From 80e530fbfa73f6509834ba2026550226b65705fa Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 24 Jul 2024 01:38:18 +0200 Subject: [PATCH 01/35] initial work draft for freenoise; needs massive cleanup --- src/diffusers/models/attention.py | 248 +++++++++++++----- .../animatediff/pipeline_animatediff.py | 37 ++- src/diffusers/pipelines/free_noise_utils.py | 31 +++ 3 files changed, 249 insertions(+), 67 deletions(-) create mode 100644 src/diffusers/pipelines/free_noise_utils.py diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index f6969470c36e..9daa37648394 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -209,6 +209,30 @@ def forward( return encoder_hidden_states, hidden_states +def get_frame_indices(num_frames: int, context_length: int = 16, context_stride: int = 4): + batch_indices = [] + for i in range(0, num_frames - context_length, context_stride): + window_start = i + window_end = min(num_frames, i + context_length) + batch_indices.append((window_start, window_end)) + return batch_indices + +def get_frame_weights(num_frames: int, weight_type: str = "pyramid"): + if weight_type == "pyramid": + if num_frames % 2 == 0: + # num_frames = 4 => [1, 2, 2, 1] + weights = list(range(1, num_frames // 2 + 1)) + weights = weights + weights[::-1] + else: + # num_frames = 5 => [1, 2, 3, 2, 1] + weights = list(range(1, num_frames // 2 + 1)) + weights = weights + [num_frames // 2 + 1] + weights[::-1] + else: + raise ValueError(f"Invalid `weight_type`: {weight_type}") + + return weights + + @maybe_allow_in_graph class BasicTransformerBlock(nn.Module): r""" @@ -422,81 +446,185 @@ def forward( if cross_attention_kwargs.get("scale", None) is not None: logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") + # TODO: This is just for trial because it is pretty hard to pass the free noise args from pipeline to here + # without making too many changes + context_length = 16 + context_stride = 4 + hardcoded_num_frames = 80 + # Notice that normalization is always applied before the real computation in the following blocks. # 0. 
Self-Attention - batch_size = hidden_states.shape[0] + print("num_frames:", hidden_states.shape) + + if hidden_states.size(1) == hardcoded_num_frames: + # new implementation + hidden_states_original = hidden_states # [bhw, f, c] + frame_indices = get_frame_indices(hardcoded_num_frames, context_length, context_stride) + num_times_processed = torch.zeros((1, hardcoded_num_frames, 1), device=hidden_states_original.device) + processed_values = torch.zeros_like(hidden_states_original) + + for frame_start, frame_end in frame_indices: + weights = get_frame_weights(context_length, weight_type="pyramid") + weights_tensor = torch.ones_like(num_times_processed[:, frame_start : frame_end]) + weights_tensor *= torch.tensor(weights, device=hidden_states_original.device, dtype=hidden_states_original.dtype).unsqueeze(0).unsqueeze(-1) + + hidden_states = hidden_states_original[:, frame_start : frame_end] + + # Copied original implementation of norm1 and attn1 + batch_size = hidden_states.shape[0] + + if self.norm_type == "ada_norm": + norm_hidden_states = self.norm1(hidden_states, timestep) + elif self.norm_type == "ada_norm_zero": + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( + hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype + ) + elif self.norm_type in ["layer_norm", "layer_norm_i2vgen"]: + norm_hidden_states = self.norm1(hidden_states) + elif self.norm_type == "ada_norm_continuous": + norm_hidden_states = self.norm1(hidden_states, added_cond_kwargs["pooled_text_emb"]) + elif self.norm_type == "ada_norm_single": + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( + self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1) + ).chunk(6, dim=1) + norm_hidden_states = self.norm1(hidden_states) + norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa + else: + raise ValueError("Incorrect norm used") + + if self.pos_embed is not None: + norm_hidden_states = self.pos_embed(norm_hidden_states) + + # 1. 
Prepare GLIGEN inputs + cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {} + gligen_kwargs = cross_attention_kwargs.pop("gligen", None) + + attn_output = self.attn1( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) - if self.norm_type == "ada_norm": - norm_hidden_states = self.norm1(hidden_states, timestep) - elif self.norm_type == "ada_norm_zero": - norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( - hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype - ) - elif self.norm_type in ["layer_norm", "layer_norm_i2vgen"]: - norm_hidden_states = self.norm1(hidden_states) - elif self.norm_type == "ada_norm_continuous": - norm_hidden_states = self.norm1(hidden_states, added_cond_kwargs["pooled_text_emb"]) - elif self.norm_type == "ada_norm_single": - shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( - self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1) - ).chunk(6, dim=1) - norm_hidden_states = self.norm1(hidden_states) - norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa + if self.norm_type == "ada_norm_zero": + attn_output = gate_msa.unsqueeze(1) * attn_output + elif self.norm_type == "ada_norm_single": + attn_output = gate_msa * attn_output + + hidden_states = attn_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = hidden_states.squeeze(1) + + # 1.2 GLIGEN Control + if gligen_kwargs is not None: + hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"]) + + # 3. Cross-Attention + if self.attn2 is not None: + if self.norm_type == "ada_norm": + norm_hidden_states = self.norm2(hidden_states, timestep) + elif self.norm_type in ["ada_norm_zero", "layer_norm", "layer_norm_i2vgen"]: + norm_hidden_states = self.norm2(hidden_states) + elif self.norm_type == "ada_norm_single": + # For PixArt norm2 isn't applied here: + # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103 + norm_hidden_states = hidden_states + elif self.norm_type == "ada_norm_continuous": + norm_hidden_states = self.norm2(hidden_states, added_cond_kwargs["pooled_text_emb"]) + else: + raise ValueError("Incorrect norm") + + if self.pos_embed is not None and self.norm_type != "ada_norm_single": + norm_hidden_states = self.pos_embed(norm_hidden_states) + + attn_output = self.attn2( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + **cross_attention_kwargs, + ) + hidden_states = attn_output + hidden_states + + processed_values[:, frame_start : frame_end] += hidden_states * weights_tensor + num_times_processed[:, frame_start : frame_end] += weights_tensor + + hidden_states = torch.where(num_times_processed > 0, processed_values / num_times_processed, processed_values).to(hidden_states_original.dtype) + else: - raise ValueError("Incorrect norm used") - - if self.pos_embed is not None: - norm_hidden_states = self.pos_embed(norm_hidden_states) - - # 1. 
Prepare GLIGEN inputs - cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {} - gligen_kwargs = cross_attention_kwargs.pop("gligen", None) - - attn_output = self.attn1( - norm_hidden_states, - encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, - attention_mask=attention_mask, - **cross_attention_kwargs, - ) - - if self.norm_type == "ada_norm_zero": - attn_output = gate_msa.unsqueeze(1) * attn_output - elif self.norm_type == "ada_norm_single": - attn_output = gate_msa * attn_output + # old implementation + batch_size = hidden_states.shape[0] - hidden_states = attn_output + hidden_states - if hidden_states.ndim == 4: - hidden_states = hidden_states.squeeze(1) - - # 1.2 GLIGEN Control - if gligen_kwargs is not None: - hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"]) - - # 3. Cross-Attention - if self.attn2 is not None: if self.norm_type == "ada_norm": - norm_hidden_states = self.norm2(hidden_states, timestep) - elif self.norm_type in ["ada_norm_zero", "layer_norm", "layer_norm_i2vgen"]: - norm_hidden_states = self.norm2(hidden_states) - elif self.norm_type == "ada_norm_single": - # For PixArt norm2 isn't applied here: - # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103 - norm_hidden_states = hidden_states + norm_hidden_states = self.norm1(hidden_states, timestep) + elif self.norm_type == "ada_norm_zero": + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( + hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype + ) + elif self.norm_type in ["layer_norm", "layer_norm_i2vgen"]: + norm_hidden_states = self.norm1(hidden_states) elif self.norm_type == "ada_norm_continuous": - norm_hidden_states = self.norm2(hidden_states, added_cond_kwargs["pooled_text_emb"]) + norm_hidden_states = self.norm1(hidden_states, added_cond_kwargs["pooled_text_emb"]) + elif self.norm_type == "ada_norm_single": + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( + self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1) + ).chunk(6, dim=1) + norm_hidden_states = self.norm1(hidden_states) + norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa else: - raise ValueError("Incorrect norm") + raise ValueError("Incorrect norm used") - if self.pos_embed is not None and self.norm_type != "ada_norm_single": + if self.pos_embed is not None: norm_hidden_states = self.pos_embed(norm_hidden_states) - attn_output = self.attn2( + # 1. Prepare GLIGEN inputs + cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {} + gligen_kwargs = cross_attention_kwargs.pop("gligen", None) + + attn_output = self.attn1( norm_hidden_states, - encoder_hidden_states=encoder_hidden_states, - attention_mask=encoder_attention_mask, + encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, + attention_mask=attention_mask, **cross_attention_kwargs, ) + + if self.norm_type == "ada_norm_zero": + attn_output = gate_msa.unsqueeze(1) * attn_output + elif self.norm_type == "ada_norm_single": + attn_output = gate_msa * attn_output + hidden_states = attn_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = hidden_states.squeeze(1) + + # 1.2 GLIGEN Control + if gligen_kwargs is not None: + hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"]) + + # 3. 
Cross-Attention + if self.attn2 is not None: + if self.norm_type == "ada_norm": + norm_hidden_states = self.norm2(hidden_states, timestep) + elif self.norm_type in ["ada_norm_zero", "layer_norm", "layer_norm_i2vgen"]: + norm_hidden_states = self.norm2(hidden_states) + elif self.norm_type == "ada_norm_single": + # For PixArt norm2 isn't applied here: + # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103 + norm_hidden_states = hidden_states + elif self.norm_type == "ada_norm_continuous": + norm_hidden_states = self.norm2(hidden_states, added_cond_kwargs["pooled_text_emb"]) + else: + raise ValueError("Incorrect norm") + + if self.pos_embed is not None and self.norm_type != "ada_norm_single": + norm_hidden_states = self.pos_embed(norm_hidden_states) + + attn_output = self.attn2( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + **cross_attention_kwargs, + ) + hidden_states = attn_output + hidden_states # 4. Feed-forward # i2vgen doesn't have this norm 🤷‍♂️ diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py index bc684259aeb8..a480470017bb 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py @@ -42,6 +42,7 @@ from ...utils.torch_utils import randn_tensor from ...video_processor import VideoProcessor from ..free_init_utils import FreeInitMixin +from ..free_noise_utils import FreeNoiseMixin from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin from .pipeline_output import AnimateDiffPipelineOutput @@ -72,6 +73,7 @@ class AnimateDiffPipeline( IPAdapterMixin, LoraLoaderMixin, FreeInitMixin, + FreeNoiseMixin, ): r""" Pipeline for text-to-video generation. 
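The windowed attention path added to `src/diffusers/models/attention.py` above slices the frame axis into overlapping context windows, reuses the original attention code per window, and blends the per-window outputs with pyramid weights. A minimal standalone sketch of that bookkeeping, mirroring `get_frame_indices`/`get_frame_weights` from this patch with the hard-coded `context_length=16`, `context_stride=4`, and 80 frames; the toy tensor sizes are only illustrative and the per-window attention output is a placeholder:

```py
import torch

def get_frame_indices(num_frames, context_length=16, context_stride=4):
    # sliding windows over the frame axis; with the draft range() the last window is
    # (60, 76) for 80 frames, so frames 76-79 are never visited
    return [(i, min(num_frames, i + context_length))
            for i in range(0, num_frames - context_length, context_stride)]

def get_frame_weights(num_frames):
    # "pyramid" weighting, e.g. 16 -> [1, 2, ..., 8, 8, ..., 2, 1]
    half = list(range(1, num_frames // 2 + 1))
    return half + half[::-1] if num_frames % 2 == 0 else half + [num_frames // 2 + 1] + half[::-1]

num_frames, channels = 80, 8
hidden = torch.randn(2, num_frames, channels)   # stands in for [batch*height*width, frames, channels]
accum = torch.zeros_like(hidden)
counts = torch.zeros(1, num_frames, 1)

for start, end in get_frame_indices(num_frames):
    weights = torch.tensor(get_frame_weights(end - start), dtype=hidden.dtype).view(1, -1, 1)
    window_out = hidden[:, start:end]           # placeholder for the per-window attention output
    accum[:, start:end] += window_out * weights
    counts[:, start:end] += weights

# overlapping windows are averaged by their accumulated weights; unvisited frames stay zero here
blended = torch.where(counts > 0, accum / counts, accum)
print(get_frame_indices(num_frames)[:2], get_frame_indices(num_frames)[-1])
print(get_frame_weights(16))
print(blended.shape)
```

Frames covered by several windows are weighted most heavily near the window centers, which is what smooths transitions between neighboring contexts in the blended output.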
@@ -394,15 +396,20 @@ def prepare_ip_adapter_image_embeds( return ip_adapter_image_embeds - # Copied from diffusers.pipelines.text_to_video_synthesis/pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents - def decode_latents(self, latents): + def decode_latents(self, latents, decode_batch_size: int = 16): latents = 1 / self.vae.config.scaling_factor * latents batch_size, channels, num_frames, height, width = latents.shape latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width) - - image = self.vae.decode(latents).sample - video = image[None, :].reshape((batch_size, num_frames, -1) + image.shape[2:]).permute(0, 2, 1, 3, 4) + + video = [] + for i in range(0, latents.shape[0], decode_batch_size): + batch_latents = latents[i : i + decode_batch_size] + batch_latents = self.vae.decode(batch_latents).sample + video.append(batch_latents) + + video = torch.cat(video) + video = video[None, :].reshape((batch_size, num_frames, -1) + video.shape[2:]).permute(0, 2, 1, 3, 4) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 video = video.float() return video @@ -495,7 +502,6 @@ def check_inputs( f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" ) - # Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_synth.TextToVideoSDPipeline.prepare_latents def prepare_latents( self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None ): @@ -516,6 +522,22 @@ def prepare_latents( latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) else: latents = latents.to(device) + + if self.free_noise_enabled and self._free_noise_shuffle: + for i in range(self._free_noise_context_length, num_frames, self._free_noise_context_stride): + # ensure window is within bounds + window_start = max(0, i - self._free_noise_context_length) + window_end = min(num_frames, window_start + self._free_noise_context_stride) + window_length = window_end - window_start + + if window_length == 0: + break + + indices = torch.LongTensor(list(range(window_start, window_end))) + shuffled_indices = indices[torch.randperm(window_length, generator=generator)] + + # shuffle latents in every window + latents[:, :, window_start : window_end] = latents[:, :, shuffled_indices] # scale the initial noise by the standard deviation required by the scheduler latents = latents * self.scheduler.init_noise_sigma @@ -569,6 +591,7 @@ def __call__( clip_skip: Optional[int] = None, callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], + decode_batch_size: int = 16, **kwargs, ): r""" @@ -808,7 +831,7 @@ def __call__( if output_type == "latent": video = latents else: - video_tensor = self.decode_latents(latents) + video_tensor = self.decode_latents(latents, decode_batch_size) video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type) # 10. Offload all models diff --git a/src/diffusers/pipelines/free_noise_utils.py b/src/diffusers/pipelines/free_noise_utils.py new file mode 100644 index 000000000000..dff9ae8c2fa8 --- /dev/null +++ b/src/diffusers/pipelines/free_noise_utils.py @@ -0,0 +1,31 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
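The loop added to `prepare_latents` above re-orders the initial noise along the frame axis when shuffling is enabled. A small sketch of that draft loop in isolation, showing which frame slices it permutes for the defaults `context_length=16` and `context_stride=4`; the 32-frame count and latent sizes are only illustrative, and the per-window generator handling is simplified:

```py
import torch

num_frames, context_length, context_stride = 32, 16, 4
latents = torch.randn(1, 4, num_frames, 8, 8)  # (batch, channels, frames, height, width)

for i in range(context_length, num_frames, context_stride):
    window_start = max(0, i - context_length)
    window_end = min(num_frames, window_start + context_stride)
    window_length = window_end - window_start

    if window_length == 0:
        break

    print(f"permuting frames [{window_start}, {window_end})")
    indices = torch.arange(window_start, window_end)
    shuffled_indices = indices[torch.randperm(window_length)]
    latents[:, :, window_start:window_end] = latents[:, :, shuffled_indices]
```

With these values the visited slices are `[0, 4)`, `[4, 8)`, `[8, 12)` and `[12, 16)`, i.e. each pass shuffles a `context_stride`-sized slice of the initial noise in place.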
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional + + +class FreeNoiseMixin: + r"""Mixin class for [FreeNoise](https://arxiv.org/abs/2310.15169).""" + + def enable_free_noise(self, context_length: Optional[int] = 16, context_stride: int = 4, shuffle: bool = True) -> None: + self._free_noise_context_length = context_length or self.motion_adapter.config.motion_max_seq_length + self._free_noise_context_stride = context_stride + self._free_noise_shuffle = shuffle + + def disable_free_noise(self) -> None: + self._free_noise_context_length = None + + @property + def free_noise_enabled(self): + return hasattr(self, "_free_noise_context_length") and self._free_noise_context_length is not None From 441d321152e06fcd17cc03feb4c1f545cc80b299 Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 24 Jul 2024 23:54:29 +0200 Subject: [PATCH 02/35] fix freeinit bug --- src/diffusers/pipelines/free_init_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/free_init_utils.py b/src/diffusers/pipelines/free_init_utils.py index 4f7965a038c5..8b8e0bd7da71 100644 --- a/src/diffusers/pipelines/free_init_utils.py +++ b/src/diffusers/pipelines/free_init_utils.py @@ -180,6 +180,6 @@ def _apply_free_init( num_inference_steps = max( 1, int(num_inference_steps / self._free_init_num_iters * (free_init_iteration + 1)) ) - self.scheduler.set_timesteps(num_inference_steps, device=device) - + + self.scheduler.set_timesteps(num_inference_steps, device=device) return latents, self.scheduler.timesteps From 5d0f4c34076c145f0d5ade0d89394b11cd1d4d84 Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 24 Jul 2024 23:54:42 +0200 Subject: [PATCH 03/35] add animatediff controlnet implementation --- src/diffusers/models/attention.py | 2 +- .../pipeline_animatediff_controlnet.py | 1073 +++++++++++++++++ 2 files changed, 1074 insertions(+), 1 deletion(-) create mode 100644 src/diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 9daa37648394..65b4d67bb8ff 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -454,7 +454,7 @@ def forward( # Notice that normalization is always applied before the real computation in the following blocks. # 0. Self-Attention - print("num_frames:", hidden_states.shape) + # print("num_frames:", hidden_states.shape) if hidden_states.size(1) == hardcoded_num_frames: # new implementation diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py new file mode 100644 index 000000000000..1a789316cb01 --- /dev/null +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py @@ -0,0 +1,1073 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
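A hypothetical call pattern for the `FreeNoiseMixin` introduced above, as it is wired into `AnimateDiffPipeline` in this patch. The checkpoint ids mirror the AnimateDiff examples elsewhere in this PR and are only illustrative; `num_frames=80` matches the frame count the draft attention path is currently hard-coded to, and passing `context_length=None` would fall back to the motion adapter's `motion_max_seq_length`:

```py
import torch
from diffusers import AnimateDiffPipeline, MotionAdapter

adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2", torch_dtype=torch.float16)
pipe = AnimateDiffPipeline.from_pretrained(
    "SG161222/Realistic_Vision_V5.1_noVAE", motion_adapter=adapter, torch_dtype=torch.float16
).to("cuda")

# opt in to FreeNoise before calling the pipeline
pipe.enable_free_noise(context_length=16, context_stride=4, shuffle=True)
assert pipe.free_noise_enabled

frames = pipe(
    prompt="an astronaut dancing in space, high quality",
    num_frames=80,
    num_inference_steps=25,
    decode_batch_size=16,  # frames decoded per VAE batch, added to __call__ in this patch
).frames[0]

pipe.disable_free_noise()
```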
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn.functional as F +from PIL import Image +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection + +from ...image_processor import PipelineImageInput +from ...loaders import IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel, UNetMotionModel +from ...models.lora import adjust_lora_scale_text_encoder +from ...models.unets.unet_motion_model import MotionAdapter +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers +from ...utils.torch_utils import is_compiled_module, randn_tensor +from ...video_processor import VideoProcessor +from ..controlnet.multicontrolnet import MultiControlNetModel +from ..free_init_utils import FreeInitMixin +from ..free_noise_utils import FreeNoiseMixin +from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from .pipeline_output import AnimateDiffPipelineOutput + + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import AutoencoderKL, ControlNetModel, MotionAdapter + >>> from diffusers.pipelines import DiffusionPipeline + >>> from diffusers.schedulers import DPMSolverMultistepScheduler + >>> from PIL import Image + + >>> motion_id = "guoyww/animatediff-motion-adapter-v1-5-2" + >>> adapter = MotionAdapter.from_pretrained(motion_id) + >>> controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_openpose", torch_dtype=torch.float16) + >>> vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16) + + >>> model_id = "SG161222/Realistic_Vision_V5.1_noVAE" + >>> pipe = DiffusionPipeline.from_pretrained( + ... model_id, + ... motion_adapter=adapter, + ... controlnet=controlnet, + ... vae=vae, + ... custom_pipeline="pipeline_animatediff_controlnet", + ... ).to(device="cuda", dtype=torch.float16) + >>> pipe.scheduler = DPMSolverMultistepScheduler.from_pretrained( + ... model_id, subfolder="scheduler", clip_sample=False, timestep_spacing="linspace", steps_offset=1, beta_schedule="linear", + ... ) + >>> pipe.enable_vae_slicing() + + >>> conditioning_frames = [] + >>> for i in range(1, 16 + 1): + ... conditioning_frames.append(Image.open(f"frame_{i}.png")) + + >>> prompt = "astronaut in space, dancing" + >>> negative_prompt = "bad quality, worst quality, jpeg artifacts, ugly" + >>> result = pipe( + ... prompt=prompt, + ... negative_prompt=negative_prompt, + ... width=512, + ... height=768, + ... conditioning_frames=conditioning_frames, + ... num_inference_steps=12, + ... 
) + + >>> from diffusers.utils import export_to_gif + >>> export_to_gif(result.frames[0], "result.gif") + ``` +""" + + +class AnimateDiffControlNetPipeline( + DiffusionPipeline, StableDiffusionMixin, TextualInversionLoaderMixin, IPAdapterMixin, LoraLoaderMixin, FreeInitMixin, FreeNoiseMixin +): + r""" + Pipeline for text-to-video generation. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + tokenizer (`CLIPTokenizer`): + A [`~transformers.CLIPTokenizer`] to tokenize text. + unet ([`UNet2DConditionModel`]): + A [`UNet2DConditionModel`] used to create a UNetMotionModel to denoise the encoded video latents. + motion_adapter ([`MotionAdapter`]): + A [`MotionAdapter`] to be used in combination with `unet` to denoise the encoded video latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. 
+ """ + + model_cpu_offload_seq = "text_encoder->unet->vae" + _optional_components = ["feature_extractor", "image_encoder"] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: Union[UNet2DConditionModel, UNetMotionModel], + motion_adapter: MotionAdapter, + controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel], + scheduler: KarrasDiffusionSchedulers, + feature_extractor: Optional[CLIPImageProcessor] = None, + image_encoder: Optional[CLIPVisionModelWithProjection] = None, + ): + super().__init__() + if isinstance(unet, UNet2DConditionModel): + unet = UNetMotionModel.from_unet2d(unet, motion_adapter) + + if isinstance(controlnet, (list, tuple)): + controlnet = MultiControlNetModel(controlnet) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + motion_adapter=motion_adapter, + controlnet=controlnet, + scheduler=scheduler, + feature_extractor=feature_extractor, + image_encoder=image_encoder, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor) + self.control_video_processor = VideoProcessor( + vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt with num_images_per_prompt -> num_videos_per_prompt + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + lora_scale: Optional[float] = None, + clip_skip: Optional[int] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. 
+ """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, LoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + if clip_skip is None: + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) + prompt_embeds = prompt_embeds[0] + else: + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True + ) + # Access the `hidden_states` first, that contains a tuple of + # all the hidden states from the encoder layers. Then index into + # the tuple to access the hidden states from the desired layer. + prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] + # We also need to apply the final LayerNorm here to not mess with the + # representations. The `last_hidden_states` that we typically use for + # obtaining the final prompt representations passes through the LayerNorm + # layer. 
+ prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) + + if self.text_encoder is not None: + prompt_embeds_dtype = self.text_encoder.dtype + elif self.unet is not None: + prompt_embeds_dtype = self.unet.dtype + else: + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return prompt_embeds, negative_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = 
image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt + ): + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." + ) + + image_embeds = [] + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers + ): + output_hidden_state = not isinstance(image_proj_layer, ImageProjection) + single_image_embeds, single_negative_image_embeds = self.encode_image( + single_ip_adapter_image, device, 1, output_hidden_state + ) + single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0) + single_negative_image_embeds = torch.stack( + [single_negative_image_embeds] * num_images_per_prompt, dim=0 + ) + + if self.do_classifier_free_guidance: + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds]) + single_image_embeds = single_image_embeds.to(device) + + image_embeds.append(single_image_embeds) + else: + image_embeds = ip_adapter_image_embeds + return image_embeds + + def decode_latents(self, latents, decode_batch_size: int = 16): + latents = 1 / self.vae.config.scaling_factor * latents + + batch_size, channels, num_frames, height, width = latents.shape + latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width) + + video = [] + for i in range(0, latents.shape[0], decode_batch_size): + batch_latents = latents[i : i + decode_batch_size] + batch_latents = self.vae.decode(batch_latents).sample + video.append(batch_latents) + + video = torch.cat(video) + video = video[None, :].reshape((batch_size, num_frames, -1) + video.shape[2:]).permute(0, 2, 1, 3, 4) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + video = video.float() + return video + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + height, + width, + num_frames, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + image=None, + controlnet_conditioning_scale=1.0, + control_guidance_start=0.0, + control_guidance_end=1.0, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + # `prompt` needs more sophisticated handling when there are multiple + # conditionings. + if isinstance(self.controlnet, MultiControlNetModel): + if isinstance(prompt, list): + logger.warning( + f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}" + " prompts. The conditionings will be fixed across the prompts." 
+ ) + + # Check `image` + is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance( + self.controlnet, torch._dynamo.eval_frame.OptimizedModule + ) + if ( + isinstance(self.controlnet, ControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, ControlNetModel) + ): + if not isinstance(image, list): + raise TypeError(f"For single controlnet, `image` must be of type `list` but got {type(image)}") + if len(image) != num_frames: + raise ValueError(f"Excepted image to have length {num_frames} but got {len(image)=}") + elif ( + isinstance(self.controlnet, MultiControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, MultiControlNetModel) + ): + if not isinstance(image, list) or not isinstance(image[0], list): + raise TypeError(f"For multiple controlnets: `image` must be type list of lists but got {type(image)=}") + if len(image[0]) != num_frames: + raise ValueError(f"Expected length of image sublist as {num_frames} but got {len(image[0])=}") + if any(len(img) != len(image[0]) for img in image): + raise ValueError("All conditioning frame batches for multicontrolnet must be same size") + else: + assert False + + # Check `controlnet_conditioning_scale` + if ( + isinstance(self.controlnet, ControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, ControlNetModel) + ): + if not isinstance(controlnet_conditioning_scale, float): + raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.") + elif ( + isinstance(self.controlnet, MultiControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, MultiControlNetModel) + ): + if isinstance(controlnet_conditioning_scale, list): + if any(isinstance(i, list) for i in controlnet_conditioning_scale): + raise ValueError("A single batch of multiple conditionings are supported at the moment.") + elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len( + self.controlnet.nets + ): + raise ValueError( + "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have" + " the same length as the number of controlnets" + ) + else: + assert False + + if not isinstance(control_guidance_start, (tuple, list)): + control_guidance_start = [control_guidance_start] + + if not isinstance(control_guidance_end, (tuple, list)): + control_guidance_end = [control_guidance_end] + + if len(control_guidance_start) != len(control_guidance_end): + raise ValueError( + f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list." + ) + + if isinstance(self.controlnet, MultiControlNetModel): + if len(control_guidance_start) != len(self.controlnet.nets): + raise ValueError( + f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}." + ) + + for start, end in zip(control_guidance_start, control_guidance_end): + if start >= end: + raise ValueError( + f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}." 
+ ) + if start < 0.0: + raise ValueError(f"control guidance start: {start} can't be smaller than 0.") + if end > 1.0: + raise ValueError(f"control guidance end: {end} can't be larger than 1.0.") + + # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.check_image + def check_image(self, image, prompt, prompt_embeds): + image_is_pil = isinstance(image, Image.Image) + image_is_tensor = isinstance(image, torch.Tensor) + image_is_np = isinstance(image, np.ndarray) + image_is_pil_list = isinstance(image, list) and isinstance(image[0], Image.Image) + image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor) + image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray) + + if ( + not image_is_pil + and not image_is_tensor + and not image_is_np + and not image_is_pil_list + and not image_is_tensor_list + and not image_is_np_list + ): + raise TypeError( + f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}" + ) + + if image_is_pil: + image_batch_size = 1 + else: + image_batch_size = len(image) + + if prompt is not None and isinstance(prompt, str): + prompt_batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + prompt_batch_size = len(prompt) + elif prompt_embeds is not None: + prompt_batch_size = prompt_embeds.shape[0] + + if image_batch_size != 1 and image_batch_size != prompt_batch_size: + raise ValueError( + f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}" + ) + + # Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_synth.TextToVideoSDPipeline.prepare_latents + def prepare_latents( + self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None + ): + shape = ( + batch_size, + num_channels_latents, + num_frames, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
+ ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.prepare_image + def prepare_image( + self, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + do_classifier_free_guidance=False, + guess_mode=False, + ): + image = self.control_video_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) + image_batch_size = image.shape[0] + + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + + image = image.to(device=device, dtype=dtype) + + if do_classifier_free_guidance and not guess_mode: + image = torch.cat([image] * 2) + + return image + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]] = None, + num_frames: Optional[int] = 16, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_videos_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + ip_adapter_image: Optional[PipelineImageInput] = None, + ip_adapter_image_embeds: Optional[PipelineImageInput] = None, + conditioning_frames: Optional[List[PipelineImageInput]] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + controlnet_conditioning_scale: Union[float, List[float]] = 1.0, + guess_mode: bool = False, + control_guidance_start: Union[float, List[float]] = 0.0, + control_guidance_end: Union[float, List[float]] = 1.0, + clip_skip: Optional[int] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + decode_batch_size: int = 16, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated video. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated video. 
+ num_frames (`int`, *optional*, defaults to 16): + The number of video frames that are generated. Defaults to 16 frames which at 8 frames per seconds + amounts to 2 seconds of video. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality videos at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies + to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. Latents should be of shape + `(batch_size, num_channel, num_frames, height, width)`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + ip_adapter_image (`PipelineImageInput`, *optional*): + Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of IP-adapters. + Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should contain the negative image embedding + if `do_classifier_free_guidance` is set to `True`. + If not provided, embeddings are computed from the `ip_adapter_image` input argument. + conditioning_frames (`List[PipelineImageInput]`, *optional*): + The ControlNet input condition to provide guidance to the `unet` for generation. If multiple ControlNets + are specified, images must be passed as a list such that each element of the list can be correctly + batched for input to a single ControlNet. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated video. Choose between `torch.Tensor`, `PIL.Image` or + `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] instead + of a plain tuple. 
+ cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0): + The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added + to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set + the corresponding scale as a list. + guess_mode (`bool`, *optional*, defaults to `False`): + The ControlNet encoder tries to recognize the content of the input image even if you remove all + prompts. A `guidance_scale` value between 3.0 and 5.0 is recommended. + control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0): + The percentage of total steps at which the ControlNet starts applying. + control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0): + The percentage of total steps at which the ControlNet stops applying. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + + Examples: + + Returns: + [`~pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] is + returned, otherwise a `tuple` is returned where the first element is a list with the generated frames. + """ + controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet + + # align format for control guidance + if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): + control_guidance_start = len(control_guidance_end) * [control_guidance_start] + elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): + control_guidance_end = len(control_guidance_start) * [control_guidance_end] + elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list): + mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1 + control_guidance_start, control_guidance_end = ( + mult * [control_guidance_start], + mult * [control_guidance_end], + ) + + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + num_videos_per_prompt = 1 + + # 1. Check inputs. 
Raise error if not correct + self.check_inputs( + prompt=prompt, + height=height, + width=width, + num_frames=num_frames, + negative_prompt=negative_prompt, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + image=conditioning_frames, + controlnet_conditioning_scale=controlnet_conditioning_scale, + control_guidance_start=control_guidance_start, + control_guidance_end=control_guidance_end, + ) + + self._guidance_scale = guidance_scale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float): + controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets) + + global_pool_conditions = ( + controlnet.config.global_pool_conditions + if isinstance(controlnet, ControlNetModel) + else controlnet.nets[0].config.global_pool_conditions + ) + guess_mode = guess_mode or global_pool_conditions + + # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + device, + num_videos_per_prompt, + self.do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=self.clip_skip, + ) + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_videos_per_prompt, + self.do_classifier_free_guidance, + ) + + if isinstance(controlnet, ControlNetModel): + conditioning_frames = self.prepare_image( + image=conditioning_frames, + width=width, + height=height, + batch_size=batch_size * num_videos_per_prompt * num_frames, + num_images_per_prompt=num_videos_per_prompt, + device=device, + dtype=controlnet.dtype, + do_classifier_free_guidance=self.do_classifier_free_guidance, + guess_mode=guess_mode, + ) + elif isinstance(controlnet, MultiControlNetModel): + cond_prepared_frames = [] + for frame_ in conditioning_frames: + prepared_frame = self.prepare_image( + image=frame_, + width=width, + height=height, + batch_size=batch_size * num_videos_per_prompt * num_frames, + num_images_per_prompt=num_videos_per_prompt, + device=device, + dtype=controlnet.dtype, + do_classifier_free_guidance=self.do_classifier_free_guidance, + guess_mode=guess_mode, + ) + cond_prepared_frames.append(prepared_frame) + conditioning_frames = cond_prepared_frames + else: + assert False + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. 
Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + num_channels_latents, + num_frames, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Add image embeds for IP-Adapter + added_cond_kwargs = ( + {"image_embeds": ip_adapter_image_embeds} + if ip_adapter_image is not None or ip_adapter_image_embeds is not None + else None + ) + + # 7.1 Create tensor stating which controlnets to keep + controlnet_keep = [] + for i in range(len(timesteps)): + keeps = [ + 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) + for s, e in zip(control_guidance_start, control_guidance_end) + ] + controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps) + + num_free_init_iters = self._free_init_num_iters if self.free_init_enabled else 1 + for free_init_iter in range(num_free_init_iters): + if self.free_init_enabled: + latents, timesteps = self._apply_free_init( + latents, free_init_iter, num_inference_steps, device, latents.dtype, generator + ) + + self._num_timesteps = len(timesteps) + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + + # 8. Denoising loop + with self.progress_bar(total=self._num_timesteps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + if guess_mode and self.do_classifier_free_guidance: + # Infer ControlNet only for the conditional batch. 
+ control_model_input = latents + control_model_input = self.scheduler.scale_model_input(control_model_input, t) + controlnet_prompt_embeds = prompt_embeds.chunk(2)[1] + else: + control_model_input = latent_model_input + controlnet_prompt_embeds = prompt_embeds + controlnet_prompt_embeds = controlnet_prompt_embeds.repeat_interleave(num_frames, dim=0) + + if isinstance(controlnet_keep[i], list): + cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])] + else: + controlnet_cond_scale = controlnet_conditioning_scale + if isinstance(controlnet_cond_scale, list): + controlnet_cond_scale = controlnet_cond_scale[0] + cond_scale = controlnet_cond_scale * controlnet_keep[i] + + control_model_input = torch.transpose(control_model_input, 1, 2) + control_model_input = control_model_input.reshape( + (-1, control_model_input.shape[2], control_model_input.shape[3], control_model_input.shape[4]) + ) + + down_block_res_samples, mid_block_res_sample = self.controlnet( + control_model_input, + t, + encoder_hidden_states=controlnet_prompt_embeds, + controlnet_cond=conditioning_frames, + conditioning_scale=cond_scale, + guess_mode=guess_mode, + return_dict=False, + ) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=self.cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + ).sample + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + # 9. Post processing + if output_type == "latent": + video = latents + else: + video_tensor = self.decode_latents(latents, decode_batch_size) + video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type) + + # 10. 
Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return AnimateDiffPipelineOutput(frames=video) From 610f433d1c1b667487d465d6b5b52d2b8fcc55c8 Mon Sep 17 00:00:00 2001 From: Aryan Date: Sat, 27 Jul 2024 15:15:34 +0200 Subject: [PATCH 04/35] revert attention changes --- src/diffusers/models/attention.py | 235 +++++++++--------------------- 1 file changed, 71 insertions(+), 164 deletions(-) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 65b4d67bb8ff..b77af6e05d1a 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -296,6 +296,17 @@ def __init__( attention_out_bias: bool = True, ): super().__init__() + self.dim = dim + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + self.dropout = dropout + self.cross_attention_dim = cross_attention_dim + self.activation_fn = activation_fn + self.attention_bias = attention_bias + self.double_self_attention = double_self_attention + self.norm_elementwise_affine = norm_elementwise_affine + self.positional_embeddings = positional_embeddings + self.num_positional_embeddings = num_positional_embeddings self.only_cross_attention = only_cross_attention # We keep these boolean flags for backward-compatibility. @@ -446,185 +457,81 @@ def forward( if cross_attention_kwargs.get("scale", None) is not None: logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") - # TODO: This is just for trial because it is pretty hard to pass the free noise args from pipeline to here - # without making too many changes - context_length = 16 - context_stride = 4 - hardcoded_num_frames = 80 - # Notice that normalization is always applied before the real computation in the following blocks. # 0. 
Self-Attention - # print("num_frames:", hidden_states.shape) - - if hidden_states.size(1) == hardcoded_num_frames: - # new implementation - hidden_states_original = hidden_states # [bhw, f, c] - frame_indices = get_frame_indices(hardcoded_num_frames, context_length, context_stride) - num_times_processed = torch.zeros((1, hardcoded_num_frames, 1), device=hidden_states_original.device) - processed_values = torch.zeros_like(hidden_states_original) - - for frame_start, frame_end in frame_indices: - weights = get_frame_weights(context_length, weight_type="pyramid") - weights_tensor = torch.ones_like(num_times_processed[:, frame_start : frame_end]) - weights_tensor *= torch.tensor(weights, device=hidden_states_original.device, dtype=hidden_states_original.dtype).unsqueeze(0).unsqueeze(-1) - - hidden_states = hidden_states_original[:, frame_start : frame_end] - - # Copied original implementation of norm1 and attn1 - batch_size = hidden_states.shape[0] - - if self.norm_type == "ada_norm": - norm_hidden_states = self.norm1(hidden_states, timestep) - elif self.norm_type == "ada_norm_zero": - norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( - hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype - ) - elif self.norm_type in ["layer_norm", "layer_norm_i2vgen"]: - norm_hidden_states = self.norm1(hidden_states) - elif self.norm_type == "ada_norm_continuous": - norm_hidden_states = self.norm1(hidden_states, added_cond_kwargs["pooled_text_emb"]) - elif self.norm_type == "ada_norm_single": - shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( - self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1) - ).chunk(6, dim=1) - norm_hidden_states = self.norm1(hidden_states) - norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa - else: - raise ValueError("Incorrect norm used") - - if self.pos_embed is not None: - norm_hidden_states = self.pos_embed(norm_hidden_states) - - # 1. Prepare GLIGEN inputs - cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {} - gligen_kwargs = cross_attention_kwargs.pop("gligen", None) - - attn_output = self.attn1( - norm_hidden_states, - encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, - attention_mask=attention_mask, - **cross_attention_kwargs, - ) + batch_size = hidden_states.shape[0] - if self.norm_type == "ada_norm_zero": - attn_output = gate_msa.unsqueeze(1) * attn_output - elif self.norm_type == "ada_norm_single": - attn_output = gate_msa * attn_output - - hidden_states = attn_output + hidden_states - if hidden_states.ndim == 4: - hidden_states = hidden_states.squeeze(1) - - # 1.2 GLIGEN Control - if gligen_kwargs is not None: - hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"]) - - # 3. 
Cross-Attention - if self.attn2 is not None: - if self.norm_type == "ada_norm": - norm_hidden_states = self.norm2(hidden_states, timestep) - elif self.norm_type in ["ada_norm_zero", "layer_norm", "layer_norm_i2vgen"]: - norm_hidden_states = self.norm2(hidden_states) - elif self.norm_type == "ada_norm_single": - # For PixArt norm2 isn't applied here: - # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103 - norm_hidden_states = hidden_states - elif self.norm_type == "ada_norm_continuous": - norm_hidden_states = self.norm2(hidden_states, added_cond_kwargs["pooled_text_emb"]) - else: - raise ValueError("Incorrect norm") - - if self.pos_embed is not None and self.norm_type != "ada_norm_single": - norm_hidden_states = self.pos_embed(norm_hidden_states) - - attn_output = self.attn2( - norm_hidden_states, - encoder_hidden_states=encoder_hidden_states, - attention_mask=encoder_attention_mask, - **cross_attention_kwargs, - ) - hidden_states = attn_output + hidden_states - - processed_values[:, frame_start : frame_end] += hidden_states * weights_tensor - num_times_processed[:, frame_start : frame_end] += weights_tensor - - hidden_states = torch.where(num_times_processed > 0, processed_values / num_times_processed, processed_values).to(hidden_states_original.dtype) - + if self.norm_type == "ada_norm": + norm_hidden_states = self.norm1(hidden_states, timestep) + elif self.norm_type == "ada_norm_zero": + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( + hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype + ) + elif self.norm_type in ["layer_norm", "layer_norm_i2vgen"]: + norm_hidden_states = self.norm1(hidden_states) + elif self.norm_type == "ada_norm_continuous": + norm_hidden_states = self.norm1(hidden_states, added_cond_kwargs["pooled_text_emb"]) + elif self.norm_type == "ada_norm_single": + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( + self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1) + ).chunk(6, dim=1) + norm_hidden_states = self.norm1(hidden_states) + norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa else: - # old implementation - batch_size = hidden_states.shape[0] + raise ValueError("Incorrect norm used") + + if self.pos_embed is not None: + norm_hidden_states = self.pos_embed(norm_hidden_states) + + # 1. Prepare GLIGEN inputs + cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {} + gligen_kwargs = cross_attention_kwargs.pop("gligen", None) + + attn_output = self.attn1( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) + + if self.norm_type == "ada_norm_zero": + attn_output = gate_msa.unsqueeze(1) * attn_output + elif self.norm_type == "ada_norm_single": + attn_output = gate_msa * attn_output + hidden_states = attn_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = hidden_states.squeeze(1) + + # 1.2 GLIGEN Control + if gligen_kwargs is not None: + hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"]) + + # 3. 
Cross-Attention + if self.attn2 is not None: if self.norm_type == "ada_norm": - norm_hidden_states = self.norm1(hidden_states, timestep) - elif self.norm_type == "ada_norm_zero": - norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( - hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype - ) - elif self.norm_type in ["layer_norm", "layer_norm_i2vgen"]: - norm_hidden_states = self.norm1(hidden_states) - elif self.norm_type == "ada_norm_continuous": - norm_hidden_states = self.norm1(hidden_states, added_cond_kwargs["pooled_text_emb"]) + norm_hidden_states = self.norm2(hidden_states, timestep) + elif self.norm_type in ["ada_norm_zero", "layer_norm", "layer_norm_i2vgen"]: + norm_hidden_states = self.norm2(hidden_states) elif self.norm_type == "ada_norm_single": - shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( - self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1) - ).chunk(6, dim=1) - norm_hidden_states = self.norm1(hidden_states) - norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa + # For PixArt norm2 isn't applied here: + # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103 + norm_hidden_states = hidden_states + elif self.norm_type == "ada_norm_continuous": + norm_hidden_states = self.norm2(hidden_states, added_cond_kwargs["pooled_text_emb"]) else: - raise ValueError("Incorrect norm used") + raise ValueError("Incorrect norm") - if self.pos_embed is not None: + if self.pos_embed is not None and self.norm_type != "ada_norm_single": norm_hidden_states = self.pos_embed(norm_hidden_states) - # 1. Prepare GLIGEN inputs - cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {} - gligen_kwargs = cross_attention_kwargs.pop("gligen", None) - - attn_output = self.attn1( + attn_output = self.attn2( norm_hidden_states, - encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, - attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, **cross_attention_kwargs, ) - - if self.norm_type == "ada_norm_zero": - attn_output = gate_msa.unsqueeze(1) * attn_output - elif self.norm_type == "ada_norm_single": - attn_output = gate_msa * attn_output - hidden_states = attn_output + hidden_states - if hidden_states.ndim == 4: - hidden_states = hidden_states.squeeze(1) - - # 1.2 GLIGEN Control - if gligen_kwargs is not None: - hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"]) - - # 3. 
Cross-Attention - if self.attn2 is not None: - if self.norm_type == "ada_norm": - norm_hidden_states = self.norm2(hidden_states, timestep) - elif self.norm_type in ["ada_norm_zero", "layer_norm", "layer_norm_i2vgen"]: - norm_hidden_states = self.norm2(hidden_states) - elif self.norm_type == "ada_norm_single": - # For PixArt norm2 isn't applied here: - # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103 - norm_hidden_states = hidden_states - elif self.norm_type == "ada_norm_continuous": - norm_hidden_states = self.norm2(hidden_states, added_cond_kwargs["pooled_text_emb"]) - else: - raise ValueError("Incorrect norm") - - if self.pos_embed is not None and self.norm_type != "ada_norm_single": - norm_hidden_states = self.pos_embed(norm_hidden_states) - - attn_output = self.attn2( - norm_hidden_states, - encoder_hidden_states=encoder_hidden_states, - attention_mask=encoder_attention_mask, - **cross_attention_kwargs, - ) - hidden_states = attn_output + hidden_states # 4. Feed-forward # i2vgen doesn't have this norm 🤷‍♂️ From 10b65b310cda41fcbcedc0058d2b4b34ea58dd00 Mon Sep 17 00:00:00 2001 From: Aryan Date: Sat, 27 Jul 2024 15:16:02 +0200 Subject: [PATCH 05/35] add freenoise --- .../models/unets/unet_motion_model.py | 283 +++++++++++++++++- .../animatediff/pipeline_animatediff.py | 4 +- .../pipeline_animatediff_controlnet.py | 4 +- .../animatediff/pipeline_animatediff_sdxl.py | 2 + .../pipeline_animatediff_sparsectrl.py | 3 + .../pipeline_animatediff_video2video.py | 2 + src/diffusers/pipelines/free_noise_utils.py | 96 +++++- src/diffusers/pipelines/pia/pipeline_pia.py | 2 + 8 files changed, 386 insertions(+), 10 deletions(-) diff --git a/src/diffusers/models/unets/unet_motion_model.py b/src/diffusers/models/unets/unet_motion_model.py index 196f947d599b..890864942754 100644 --- a/src/diffusers/models/unets/unet_motion_model.py +++ b/src/diffusers/models/unets/unet_motion_model.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union import torch import torch.nn as nn @@ -19,8 +19,10 @@ import torch.utils.checkpoint from ...configuration_utils import ConfigMixin, FrozenDict, register_to_config -from ...loaders import FromOriginalModelMixin, UNet2DConditionLoadersMixin +from ...loaders import FromOriginalModelMixin, PeftAdapterMixin, UNet2DConditionLoadersMixin from ...utils import logging +from ...utils.torch_utils import maybe_allow_in_graph +from ..attention import FeedForward, _chunked_feed_forward from ..attention_processor import ( ADDED_KV_ATTENTION_PROCESSORS, CROSS_ATTENTION_PROCESSORS, @@ -33,7 +35,7 @@ IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, ) -from ..embeddings import TimestepEmbedding, Timesteps +from ..embeddings import SinusoidalPositionalEmbedding, TimestepEmbedding, Timesteps from ..modeling_utils import ModelMixin from ..transformers.transformer_temporal import TransformerTemporalModel from .unet_2d_blocks import UNetMidBlock2DCrossAttn @@ -53,6 +55,281 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name +@maybe_allow_in_graph +class FreeNoiseTransformerBlock(nn.Module): + r""" + A FreeNoise Transformer block. + + Parameters: + dim (`int`): The number of channels in the input and output. 
+ num_attention_heads (`int`): The number of heads to use for multi-head attention. + attention_head_dim (`int`): The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + num_embeds_ada_norm (: + obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`. + attention_bias (: + obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter. + only_cross_attention (`bool`, *optional*): + Whether to use only cross-attention layers. In this case two cross attention layers are used. + double_self_attention (`bool`, *optional*): + Whether to use two self-attention layers. In this case no cross attention layers are used. + upcast_attention (`bool`, *optional*): + Whether to upcast the attention computation to float32. This is useful for mixed precision training. + norm_elementwise_affine (`bool`, *optional*, defaults to `True`): + Whether to use learnable elementwise affine parameters for normalization. + norm_type (`str`, *optional*, defaults to `"layer_norm"`): + The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`. + final_dropout (`bool` *optional*, defaults to False): + Whether to apply a final dropout after the last feed-forward layer. + attention_type (`str`, *optional*, defaults to `"default"`): + The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`. + positional_embeddings (`str`, *optional*, defaults to `None`): + The type of positional embeddings to apply to. + num_positional_embeddings (`int`, *optional*, defaults to `None`): + The maximum number of positional embeddings to apply. + """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + dropout: float = 0.0, + cross_attention_dim: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + attention_bias: bool = False, + only_cross_attention: bool = False, + double_self_attention: bool = False, + upcast_attention: bool = False, + norm_elementwise_affine: bool = True, + norm_type: str = "layer_norm", + norm_eps: float = 1e-5, + final_dropout: bool = False, + positional_embeddings: Optional[str] = None, + num_positional_embeddings: Optional[int] = None, + ff_inner_dim: Optional[int] = None, + ff_bias: bool = True, + attention_out_bias: bool = True, + context_length: int = 16, + context_stride: int = 4, + weighting_scheme: str = "pyramid", + ): + super().__init__() + self.dim = dim + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + self.dropout = dropout + self.cross_attention_dim = cross_attention_dim + self.activation_fn = activation_fn + self.attention_bias = attention_bias + self.double_self_attention = double_self_attention + self.norm_elementwise_affine = norm_elementwise_affine + self.positional_embeddings = positional_embeddings + self.num_positional_embeddings = num_positional_embeddings + self.only_cross_attention = only_cross_attention + + self.set_free_noise_properties(context_length, context_stride, weighting_scheme) + + # We keep these boolean flags for backward-compatibility. 
+ self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero" + self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm" + self.use_ada_layer_norm_single = norm_type == "ada_norm_single" + self.use_layer_norm = norm_type == "layer_norm" + self.use_ada_layer_norm_continuous = norm_type == "ada_norm_continuous" + + if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None: + raise ValueError( + f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to" + f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}." + ) + + self.norm_type = norm_type + self.num_embeds_ada_norm = num_embeds_ada_norm + + if positional_embeddings and (num_positional_embeddings is None): + raise ValueError( + "If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined." + ) + + if positional_embeddings == "sinusoidal": + self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings) + else: + self.pos_embed = None + + # Define 3 blocks. Each block has its own normalization layer. + # 1. Self-Attn + self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) + + self.attn1 = Attention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=cross_attention_dim if only_cross_attention else None, + upcast_attention=upcast_attention, + out_bias=attention_out_bias, + ) + + # 2. Cross-Attn + if cross_attention_dim is not None or double_self_attention: + self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine) + + self.attn2 = Attention( + query_dim=dim, + cross_attention_dim=cross_attention_dim if not double_self_attention else None, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + out_bias=attention_out_bias, + ) # is self-attn if encoder_hidden_states is none + + # 3. 
Feed-forward + self.ff = FeedForward( + dim, + dropout=dropout, + activation_fn=activation_fn, + final_dropout=final_dropout, + inner_dim=ff_inner_dim, + bias=ff_bias, + ) + + self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine) + + # let chunk size default to None + self._chunk_size = None + self._chunk_dim = 0 + + def _get_frame_indices(self, num_frames: int) -> List[Tuple[int, int]]: + frame_indices = [] + for i in range(0, num_frames - self.context_length + 1, self.context_stride): + window_start = i + window_end = min(num_frames, i + self.context_length) + frame_indices.append((window_start, window_end)) + return frame_indices + + def _get_frame_weights(self, num_frames: int, weighting_scheme: str = "pyramid") -> List[float]: + if weighting_scheme == "pyramid": + if num_frames % 2 == 0: + # num_frames = 4 => [1, 2, 2, 1] + weights = list(range(1, num_frames // 2 + 1)) + weights = weights + weights[::-1] + else: + # num_frames = 5 => [1, 2, 3, 2, 1] + weights = list(range(1, num_frames // 2 + 1)) + weights = weights + [num_frames // 2 + 1] + weights[::-1] + else: + raise ValueError(f"Unsupported value for weighting_scheme={weighting_scheme}") + + return weights + + def set_free_noise_properties(self, context_length: int, context_stride: int, weighting_scheme: str = "pyramid") -> None: + self.context_length = context_length + self.context_stride = context_stride + self.weighting_scheme = weighting_scheme + + def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0) -> None: + # Sets chunk feed-forward + self._chunk_size = chunk_size + self._chunk_dim = dim + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + cross_attention_kwargs: Dict[str, Any] = None, + *args, + **kwargs, + ) -> torch.Tensor: + if cross_attention_kwargs is not None: + if cross_attention_kwargs.get("scale", None) is not None: + logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") + + cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {} + + # hidden_states: [B x H x W, F, C] + device = hidden_states.device + dtype = hidden_states.dtype + + num_frames = hidden_states.size(1) + frame_indices = self._get_frame_indices(num_frames) + frame_weights = self._get_frame_weights(self.context_length, self.weighting_scheme) + frame_weights = torch.tensor(frame_weights, device=device, dtype=dtype).unsqueeze(0).unsqueeze(-1) + + num_times_accumulated = torch.zeros((1, num_frames, 1), device=device) + accumulated_values = torch.zeros_like(hidden_states) + + for frame_start, frame_end in frame_indices: + # The reason for slicing here is to ensure that if (frame_end - frame_start) is to handle + # cases like frame_indices=[(0, 16), (16, 20)], if the user provided a video with 19 frames, or + # essentially a non-multiple of `context_length`. + weights = torch.ones_like(num_times_accumulated[:, frame_start : frame_end]) + weights *= frame_weights + + hidden_states_chunk = hidden_states[:, frame_start : frame_end] + + # Notice that normalization is always applied before the real computation in the following blocks. + # 1. 
Self-Attention + # assert self.norm_type == "layer_norm" + norm_hidden_states = self.norm1(hidden_states_chunk) + + if self.pos_embed is not None: + norm_hidden_states = self.pos_embed(norm_hidden_states) + + attn_output = self.attn1( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) + + hidden_states_chunk = attn_output + hidden_states_chunk + if hidden_states_chunk.ndim == 4: + hidden_states_chunk = hidden_states_chunk.squeeze(1) + + # 2. Cross-Attention + if self.attn2 is not None: + norm_hidden_states = self.norm2(hidden_states_chunk) + + if self.pos_embed is not None and self.norm_type != "ada_norm_single": + norm_hidden_states = self.pos_embed(norm_hidden_states) + + attn_output = self.attn2( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + **cross_attention_kwargs, + ) + hidden_states_chunk = attn_output + hidden_states_chunk + + accumulated_values[:, frame_start : frame_end] += hidden_states_chunk * weights + num_times_accumulated[:, frame_start : frame_end] += weights + + hidden_states = torch.where(num_times_accumulated > 0, accumulated_values / num_times_accumulated, accumulated_values).to(dtype) + + # 3. Feed-forward + norm_hidden_states = self.norm3(hidden_states) + + if self._chunk_size is not None: + # "feed_forward_chunk_size" can be used to save memory + ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size) + else: + ff_output = self.ff(norm_hidden_states) + + hidden_states = ff_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = hidden_states.squeeze(1) + + return hidden_states + + class MotionModules(nn.Module): def __init__( self, diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py index db5716860319..3a026235e5a4 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py @@ -42,7 +42,7 @@ from ...utils.torch_utils import randn_tensor from ...video_processor import VideoProcessor from ..free_init_utils import FreeInitMixin -from ..free_noise_utils import FreeNoiseMixin +from ..free_noise_utils import AnimateDiffFreeNoiseMixin from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin from .pipeline_output import AnimateDiffPipelineOutput @@ -73,7 +73,7 @@ class AnimateDiffPipeline( IPAdapterMixin, StableDiffusionLoraLoaderMixin, FreeInitMixin, - FreeNoiseMixin, + AnimateDiffFreeNoiseMixin, ): r""" Pipeline for text-to-video generation. 
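
The core of the `FreeNoiseTransformerBlock.forward` introduced above is the windowed blending: overlapping frame windows are processed independently, weighted with a pyramid profile, accumulated, and then normalized by the accumulated weights. The snippet below is a minimal, self-contained sketch of that blending only. The helper names mirror `_get_frame_indices` / `_get_frame_weights` from the block, but `fake_temporal_mixing` is a hypothetical no-op stand-in for the real norm/attention stack, so this is an illustration of the technique rather than the module itself.

```py
# Minimal sketch of the FreeNoise windowed blending, not the actual module.
import torch

def get_frame_indices(num_frames, context_length=16, context_stride=4):
    return [
        (i, min(num_frames, i + context_length))
        for i in range(0, num_frames - context_length + 1, context_stride)
    ]

def get_pyramid_weights(n):
    half = list(range(1, n // 2 + 1))
    return half + half[::-1] if n % 2 == 0 else half + [n // 2 + 1] + half[::-1]

def fake_temporal_mixing(x):
    return x  # hypothetical stand-in for the per-window norm + attention computation

num_frames, context_length, context_stride = 24, 16, 4
hidden_states = torch.randn(2, num_frames, 8)  # [batch * height * width, frames, channels]

weights = torch.tensor(get_pyramid_weights(context_length), dtype=hidden_states.dtype)
weights = weights[None, :, None]  # broadcast over the batch and channel dimensions

accumulated = torch.zeros_like(hidden_states)
counts = torch.zeros((1, num_frames, 1))

for start, end in get_frame_indices(num_frames, context_length, context_stride):
    w = weights[:, : end - start]  # also covers a final, shorter window
    out = fake_temporal_mixing(hidden_states[:, start:end])
    accumulated[:, start:end] += out * w
    counts[:, start:end] += w

blended = torch.where(counts > 0, accumulated / counts, accumulated)
print(blended.shape)  # torch.Size([2, 24, 8])
```

Frames covered by several windows (for example frames 8–15 above) receive larger pyramid weights near window centers, which is what smooths the transitions between adjacent windows.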
diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py index 1a789316cb01..ebdb587d383f 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py @@ -32,7 +32,7 @@ from ...video_processor import VideoProcessor from ..controlnet.multicontrolnet import MultiControlNetModel from ..free_init_utils import FreeInitMixin -from ..free_noise_utils import FreeNoiseMixin +from ..free_noise_utils import AnimateDiffFreeNoiseMixin from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin from .pipeline_output import AnimateDiffPipelineOutput @@ -89,7 +89,7 @@ class AnimateDiffControlNetPipeline( - DiffusionPipeline, StableDiffusionMixin, TextualInversionLoaderMixin, IPAdapterMixin, LoraLoaderMixin, FreeInitMixin, FreeNoiseMixin + DiffusionPipeline, StableDiffusionMixin, TextualInversionLoaderMixin, IPAdapterMixin, LoraLoaderMixin, FreeInitMixin, AnimateDiffFreeNoiseMixin ): r""" Pipeline for text-to-video generation. diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py index a46682347519..4dfe16d883c6 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py @@ -56,6 +56,7 @@ from ...utils.torch_utils import randn_tensor from ...video_processor import VideoProcessor from ..free_init_utils import FreeInitMixin +from ..free_noise_utils import AnimateDiffFreeNoiseMixin from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin from .pipeline_output import AnimateDiffPipelineOutput @@ -194,6 +195,7 @@ class AnimateDiffSDXLPipeline( TextualInversionLoaderMixin, IPAdapterMixin, FreeInitMixin, + AnimateDiffFreeNoiseMixin, ): r""" Pipeline for text-to-video generation using Stable Diffusion XL. 
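
With `AnimateDiffFreeNoiseMixin` mixed into these pipelines, FreeNoise would be enabled from user code roughly as sketched below. This is a usage sketch against the in-progress API from this PR (`enable_free_noise` / `disable_free_noise`, plus the new `decode_batch_size` argument); the checkpoint names are the ones already used in the example docstrings of this PR, the prompt and frame count are illustrative, and the final keyword arguments may still change.

```py
import torch
from diffusers import AnimateDiffPipeline, DDIMScheduler, MotionAdapter
from diffusers.utils import export_to_gif

adapter = MotionAdapter.from_pretrained(
    "guoyww/animatediff-motion-adapter-v1-5-2", torch_dtype=torch.float16
)
pipe = AnimateDiffPipeline.from_pretrained(
    "SG161222/Realistic_Vision_V5.1_noVAE", motion_adapter=adapter, torch_dtype=torch.float16
).to("cuda")
pipe.scheduler = DDIMScheduler.from_pretrained(
    "SG161222/Realistic_Vision_V5.1_noVAE", subfolder="scheduler", beta_schedule="linear"
)

# Process 16-frame windows with a stride of 4 so that far more than 16 frames can be
# generated while the motion module only ever sees sequences it was trained on.
pipe.enable_free_noise(context_length=16, context_stride=4)

output = pipe(
    prompt="a panda surfing on a wave, high quality",
    num_frames=64,
    guidance_scale=7.5,
    num_inference_steps=25,
    decode_batch_size=16,  # decode latents in chunks to bound VAE memory use
)
export_to_gif(output.frames[0], "freenoise.gif")

# FreeNoise can be switched off again, restoring the original transformer blocks.
pipe.disable_free_noise()
```

Under this scheme the 64 requested frames are covered by overlapping 16-frame windows, so long-video generation does not require any change to the motion adapter weights.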
diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py index e9e0d518c806..00bc6b6d5ab8 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py @@ -38,6 +38,7 @@ from ...utils.torch_utils import is_compiled_module, randn_tensor from ...video_processor import VideoProcessor from ..free_init_utils import FreeInitMixin +from ..free_noise_utils import AnimateDiffFreeNoiseMixin from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin from .pipeline_output import AnimateDiffPipelineOutput @@ -127,6 +128,8 @@ class AnimateDiffSparseControlNetPipeline( IPAdapterMixin, StableDiffusionLoraLoaderMixin, FreeInitMixin, + AnimateDiffFreeNoiseMixin, + ): r""" Pipeline for controlled text-to-video generation using the method described in [SparseCtrl: Adding Sparse Controls diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py index 8129b88dc408..c50eed2c6ea8 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py @@ -35,6 +35,7 @@ from ...utils.torch_utils import randn_tensor from ...video_processor import VideoProcessor from ..free_init_utils import FreeInitMixin +from ..free_noise_utils import AnimateDiffFreeNoiseMixin from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin from .pipeline_output import AnimateDiffPipelineOutput @@ -176,6 +177,7 @@ class AnimateDiffVideoToVideoPipeline( IPAdapterMixin, StableDiffusionLoraLoaderMixin, FreeInitMixin, + AnimateDiffFreeNoiseMixin, ): r""" Pipeline for video-to-video generation. diff --git a/src/diffusers/pipelines/free_noise_utils.py b/src/diffusers/pipelines/free_noise_utils.py index dff9ae8c2fa8..67a9cf610fd8 100644 --- a/src/diffusers/pipelines/free_noise_utils.py +++ b/src/diffusers/pipelines/free_noise_utils.py @@ -12,20 +12,110 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Optional +from typing import Optional, Union +from ..models.attention import BasicTransformerBlock +from ..models.unets.unet_motion_model import CrossAttnDownBlockMotion, DownBlockMotion, FreeNoiseTransformerBlock, TransformerTemporalModel, UpBlockMotion -class FreeNoiseMixin: + +class AnimateDiffFreeNoiseMixin: r"""Mixin class for [FreeNoise](https://arxiv.org/abs/2310.15169).""" - def enable_free_noise(self, context_length: Optional[int] = 16, context_stride: int = 4, shuffle: bool = True) -> None: + def _enable_free_noise_in_block(self, block: Union[CrossAttnDownBlockMotion, DownBlockMotion, UpBlockMotion]): + r"""Helper function to enable FreeNoise in transformer blocks.""" + + for motion_module in block.motion_modules: + motion_module: TransformerTemporalModel + num_transformer_blocks = len(motion_module.transformer_blocks) + + for i in range(num_transformer_blocks): + if isinstance(motion_module.transformer_blocks[i], FreeNoiseTransformerBlock): + motion_module.transformer_blocks[i].set_free_noise_properties(self._free_noise_context_length, self._free_noise_context_stride, self._free_noise_weighting_scheme) + else: + assert isinstance(motion_module.transformer_blocks[i], BasicTransformerBlock) + basic_transfomer_block = motion_module.transformer_blocks[i] + + motion_module.transformer_blocks[i] = FreeNoiseTransformerBlock( + dim=basic_transfomer_block.dim, + num_attention_heads=basic_transfomer_block.num_attention_heads, + attention_head_dim=basic_transfomer_block.attention_head_dim, + dropout=basic_transfomer_block.dropout, + cross_attention_dim=basic_transfomer_block.cross_attention_dim, + activation_fn=basic_transfomer_block.activation_fn, + attention_bias=basic_transfomer_block.attention_bias, + only_cross_attention=basic_transfomer_block.only_cross_attention, + double_self_attention=basic_transfomer_block.double_self_attention, + positional_embeddings=basic_transfomer_block.positional_embeddings, + num_positional_embeddings=basic_transfomer_block.num_positional_embeddings, + context_length=self._free_noise_context_length, + context_stride=self._free_noise_context_stride, + weighting_scheme=self._free_noise_weighting_scheme, + ).to(device=self.device, dtype=self.dtype) + + motion_module.transformer_blocks[i].load_state_dict(basic_transfomer_block.state_dict(), strict=True) + + def _disable_free_noise_in_block(self, block: Union[CrossAttnDownBlockMotion, DownBlockMotion, UpBlockMotion]): + r"""Helper function to disable FreeNoise in transformer blocks.""" + + for motion_module in block.motion_modules: + motion_module: TransformerTemporalModel + num_transformer_blocks = len(motion_module.transformer_blocks) + + for i in range(num_transformer_blocks): + if isinstance(motion_module.transformer_blocks[i], FreeNoiseTransformerBlock): + free_noise_transfomer_block = motion_module.transformer_blocks[i] + + motion_module.transformer_blocks[i] = BasicTransformerBlock( + dim=free_noise_transfomer_block.dim, + num_attention_heads=free_noise_transfomer_block.num_attention_heads, + attention_head_dim=free_noise_transfomer_block.attention_head_dim, + dropout=free_noise_transfomer_block.dropout, + cross_attention_dim=free_noise_transfomer_block.cross_attention_dim, + activation_fn=free_noise_transfomer_block.activation_fn, + attention_bias=free_noise_transfomer_block.attention_bias, + only_cross_attention=free_noise_transfomer_block.only_cross_attention, + double_self_attention=free_noise_transfomer_block.double_self_attention, + 
positional_embeddings=free_noise_transfomer_block.positional_embeddings, + num_positional_embeddings=free_noise_transfomer_block.num_positional_embeddings, + ).to(device=self.device, dtype=self.dtype) + + motion_module.transformer_blocks[i].load_state_dict(free_noise_transfomer_block.state_dict(), strict=True) + + def enable_free_noise(self, context_length: Optional[int] = 16, context_stride: int = 4, weighting_scheme: str = "pyramid", shuffle: bool = True) -> None: + r""" + Enable long video generation using FreeNoise. + + Args: + context_length (`int`, defaults to `16`, *optional*): + The number of video frames to process at once. It's recommended to set this to the maximum frames the + Motion Adapter was trained with (usually 16/24/32). If `None`, the default value from the motion adapter + config is used. + context_stride (`int`, *optional*): + Long videos are generated by processing many frames. FreeNoise processes these frames in sliding windows of + size `context_length`. Context stride allows you to specify how many frames to skip between each window. + For example, a context length of 16 and context stride of 4 would process 24 frames as: + [0, 15], [4, 19], [8, 23] (0-based indexing) + weighting_scheme (`str`, defaults to `4`): + TODO(aryan) + shuffle (`str`, defaults to `True`): + TODO(aryan): decide if this is even needed + """ self._free_noise_context_length = context_length or self.motion_adapter.config.motion_max_seq_length self._free_noise_context_stride = context_stride + self._free_noise_weighting_scheme = weighting_scheme self._free_noise_shuffle = shuffle + blocks = [*self.unet.down_blocks, self.unet.mid_block, *self.unet.up_blocks] + for block in blocks: + self._enable_free_noise_in_block(block) + def disable_free_noise(self) -> None: self._free_noise_context_length = None + blocks = [*self.unet.down_blocks, self.unet.mid_block, *self.unet.up_blocks] + for block in blocks: + self._disable_free_noise_in_block(block) + @property def free_noise_enabled(self): return hasattr(self, "_free_noise_context_length") and self._free_noise_context_length is not None diff --git a/src/diffusers/pipelines/pia/pipeline_pia.py b/src/diffusers/pipelines/pia/pipeline_pia.py index f383af7cc182..e44eb186c38f 100644 --- a/src/diffusers/pipelines/pia/pipeline_pia.py +++ b/src/diffusers/pipelines/pia/pipeline_pia.py @@ -45,6 +45,7 @@ from ...utils.torch_utils import randn_tensor from ...video_processor import VideoProcessor from ..free_init_utils import FreeInitMixin +from ..free_noise_utils import AnimateDiffFreeNoiseMixin from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin @@ -131,6 +132,7 @@ class PIAPipeline( StableDiffusionLoraLoaderMixin, FromSingleFileMixin, FreeInitMixin, + AnimateDiffFreeNoiseMixin, ): r""" Pipeline for text-to-video generation. 
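
`_enable_free_noise_in_block` above works by rebuilding each temporal `BasicTransformerBlock` as a `FreeNoiseTransformerBlock` with the same configuration and copying the weights over with `load_state_dict`, so enabling FreeNoise introduces no new parameters. The sketch below shows that swap pattern in isolation, using two hypothetical module classes instead of the real diffusers blocks and uniform (rather than pyramid) window weighting; see the earlier blending sketch for the weighted variant.

```py
import torch
import torch.nn as nn

# Hypothetical stand-ins for BasicTransformerBlock / FreeNoiseTransformerBlock:
# both own identically named submodules, so their state dicts are interchangeable.
class PlainBlock(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        self.norm = nn.LayerNorm(dim)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        return x + self.proj(self.norm(x))

class WindowedBlock(PlainBlock):
    """Same parameters as PlainBlock, but processes the frame axis in sliding windows."""

    def forward(self, x, context_length=16, context_stride=4):
        out = torch.zeros_like(x)
        counts = torch.zeros((1, x.size(1), 1))
        for start in range(0, max(x.size(1) - context_length, 0) + 1, context_stride):
            end = min(x.size(1), start + context_length)
            out[:, start:end] += super().forward(x[:, start:end])
            counts[:, start:end] += 1
        return torch.where(counts > 0, out / counts, out)

blocks = nn.ModuleList([PlainBlock(32) for _ in range(2)])

# enable: rebuild each block with the same config and reuse its weights
for i in range(len(blocks)):
    old = blocks[i]
    new_block = WindowedBlock(old.dim)
    new_block.load_state_dict(old.state_dict(), strict=True)
    blocks[i] = new_block

x = torch.randn(1, 24, 32)
print(blocks[0](x).shape)  # torch.Size([1, 24, 32])
```

Because the windowed block reuses the exact parameter names of the plain block, loading with `strict=True` guarantees that enabling or disabling the feature never silently drops or reinitializes weights.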
From a41f843dbae928fd060a35f848331fa7f42cbb5c Mon Sep 17 00:00:00 2001 From: Aryan Date: Sat, 27 Jul 2024 15:16:45 +0200 Subject: [PATCH 06/35] remove old helper functions --- src/diffusers/models/attention.py | 24 ------------------------ 1 file changed, 24 deletions(-) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index b77af6e05d1a..06bf152df049 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -209,30 +209,6 @@ def forward( return encoder_hidden_states, hidden_states -def get_frame_indices(num_frames: int, context_length: int = 16, context_stride: int = 4): - batch_indices = [] - for i in range(0, num_frames - context_length, context_stride): - window_start = i - window_end = min(num_frames, i + context_length) - batch_indices.append((window_start, window_end)) - return batch_indices - -def get_frame_weights(num_frames: int, weight_type: str = "pyramid"): - if weight_type == "pyramid": - if num_frames % 2 == 0: - # num_frames = 4 => [1, 2, 2, 1] - weights = list(range(1, num_frames // 2 + 1)) - weights = weights + weights[::-1] - else: - # num_frames = 5 => [1, 2, 3, 2, 1] - weights = list(range(1, num_frames // 2 + 1)) - weights = weights + [num_frames // 2 + 1] + weights[::-1] - else: - raise ValueError(f"Invalid `weight_type`: {weight_type}") - - return weights - - @maybe_allow_in_graph class BasicTransformerBlock(nn.Module): r""" From f6897ae46a0f59a778ea40057075def6c448fc52 Mon Sep 17 00:00:00 2001 From: Aryan Date: Sat, 27 Jul 2024 15:26:14 +0200 Subject: [PATCH 07/35] add decode batch size param to all pipelines --- .../pipelines/animatediff/pipeline_animatediff.py | 2 ++ .../animatediff/pipeline_animatediff_controlnet.py | 3 +++ .../pipelines/animatediff/pipeline_animatediff_sdxl.py | 9 ++++++--- .../animatediff/pipeline_animatediff_sparsectrl.py | 9 ++++++--- .../animatediff/pipeline_animatediff_video2video.py | 9 ++++++--- src/diffusers/pipelines/pia/pipeline_pia.py | 9 ++++++--- 6 files changed, 29 insertions(+), 12 deletions(-) diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py index 3a026235e5a4..cae505c04988 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py @@ -660,6 +660,8 @@ def __call__( The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the `._callback_tensor_inputs` attribute of your pipeline class. + decode_batch_size (`int`, defaults to `16`): + The number of frames to decode at a time when calling `decode_latents` method. 
Examples: diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py index ebdb587d383f..ad0bc485a420 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py @@ -401,6 +401,7 @@ def prepare_ip_adapter_image_embeds( image_embeds = ip_adapter_image_embeds return image_embeds + # Copied from diffusers.pipelines.animatediff.pipeline_animatediff.AnimateDiffVideoToVideo.decode_latents def decode_latents(self, latents, decode_batch_size: int = 16): latents = 1 / self.vae.config.scaling_factor * latents @@ -808,6 +809,8 @@ def __call__( The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the `._callback_tensor_inputs` attribute of your pipeline class. + decode_batch_size (`int`, defaults to `16`): + The number of frames to decode at a time when calling `decode_latents` method. Examples: diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py index 4dfe16d883c6..7241a5efce38 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py @@ -608,8 +608,8 @@ def prepare_ip_adapter_image_embeds( return ip_adapter_image_embeds - # Copied from diffusers.pipelines.text_to_video_synthesis/pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents - def decode_latents(self, latents): + # Copied from diffusers.pipelines.animatediff.pipeline_animatediff.AnimateDiffVideoToVideo.decode_latents + def decode_latents(self, latents, decode_batch_size: int = 16): latents = 1 / self.vae.config.scaling_factor * latents batch_size, channels, num_frames, height, width = latents.shape @@ -878,6 +878,7 @@ def __call__( clip_skip: Optional[int] = None, callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], + decode_batch_size: int = 16, ): r""" Function invoked when calling the pipeline for generation. @@ -1017,6 +1018,8 @@ def __call__( The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the `._callback_tensor_inputs` attribute of your pipeline class. + decode_batch_size (`int`, defaults to `16`): + The number of frames to decode at a time when calling `decode_latents` method. 
Examples: @@ -1260,7 +1263,7 @@ def __call__( if output_type == "latent": video = latents else: - video_tensor = self.decode_latents(latents) + video_tensor = self.decode_latents(latents, decode_batch_size) video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type) # cast back to fp16 if needed diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py index 00bc6b6d5ab8..fcff86aa36f7 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py @@ -451,8 +451,8 @@ def prepare_ip_adapter_image_embeds( return ip_adapter_image_embeds - # Copied from diffusers.pipelines.text_to_video_synthesis/pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents - def decode_latents(self, latents): + # Copied from diffusers.pipelines.animatediff.pipeline_animatediff.AnimateDiffVideoToVideo.decode_latents + def decode_latents(self, latents, decode_batch_size: int = 16): latents = 1 / self.vae.config.scaling_factor * latents batch_size, channels, num_frames, height, width = latents.shape @@ -731,6 +731,7 @@ def __call__( clip_skip: Optional[int] = None, callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], + decode_batch_size: int = 16, ): r""" The call function to the pipeline for generation. @@ -809,6 +810,8 @@ def __call__( The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the `._callback_tensor_inputs` attribute of your pipeline class. + decode_batch_size (`int`, defaults to `16`): + The number of frames to decode at a time when calling `decode_latents` method. Examples: @@ -999,7 +1002,7 @@ def __call__( if output_type == "latent": video = latents else: - video_tensor = self.decode_latents(latents) + video_tensor = self.decode_latents(latents, decode_batch_size) video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type) # 12. Offload all models diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py index c50eed2c6ea8..aef3b4e49065 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py @@ -500,8 +500,8 @@ def prepare_ip_adapter_image_embeds( return ip_adapter_image_embeds - # Copied from diffusers.pipelines.text_to_video_synthesis/pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents - def decode_latents(self, latents): + # Copied from diffusers.pipelines.animatediff.pipeline_animatediff.AnimateDiffVideoToVideo.decode_latents + def decode_latents(self, latents, decode_batch_size: int = 16): latents = 1 / self.vae.config.scaling_factor * latents batch_size, channels, num_frames, height, width = latents.shape @@ -749,6 +749,7 @@ def __call__( clip_skip: Optional[int] = None, callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], + decode_batch_size: int = 16, ): r""" The call function to the pipeline for generation. @@ -824,6 +825,8 @@ def __call__( The list of tensor inputs for the `callback_on_step_end` function. 
The tensors specified in the list will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the `._callback_tensor_inputs` attribute of your pipeline class. + decode_batch_size (`int`, defaults to `16`): + The number of frames to decode at a time when calling `decode_latents` method. Examples: @@ -992,7 +995,7 @@ def __call__( if output_type == "latent": video = latents else: - video_tensor = self.decode_latents(latents) + video_tensor = self.decode_latents(latents, decode_batch_size) video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type) # 10. Offload all models diff --git a/src/diffusers/pipelines/pia/pipeline_pia.py b/src/diffusers/pipelines/pia/pipeline_pia.py index e44eb186c38f..53cda0aee9bb 100644 --- a/src/diffusers/pipelines/pia/pipeline_pia.py +++ b/src/diffusers/pipelines/pia/pipeline_pia.py @@ -409,8 +409,8 @@ def encode_image(self, image, device, num_images_per_prompt, output_hidden_state return image_embeds, uncond_image_embeds - # Copied from diffusers.pipelines.text_to_video_synthesis/pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents - def decode_latents(self, latents): + # Copied from diffusers.pipelines.animatediff.pipeline_animatediff.AnimateDiffVideoToVideo.decode_latents + def decode_latents(self, latents, decode_batch_size: int = 16): latents = 1 / self.vae.config.scaling_factor * latents batch_size, channels, num_frames, height, width = latents.shape @@ -689,6 +689,7 @@ def __call__( clip_skip: Optional[int] = None, callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], + decode_batch_size: int = 16, ): r""" The call function to the pipeline for generation. @@ -765,6 +766,8 @@ def __call__( The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the `._callback_tensor_inputs` attribute of your pipeline class. + decode_batch_size (`int`, defaults to `16`): + The number of frames to decode at a time when calling `decode_latents` method. Examples: @@ -933,7 +936,7 @@ def __call__( if output_type == "latent": video = latents else: - video_tensor = self.decode_latents(latents) + video_tensor = self.decode_latents(latents, decode_batch_size) video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type) # 10. 
Offload all models From 024e2da8645dedac59b1e00378a67ca5d6b80907 Mon Sep 17 00:00:00 2001 From: Aryan Date: Sat, 27 Jul 2024 15:27:28 +0200 Subject: [PATCH 08/35] make style --- .../models/unets/unet_motion_model.py | 36 +++++++------- .../animatediff/pipeline_animatediff.py | 12 ++--- .../pipeline_animatediff_controlnet.py | 46 +++++++++++------- .../pipeline_animatediff_sparsectrl.py | 1 - src/diffusers/pipelines/free_noise_utils.py | 48 +++++++++++++------ 5 files changed, 89 insertions(+), 54 deletions(-) diff --git a/src/diffusers/models/unets/unet_motion_model.py b/src/diffusers/models/unets/unet_motion_model.py index 890864942754..367cf62ad942 100644 --- a/src/diffusers/models/unets/unet_motion_model.py +++ b/src/diffusers/models/unets/unet_motion_model.py @@ -19,7 +19,7 @@ import torch.utils.checkpoint from ...configuration_utils import ConfigMixin, FrozenDict, register_to_config -from ...loaders import FromOriginalModelMixin, PeftAdapterMixin, UNet2DConditionLoadersMixin +from ...loaders import FromOriginalModelMixin, UNet2DConditionLoadersMixin from ...utils import logging from ...utils.torch_utils import maybe_allow_in_graph from ..attention import FeedForward, _chunked_feed_forward @@ -130,7 +130,7 @@ def __init__( self.positional_embeddings = positional_embeddings self.num_positional_embeddings = num_positional_embeddings self.only_cross_attention = only_cross_attention - + self.set_free_noise_properties(context_length, context_stride, weighting_scheme) # We keep these boolean flags for backward-compatibility. @@ -204,7 +204,7 @@ def __init__( # let chunk size default to None self._chunk_size = None self._chunk_dim = 0 - + def _get_frame_indices(self, num_frames: int) -> List[Tuple[int, int]]: frame_indices = [] for i in range(0, num_frames - self.context_length + 1, self.context_stride): @@ -212,7 +212,7 @@ def _get_frame_indices(self, num_frames: int) -> List[Tuple[int, int]]: window_end = min(num_frames, i + self.context_length) frame_indices.append((window_start, window_end)) return frame_indices - + def _get_frame_weights(self, num_frames: int, weighting_scheme: str = "pyramid") -> List[float]: if weighting_scheme == "pyramid": if num_frames % 2 == 0: @@ -227,8 +227,10 @@ def _get_frame_weights(self, num_frames: int, weighting_scheme: str = "pyramid") raise ValueError(f"Unsupported value for weighting_scheme={weighting_scheme}") return weights - - def set_free_noise_properties(self, context_length: int, context_stride: int, weighting_scheme: str = "pyramid") -> None: + + def set_free_noise_properties( + self, context_length: int, context_stride: int, weighting_scheme: str = "pyramid" + ) -> None: self.context_length = context_length self.context_stride = context_stride self.weighting_scheme = weighting_scheme @@ -251,9 +253,9 @@ def forward( if cross_attention_kwargs is not None: if cross_attention_kwargs.get("scale", None) is not None: logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. 
`scale` will be ignored.") - + cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {} - + # hidden_states: [B x H x W, F, C] device = hidden_states.device dtype = hidden_states.dtype @@ -262,7 +264,7 @@ def forward( frame_indices = self._get_frame_indices(num_frames) frame_weights = self._get_frame_weights(self.context_length, self.weighting_scheme) frame_weights = torch.tensor(frame_weights, device=device, dtype=dtype).unsqueeze(0).unsqueeze(-1) - + num_times_accumulated = torch.zeros((1, num_frames, 1), device=device) accumulated_values = torch.zeros_like(hidden_states) @@ -270,10 +272,10 @@ def forward( # The reason for slicing here is to ensure that if (frame_end - frame_start) is to handle # cases like frame_indices=[(0, 16), (16, 20)], if the user provided a video with 19 frames, or # essentially a non-multiple of `context_length`. - weights = torch.ones_like(num_times_accumulated[:, frame_start : frame_end]) + weights = torch.ones_like(num_times_accumulated[:, frame_start:frame_end]) weights *= frame_weights - hidden_states_chunk = hidden_states[:, frame_start : frame_end] + hidden_states_chunk = hidden_states[:, frame_start:frame_end] # Notice that normalization is always applied before the real computation in the following blocks. # 1. Self-Attention @@ -308,11 +310,13 @@ def forward( **cross_attention_kwargs, ) hidden_states_chunk = attn_output + hidden_states_chunk - - accumulated_values[:, frame_start : frame_end] += hidden_states_chunk * weights - num_times_accumulated[:, frame_start : frame_end] += weights - - hidden_states = torch.where(num_times_accumulated > 0, accumulated_values / num_times_accumulated, accumulated_values).to(dtype) + + accumulated_values[:, frame_start:frame_end] += hidden_states_chunk * weights + num_times_accumulated[:, frame_start:frame_end] += weights + + hidden_states = torch.where( + num_times_accumulated > 0, accumulated_values / num_times_accumulated, accumulated_values + ).to(dtype) # 3. 
Feed-forward norm_hidden_states = self.norm3(hidden_states) diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py index cae505c04988..d2220687d1ac 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py @@ -401,13 +401,13 @@ def decode_latents(self, latents, decode_batch_size: int = 16): batch_size, channels, num_frames, height, width = latents.shape latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width) - + video = [] for i in range(0, latents.shape[0], decode_batch_size): batch_latents = latents[i : i + decode_batch_size] batch_latents = self.vae.decode(batch_latents).sample video.append(batch_latents) - + video = torch.cat(video) video = video[None, :].reshape((batch_size, num_frames, -1) + video.shape[2:]).permute(0, 2, 1, 3, 4) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 @@ -522,22 +522,22 @@ def prepare_latents( latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) else: latents = latents.to(device) - + if self.free_noise_enabled and self._free_noise_shuffle: for i in range(self._free_noise_context_length, num_frames, self._free_noise_context_stride): # ensure window is within bounds window_start = max(0, i - self._free_noise_context_length) window_end = min(num_frames, window_start + self._free_noise_context_stride) window_length = window_end - window_start - + if window_length == 0: break - + indices = torch.LongTensor(list(range(window_start, window_end))) shuffled_indices = indices[torch.randperm(window_length, generator=generator)] # shuffle latents in every window - latents[:, :, window_start : window_end] = latents[:, :, shuffled_indices] + latents[:, :, window_start:window_end] = latents[:, :, shuffled_indices] # scale the initial noise by the standard deviation required by the scheduler latents = latents * self.scheduler.init_noise_sigma diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py index ad0bc485a420..a0a172b059b6 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py @@ -27,7 +27,7 @@ from ...models.lora import adjust_lora_scale_text_encoder from ...models.unets.unet_motion_model import MotionAdapter from ...schedulers import KarrasDiffusionSchedulers -from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers +from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers from ...utils.torch_utils import is_compiled_module, randn_tensor from ...video_processor import VideoProcessor from ..controlnet.multicontrolnet import MultiControlNetModel @@ -37,7 +37,6 @@ from .pipeline_output import AnimateDiffPipelineOutput - logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ @@ -51,7 +50,9 @@ >>> motion_id = "guoyww/animatediff-motion-adapter-v1-5-2" >>> adapter = MotionAdapter.from_pretrained(motion_id) - >>> controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_openpose", torch_dtype=torch.float16) + >>> controlnet = ControlNetModel.from_pretrained( + ... "lllyasviel/control_v11p_sd15_openpose", torch_dtype=torch.float16 + ... 
) >>> vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16) >>> model_id = "SG161222/Realistic_Vision_V5.1_noVAE" @@ -63,7 +64,12 @@ ... custom_pipeline="pipeline_animatediff_controlnet", ... ).to(device="cuda", dtype=torch.float16) >>> pipe.scheduler = DPMSolverMultistepScheduler.from_pretrained( - ... model_id, subfolder="scheduler", clip_sample=False, timestep_spacing="linspace", steps_offset=1, beta_schedule="linear", + ... model_id, + ... subfolder="scheduler", + ... clip_sample=False, + ... timestep_spacing="linspace", + ... steps_offset=1, + ... beta_schedule="linear", ... ) >>> pipe.enable_vae_slicing() @@ -83,13 +89,20 @@ ... ) >>> from diffusers.utils import export_to_gif + >>> export_to_gif(result.frames[0], "result.gif") ``` """ class AnimateDiffControlNetPipeline( - DiffusionPipeline, StableDiffusionMixin, TextualInversionLoaderMixin, IPAdapterMixin, LoraLoaderMixin, FreeInitMixin, AnimateDiffFreeNoiseMixin + DiffusionPipeline, + StableDiffusionMixin, + TextualInversionLoaderMixin, + IPAdapterMixin, + LoraLoaderMixin, + FreeInitMixin, + AnimateDiffFreeNoiseMixin, ): r""" Pipeline for text-to-video generation. @@ -407,13 +420,13 @@ def decode_latents(self, latents, decode_batch_size: int = 16): batch_size, channels, num_frames, height, width = latents.shape latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width) - + video = [] for i in range(0, latents.shape[0], decode_batch_size): batch_latents = latents[i : i + decode_batch_size] batch_latents = self.vae.decode(batch_latents).sample video.append(batch_latents) - + video = torch.cat(video) video = video[None, :].reshape((batch_size, num_frames, -1) + video.shape[2:]).permute(0, 2, 1, 3, 4) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 @@ -769,17 +782,16 @@ def __call__( ip_adapter_image (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*): - Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of IP-adapters. - Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should contain the negative image embedding - if `do_classifier_free_guidance` is set to `True`. - If not provided, embeddings are computed from the `ip_adapter_image` input argument. + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should + contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. conditioning_frames (`List[PipelineImageInput]`, *optional*): - The ControlNet input condition to provide guidance to the `unet` for generation. If multiple ControlNets - are specified, images must be passed as a list such that each element of the list can be correctly - batched for input to a single ControlNet. + The ControlNet input condition to provide guidance to the `unet` for generation. If multiple + ControlNets are specified, images must be passed as a list such that each element of the list can be + correctly batched for input to a single ControlNet. output_type (`str`, *optional*, defaults to `"pil"`): - The output format of the generated video. Choose between `torch.Tensor`, `PIL.Image` or - `np.array`. 
+ The output format of the generated video. Choose between `torch.Tensor`, `PIL.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] instead of a plain tuple. @@ -986,7 +998,7 @@ def __call__( self._num_timesteps = len(timesteps) num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order - + # 8. Denoising loop with self.progress_bar(total=self._num_timesteps) as progress_bar: for i, t in enumerate(timesteps): diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py index fcff86aa36f7..54d924c0817d 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py @@ -129,7 +129,6 @@ class AnimateDiffSparseControlNetPipeline( StableDiffusionLoraLoaderMixin, FreeInitMixin, AnimateDiffFreeNoiseMixin, - ): r""" Pipeline for controlled text-to-video generation using the method described in [SparseCtrl: Adding Sparse Controls diff --git a/src/diffusers/pipelines/free_noise_utils.py b/src/diffusers/pipelines/free_noise_utils.py index 67a9cf610fd8..7c6dd99e1895 100644 --- a/src/diffusers/pipelines/free_noise_utils.py +++ b/src/diffusers/pipelines/free_noise_utils.py @@ -15,7 +15,13 @@ from typing import Optional, Union from ..models.attention import BasicTransformerBlock -from ..models.unets.unet_motion_model import CrossAttnDownBlockMotion, DownBlockMotion, FreeNoiseTransformerBlock, TransformerTemporalModel, UpBlockMotion +from ..models.unets.unet_motion_model import ( + CrossAttnDownBlockMotion, + DownBlockMotion, + FreeNoiseTransformerBlock, + TransformerTemporalModel, + UpBlockMotion, +) class AnimateDiffFreeNoiseMixin: @@ -30,11 +36,15 @@ def _enable_free_noise_in_block(self, block: Union[CrossAttnDownBlockMotion, Dow for i in range(num_transformer_blocks): if isinstance(motion_module.transformer_blocks[i], FreeNoiseTransformerBlock): - motion_module.transformer_blocks[i].set_free_noise_properties(self._free_noise_context_length, self._free_noise_context_stride, self._free_noise_weighting_scheme) + motion_module.transformer_blocks[i].set_free_noise_properties( + self._free_noise_context_length, + self._free_noise_context_stride, + self._free_noise_weighting_scheme, + ) else: assert isinstance(motion_module.transformer_blocks[i], BasicTransformerBlock) basic_transfomer_block = motion_module.transformer_blocks[i] - + motion_module.transformer_blocks[i] = FreeNoiseTransformerBlock( dim=basic_transfomer_block.dim, num_attention_heads=basic_transfomer_block.num_attention_heads, @@ -52,8 +62,10 @@ def _enable_free_noise_in_block(self, block: Union[CrossAttnDownBlockMotion, Dow weighting_scheme=self._free_noise_weighting_scheme, ).to(device=self.device, dtype=self.dtype) - motion_module.transformer_blocks[i].load_state_dict(basic_transfomer_block.state_dict(), strict=True) - + motion_module.transformer_blocks[i].load_state_dict( + basic_transfomer_block.state_dict(), strict=True + ) + def _disable_free_noise_in_block(self, block: Union[CrossAttnDownBlockMotion, DownBlockMotion, UpBlockMotion]): r"""Helper function to disable FreeNoise in transformer blocks.""" @@ -64,7 +76,7 @@ def _disable_free_noise_in_block(self, block: Union[CrossAttnDownBlockMotion, Do for i in range(num_transformer_blocks): if isinstance(motion_module.transformer_blocks[i], 
FreeNoiseTransformerBlock): free_noise_transfomer_block = motion_module.transformer_blocks[i] - + motion_module.transformer_blocks[i] = BasicTransformerBlock( dim=free_noise_transfomer_block.dim, num_attention_heads=free_noise_transfomer_block.num_attention_heads, @@ -79,21 +91,29 @@ def _disable_free_noise_in_block(self, block: Union[CrossAttnDownBlockMotion, Do num_positional_embeddings=free_noise_transfomer_block.num_positional_embeddings, ).to(device=self.device, dtype=self.dtype) - motion_module.transformer_blocks[i].load_state_dict(free_noise_transfomer_block.state_dict(), strict=True) - - def enable_free_noise(self, context_length: Optional[int] = 16, context_stride: int = 4, weighting_scheme: str = "pyramid", shuffle: bool = True) -> None: + motion_module.transformer_blocks[i].load_state_dict( + free_noise_transfomer_block.state_dict(), strict=True + ) + + def enable_free_noise( + self, + context_length: Optional[int] = 16, + context_stride: int = 4, + weighting_scheme: str = "pyramid", + shuffle: bool = True, + ) -> None: r""" Enable long video generation using FreeNoise. Args: context_length (`int`, defaults to `16`, *optional*): The number of video frames to process at once. It's recommended to set this to the maximum frames the - Motion Adapter was trained with (usually 16/24/32). If `None`, the default value from the motion adapter - config is used. + Motion Adapter was trained with (usually 16/24/32). If `None`, the default value from the motion + adapter config is used. context_stride (`int`, *optional*): - Long videos are generated by processing many frames. FreeNoise processes these frames in sliding windows of - size `context_length`. Context stride allows you to specify how many frames to skip between each window. - For example, a context length of 16 and context stride of 4 would process 24 frames as: + Long videos are generated by processing many frames. FreeNoise processes these frames in sliding + windows of size `context_length`. Context stride allows you to specify how many frames to skip between + each window. 
For example, a context length of 16 and context stride of 4 would process 24 frames as: [0, 15], [4, 19], [8, 23] (0-based indexing) weighting_scheme (`str`, defaults to `4`): TODO(aryan) From 1bb09845bf358fd91ec7148df3d92569c6b7b8f0 Mon Sep 17 00:00:00 2001 From: Aryan Date: Sat, 27 Jul 2024 15:29:30 +0200 Subject: [PATCH 09/35] fix copied from comments --- .../pipelines/animatediff/pipeline_animatediff_controlnet.py | 2 +- .../pipelines/animatediff/pipeline_animatediff_sdxl.py | 2 +- .../pipelines/animatediff/pipeline_animatediff_sparsectrl.py | 2 +- .../pipelines/animatediff/pipeline_animatediff_video2video.py | 2 +- src/diffusers/pipelines/pia/pipeline_pia.py | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py index a0a172b059b6..3928c912bf15 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py @@ -414,7 +414,7 @@ def prepare_ip_adapter_image_embeds( image_embeds = ip_adapter_image_embeds return image_embeds - # Copied from diffusers.pipelines.animatediff.pipeline_animatediff.AnimateDiffVideoToVideo.decode_latents + # Copied from diffusers.pipelines.animatediff.pipeline_animatediff.AnimateDiffPipeline.decode_latents def decode_latents(self, latents, decode_batch_size: int = 16): latents = 1 / self.vae.config.scaling_factor * latents diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py index 7241a5efce38..aa770b852874 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py @@ -608,7 +608,7 @@ def prepare_ip_adapter_image_embeds( return ip_adapter_image_embeds - # Copied from diffusers.pipelines.animatediff.pipeline_animatediff.AnimateDiffVideoToVideo.decode_latents + # Copied from diffusers.pipelines.animatediff.pipeline_animatediff.AnimateDiffPipeline.decode_latents def decode_latents(self, latents, decode_batch_size: int = 16): latents = 1 / self.vae.config.scaling_factor * latents diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py index 54d924c0817d..b5440c549485 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py @@ -450,7 +450,7 @@ def prepare_ip_adapter_image_embeds( return ip_adapter_image_embeds - # Copied from diffusers.pipelines.animatediff.pipeline_animatediff.AnimateDiffVideoToVideo.decode_latents + # Copied from diffusers.pipelines.animatediff.pipeline_animatediff.AnimateDiffPipeline.decode_latents def decode_latents(self, latents, decode_batch_size: int = 16): latents = 1 / self.vae.config.scaling_factor * latents diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py index aef3b4e49065..6c1d8239326f 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py @@ -500,7 +500,7 @@ def prepare_ip_adapter_image_embeds( return ip_adapter_image_embeds - # Copied from 
diffusers.pipelines.animatediff.pipeline_animatediff.AnimateDiffVideoToVideo.decode_latents + # Copied from diffusers.pipelines.animatediff.pipeline_animatediff.AnimateDiffPipeline.decode_latents def decode_latents(self, latents, decode_batch_size: int = 16): latents = 1 / self.vae.config.scaling_factor * latents diff --git a/src/diffusers/pipelines/pia/pipeline_pia.py b/src/diffusers/pipelines/pia/pipeline_pia.py index 53cda0aee9bb..6c8ec9ec1c16 100644 --- a/src/diffusers/pipelines/pia/pipeline_pia.py +++ b/src/diffusers/pipelines/pia/pipeline_pia.py @@ -409,7 +409,7 @@ def encode_image(self, image, device, num_images_per_prompt, output_hidden_state return image_embeds, uncond_image_embeds - # Copied from diffusers.pipelines.animatediff.pipeline_animatediff.AnimateDiffVideoToVideo.decode_latents + # Copied from diffusers.pipelines.animatediff.pipeline_animatediff.AnimateDiffPipeline.decode_latents def decode_latents(self, latents, decode_batch_size: int = 16): latents = 1 / self.vae.config.scaling_factor * latents From 1b7bc007d8778ae8a2c172fcd150ea33976a5eca Mon Sep 17 00:00:00 2001 From: Aryan Date: Sat, 27 Jul 2024 15:30:42 +0200 Subject: [PATCH 10/35] make fix-copies --- .../pipeline_animatediff_controlnet.py | 58 +++++++++++-------- .../animatediff/pipeline_animatediff_sdxl.py | 10 +++- .../pipeline_animatediff_sparsectrl.py | 10 +++- .../pipeline_animatediff_video2video.py | 10 +++- src/diffusers/pipelines/pia/pipeline_pia.py | 10 +++- 5 files changed, 67 insertions(+), 31 deletions(-) diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py index 3928c912bf15..d14fc2c9d82f 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py @@ -18,11 +18,11 @@ import numpy as np import torch import torch.nn.functional as F -from PIL import Image +import PIL.Image from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection from ...image_processor import PipelineImageInput -from ...loaders import IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin +from ...loaders import IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel, UNetMotionModel from ...models.lora import adjust_lora_scale_text_encoder from ...models.unets.unet_motion_model import MotionAdapter @@ -100,7 +100,7 @@ class AnimateDiffControlNetPipeline( StableDiffusionMixin, TextualInversionLoaderMixin, IPAdapterMixin, - LoraLoaderMixin, + StableDiffusionLoraLoaderMixin, FreeInitMixin, AnimateDiffFreeNoiseMixin, ): @@ -216,7 +216,7 @@ def encode_prompt( """ # set lora scale so that monkey patched LoRA # function of text encoder can correctly access it - if lora_scale is not None and isinstance(self, LoraLoaderMixin): + if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin): self._lora_scale = lora_scale # dynamically adjust the LoRA scale @@ -348,9 +348,10 @@ def encode_prompt( negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND: - # Retrieve the original scale by scaling back the LoRA layers - unscale_lora_layers(self.text_encoder, lora_scale) + if 
self.text_encoder is not None: + if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) return prompt_embeds, negative_prompt_embeds @@ -381,8 +382,11 @@ def encode_image(self, image, device, num_images_per_prompt, output_hidden_state # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds def prepare_ip_adapter_image_embeds( - self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance ): + image_embeds = [] + if do_classifier_free_guidance: + negative_image_embeds = [] if ip_adapter_image_embeds is None: if not isinstance(ip_adapter_image, list): ip_adapter_image = [ip_adapter_image] @@ -392,7 +396,6 @@ def prepare_ip_adapter_image_embeds( f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." ) - image_embeds = [] for single_ip_adapter_image, image_proj_layer in zip( ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers ): @@ -400,19 +403,28 @@ def prepare_ip_adapter_image_embeds( single_image_embeds, single_negative_image_embeds = self.encode_image( single_ip_adapter_image, device, 1, output_hidden_state ) - single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0) - single_negative_image_embeds = torch.stack( - [single_negative_image_embeds] * num_images_per_prompt, dim=0 - ) - if self.do_classifier_free_guidance: - single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds]) - single_image_embeds = single_image_embeds.to(device) - - image_embeds.append(single_image_embeds) + image_embeds.append(single_image_embeds[None, :]) + if do_classifier_free_guidance: + negative_image_embeds.append(single_negative_image_embeds[None, :]) else: - image_embeds = ip_adapter_image_embeds - return image_embeds + for single_image_embeds in ip_adapter_image_embeds: + if do_classifier_free_guidance: + single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) + negative_image_embeds.append(single_negative_image_embeds) + image_embeds.append(single_image_embeds) + + ip_adapter_image_embeds = [] + for i, single_image_embeds in enumerate(image_embeds): + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + if do_classifier_free_guidance: + single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds # Copied from diffusers.pipelines.animatediff.pipeline_animatediff.AnimateDiffPipeline.decode_latents def decode_latents(self, latents, decode_batch_size: int = 16): @@ -593,10 +605,10 @@ def check_inputs( # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.check_image def check_image(self, image, prompt, prompt_embeds): - image_is_pil = isinstance(image, Image.Image) + image_is_pil = isinstance(image, PIL.Image.Image) image_is_tensor = isinstance(image, torch.Tensor) image_is_np = isinstance(image, 
np.ndarray) - image_is_pil_list = isinstance(image, list) and isinstance(image[0], Image.Image) + image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image) image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor) image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray) @@ -668,7 +680,7 @@ def prepare_image( do_classifier_free_guidance=False, guess_mode=False, ): - image = self.control_video_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) + image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) image_batch_size = image.shape[0] if image_batch_size == 1: diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py index aa770b852874..bc067bf867ed 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py @@ -615,8 +615,14 @@ def decode_latents(self, latents, decode_batch_size: int = 16): batch_size, channels, num_frames, height, width = latents.shape latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width) - image = self.vae.decode(latents).sample - video = image[None, :].reshape((batch_size, num_frames, -1) + image.shape[2:]).permute(0, 2, 1, 3, 4) + video = [] + for i in range(0, latents.shape[0], decode_batch_size): + batch_latents = latents[i : i + decode_batch_size] + batch_latents = self.vae.decode(batch_latents).sample + video.append(batch_latents) + + video = torch.cat(video) + video = video[None, :].reshape((batch_size, num_frames, -1) + video.shape[2:]).permute(0, 2, 1, 3, 4) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 video = video.float() return video diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py index b5440c549485..d52a4dd385ee 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py @@ -457,8 +457,14 @@ def decode_latents(self, latents, decode_batch_size: int = 16): batch_size, channels, num_frames, height, width = latents.shape latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width) - image = self.vae.decode(latents).sample - video = image[None, :].reshape((batch_size, num_frames, -1) + image.shape[2:]).permute(0, 2, 1, 3, 4) + video = [] + for i in range(0, latents.shape[0], decode_batch_size): + batch_latents = latents[i : i + decode_batch_size] + batch_latents = self.vae.decode(batch_latents).sample + video.append(batch_latents) + + video = torch.cat(video) + video = video[None, :].reshape((batch_size, num_frames, -1) + video.shape[2:]).permute(0, 2, 1, 3, 4) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 video = video.float() return video diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py index 6c1d8239326f..89ede2850d71 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py @@ -507,8 +507,14 @@ def decode_latents(self, 
latents, decode_batch_size: int = 16): batch_size, channels, num_frames, height, width = latents.shape latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width) - image = self.vae.decode(latents).sample - video = image[None, :].reshape((batch_size, num_frames, -1) + image.shape[2:]).permute(0, 2, 1, 3, 4) + video = [] + for i in range(0, latents.shape[0], decode_batch_size): + batch_latents = latents[i : i + decode_batch_size] + batch_latents = self.vae.decode(batch_latents).sample + video.append(batch_latents) + + video = torch.cat(video) + video = video[None, :].reshape((batch_size, num_frames, -1) + video.shape[2:]).permute(0, 2, 1, 3, 4) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 video = video.float() return video diff --git a/src/diffusers/pipelines/pia/pipeline_pia.py b/src/diffusers/pipelines/pia/pipeline_pia.py index 6c8ec9ec1c16..3699410f2a09 100644 --- a/src/diffusers/pipelines/pia/pipeline_pia.py +++ b/src/diffusers/pipelines/pia/pipeline_pia.py @@ -416,8 +416,14 @@ def decode_latents(self, latents, decode_batch_size: int = 16): batch_size, channels, num_frames, height, width = latents.shape latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width) - image = self.vae.decode(latents).sample - video = image[None, :].reshape((batch_size, num_frames, -1) + image.shape[2:]).permute(0, 2, 1, 3, 4) + video = [] + for i in range(0, latents.shape[0], decode_batch_size): + batch_latents = latents[i : i + decode_batch_size] + batch_latents = self.vae.decode(batch_latents).sample + video.append(batch_latents) + + video = torch.cat(video) + video = video[None, :].reshape((batch_size, num_frames, -1) + video.shape[2:]).permute(0, 2, 1, 3, 4) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 video = video.float() return video From dc96a8d5cd32d6d6d5d7e46e2860ab79efde2a51 Mon Sep 17 00:00:00 2001 From: Aryan Date: Sat, 27 Jul 2024 15:31:12 +0200 Subject: [PATCH 11/35] make style --- .../pipelines/animatediff/pipeline_animatediff_controlnet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py index d14fc2c9d82f..30ebe98ae948 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py @@ -16,9 +16,9 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np +import PIL.Image import torch import torch.nn.functional as F -import PIL.Image from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection from ...image_processor import PipelineImageInput From 691facfc2e63a4e0904246224b465856d5ccb8bb Mon Sep 17 00:00:00 2001 From: Aryan Date: Sat, 27 Jul 2024 15:34:39 +0200 Subject: [PATCH 12/35] copy animatediff controlnet implementation from #8972 --- .../pipeline_animatediff_controlnet.py | 197 ++++++++---------- 1 file changed, 83 insertions(+), 114 deletions(-) diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py index 30ebe98ae948..2aa7a7e02a88 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +++ 
b/src/diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py @@ -15,8 +15,6 @@ import inspect from typing import Any, Callable, Dict, List, Optional, Tuple, Union -import numpy as np -import PIL.Image import torch import torch.nn.functional as F from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection @@ -43,54 +41,64 @@ Examples: ```py >>> import torch - >>> from diffusers import AutoencoderKL, ControlNetModel, MotionAdapter - >>> from diffusers.pipelines import DiffusionPipeline - >>> from diffusers.schedulers import DPMSolverMultistepScheduler - >>> from PIL import Image - - >>> motion_id = "guoyww/animatediff-motion-adapter-v1-5-2" - >>> adapter = MotionAdapter.from_pretrained(motion_id) - >>> controlnet = ControlNetModel.from_pretrained( - ... "lllyasviel/control_v11p_sd15_openpose", torch_dtype=torch.float16 + >>> from diffusers import ( + ... AnimateDiffControlNetPipeline, + ... AutoencoderKL, + ... ControlNetModel, + ... MotionAdapter, + ... LCMScheduler, ... ) - >>> vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16) + >>> from diffusers.utils import export_to_gif, load_video + + >>> # Additionally, you will need a preprocess videos before they can be used with the ControlNet + >>> # HF maintains just the right package for it: `pip install controlnet_aux` + >>> from controlnet_aux.processor import ZoeDetector + + >>> # Download controlnets from https://huggingface.co/lllyasviel/ControlNet-v1-1 to use .from_single_file + >>> # Download Diffusers-format controlnets, such as https://huggingface.co/lllyasviel/sd-controlnet-depth, to use .from_pretrained() + >>> controlnet = ControlNetModel.from_single_file("control_v11f1p_sd15_depth.pth", torch_dtype=torch.float16) + + >>> # We use AnimateLCM for this example but one can use the original motion adapters as well (for example, https://huggingface.co/guoyww/animatediff-motion-adapter-v1-5-3) + >>> motion_adapter = MotionAdapter.from_pretrained("wangfuyun/AnimateLCM") - >>> model_id = "SG161222/Realistic_Vision_V5.1_noVAE" - >>> pipe = DiffusionPipeline.from_pretrained( - ... model_id, - ... motion_adapter=adapter, + >>> vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16) + >>> pipe: AnimateDiffControlNetPipeline = AnimateDiffControlNetPipeline.from_pretrained( + ... "SG161222/Realistic_Vision_V5.1_noVAE", + ... motion_adapter=motion_adapter, ... controlnet=controlnet, ... vae=vae, - ... custom_pipeline="pipeline_animatediff_controlnet", ... ).to(device="cuda", dtype=torch.float16) - >>> pipe.scheduler = DPMSolverMultistepScheduler.from_pretrained( - ... model_id, - ... subfolder="scheduler", - ... clip_sample=False, - ... timestep_spacing="linspace", - ... steps_offset=1, - ... beta_schedule="linear", + >>> pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config, beta_schedule="linear") + >>> pipe.load_lora_weights( + ... "wangfuyun/AnimateLCM", weight_name="AnimateLCM_sd15_t2v_lora.safetensors", adapter_name="lcm-lora" ... ) - >>> pipe.enable_vae_slicing() + >>> pipe.set_adapters(["lcm-lora"], [0.8]) + >>> depth_detector = ZoeDetector.from_pretrained("lllyasviel/Annotators").to("cuda") + >>> video = load_video( + ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/animatediff-vid2vid-input-1.gif" + ... ) >>> conditioning_frames = [] - >>> for i in range(1, 16 + 1): - ... 
conditioning_frames.append(Image.open(f"frame_{i}.png")) - >>> prompt = "astronaut in space, dancing" - >>> negative_prompt = "bad quality, worst quality, jpeg artifacts, ugly" - >>> result = pipe( + >>> with pipe.progress_bar(total=len(video)) as progress_bar: + ... for frame in video: + ... conditioning_frames.append(depth_detector(frame)) + ... progress_bar.update() + + >>> prompt = "a panda, playing a guitar, sitting in a pink boat, in the ocean, mountains in background, realistic, high quality" + >>> negative_prompt = "bad quality, worst quality" + + >>> video = pipe( ... prompt=prompt, ... negative_prompt=negative_prompt, - ... width=512, - ... height=768, + ... num_frames=len(video), + ... num_inference_steps=10, + ... guidance_scale=2.0, ... conditioning_frames=conditioning_frames, - ... num_inference_steps=12, - ... ) + ... generator=torch.Generator().manual_seed(42), + ... ).frames[0] - >>> from diffusers.utils import export_to_gif - - >>> export_to_gif(result.frames[0], "result.gif") + >>> export_to_gif(video, "animatediff_controlnet.gif", fps=8) ``` """ @@ -105,15 +113,15 @@ class AnimateDiffControlNetPipeline( AnimateDiffFreeNoiseMixin, ): r""" - Pipeline for text-to-video generation. + Pipeline for text-to-video generation with ControlNet guidance. This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods implemented for all pipelines (downloading, saving, running on a particular device, etc.). The pipeline also inherits the following loading methods: - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings - - [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights - - [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights + - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters Args: @@ -473,7 +481,7 @@ def check_inputs( prompt_embeds=None, negative_prompt_embeds=None, callback_on_step_end_tensor_inputs=None, - image=None, + video=None, controlnet_conditioning_scale=1.0, control_guidance_start=0.0, control_guidance_end=1.0, @@ -532,20 +540,20 @@ def check_inputs( or is_compiled and isinstance(self.controlnet._orig_mod, ControlNetModel) ): - if not isinstance(image, list): - raise TypeError(f"For single controlnet, `image` must be of type `list` but got {type(image)}") - if len(image) != num_frames: - raise ValueError(f"Excepted image to have length {num_frames} but got {len(image)=}") + if not isinstance(video, list): + raise TypeError(f"For single controlnet, `image` must be of type `list` but got {type(video)}") + if len(video) != num_frames: + raise ValueError(f"Excepted image to have length {num_frames} but got {len(video)=}") elif ( isinstance(self.controlnet, MultiControlNetModel) or is_compiled and isinstance(self.controlnet._orig_mod, MultiControlNetModel) ): - if not isinstance(image, list) or not isinstance(image[0], list): - raise TypeError(f"For multiple controlnets: `image` must be type list of lists but got {type(image)=}") - if len(image[0]) != num_frames: - raise ValueError(f"Expected length of image sublist as {num_frames} but got {len(image[0])=}") - if any(len(img) != len(image[0]) for img in image): + if not isinstance(video, list) or not isinstance(video[0], list): + raise TypeError(f"For multiple controlnets: `image` must be type 
list of lists but got {type(video)=}") + if len(video[0]) != num_frames: + raise ValueError(f"Expected length of image sublist as {num_frames} but got {len(video[0])=}") + if any(len(img) != len(video[0]) for img in video): raise ValueError("All conditioning frame batches for multicontrolnet must be same size") else: assert False @@ -603,44 +611,6 @@ def check_inputs( if end > 1.0: raise ValueError(f"control guidance end: {end} can't be larger than 1.0.") - # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.check_image - def check_image(self, image, prompt, prompt_embeds): - image_is_pil = isinstance(image, PIL.Image.Image) - image_is_tensor = isinstance(image, torch.Tensor) - image_is_np = isinstance(image, np.ndarray) - image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image) - image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor) - image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray) - - if ( - not image_is_pil - and not image_is_tensor - and not image_is_np - and not image_is_pil_list - and not image_is_tensor_list - and not image_is_np_list - ): - raise TypeError( - f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}" - ) - - if image_is_pil: - image_batch_size = 1 - else: - image_batch_size = len(image) - - if prompt is not None and isinstance(prompt, str): - prompt_batch_size = 1 - elif prompt is not None and isinstance(prompt, list): - prompt_batch_size = len(prompt) - elif prompt_embeds is not None: - prompt_batch_size = prompt_embeds.shape[0] - - if image_batch_size != 1 and image_batch_size != prompt_batch_size: - raise ValueError( - f"If image batch size is not 1, image batch size must be same as prompt batch size. 
image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}" - ) - # Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_synth.TextToVideoSDPipeline.prepare_latents def prepare_latents( self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None @@ -667,36 +637,37 @@ def prepare_latents( latents = latents * self.scheduler.init_noise_sigma return latents - # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.prepare_image - def prepare_image( + def prepare_video( self, - image, + video, width, height, batch_size, - num_images_per_prompt, + num_videos_per_prompt, device, dtype, do_classifier_free_guidance=False, guess_mode=False, ): - image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) - image_batch_size = image.shape[0] + video = self.control_video_processor.preprocess_video(video, height=height, width=width).to( + dtype=torch.float32 + ) + video = video.permute(0, 2, 1, 3, 4).flatten(0, 1) + video_batch_size = video.shape[0] - if image_batch_size == 1: + if video_batch_size == 1: repeat_by = batch_size else: # image batch size is the same as prompt batch size - repeat_by = num_images_per_prompt + repeat_by = num_videos_per_prompt - image = image.repeat_interleave(repeat_by, dim=0) - - image = image.to(device=device, dtype=dtype) + video = video.repeat_interleave(repeat_by, dim=0) + video = video.to(device=device, dtype=dtype) if do_classifier_free_guidance and not guess_mode: - image = torch.cat([image] * 2) + video = torch.cat([video] * 2) - return image + return video @property def guidance_scale(self): @@ -833,8 +804,6 @@ def __call__( The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the `._callback_tensor_inputs` attribute of your pipeline class. - decode_batch_size (`int`, defaults to `16`): - The number of frames to decode at a time when calling `decode_latents` method. 
Examples: @@ -873,7 +842,7 @@ def __call__( callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, - image=conditioning_frames, + video=conditioning_frames, controlnet_conditioning_scale=controlnet_conditioning_scale, control_guidance_start=control_guidance_start, control_guidance_end=control_guidance_end, @@ -934,33 +903,33 @@ def __call__( ) if isinstance(controlnet, ControlNetModel): - conditioning_frames = self.prepare_image( - image=conditioning_frames, + conditioning_frames = self.prepare_video( + video=conditioning_frames, width=width, height=height, batch_size=batch_size * num_videos_per_prompt * num_frames, - num_images_per_prompt=num_videos_per_prompt, + num_videos_per_prompt=num_videos_per_prompt, device=device, dtype=controlnet.dtype, do_classifier_free_guidance=self.do_classifier_free_guidance, guess_mode=guess_mode, ) elif isinstance(controlnet, MultiControlNetModel): - cond_prepared_frames = [] + cond_prepared_videos = [] for frame_ in conditioning_frames: - prepared_frame = self.prepare_image( - image=frame_, + prepared_video = self.prepare_video( + video=frame_, width=width, height=height, batch_size=batch_size * num_videos_per_prompt * num_frames, - num_images_per_prompt=num_videos_per_prompt, + num_videos_per_prompt=num_videos_per_prompt, device=device, dtype=controlnet.dtype, do_classifier_free_guidance=self.do_classifier_free_guidance, guess_mode=guess_mode, ) - cond_prepared_frames.append(prepared_frame) - conditioning_frames = cond_prepared_frames + cond_prepared_videos.append(prepared_video) + conditioning_frames = cond_prepared_videos else: assert False @@ -987,7 +956,7 @@ def __call__( # 7. Add image embeds for IP-Adapter added_cond_kwargs = ( - {"image_embeds": ip_adapter_image_embeds} + {"image_embeds": image_embeds} if ip_adapter_image is not None or ip_adapter_image_embeds is not None else None ) From 5a60a62c4708163028404853257ad0995fdf8547 Mon Sep 17 00:00:00 2001 From: Aryan Date: Sat, 27 Jul 2024 16:24:36 +0200 Subject: [PATCH 13/35] add experimental support for num_frames not perfectly fitting context length, context stride --- .../models/unets/unet_motion_model.py | 23 ++++++++++++++++--- 1 file changed, 20 insertions(+), 3 deletions(-) diff --git a/src/diffusers/models/unets/unet_motion_model.py b/src/diffusers/models/unets/unet_motion_model.py index 367cf62ad942..45f770a7ef09 100644 --- a/src/diffusers/models/unets/unet_motion_model.py +++ b/src/diffusers/models/unets/unet_motion_model.py @@ -211,6 +211,7 @@ def _get_frame_indices(self, num_frames: int) -> List[Tuple[int, int]]: window_start = i window_end = min(num_frames, i + self.context_length) frame_indices.append((window_start, window_end)) + return frame_indices def _get_frame_weights(self, num_frames: int, weighting_scheme: str = "pyramid") -> List[float]: @@ -264,11 +265,21 @@ def forward( frame_indices = self._get_frame_indices(num_frames) frame_weights = self._get_frame_weights(self.context_length, self.weighting_scheme) frame_weights = torch.tensor(frame_weights, device=device, dtype=dtype).unsqueeze(0).unsqueeze(-1) + is_last_frame_batch_complete = frame_indices[-1][1] == num_frames + + # Handle out-of-bounds case if num_frames isn't perfectly divisible by context_length + # For example, num_frames=25, context_length=16, context_stride=4, then we expect the ranges: + # [(0, 16), (4, 20), (8, 24), (9, 25)] + if not is_last_frame_batch_complete: + if num_frames < self.context_length: + raise 
ValueError(f"Expected {num_frames=} to be greater than or equal to {self.context_length=}") + last_frame_batch_length = num_frames - frame_indices[-1][1] + frame_indices.append((num_frames - self.context_length, num_frames)) num_times_accumulated = torch.zeros((1, num_frames, 1), device=device) accumulated_values = torch.zeros_like(hidden_states) - for frame_start, frame_end in frame_indices: + for i, (frame_start, frame_end) in enumerate(frame_indices): # Slicing with (frame_end - frame_start) handles cases like frame_indices=[(0, 16), (16, 20)], # which occur when the user provides a video with, say, 19 frames, # i.e. a frame count that is not a multiple of `context_length`. @@ -311,8 +322,14 @@ def forward( ) hidden_states_chunk = attn_output + hidden_states_chunk - accumulated_values[:, frame_start:frame_end] += hidden_states_chunk * weights - num_times_accumulated[:, frame_start:frame_end] += weights + if i == len(frame_indices) - 1 and not is_last_frame_batch_complete: + accumulated_values[:, -last_frame_batch_length:] += ( + hidden_states_chunk[:, -last_frame_batch_length:] * weights[:, -last_frame_batch_length:] + ) + num_times_accumulated[:, -last_frame_batch_length:] += weights[:, -last_frame_batch_length:] + else: + accumulated_values[:, frame_start:frame_end] += hidden_states_chunk * weights + num_times_accumulated[:, frame_start:frame_end] += weights hidden_states = torch.where( num_times_accumulated > 0, accumulated_values / num_times_accumulated, accumulated_values From 58c2ddcb39dec2152b425be9765b589a6743529c Mon Sep 17 00:00:00 2001 From: Aryan Date: Sun, 28 Jul 2024 01:20:12 +0200 Subject: [PATCH 14/35] make unet motion model lora work again based on #8995 --- src/diffusers/loaders/peft.py | 1 + src/diffusers/models/unets/unet_motion_model.py | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py index 5625f9755b19..bca983596b3f 100644 --- a/src/diffusers/loaders/peft.py +++ b/src/diffusers/loaders/peft.py @@ -31,6 +31,7 @@ _SET_ADAPTER_SCALE_FN_MAPPING = { "UNet2DConditionModel": _maybe_expand_lora_scales, "SD3Transformer2DModel": lambda model_cls, weights: weights, + "UNetMotionModel": lambda model_cls, weights: weights, } diff --git a/src/diffusers/models/unets/unet_motion_model.py b/src/diffusers/models/unets/unet_motion_model.py index 45f770a7ef09..ae5a35bbb0f1 100644 --- a/src/diffusers/models/unets/unet_motion_model.py +++ b/src/diffusers/models/unets/unet_motion_model.py @@ -19,7 +19,7 @@ import torch.utils.checkpoint from ...configuration_utils import ConfigMixin, FrozenDict, register_to_config -from ...loaders import FromOriginalModelMixin, UNet2DConditionLoadersMixin +from ...loaders import FromOriginalModelMixin, PeftAdapterMixin, UNet2DConditionLoadersMixin from ...utils import logging from ...utils.torch_utils import maybe_allow_in_graph from ..attention import FeedForward, _chunked_feed_forward @@ -529,7 +529,7 @@ def forward(self, sample): pass -class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin): +class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin): r""" A modified conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample shaped output.
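The windowed accumulation that [PATCH 13/35] adds to the FreeNoise transformer block forward pass is easier to see outside the diff. The sketch below is illustrative only and not part of the patch series: it reimplements the window and pyramid-weight logic with standalone helpers, uses a random tensor as a stand-in for the per-window attention output, and omits the patch's extra branch that restricts the final window's contribution to only the frames not already covered by earlier windows.

```py
import torch


def get_frame_indices(num_frames: int, context_length: int, context_stride: int):
    # Sliding windows of `context_length` frames, advanced by `context_stride`.
    indices = [
        (i, i + context_length)
        for i in range(0, num_frames - context_length + 1, context_stride)
    ]
    # If the stride does not land exactly on the last frame, append one final window that
    # ends at `num_frames`, e.g. 25 frames -> [(0, 16), (4, 20), (8, 24), (9, 25)].
    if indices[-1][1] != num_frames:
        indices.append((num_frames - context_length, num_frames))
    return indices


def get_pyramid_weights(length: int):
    # [1, 2, ..., n, ..., 2, 1]: center frames of a window count more than its edges.
    half = list(range(1, length // 2 + 1))
    middle = [length // 2 + 1] if length % 2 else []
    return half + middle + half[::-1]


num_frames, context_length, context_stride = 25, 16, 4
hidden_states = torch.randn(2, num_frames, 8)  # [batch * height * width, frames, channels]
weights = torch.tensor(get_pyramid_weights(context_length), dtype=hidden_states.dtype).view(1, -1, 1)

accumulated = torch.zeros_like(hidden_states)
counts = torch.zeros(1, num_frames, 1)
for start, end in get_frame_indices(num_frames, context_length, context_stride):
    chunk = hidden_states[:, start:end]  # stand-in for the attention output of this window
    accumulated[:, start:end] += chunk * weights
    counts[:, start:end] += weights

# Frames covered by several windows are blended according to their pyramid weights.
blended = torch.where(counts > 0, accumulated / counts, accumulated)
print(blended.shape)  # torch.Size([2, 25, 8])
```

With these numbers the helper produces the windows [(0, 16), (4, 20), (8, 24), (9, 25)], matching the out-of-bounds handling described in the patch above.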
From 70001864871a90872bc5f7d514915305b339a4d7 Mon Sep 17 00:00:00 2001 From: Aryan Date: Sun, 28 Jul 2024 01:20:35 +0200 Subject: [PATCH 15/35] copy load video utils from #8972 --- src/diffusers/utils/__init__.py | 2 +- src/diffusers/utils/loading_utils.py | 77 ++++++++++++++++++++++++++-- 2 files changed, 75 insertions(+), 4 deletions(-) diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index f41edfcda3d8..d2633d2ec9e7 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -93,7 +93,7 @@ is_xformers_available, requires_backends, ) -from .loading_utils import load_image +from .loading_utils import load_image, load_video from .logging import get_logger from .outputs import BaseOutput from .peft_utils import ( diff --git a/src/diffusers/utils/loading_utils.py b/src/diffusers/utils/loading_utils.py index aa087e981731..62a854cf5299 100644 --- a/src/diffusers/utils/loading_utils.py +++ b/src/diffusers/utils/loading_utils.py @@ -1,13 +1,16 @@ import os -from typing import Callable, Union +import tempfile +from typing import Callable, List, Optional, Union import PIL.Image import PIL.ImageOps import requests +from .import_utils import BACKENDS_MAPPING, is_opencv_available + def load_image( - image: Union[str, PIL.Image.Image], convert_method: Callable[[PIL.Image.Image], PIL.Image.Image] = None + image: Union[str, PIL.Image.Image], convert_method: Optional[Callable[[PIL.Image.Image], PIL.Image.Image]] = None ) -> PIL.Image.Image: """ Loads `image` to a PIL Image. @@ -15,7 +18,7 @@ def load_image( Args: image (`str` or `PIL.Image.Image`): The image to convert to the PIL Image format. - convert_method (Callable[[PIL.Image.Image], PIL.Image.Image], optional): + convert_method (Callable[[PIL.Image.Image], PIL.Image.Image], *optional*): A conversion method to apply to the image after loading it. When set to `None` the image will be converted "RGB". @@ -47,3 +50,71 @@ def load_image( image = image.convert("RGB") return image + + +def load_video( + video: Union[str], convert_method: Optional[Callable[[List[PIL.Image.Image]], List[PIL.Image.Image]]] = None +) -> List[PIL.Image.Image]: + """ + Loads `video` to a list of PIL Image. + + Args: + video (`str`): + The video URL, or path to local file, to load and convert to a list of PIL Images. + convert_method (Callable[[List[PIL.Image.Image]], List[PIL.Image.Image]], *optional*): + A conversion method to apply to the video after loading it. When set to `None` the images will be converted + to "RGB". + + Returns: + `List[PIL.Image.Image]`: + The video as a list of PIL images. + """ + if isinstance(video, str): + was_tempfile_created = False + + if video.startswith("http://") or video.startswith("https://"): + video_data = requests.get(video, stream=True).raw + video_path = tempfile.NamedTemporaryFile(suffix=os.path.splitext(video)[1], delete=False).name + was_tempfile_created = True + with open(video_path, "wb") as f: + f.write(video_data.read()) + video = video_path + elif not os.path.isfile(video): + raise ValueError( + f"Incorrect path or URL. URLs must start with `http://` or `https://`, and {video} is not a valid path." 
+ ) + + if video.endswith(".gif"): + pil_images = [] + gif = PIL.Image.open(video) + try: + while True: + pil_images.append(gif.copy()) + gif.seek(gif.tell() + 1) + except EOFError: + pass + else: + if is_opencv_available(): + import cv2 + else: + raise ImportError(BACKENDS_MAPPING["opencv"][1].format("load_video")) + pil_images = [] + video_capture = cv2.VideoCapture(video) + success, frame = video_capture.read() + while success: + frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + pil_images.append(PIL.Image.fromarray(frame)) + success, frame = video_capture.read() + video_capture.release() + + if was_tempfile_created: + os.remove(video_path) + else: + raise ValueError("Incorrect format used for the video. Expected a URL or a local path.") + + if convert_method is not None: + pil_images = convert_method(pil_images) + else: + pil_images = [image.convert("RGB") for image in pil_images] + + return pil_images From c5db39f88364217dcae592c4e7ec14365c4988d5 Mon Sep 17 00:00:00 2001 From: Aryan Date: Sun, 28 Jul 2024 02:29:13 +0200 Subject: [PATCH 16/35] copied from AnimateDiff::prepare_latents --- .../animatediff/pipeline_animatediff.py | 7 +++++-- .../pipeline_animatediff_controlnet.py | 21 ++++++++++++++++++- .../animatediff/pipeline_animatediff_sdxl.py | 21 ++++++++++++++++++- .../pipeline_animatediff_sparsectrl.py | 21 ++++++++++++++++++- src/diffusers/pipelines/pia/pipeline_pia.py | 21 ++++++++++++++++++- 5 files changed, 85 insertions(+), 6 deletions(-) diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py index d2220687d1ac..f80b7753a5a1 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py @@ -523,6 +523,8 @@ def prepare_latents( else: latents = latents.to(device) + # If FreeNoise is enabled, shuffle latents in every window as described in Equation (7) of + # [FreeNoise](https://arxiv.org/abs/2310.15169) if self.free_noise_enabled and self._free_noise_shuffle: for i in range(self._free_noise_context_length, num_frames, self._free_noise_context_stride): # ensure window is within bounds @@ -536,8 +538,9 @@ def prepare_latents( indices = torch.LongTensor(list(range(window_start, window_end))) shuffled_indices = indices[torch.randperm(window_length, generator=generator)] - # shuffle latents in every window - latents[:, :, window_start:window_end] = latents[:, :, shuffled_indices] + current_start = i + current_end = min(num_frames, i + self._free_noise_context_stride) + latents[:, :, current_start:current_end] = latents[:, :, shuffled_indices] # scale the initial noise by the standard deviation required by the scheduler latents = latents * self.scheduler.init_noise_sigma diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py index 2aa7a7e02a88..246d57890a6f 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py @@ -611,7 +611,7 @@ def check_inputs( if end > 1.0: raise ValueError(f"control guidance end: {end} can't be larger than 1.0.") - # Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_synth.TextToVideoSDPipeline.prepare_latents + # Copied from diffusers.pipelines.animatediff.pipeline_animatediff.AnimateDiffPipeline.prepare_latents def prepare_latents( self, batch_size, num_channels_latents, 
num_frames, height, width, dtype, device, generator, latents=None ): @@ -633,6 +633,25 @@ def prepare_latents( else: latents = latents.to(device) + # If FreeNoise is enabled, shuffle latents in every window as described in Equation (7) of + # [FreeNoise](https://arxiv.org/abs/2310.15169) + if self.free_noise_enabled and self._free_noise_shuffle: + for i in range(self._free_noise_context_length, num_frames, self._free_noise_context_stride): + # ensure window is within bounds + window_start = max(0, i - self._free_noise_context_length) + window_end = min(num_frames, window_start + self._free_noise_context_stride) + window_length = window_end - window_start + + if window_length == 0: + break + + indices = torch.LongTensor(list(range(window_start, window_end))) + shuffled_indices = indices[torch.randperm(window_length, generator=generator)] + + current_start = i + current_end = min(num_frames, i + self._free_noise_context_stride) + latents[:, :, current_start:current_end] = latents[:, :, shuffled_indices] + # scale the initial noise by the standard deviation required by the scheduler latents = latents * self.scheduler.init_noise_sigma return latents diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py index bc067bf867ed..2b794d9c8492 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py @@ -717,7 +717,7 @@ def check_inputs( "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." ) - # Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_synth.TextToVideoSDPipeline.prepare_latents + # Copied from diffusers.pipelines.animatediff.pipeline_animatediff.AnimateDiffPipeline.prepare_latents def prepare_latents( self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None ): @@ -739,6 +739,25 @@ def prepare_latents( else: latents = latents.to(device) + # If FreeNoise is enabled, shuffle latents in every window as described in Equation (7) of + # [FreeNoise](https://arxiv.org/abs/2310.15169) + if self.free_noise_enabled and self._free_noise_shuffle: + for i in range(self._free_noise_context_length, num_frames, self._free_noise_context_stride): + # ensure window is within bounds + window_start = max(0, i - self._free_noise_context_length) + window_end = min(num_frames, window_start + self._free_noise_context_stride) + window_length = window_end - window_start + + if window_length == 0: + break + + indices = torch.LongTensor(list(range(window_start, window_end))) + shuffled_indices = indices[torch.randperm(window_length, generator=generator)] + + current_start = i + current_end = min(num_frames, i + self._free_noise_context_stride) + latents[:, :, current_start:current_end] = latents[:, :, shuffled_indices] + # scale the initial noise by the standard deviation required by the scheduler latents = latents * self.scheduler.init_noise_sigma return latents diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py index d52a4dd385ee..5dfc942ab3e7 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +++ 
b/src/diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py @@ -619,7 +619,7 @@ def check_image(self, image, prompt, prompt_embeds): f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}" ) - # Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_synth.TextToVideoSDPipeline.prepare_latents + # Copied from diffusers.pipelines.animatediff.pipeline_animatediff.AnimateDiffPipeline.prepare_latents def prepare_latents( self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None ): @@ -641,6 +641,25 @@ def prepare_latents( else: latents = latents.to(device) + # If FreeNoise is enabled, shuffle latents in every window as described in Equation (7) of + # [FreeNoise](https://arxiv.org/abs/2310.15169) + if self.free_noise_enabled and self._free_noise_shuffle: + for i in range(self._free_noise_context_length, num_frames, self._free_noise_context_stride): + # ensure window is within bounds + window_start = max(0, i - self._free_noise_context_length) + window_end = min(num_frames, window_start + self._free_noise_context_stride) + window_length = window_end - window_start + + if window_length == 0: + break + + indices = torch.LongTensor(list(range(window_start, window_end))) + shuffled_indices = indices[torch.randperm(window_length, generator=generator)] + + current_start = i + current_end = min(num_frames, i + self._free_noise_context_stride) + latents[:, :, current_start:current_end] = latents[:, :, shuffled_indices] + # scale the initial noise by the standard deviation required by the scheduler latents = latents * self.scheduler.init_noise_sigma return latents diff --git a/src/diffusers/pipelines/pia/pipeline_pia.py b/src/diffusers/pipelines/pia/pipeline_pia.py index 3699410f2a09..f02a43d6352a 100644 --- a/src/diffusers/pipelines/pia/pipeline_pia.py +++ b/src/diffusers/pipelines/pia/pipeline_pia.py @@ -555,7 +555,7 @@ def prepare_ip_adapter_image_embeds( return ip_adapter_image_embeds - # Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_synth.TextToVideoSDPipeline.prepare_latents + # Copied from diffusers.pipelines.animatediff.pipeline_animatediff.AnimateDiffPipeline.prepare_latents def prepare_latents( self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None ): @@ -577,6 +577,25 @@ def prepare_latents( else: latents = latents.to(device) + # If FreeNoise is enabled, shuffle latents in every window as described in Equation (7) of + # [FreeNoise](https://arxiv.org/abs/2310.15169) + if self.free_noise_enabled and self._free_noise_shuffle: + for i in range(self._free_noise_context_length, num_frames, self._free_noise_context_stride): + # ensure window is within bounds + window_start = max(0, i - self._free_noise_context_length) + window_end = min(num_frames, window_start + self._free_noise_context_stride) + window_length = window_end - window_start + + if window_length == 0: + break + + indices = torch.LongTensor(list(range(window_start, window_end))) + shuffled_indices = indices[torch.randperm(window_length, generator=generator)] + + current_start = i + current_end = min(num_frames, i + self._free_noise_context_stride) + latents[:, :, current_start:current_end] = latents[:, :, shuffled_indices] + # scale the initial noise by the standard deviation required by the scheduler latents = latents * self.scheduler.init_noise_sigma return latents 
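For reference, the window-based noise shuffling that PATCH 16 copies into each pipeline's `prepare_latents` (and that PATCH 17 below refines for a partial final window) can be read in isolation as the following minimal sketch. `shuffle_free_noise_latents` is a hypothetical standalone helper, not a diffusers API; it assumes latents shaped `[batch, channels, frames, height, width]` and uses plain arguments in place of the pipeline's `_free_noise_context_length`, `_free_noise_context_stride`, and `generator` attributes.

import torch


def shuffle_free_noise_latents(latents, context_length=16, context_stride=4, generator=None):
    # latents: [batch, channels, frames, height, width]; only the frame axis is touched.
    num_frames = latents.size(2)

    for i in range(context_length, num_frames, context_stride):
        # Source window: earlier frames whose already-sampled noise is reused.
        window_start = max(0, i - context_length)
        window_end = min(num_frames, window_start + context_stride)
        window_length = window_end - window_start
        if window_length == 0:
            break

        indices = torch.arange(window_start, window_end)
        shuffled_indices = indices[torch.randperm(window_length, generator=generator)]

        # Destination window: clip the last (possibly partial) batch of frames,
        # which is the edge case addressed in the next patch.
        current_start = i
        current_end = min(num_frames, current_start + window_length)
        shuffled_indices = shuffled_indices[: current_end - current_start]
        latents[:, :, current_start:current_end] = latents[:, :, shuffled_indices]

    return latents


# 80 frames of initial noise built from shuffled repeats of the first 16 frames.
noise = torch.randn(1, 4, 80, 8, 8)
noise = shuffle_free_noise_latents(noise, context_length=16, context_stride=4, generator=torch.Generator().manual_seed(0))

In this reading, only the first `context_length` frames carry independent noise; every later window reuses a shuffled copy of an earlier window, which is what keeps long videos temporally correlated once the FreeNoise blocks average the overlapping context windows during denoising.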
From 594d2d2c7ccb628456aa7c51874c811e3e96a59c Mon Sep 17 00:00:00 2001 From: Aryan Date: Sun, 28 Jul 2024 03:05:55 +0200 Subject: [PATCH 17/35] address the case where last batch of frames does not match length of indices in prepare latents --- .../pipelines/animatediff/pipeline_animatediff.py | 11 +++++++++-- .../animatediff/pipeline_animatediff_controlnet.py | 11 +++++++++-- .../animatediff/pipeline_animatediff_sdxl.py | 11 +++++++++-- .../animatediff/pipeline_animatediff_sparsectrl.py | 11 +++++++++-- src/diffusers/pipelines/pia/pipeline_pia.py | 11 +++++++++-- 5 files changed, 45 insertions(+), 10 deletions(-) diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py index f80b7753a5a1..275f65d39baf 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py @@ -539,8 +539,15 @@ def prepare_latents( shuffled_indices = indices[torch.randperm(window_length, generator=generator)] current_start = i - current_end = min(num_frames, i + self._free_noise_context_stride) - latents[:, :, current_start:current_end] = latents[:, :, shuffled_indices] + current_end = min(num_frames, current_start + window_length) + if current_end == current_start + window_length: + # batch of frames perfectly fits the window + latents[:, :, current_start:current_end] = latents[:, :, shuffled_indices] + else: + # handle the case where the last batch of frames does not fit perfectly with the window + prefix_length = current_end - current_start + shuffled_indices = shuffled_indices[:prefix_length] + latents[:, :, current_start:current_end] = latents[:, :, shuffled_indices] # scale the initial noise by the standard deviation required by the scheduler latents = latents * self.scheduler.init_noise_sigma diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py index 246d57890a6f..f27a4c100d22 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py @@ -649,8 +649,15 @@ def prepare_latents( shuffled_indices = indices[torch.randperm(window_length, generator=generator)] current_start = i - current_end = min(num_frames, i + self._free_noise_context_stride) - latents[:, :, current_start:current_end] = latents[:, :, shuffled_indices] + current_end = min(num_frames, current_start + window_length) + if current_end == current_start + window_length: + # batch of frames perfectly fits the window + latents[:, :, current_start:current_end] = latents[:, :, shuffled_indices] + else: + # handle the case where the last batch of frames does not fit perfectly with the window + prefix_length = current_end - current_start + shuffled_indices = shuffled_indices[:prefix_length] + latents[:, :, current_start:current_end] = latents[:, :, shuffled_indices] # scale the initial noise by the standard deviation required by the scheduler latents = latents * self.scheduler.init_noise_sigma diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py index 2b794d9c8492..30be48ac4800 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py @@ -755,8 +755,15 @@ def prepare_latents( shuffled_indices = 
indices[torch.randperm(window_length, generator=generator)] current_start = i - current_end = min(num_frames, i + self._free_noise_context_stride) - latents[:, :, current_start:current_end] = latents[:, :, shuffled_indices] + current_end = min(num_frames, current_start + window_length) + if current_end == current_start + window_length: + # batch of frames perfectly fits the window + latents[:, :, current_start:current_end] = latents[:, :, shuffled_indices] + else: + # handle the case where the last batch of frames does not fit perfectly with the window + prefix_length = current_end - current_start + shuffled_indices = shuffled_indices[:prefix_length] + latents[:, :, current_start:current_end] = latents[:, :, shuffled_indices] # scale the initial noise by the standard deviation required by the scheduler latents = latents * self.scheduler.init_noise_sigma diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py index 5dfc942ab3e7..f7fb4e5ebb2d 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py @@ -657,8 +657,15 @@ def prepare_latents( shuffled_indices = indices[torch.randperm(window_length, generator=generator)] current_start = i - current_end = min(num_frames, i + self._free_noise_context_stride) - latents[:, :, current_start:current_end] = latents[:, :, shuffled_indices] + current_end = min(num_frames, current_start + window_length) + if current_end == current_start + window_length: + # batch of frames perfectly fits the window + latents[:, :, current_start:current_end] = latents[:, :, shuffled_indices] + else: + # handle the case where the last batch of frames does not fit perfectly with the window + prefix_length = current_end - current_start + shuffled_indices = shuffled_indices[:prefix_length] + latents[:, :, current_start:current_end] = latents[:, :, shuffled_indices] # scale the initial noise by the standard deviation required by the scheduler latents = latents * self.scheduler.init_noise_sigma diff --git a/src/diffusers/pipelines/pia/pipeline_pia.py b/src/diffusers/pipelines/pia/pipeline_pia.py index f02a43d6352a..8666d58b01f6 100644 --- a/src/diffusers/pipelines/pia/pipeline_pia.py +++ b/src/diffusers/pipelines/pia/pipeline_pia.py @@ -593,8 +593,15 @@ def prepare_latents( shuffled_indices = indices[torch.randperm(window_length, generator=generator)] current_start = i - current_end = min(num_frames, i + self._free_noise_context_stride) - latents[:, :, current_start:current_end] = latents[:, :, shuffled_indices] + current_end = min(num_frames, current_start + window_length) + if current_end == current_start + window_length: + # batch of frames perfectly fits the window + latents[:, :, current_start:current_end] = latents[:, :, shuffled_indices] + else: + # handle the case where the last batch of frames does not fit perfectly with the window + prefix_length = current_end - current_start + shuffled_indices = shuffled_indices[:prefix_length] + latents[:, :, current_start:current_end] = latents[:, :, shuffled_indices] # scale the initial noise by the standard deviation required by the scheduler latents = latents * self.scheduler.init_noise_sigma From fb9ca3471839f628b48cbeb56f40e559f0ed9fc1 Mon Sep 17 00:00:00 2001 From: Aryan Date: Sun, 28 Jul 2024 04:23:09 +0200 Subject: [PATCH 18/35] decode_batch_size->vae_batch_size; batch vae encode support in animatediff vid2vid --- 
.../animatediff/pipeline_animatediff.py | 12 ++++---- .../pipeline_animatediff_controlnet.py | 10 +++---- .../animatediff/pipeline_animatediff_sdxl.py | 12 ++++---- .../pipeline_animatediff_sparsectrl.py | 12 ++++---- .../pipeline_animatediff_video2video.py | 29 ++++++++++++------- src/diffusers/pipelines/pia/pipeline_pia.py | 12 ++++---- 6 files changed, 47 insertions(+), 40 deletions(-) diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py index 275f65d39baf..05812f8acd86 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py @@ -396,15 +396,15 @@ def prepare_ip_adapter_image_embeds( return ip_adapter_image_embeds - def decode_latents(self, latents, decode_batch_size: int = 16): + def decode_latents(self, latents, vae_batch_size: int = 16): latents = 1 / self.vae.config.scaling_factor * latents batch_size, channels, num_frames, height, width = latents.shape latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width) video = [] - for i in range(0, latents.shape[0], decode_batch_size): - batch_latents = latents[i : i + decode_batch_size] + for i in range(0, latents.shape[0], vae_batch_size): + batch_latents = latents[i : i + vae_batch_size] batch_latents = self.vae.decode(batch_latents).sample video.append(batch_latents) @@ -601,7 +601,7 @@ def __call__( clip_skip: Optional[int] = None, callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], - decode_batch_size: int = 16, + vae_batch_size: int = 16, **kwargs, ): r""" @@ -670,7 +670,7 @@ def __call__( The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the `._callback_tensor_inputs` attribute of your pipeline class. - decode_batch_size (`int`, defaults to `16`): + vae_batch_size (`int`, defaults to `16`): The number of frames to decode at a time when calling `decode_latents` method. Examples: @@ -843,7 +843,7 @@ def __call__( if output_type == "latent": video = latents else: - video_tensor = self.decode_latents(latents, decode_batch_size) + video_tensor = self.decode_latents(latents, vae_batch_size) video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type) # 10. 
Offload all models diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py index f27a4c100d22..ace61ea1e452 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py @@ -435,15 +435,15 @@ def prepare_ip_adapter_image_embeds( return ip_adapter_image_embeds # Copied from diffusers.pipelines.animatediff.pipeline_animatediff.AnimateDiffPipeline.decode_latents - def decode_latents(self, latents, decode_batch_size: int = 16): + def decode_latents(self, latents, vae_batch_size: int = 16): latents = 1 / self.vae.config.scaling_factor * latents batch_size, channels, num_frames, height, width = latents.shape latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width) video = [] - for i in range(0, latents.shape[0], decode_batch_size): - batch_latents = latents[i : i + decode_batch_size] + for i in range(0, latents.shape[0], vae_batch_size): + batch_latents = latents[i : i + vae_batch_size] batch_latents = self.vae.decode(batch_latents).sample video.append(batch_latents) @@ -747,7 +747,7 @@ def __call__( clip_skip: Optional[int] = None, callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], - decode_batch_size: int = 16, + vae_batch_size: int = 16, ): r""" The call function to the pipeline for generation. @@ -1083,7 +1083,7 @@ def __call__( if output_type == "latent": video = latents else: - video_tensor = self.decode_latents(latents, decode_batch_size) + video_tensor = self.decode_latents(latents, vae_batch_size) video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type) # 10. Offload all models diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py index 30be48ac4800..e1b88ab59974 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py @@ -609,15 +609,15 @@ def prepare_ip_adapter_image_embeds( return ip_adapter_image_embeds # Copied from diffusers.pipelines.animatediff.pipeline_animatediff.AnimateDiffPipeline.decode_latents - def decode_latents(self, latents, decode_batch_size: int = 16): + def decode_latents(self, latents, vae_batch_size: int = 16): latents = 1 / self.vae.config.scaling_factor * latents batch_size, channels, num_frames, height, width = latents.shape latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width) video = [] - for i in range(0, latents.shape[0], decode_batch_size): - batch_latents = latents[i : i + decode_batch_size] + for i in range(0, latents.shape[0], vae_batch_size): + batch_latents = latents[i : i + vae_batch_size] batch_latents = self.vae.decode(batch_latents).sample video.append(batch_latents) @@ -910,7 +910,7 @@ def __call__( clip_skip: Optional[int] = None, callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], - decode_batch_size: int = 16, + vae_batch_size: int = 16, ): r""" Function invoked when calling the pipeline for generation. @@ -1050,7 +1050,7 @@ def __call__( The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list will be passed as `callback_kwargs` argument. 
You will only be able to include variables listed in the `._callback_tensor_inputs` attribute of your pipeline class. - decode_batch_size (`int`, defaults to `16`): + vae_batch_size (`int`, defaults to `16`): The number of frames to decode at a time when calling `decode_latents` method. Examples: @@ -1295,7 +1295,7 @@ def __call__( if output_type == "latent": video = latents else: - video_tensor = self.decode_latents(latents, decode_batch_size) + video_tensor = self.decode_latents(latents, vae_batch_size) video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type) # cast back to fp16 if needed diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py index f7fb4e5ebb2d..0c92d4ce0366 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py @@ -451,15 +451,15 @@ def prepare_ip_adapter_image_embeds( return ip_adapter_image_embeds # Copied from diffusers.pipelines.animatediff.pipeline_animatediff.AnimateDiffPipeline.decode_latents - def decode_latents(self, latents, decode_batch_size: int = 16): + def decode_latents(self, latents, vae_batch_size: int = 16): latents = 1 / self.vae.config.scaling_factor * latents batch_size, channels, num_frames, height, width = latents.shape latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width) video = [] - for i in range(0, latents.shape[0], decode_batch_size): - batch_latents = latents[i : i + decode_batch_size] + for i in range(0, latents.shape[0], vae_batch_size): + batch_latents = latents[i : i + vae_batch_size] batch_latents = self.vae.decode(batch_latents).sample video.append(batch_latents) @@ -762,7 +762,7 @@ def __call__( clip_skip: Optional[int] = None, callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], - decode_batch_size: int = 16, + vae_batch_size: int = 16, ): r""" The call function to the pipeline for generation. @@ -841,7 +841,7 @@ def __call__( The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the `._callback_tensor_inputs` attribute of your pipeline class. - decode_batch_size (`int`, defaults to `16`): + vae_batch_size (`int`, defaults to `16`): The number of frames to decode at a time when calling `decode_latents` method. Examples: @@ -1033,7 +1033,7 @@ def __call__( if output_type == "latent": video = latents else: - video_tensor = self.decode_latents(latents, decode_batch_size) + video_tensor = self.decode_latents(latents, vae_batch_size) video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type) # 12. 
Offload all models diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py index 89ede2850d71..190170abd868 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py @@ -500,16 +500,24 @@ def prepare_ip_adapter_image_embeds( return ip_adapter_image_embeds + def encode_video(self, video, generator, vae_batch_size: int = 16) -> torch.Tensor: + latents = [] + for i in range(0, len(video), vae_batch_size): + batch_video = video[i : i + vae_batch_size] + batch_video = retrieve_latents(self.vae.encode(batch_video), generator=generator) + latents.append(batch_video) + return torch.cat(latents) + # Copied from diffusers.pipelines.animatediff.pipeline_animatediff.AnimateDiffPipeline.decode_latents - def decode_latents(self, latents, decode_batch_size: int = 16): + def decode_latents(self, latents, vae_batch_size: int = 16): latents = 1 / self.vae.config.scaling_factor * latents batch_size, channels, num_frames, height, width = latents.shape latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width) video = [] - for i in range(0, latents.shape[0], decode_batch_size): - batch_latents = latents[i : i + decode_batch_size] + for i in range(0, latents.shape[0], vae_batch_size): + batch_latents = latents[i : i + vae_batch_size] batch_latents = self.vae.decode(batch_latents).sample video.append(batch_latents) @@ -630,6 +638,7 @@ def prepare_latents( device, generator, latents=None, + vae_batch_size: int = 16, ): if latents is None: num_frames = video.shape[1] @@ -664,13 +673,10 @@ def prepare_latents( ) init_latents = [ - retrieve_latents(self.vae.encode(video[i]), generator=generator[i]).unsqueeze(0) - for i in range(batch_size) + self.encode_video(video[i], generator[i], vae_batch_size).unsqueeze(0) for i in range(batch_size) ] else: - init_latents = [ - retrieve_latents(self.vae.encode(vid), generator=generator).unsqueeze(0) for vid in video - ] + init_latents = [self.encode_video(vid, generator, vae_batch_size).unsqueeze(0) for vid in video] init_latents = torch.cat(init_latents, dim=0) @@ -755,7 +761,7 @@ def __call__( clip_skip: Optional[int] = None, callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], - decode_batch_size: int = 16, + vae_batch_size: int = 16, ): r""" The call function to the pipeline for generation. @@ -831,7 +837,7 @@ def __call__( The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the `._callback_tensor_inputs` attribute of your pipeline class. - decode_batch_size (`int`, defaults to `16`): + vae_batch_size (`int`, defaults to `16`): The number of frames to decode at a time when calling `decode_latents` method. Examples: @@ -934,6 +940,7 @@ def __call__( device=device, generator=generator, latents=latents, + vae_batch_size=vae_batch_size, ) # 6. Prepare extra step kwargs. 
TODO: Logic should ideally just be moved out of the pipeline @@ -1001,7 +1008,7 @@ def __call__( if output_type == "latent": video = latents else: - video_tensor = self.decode_latents(latents, decode_batch_size) + video_tensor = self.decode_latents(latents, vae_batch_size) video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type) # 10. Offload all models diff --git a/src/diffusers/pipelines/pia/pipeline_pia.py b/src/diffusers/pipelines/pia/pipeline_pia.py index 8666d58b01f6..70177b2a5948 100644 --- a/src/diffusers/pipelines/pia/pipeline_pia.py +++ b/src/diffusers/pipelines/pia/pipeline_pia.py @@ -410,15 +410,15 @@ def encode_image(self, image, device, num_images_per_prompt, output_hidden_state return image_embeds, uncond_image_embeds # Copied from diffusers.pipelines.animatediff.pipeline_animatediff.AnimateDiffPipeline.decode_latents - def decode_latents(self, latents, decode_batch_size: int = 16): + def decode_latents(self, latents, vae_batch_size: int = 16): latents = 1 / self.vae.config.scaling_factor * latents batch_size, channels, num_frames, height, width = latents.shape latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width) video = [] - for i in range(0, latents.shape[0], decode_batch_size): - batch_latents = latents[i : i + decode_batch_size] + for i in range(0, latents.shape[0], vae_batch_size): + batch_latents = latents[i : i + vae_batch_size] batch_latents = self.vae.decode(batch_latents).sample video.append(batch_latents) @@ -721,7 +721,7 @@ def __call__( clip_skip: Optional[int] = None, callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], - decode_batch_size: int = 16, + vae_batch_size: int = 16, ): r""" The call function to the pipeline for generation. @@ -798,7 +798,7 @@ def __call__( The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the `._callback_tensor_inputs` attribute of your pipeline class. - decode_batch_size (`int`, defaults to `16`): + vae_batch_size (`int`, defaults to `16`): The number of frames to decode at a time when calling `decode_latents` method. Examples: @@ -968,7 +968,7 @@ def __call__( if output_type == "latent": video = latents else: - video_tensor = self.decode_latents(latents, decode_batch_size) + video_tensor = self.decode_latents(latents, vae_batch_size) video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type) # 10. 
Offload all models From 77ee296a42a3364663b5a30b32feabb6122f4d60 Mon Sep 17 00:00:00 2001 From: Aryan Date: Sun, 28 Jul 2024 05:15:30 +0200 Subject: [PATCH 19/35] revert sparsectrl and sdxl freenoise changes --- .../animatediff/pipeline_animatediff_sdxl.py | 49 +++---------------- .../pipeline_animatediff_sparsectrl.py | 49 +++---------------- 2 files changed, 12 insertions(+), 86 deletions(-) diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py index e1b88ab59974..a46682347519 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py @@ -56,7 +56,6 @@ from ...utils.torch_utils import randn_tensor from ...video_processor import VideoProcessor from ..free_init_utils import FreeInitMixin -from ..free_noise_utils import AnimateDiffFreeNoiseMixin from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin from .pipeline_output import AnimateDiffPipelineOutput @@ -195,7 +194,6 @@ class AnimateDiffSDXLPipeline( TextualInversionLoaderMixin, IPAdapterMixin, FreeInitMixin, - AnimateDiffFreeNoiseMixin, ): r""" Pipeline for text-to-video generation using Stable Diffusion XL. @@ -608,21 +606,15 @@ def prepare_ip_adapter_image_embeds( return ip_adapter_image_embeds - # Copied from diffusers.pipelines.animatediff.pipeline_animatediff.AnimateDiffPipeline.decode_latents - def decode_latents(self, latents, vae_batch_size: int = 16): + # Copied from diffusers.pipelines.text_to_video_synthesis/pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents + def decode_latents(self, latents): latents = 1 / self.vae.config.scaling_factor * latents batch_size, channels, num_frames, height, width = latents.shape latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width) - video = [] - for i in range(0, latents.shape[0], vae_batch_size): - batch_latents = latents[i : i + vae_batch_size] - batch_latents = self.vae.decode(batch_latents).sample - video.append(batch_latents) - - video = torch.cat(video) - video = video[None, :].reshape((batch_size, num_frames, -1) + video.shape[2:]).permute(0, 2, 1, 3, 4) + image = self.vae.decode(latents).sample + video = image[None, :].reshape((batch_size, num_frames, -1) + image.shape[2:]).permute(0, 2, 1, 3, 4) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 video = video.float() return video @@ -717,7 +709,7 @@ def check_inputs( "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." 
) - # Copied from diffusers.pipelines.animatediff.pipeline_animatediff.AnimateDiffPipeline.prepare_latents + # Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_synth.TextToVideoSDPipeline.prepare_latents def prepare_latents( self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None ): @@ -739,32 +731,6 @@ def prepare_latents( else: latents = latents.to(device) - # If FreeNoise is enabled, shuffle latents in every window as described in Equation (7) of - # [FreeNoise](https://arxiv.org/abs/2310.15169) - if self.free_noise_enabled and self._free_noise_shuffle: - for i in range(self._free_noise_context_length, num_frames, self._free_noise_context_stride): - # ensure window is within bounds - window_start = max(0, i - self._free_noise_context_length) - window_end = min(num_frames, window_start + self._free_noise_context_stride) - window_length = window_end - window_start - - if window_length == 0: - break - - indices = torch.LongTensor(list(range(window_start, window_end))) - shuffled_indices = indices[torch.randperm(window_length, generator=generator)] - - current_start = i - current_end = min(num_frames, current_start + window_length) - if current_end == current_start + window_length: - # batch of frames perfectly fits the window - latents[:, :, current_start:current_end] = latents[:, :, shuffled_indices] - else: - # handle the case where the last batch of frames does not fit perfectly with the window - prefix_length = current_end - current_start - shuffled_indices = shuffled_indices[:prefix_length] - latents[:, :, current_start:current_end] = latents[:, :, shuffled_indices] - # scale the initial noise by the standard deviation required by the scheduler latents = latents * self.scheduler.init_noise_sigma return latents @@ -910,7 +876,6 @@ def __call__( clip_skip: Optional[int] = None, callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], - vae_batch_size: int = 16, ): r""" Function invoked when calling the pipeline for generation. @@ -1050,8 +1015,6 @@ def __call__( The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the `._callback_tensor_inputs` attribute of your pipeline class. - vae_batch_size (`int`, defaults to `16`): - The number of frames to decode at a time when calling `decode_latents` method. 
Examples: @@ -1295,7 +1258,7 @@ def __call__( if output_type == "latent": video = latents else: - video_tensor = self.decode_latents(latents, vae_batch_size) + video_tensor = self.decode_latents(latents) video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type) # cast back to fp16 if needed diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py index 0c92d4ce0366..e9e0d518c806 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py @@ -38,7 +38,6 @@ from ...utils.torch_utils import is_compiled_module, randn_tensor from ...video_processor import VideoProcessor from ..free_init_utils import FreeInitMixin -from ..free_noise_utils import AnimateDiffFreeNoiseMixin from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin from .pipeline_output import AnimateDiffPipelineOutput @@ -128,7 +127,6 @@ class AnimateDiffSparseControlNetPipeline( IPAdapterMixin, StableDiffusionLoraLoaderMixin, FreeInitMixin, - AnimateDiffFreeNoiseMixin, ): r""" Pipeline for controlled text-to-video generation using the method described in [SparseCtrl: Adding Sparse Controls @@ -450,21 +448,15 @@ def prepare_ip_adapter_image_embeds( return ip_adapter_image_embeds - # Copied from diffusers.pipelines.animatediff.pipeline_animatediff.AnimateDiffPipeline.decode_latents - def decode_latents(self, latents, vae_batch_size: int = 16): + # Copied from diffusers.pipelines.text_to_video_synthesis/pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents + def decode_latents(self, latents): latents = 1 / self.vae.config.scaling_factor * latents batch_size, channels, num_frames, height, width = latents.shape latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width) - video = [] - for i in range(0, latents.shape[0], vae_batch_size): - batch_latents = latents[i : i + vae_batch_size] - batch_latents = self.vae.decode(batch_latents).sample - video.append(batch_latents) - - video = torch.cat(video) - video = video[None, :].reshape((batch_size, num_frames, -1) + video.shape[2:]).permute(0, 2, 1, 3, 4) + image = self.vae.decode(latents).sample + video = image[None, :].reshape((batch_size, num_frames, -1) + image.shape[2:]).permute(0, 2, 1, 3, 4) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 video = video.float() return video @@ -619,7 +611,7 @@ def check_image(self, image, prompt, prompt_embeds): f"If image batch size is not 1, image batch size must be same as prompt batch size. 
image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}" ) - # Copied from diffusers.pipelines.animatediff.pipeline_animatediff.AnimateDiffPipeline.prepare_latents + # Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_synth.TextToVideoSDPipeline.prepare_latents def prepare_latents( self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None ): @@ -641,32 +633,6 @@ def prepare_latents( else: latents = latents.to(device) - # If FreeNoise is enabled, shuffle latents in every window as described in Equation (7) of - # [FreeNoise](https://arxiv.org/abs/2310.15169) - if self.free_noise_enabled and self._free_noise_shuffle: - for i in range(self._free_noise_context_length, num_frames, self._free_noise_context_stride): - # ensure window is within bounds - window_start = max(0, i - self._free_noise_context_length) - window_end = min(num_frames, window_start + self._free_noise_context_stride) - window_length = window_end - window_start - - if window_length == 0: - break - - indices = torch.LongTensor(list(range(window_start, window_end))) - shuffled_indices = indices[torch.randperm(window_length, generator=generator)] - - current_start = i - current_end = min(num_frames, current_start + window_length) - if current_end == current_start + window_length: - # batch of frames perfectly fits the window - latents[:, :, current_start:current_end] = latents[:, :, shuffled_indices] - else: - # handle the case where the last batch of frames does not fit perfectly with the window - prefix_length = current_end - current_start - shuffled_indices = shuffled_indices[:prefix_length] - latents[:, :, current_start:current_end] = latents[:, :, shuffled_indices] - # scale the initial noise by the standard deviation required by the scheduler latents = latents * self.scheduler.init_noise_sigma return latents @@ -762,7 +728,6 @@ def __call__( clip_skip: Optional[int] = None, callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], - vae_batch_size: int = 16, ): r""" The call function to the pipeline for generation. @@ -841,8 +806,6 @@ def __call__( The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the `._callback_tensor_inputs` attribute of your pipeline class. - vae_batch_size (`int`, defaults to `16`): - The number of frames to decode at a time when calling `decode_latents` method. Examples: @@ -1033,7 +996,7 @@ def __call__( if output_type == "latent": video = latents else: - video_tensor = self.decode_latents(latents, vae_batch_size) + video_tensor = self.decode_latents(latents) video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type) # 12. 
Offload all models From 52884b3e641b6cf2922a636fea7b7d5a32c8a224 Mon Sep 17 00:00:00 2001 From: Aryan Date: Sun, 28 Jul 2024 05:26:36 +0200 Subject: [PATCH 20/35] revert pia --- src/diffusers/pipelines/pia/pipeline_pia.py | 49 +++------------------ 1 file changed, 6 insertions(+), 43 deletions(-) diff --git a/src/diffusers/pipelines/pia/pipeline_pia.py b/src/diffusers/pipelines/pia/pipeline_pia.py index 70177b2a5948..f383af7cc182 100644 --- a/src/diffusers/pipelines/pia/pipeline_pia.py +++ b/src/diffusers/pipelines/pia/pipeline_pia.py @@ -45,7 +45,6 @@ from ...utils.torch_utils import randn_tensor from ...video_processor import VideoProcessor from ..free_init_utils import FreeInitMixin -from ..free_noise_utils import AnimateDiffFreeNoiseMixin from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin @@ -132,7 +131,6 @@ class PIAPipeline( StableDiffusionLoraLoaderMixin, FromSingleFileMixin, FreeInitMixin, - AnimateDiffFreeNoiseMixin, ): r""" Pipeline for text-to-video generation. @@ -409,21 +407,15 @@ def encode_image(self, image, device, num_images_per_prompt, output_hidden_state return image_embeds, uncond_image_embeds - # Copied from diffusers.pipelines.animatediff.pipeline_animatediff.AnimateDiffPipeline.decode_latents - def decode_latents(self, latents, vae_batch_size: int = 16): + # Copied from diffusers.pipelines.text_to_video_synthesis/pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents + def decode_latents(self, latents): latents = 1 / self.vae.config.scaling_factor * latents batch_size, channels, num_frames, height, width = latents.shape latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width) - video = [] - for i in range(0, latents.shape[0], vae_batch_size): - batch_latents = latents[i : i + vae_batch_size] - batch_latents = self.vae.decode(batch_latents).sample - video.append(batch_latents) - - video = torch.cat(video) - video = video[None, :].reshape((batch_size, num_frames, -1) + video.shape[2:]).permute(0, 2, 1, 3, 4) + image = self.vae.decode(latents).sample + video = image[None, :].reshape((batch_size, num_frames, -1) + image.shape[2:]).permute(0, 2, 1, 3, 4) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 video = video.float() return video @@ -555,7 +547,7 @@ def prepare_ip_adapter_image_embeds( return ip_adapter_image_embeds - # Copied from diffusers.pipelines.animatediff.pipeline_animatediff.AnimateDiffPipeline.prepare_latents + # Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_synth.TextToVideoSDPipeline.prepare_latents def prepare_latents( self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None ): @@ -577,32 +569,6 @@ def prepare_latents( else: latents = latents.to(device) - # If FreeNoise is enabled, shuffle latents in every window as described in Equation (7) of - # [FreeNoise](https://arxiv.org/abs/2310.15169) - if self.free_noise_enabled and self._free_noise_shuffle: - for i in range(self._free_noise_context_length, num_frames, self._free_noise_context_stride): - # ensure window is within bounds - window_start = max(0, i - self._free_noise_context_length) - window_end = min(num_frames, window_start + self._free_noise_context_stride) - window_length = window_end - window_start - - if window_length == 0: - break - - indices = torch.LongTensor(list(range(window_start, window_end))) - shuffled_indices = indices[torch.randperm(window_length, generator=generator)] - - 
current_start = i - current_end = min(num_frames, current_start + window_length) - if current_end == current_start + window_length: - # batch of frames perfectly fits the window - latents[:, :, current_start:current_end] = latents[:, :, shuffled_indices] - else: - # handle the case where the last batch of frames does not fit perfectly with the window - prefix_length = current_end - current_start - shuffled_indices = shuffled_indices[:prefix_length] - latents[:, :, current_start:current_end] = latents[:, :, shuffled_indices] - # scale the initial noise by the standard deviation required by the scheduler latents = latents * self.scheduler.init_noise_sigma return latents @@ -721,7 +687,6 @@ def __call__( clip_skip: Optional[int] = None, callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], - vae_batch_size: int = 16, ): r""" The call function to the pipeline for generation. @@ -798,8 +763,6 @@ def __call__( The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the `._callback_tensor_inputs` attribute of your pipeline class. - vae_batch_size (`int`, defaults to `16`): - The number of frames to decode at a time when calling `decode_latents` method. Examples: @@ -968,7 +931,7 @@ def __call__( if output_type == "latent": video = latents else: - video_tensor = self.decode_latents(latents, vae_batch_size) + video_tensor = self.decode_latents(latents) video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type) # 10. Offload all models From 1e2ef4dfc2d15735ae79172bef96e02d269ae1a2 Mon Sep 17 00:00:00 2001 From: Aryan Date: Sun, 28 Jul 2024 06:02:05 +0200 Subject: [PATCH 21/35] add freenoise tests --- .../pipelines/animatediff/test_animatediff.py | 59 ++++++++++++++++ .../test_animatediff_video2video.py | 69 ++++++++++++++++++- 2 files changed, 125 insertions(+), 3 deletions(-) diff --git a/tests/pipelines/animatediff/test_animatediff.py b/tests/pipelines/animatediff/test_animatediff.py index dd636b7ce669..3db63767a544 100644 --- a/tests/pipelines/animatediff/test_animatediff.py +++ b/tests/pipelines/animatediff/test_animatediff.py @@ -17,6 +17,7 @@ UNet2DConditionModel, UNetMotionModel, ) +from diffusers.models.unets.unet_motion_model import FreeNoiseTransformerBlock from diffusers.utils import is_xformers_available, logging from diffusers.utils.testing_utils import numpy_cosine_similarity_distance, require_torch_gpu, slow, torch_device @@ -401,6 +402,64 @@ def test_free_init_with_schedulers(self): "Enabling of FreeInit should lead to results different from the default pipeline results", ) + def test_free_noise_blocks(self): + components = self.get_dummy_components() + pipe: AnimateDiffPipeline = self.pipeline_class(**components) + pipe.set_progress_bar_config(disable=None) + pipe.to(torch_device) + + pipe.enable_free_noise() + for block in pipe.unet.down_blocks: + for motion_module in block.motion_modules: + for transformer_block in motion_module.transformer_blocks: + self.assertTrue( + isinstance(transformer_block, FreeNoiseTransformerBlock), + "Motion module transformer blocks must be an instance of `FreeNoiseTransformerBlock` after enabling FreeNoise.", + ) + + pipe.disable_free_noise() + for block in pipe.unet.down_blocks: + for motion_module in block.motion_modules: + for transformer_block in motion_module.transformer_blocks: + 
self.assertFalse( + isinstance(transformer_block, FreeNoiseTransformerBlock), + "Motion module transformer blocks must not be an instance of `FreeNoiseTransformerBlock` after enabling FreeNoise.", + ) + + def test_free_noise(self): + components = self.get_dummy_components() + pipe: AnimateDiffPipeline = self.pipeline_class(**components) + pipe.set_progress_bar_config(disable=None) + pipe.to(torch_device) + + inputs_normal = self.get_dummy_inputs(torch_device) + frames_normal = pipe(**inputs_normal).frames[0] + + for context_length in [8, 9]: + for context_stride in [4, 6]: + pipe.enable_free_noise(context_length, context_stride, weighting_scheme="pyramid", shuffle=True) + + inputs_enable_free_noise = self.get_dummy_inputs(torch_device) + frames_enable_free_noise = pipe(**inputs_enable_free_noise).frames[0] + + pipe.disable_free_noise() + + inputs_disable_free_noise = self.get_dummy_inputs(torch_device) + frames_disable_free_noise = pipe(**inputs_disable_free_noise).frames[0] + + sum_enabled = np.abs(to_np(frames_normal) - to_np(frames_enable_free_noise)).sum() + max_diff_disabled = np.abs(to_np(frames_normal) - to_np(frames_disable_free_noise)).max() + self.assertGreater( + sum_enabled, + 1e1, + "Enabling of FreeNoise should lead to results different from the default pipeline results", + ) + self.assertLess( + max_diff_disabled, + 1e-4, + "Disabling of FreeNoise should lead to results similar to the default pipeline results", + ) + @unittest.skipIf( torch_device != "cuda" or not is_xformers_available(), reason="XFormers attention is only available with CUDA and `xformers` installed", diff --git a/tests/pipelines/animatediff/test_animatediff_video2video.py b/tests/pipelines/animatediff/test_animatediff_video2video.py index ced042b4a702..375e8c9cbb8e 100644 --- a/tests/pipelines/animatediff/test_animatediff_video2video.py +++ b/tests/pipelines/animatediff/test_animatediff_video2video.py @@ -17,6 +17,7 @@ UNet2DConditionModel, UNetMotionModel, ) +from diffusers.models.unets.unet_motion_model import FreeNoiseTransformerBlock from diffusers.utils import is_xformers_available, logging from diffusers.utils.testing_utils import torch_device @@ -114,7 +115,7 @@ def get_dummy_components(self): } return components - def get_dummy_inputs(self, device, seed=0): + def get_dummy_inputs(self, device, seed=0, num_frames: int = 2): if str(device).startswith("mps"): generator = torch.manual_seed(seed) else: @@ -122,8 +123,7 @@ def get_dummy_inputs(self, device, seed=0): video_height = 32 video_width = 32 - video_num_frames = 2 - video = [Image.new("RGB", (video_width, video_height))] * video_num_frames + video = [Image.new("RGB", (video_width, video_height))] * num_frames inputs = { "video": video, @@ -428,3 +428,66 @@ def test_free_init_with_schedulers(self): 1e1, "Enabling of FreeInit should lead to results different from the default pipeline results", ) + + def test_free_noise_blocks(self): + components = self.get_dummy_components() + pipe: AnimateDiffVideoToVideoPipeline = self.pipeline_class(**components) + pipe.set_progress_bar_config(disable=None) + pipe.to(torch_device) + + pipe.enable_free_noise() + for block in pipe.unet.down_blocks: + for motion_module in block.motion_modules: + for transformer_block in motion_module.transformer_blocks: + self.assertTrue( + isinstance(transformer_block, FreeNoiseTransformerBlock), + "Motion module transformer blocks must be an instance of `FreeNoiseTransformerBlock` after enabling FreeNoise.", + ) + + pipe.disable_free_noise() + for block in pipe.unet.down_blocks: 
+ for motion_module in block.motion_modules: + for transformer_block in motion_module.transformer_blocks: + self.assertFalse( + isinstance(transformer_block, FreeNoiseTransformerBlock), + "Motion module transformer blocks must not be an instance of `FreeNoiseTransformerBlock` after enabling FreeNoise.", + ) + + def test_free_noise(self): + components = self.get_dummy_components() + pipe: AnimateDiffVideoToVideoPipeline = self.pipeline_class(**components) + pipe.set_progress_bar_config(disable=None) + pipe.to(torch_device) + + inputs_normal = self.get_dummy_inputs(torch_device, num_frames=16) + inputs_normal["num_inference_steps"] = 2 + inputs_normal["strength"] = 0.5 + frames_normal = pipe(**inputs_normal).frames[0] + + for context_length in [8, 9]: + for context_stride in [4, 6]: + pipe.enable_free_noise(context_length, context_stride, weighting_scheme="pyramid", shuffle=True) + + inputs_enable_free_noise = self.get_dummy_inputs(torch_device, num_frames=16) + inputs_enable_free_noise["num_inference_steps"] = 2 + inputs_enable_free_noise["strength"] = 0.5 + frames_enable_free_noise = pipe(**inputs_enable_free_noise).frames[0] + + pipe.disable_free_noise() + inputs_disable_free_noise = self.get_dummy_inputs(torch_device, num_frames=16) + inputs_disable_free_noise["num_inference_steps"] = 2 + inputs_disable_free_noise["strength"] = 0.5 + frames_disable_free_noise = pipe(**inputs_disable_free_noise).frames[0] + + sum_enabled = np.abs(to_np(frames_normal) - to_np(frames_enable_free_noise)).sum() + max_diff_disabled = np.abs(to_np(frames_normal) - to_np(frames_disable_free_noise)).max() + self.assertGreater( + sum_enabled, + 1e1, + "Enabling of FreeNoise should lead to results different from the default pipeline results", + ) + self.assertLess( + max_diff_disabled, + 1e-4, + "Disabling of FreeNoise should lead to results similar to the default pipeline results", + ) From 3d9b183311096cc3e9e3f360c6f605d1570a7c4b Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 30 Jul 2024 13:47:53 +0200 Subject: [PATCH 22/35] make fix-copies --- src/diffusers/loaders/peft.py | 1 - .../animatediff/pipeline_animatediff_controlnet.py | 11 ++++++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py index 1d72aa9861f2..fd6c639a7cdf 100644 --- a/src/diffusers/loaders/peft.py +++ b/src/diffusers/loaders/peft.py @@ -32,7 +32,6 @@ "UNet2DConditionModel": _maybe_expand_lora_scales, "UNetMotionModel": _maybe_expand_lora_scales, "SD3Transformer2DModel": lambda model_cls, weights: weights, - "UNetMotionModel": lambda model_cls, weights: weights, } diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py index 5574556a9b96..cdb6299e85e1 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py @@ -432,15 +432,16 @@ def prepare_ip_adapter_image_embeds( return ip_adapter_image_embeds - def decode_latents(self, latents, decode_batch_size: int = 16): + # Copied from diffusers.pipelines.animatediff.pipeline_animatediff.AnimateDiffPipeline.decode_latents + def decode_latents(self, latents, vae_batch_size: int = 16): latents = 1 / self.vae.config.scaling_factor * latents batch_size, channels, num_frames, height, width = latents.shape latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width) video = [] - for i in 
range(0, latents.shape[0], decode_batch_size): - batch_latents = latents[i : i + decode_batch_size] + for i in range(0, latents.shape[0], vae_batch_size): + batch_latents = latents[i : i + vae_batch_size] batch_latents = self.vae.decode(batch_latents).sample video.append(batch_latents) @@ -718,7 +719,7 @@ def __call__( clip_skip: Optional[int] = None, callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], - decode_batch_size: int = 16, + vae_batch_size: int = 16, ): r""" The call function to the pipeline for generation. @@ -1054,7 +1055,7 @@ def __call__( if output_type == "latent": video = latents else: - video_tensor = self.decode_latents(latents, decode_batch_size) + video_tensor = self.decode_latents(latents, vae_batch_size) video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type) # 10. Offload all models From 44e40a2890de2e31981107fcc232568345a42966 Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 30 Jul 2024 14:05:31 +0200 Subject: [PATCH 23/35] improve docstrings --- .../models/unets/unet_motion_model.py | 58 ++++++++++++------- src/diffusers/pipelines/free_noise_utils.py | 14 +++-- 2 files changed, 46 insertions(+), 26 deletions(-) diff --git a/src/diffusers/models/unets/unet_motion_model.py b/src/diffusers/models/unets/unet_motion_model.py index ae5a35bbb0f1..1a6fa76c8dae 100644 --- a/src/diffusers/models/unets/unet_motion_model.py +++ b/src/diffusers/models/unets/unet_motion_model.py @@ -61,34 +61,54 @@ class FreeNoiseTransformerBlock(nn.Module): A FreeNoise Transformer block. Parameters: - dim (`int`): The number of channels in the input and output. - num_attention_heads (`int`): The number of heads to use for multi-head attention. - attention_head_dim (`int`): The number of channels in each head. - dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. - cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. - activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. - num_embeds_ada_norm (: - obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`. - attention_bias (: - obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter. - only_cross_attention (`bool`, *optional*): + dim (`int`): + The number of channels in the input and output. + num_attention_heads (`int`): + The number of heads to use for multi-head attention. + attention_head_dim (`int`): + The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): + The dropout probability to use. + cross_attention_dim (`int`, *optional*): + The size of the encoder_hidden_states vector for cross attention. + activation_fn (`str`, *optional*, defaults to `"geglu"`): + Activation function to be used in feed-forward. + num_embeds_ada_norm (`int`, *optional*): + The number of diffusion steps used during training. See `Transformer2DModel`. + attention_bias (`bool`, defaults to `False`): + Configure if the attentions should contain a bias parameter. + only_cross_attention (`bool`, defaults to `False`): Whether to use only cross-attention layers. In this case two cross attention layers are used. - double_self_attention (`bool`, *optional*): + double_self_attention (`bool`, defaults to `False`): Whether to use two self-attention layers. In this case no cross attention layers are used. 
- upcast_attention (`bool`, *optional*): + upcast_attention (`bool`, defaults to `False`): Whether to upcast the attention computation to float32. This is useful for mixed precision training. - norm_elementwise_affine (`bool`, *optional*, defaults to `True`): + norm_elementwise_affine (`bool`, defaults to `True`): Whether to use learnable elementwise affine parameters for normalization. - norm_type (`str`, *optional*, defaults to `"layer_norm"`): + norm_type (`str`, defaults to `"layer_norm"`): The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`. - final_dropout (`bool` *optional*, defaults to False): + final_dropout (`bool` defaults to `False`): Whether to apply a final dropout after the last feed-forward layer. - attention_type (`str`, *optional*, defaults to `"default"`): + attention_type (`str`, defaults to `"default"`): The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`. - positional_embeddings (`str`, *optional*, defaults to `None`): + positional_embeddings (`str`, *optional*): The type of positional embeddings to apply to. num_positional_embeddings (`int`, *optional*, defaults to `None`): The maximum number of positional embeddings to apply. + ff_inner_dim (`int`, *optional*): + Hidden dimension of feed-forward MLP. + ff_bias (`bool`, defaults to `True`): + Whether or not to use bias in feed-forward MLP. + attention_out_bias (`bool`, defaults to `True`): + Whether or not to use bias in attention output project layer. + context_length (`int`, defaults to `16`): + The maximum number of frames that the FreeNoise block processes at once. + context_stride (`int`, defaults to `4`): + The number of frames to be skipped before starting to process a new batch of `context_length` frames. + weighting_scheme (`str`, defaults to `"pyramid"`): + The weighting scheme to use for weighting averaging of processed latent frames. As described in the + Equation 9. of the [FreeNoise](https://arxiv.org/abs/2310.15169) paper, "pyramid" is the default setting + used. """ def __init__( @@ -290,7 +310,6 @@ def forward( # Notice that normalization is always applied before the real computation in the following blocks. # 1. 
Self-Attention - # assert self.norm_type == "layer_norm" norm_hidden_states = self.norm1(hidden_states_chunk) if self.pos_embed is not None: @@ -339,7 +358,6 @@ def forward( norm_hidden_states = self.norm3(hidden_states) if self._chunk_size is not None: - # "feed_forward_chunk_size" can be used to save memory ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size) else: ff_output = self.ff(norm_hidden_states) diff --git a/src/diffusers/pipelines/free_noise_utils.py b/src/diffusers/pipelines/free_noise_utils.py index 7c6dd99e1895..160a7ba1fcd2 100644 --- a/src/diffusers/pipelines/free_noise_utils.py +++ b/src/diffusers/pipelines/free_noise_utils.py @@ -19,7 +19,6 @@ CrossAttnDownBlockMotion, DownBlockMotion, FreeNoiseTransformerBlock, - TransformerTemporalModel, UpBlockMotion, ) @@ -31,7 +30,6 @@ def _enable_free_noise_in_block(self, block: Union[CrossAttnDownBlockMotion, Dow r"""Helper function to enable FreeNoise in transformer blocks.""" for motion_module in block.motion_modules: - motion_module: TransformerTemporalModel num_transformer_blocks = len(motion_module.transformer_blocks) for i in range(num_transformer_blocks): @@ -70,7 +68,6 @@ def _disable_free_noise_in_block(self, block: Union[CrossAttnDownBlockMotion, Do r"""Helper function to disable FreeNoise in transformer blocks.""" for motion_module in block.motion_modules: - motion_module: TransformerTemporalModel num_transformer_blocks = len(motion_module.transformer_blocks) for i in range(num_transformer_blocks): @@ -115,10 +112,15 @@ def enable_free_noise( windows of size `context_length`. Context stride allows you to specify how many frames to skip between each window. For example, a context length of 16 and context stride of 4 would process 24 frames as: [0, 15], [4, 19], [8, 23] (0-based indexing) - weighting_scheme (`str`, defaults to `4`): - TODO(aryan) + weighting_scheme (`str`, defaults to `pyramid`): + Weighting scheme for averaging latents after accumulation in FreeNoise blocks. The following weighting + schemes are supported currently: + - "pyramid" + Peforms weighted averaging with a pyramid like weight pattern: [1, 2, 3, 2, 1]. shuffle (`str`, defaults to `True`): - TODO(aryan): decide if this is even needed + Shuffling latents is described in Equation 9. of the paper. It is a vital in improving the video + consistency. Without shuffling, a random batch of `num_frames` latents are created. With shuffling, + only the first `context_length` latents are shuffled and repeated. 
""" self._free_noise_context_length = context_length or self.motion_adapter.config.motion_max_seq_length self._free_noise_context_stride = context_stride From a61ffffe425a9829088cdf4c9312a602a4275793 Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 30 Jul 2024 14:07:47 +0200 Subject: [PATCH 24/35] add freenoise tests to animatediff controlnet --- .../test_animatediff_controlnet.py | 59 +++++++++++++++++++ 1 file changed, 59 insertions(+) diff --git a/tests/pipelines/animatediff/test_animatediff_controlnet.py b/tests/pipelines/animatediff/test_animatediff_controlnet.py index f46e514f9ac0..206663034bdd 100644 --- a/tests/pipelines/animatediff/test_animatediff_controlnet.py +++ b/tests/pipelines/animatediff/test_animatediff_controlnet.py @@ -18,6 +18,7 @@ UNet2DConditionModel, UNetMotionModel, ) +from diffusers.models.unets.unet_motion_model import FreeNoiseTransformerBlock from diffusers.utils import logging from diffusers.utils.testing_utils import torch_device @@ -409,6 +410,64 @@ def test_free_init_with_schedulers(self): "Enabling of FreeInit should lead to results different from the default pipeline results", ) + def test_free_noise_blocks(self): + components = self.get_dummy_components() + pipe: AnimateDiffControlNetPipeline = self.pipeline_class(**components) + pipe.set_progress_bar_config(disable=None) + pipe.to(torch_device) + + pipe.enable_free_noise() + for block in pipe.unet.down_blocks: + for motion_module in block.motion_modules: + for transformer_block in motion_module.transformer_blocks: + self.assertTrue( + isinstance(transformer_block, FreeNoiseTransformerBlock), + "Motion module transformer blocks must be an instance of `FreeNoiseTransformerBlock` after enabling FreeNoise.", + ) + + pipe.disable_free_noise() + for block in pipe.unet.down_blocks: + for motion_module in block.motion_modules: + for transformer_block in motion_module.transformer_blocks: + self.assertFalse( + isinstance(transformer_block, FreeNoiseTransformerBlock), + "Motion module transformer blocks must not be an instance of `FreeNoiseTransformerBlock` after enabling FreeNoise.", + ) + + def test_free_noise(self): + components = self.get_dummy_components() + pipe: AnimateDiffControlNetPipeline = self.pipeline_class(**components) + pipe.set_progress_bar_config(disable=None) + pipe.to(torch_device) + + inputs_normal = self.get_dummy_inputs(torch_device) + frames_normal = pipe(**inputs_normal).frames[0] + + for context_length in [8, 9]: + for context_stride in [4, 6]: + pipe.enable_free_noise(context_length, context_stride, weighting_scheme="pyramid", shuffle=True) + + inputs_enable_free_noise = self.get_dummy_inputs(torch_device) + frames_enable_free_noise = pipe(**inputs_enable_free_noise).frames[0] + + pipe.disable_free_noise() + + inputs_disable_free_noise = self.get_dummy_inputs(torch_device) + frames_disable_free_noise = pipe(**inputs_disable_free_noise).frames[0] + + sum_enabled = np.abs(to_np(frames_normal) - to_np(frames_enable_free_noise)).sum() + max_diff_disabled = np.abs(to_np(frames_normal) - to_np(frames_disable_free_noise)).max() + self.assertGreater( + sum_enabled, + 1e1, + "Enabling of FreeNoise should lead to results different from the default pipeline results", + ) + self.assertLess( + max_diff_disabled, + 1e-4, + "Disabling of FreeNoise should lead to results similar to the default pipeline results", + ) + def test_vae_slicing(self, video_count=2): device = "cpu" # ensure determinism for the device-dependent torch.Generator components = self.get_dummy_components() From 
d82228ea519b990c92f6c66ecf2e207c3ae259b8 Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 30 Jul 2024 14:11:24 +0200 Subject: [PATCH 25/35] update tests --- .../animatediff/pipeline_animatediff_controlnet.py | 2 ++ tests/pipelines/animatediff/test_animatediff_controlnet.py | 6 +++--- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py index cdb6299e85e1..770982ef4bd3 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py @@ -30,6 +30,7 @@ from ...video_processor import VideoProcessor from ..controlnet.multicontrolnet import MultiControlNetModel from ..free_init_utils import FreeInitMixin +from ..free_noise_utils import AnimateDiffFreeNoiseMixin from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin from .pipeline_output import AnimateDiffPipelineOutput @@ -109,6 +110,7 @@ class AnimateDiffControlNetPipeline( IPAdapterMixin, StableDiffusionLoraLoaderMixin, FreeInitMixin, + AnimateDiffFreeNoiseMixin, ): r""" Pipeline for text-to-video generation with ControlNet guidance. diff --git a/tests/pipelines/animatediff/test_animatediff_controlnet.py b/tests/pipelines/animatediff/test_animatediff_controlnet.py index 206663034bdd..756ee5f2713c 100644 --- a/tests/pipelines/animatediff/test_animatediff_controlnet.py +++ b/tests/pipelines/animatediff/test_animatediff_controlnet.py @@ -440,19 +440,19 @@ def test_free_noise(self): pipe.set_progress_bar_config(disable=None) pipe.to(torch_device) - inputs_normal = self.get_dummy_inputs(torch_device) + inputs_normal = self.get_dummy_inputs(torch_device, num_frames=16) frames_normal = pipe(**inputs_normal).frames[0] for context_length in [8, 9]: for context_stride in [4, 6]: pipe.enable_free_noise(context_length, context_stride, weighting_scheme="pyramid", shuffle=True) - inputs_enable_free_noise = self.get_dummy_inputs(torch_device) + inputs_enable_free_noise = self.get_dummy_inputs(torch_device, num_frames=16) frames_enable_free_noise = pipe(**inputs_enable_free_noise).frames[0] pipe.disable_free_noise() - inputs_disable_free_noise = self.get_dummy_inputs(torch_device) + inputs_disable_free_noise = self.get_dummy_inputs(torch_device, num_frames=16) frames_disable_free_noise = pipe(**inputs_disable_free_noise).frames[0] sum_enabled = np.abs(to_np(frames_normal) - to_np(frames_enable_free_noise)).sum() From 037ee07dcd669b42ee400cdcb93c483e04b26daf Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 30 Jul 2024 17:42:10 +0530 Subject: [PATCH 26/35] Update src/diffusers/models/unets/unet_motion_model.py --- src/diffusers/models/unets/unet_motion_model.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/models/unets/unet_motion_model.py b/src/diffusers/models/unets/unet_motion_model.py index 1a6fa76c8dae..6cc9d2fa1da2 100644 --- a/src/diffusers/models/unets/unet_motion_model.py +++ b/src/diffusers/models/unets/unet_motion_model.py @@ -231,7 +231,6 @@ def _get_frame_indices(self, num_frames: int) -> List[Tuple[int, int]]: window_start = i window_end = min(num_frames, i + self.context_length) frame_indices.append((window_start, window_end)) - return frame_indices def _get_frame_weights(self, num_frames: int, weighting_scheme: str = "pyramid") -> List[float]: From d19ddb483b72ca5c8c410ec79abda9666e869d09 Mon Sep 17 00:00:00 2001 From: Aryan Date: Fri, 2 Aug 2024 11:37:29 +0200 Subject: 
[PATCH 27/35] add freenoise to animatediff pag --- .../pag/pipeline_pag_sd_animatediff.py | 19 ++++-- tests/pipelines/pag/test_pag_animatediff.py | 59 +++++++++++++++++++ 2 files changed, 73 insertions(+), 5 deletions(-) diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py b/src/diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py index e37506a60c61..89782923bb7b 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py @@ -35,6 +35,7 @@ from ...video_processor import VideoProcessor from ..animatediff.pipeline_output import AnimateDiffPipelineOutput from ..free_init_utils import FreeInitMixin +from ..free_noise_utils import AnimateDiffFreeNoiseMixin from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin from .pag_utils import PAGMixin @@ -83,6 +84,7 @@ class AnimateDiffPAGPipeline( IPAdapterMixin, StableDiffusionLoraLoaderMixin, FreeInitMixin, + AnimateDiffFreeNoiseMixin, PAGMixin, ): r""" @@ -404,15 +406,21 @@ def prepare_ip_adapter_image_embeds( return ip_adapter_image_embeds - # Copied from diffusers.pipelines.text_to_video_synthesis/pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents - def decode_latents(self, latents): + # Copied from diffusers.pipelines.animatediff.pipeline_animatediff.AnimateDiffPipeline.decode_latents + def decode_latents(self, latents, vae_batch_size: int = 16): latents = 1 / self.vae.config.scaling_factor * latents batch_size, channels, num_frames, height, width = latents.shape latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width) - image = self.vae.decode(latents).sample - video = image[None, :].reshape((batch_size, num_frames, -1) + image.shape[2:]).permute(0, 2, 1, 3, 4) + video = [] + for i in range(0, latents.shape[0], vae_batch_size): + batch_latents = latents[i : i + vae_batch_size] + batch_latents = self.vae.decode(batch_latents).sample + video.append(batch_latents) + + video = torch.cat(video) + video = video[None, :].reshape((batch_size, num_frames, -1) + video.shape[2:]).permute(0, 2, 1, 3, 4) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 video = video.float() return video @@ -573,6 +581,7 @@ def __call__( clip_skip: Optional[int] = None, callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], + vae_batch_size: int = 16, pag_scale: float = 3.0, pag_adaptive_scale: float = 0.0, ): @@ -831,7 +840,7 @@ def __call__( if output_type == "latent": video = latents else: - video_tensor = self.decode_latents(latents) + video_tensor = self.decode_latents(latents, vae_batch_size) video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type) # 10. 
Offload all models diff --git a/tests/pipelines/pag/test_pag_animatediff.py b/tests/pipelines/pag/test_pag_animatediff.py index 8f637b991056..d664a7dbedd7 100644 --- a/tests/pipelines/pag/test_pag_animatediff.py +++ b/tests/pipelines/pag/test_pag_animatediff.py @@ -17,6 +17,7 @@ UNet2DConditionModel, UNetMotionModel, ) +from diffusers.models.unets.unet_motion_model import FreeNoiseTransformerBlock from diffusers.utils import is_xformers_available from diffusers.utils.testing_utils import torch_device @@ -347,6 +348,64 @@ def test_free_init_with_schedulers(self): "Enabling of FreeInit should lead to results different from the default pipeline results", ) + def test_free_noise_blocks(self): + components = self.get_dummy_components() + pipe: AnimateDiffPAGPipeline = self.pipeline_class(**components) + pipe.set_progress_bar_config(disable=None) + pipe.to(torch_device) + + pipe.enable_free_noise() + for block in pipe.unet.down_blocks: + for motion_module in block.motion_modules: + for transformer_block in motion_module.transformer_blocks: + self.assertTrue( + isinstance(transformer_block, FreeNoiseTransformerBlock), + "Motion module transformer blocks must be an instance of `FreeNoiseTransformerBlock` after enabling FreeNoise.", + ) + + pipe.disable_free_noise() + for block in pipe.unet.down_blocks: + for motion_module in block.motion_modules: + for transformer_block in motion_module.transformer_blocks: + self.assertFalse( + isinstance(transformer_block, FreeNoiseTransformerBlock), + "Motion module transformer blocks must not be an instance of `FreeNoiseTransformerBlock` after enabling FreeNoise.", + ) + + def test_free_noise(self): + components = self.get_dummy_components() + pipe: AnimateDiffPAGPipeline = self.pipeline_class(**components) + pipe.set_progress_bar_config(disable=None) + pipe.to(torch_device) + + inputs_normal = self.get_dummy_inputs(torch_device) + frames_normal = pipe(**inputs_normal).frames[0] + + for context_length in [8, 9]: + for context_stride in [4, 6]: + pipe.enable_free_noise(context_length, context_stride, weighting_scheme="pyramid", shuffle=True) + + inputs_enable_free_noise = self.get_dummy_inputs(torch_device) + frames_enable_free_noise = pipe(**inputs_enable_free_noise).frames[0] + + pipe.disable_free_noise() + + inputs_disable_free_noise = self.get_dummy_inputs(torch_device) + frames_disable_free_noise = pipe(**inputs_disable_free_noise).frames[0] + + sum_enabled = np.abs(to_np(frames_normal) - to_np(frames_enable_free_noise)).sum() + max_diff_disabled = np.abs(to_np(frames_normal) - to_np(frames_disable_free_noise)).max() + self.assertGreater( + sum_enabled, + 1e1, + "Enabling of FreeNoise should lead to results different from the default pipeline results", + ) + self.assertLess( + max_diff_disabled, + 1e-4, + "Disabling of FreeNoise should lead to results similar to the default pipeline results", + ) + @unittest.skipIf( torch_device != "cuda" or not is_xformers_available(), reason="XFormers attention is only available with CUDA and `xformers` installed", From 12cc84a8c5b3403ba0ce164e2a47148a53f058c6 Mon Sep 17 00:00:00 2001 From: Aryan Date: Fri, 2 Aug 2024 16:13:40 +0200 Subject: [PATCH 28/35] address review comments --- .../animatediff/pipeline_animatediff.py | 43 +++------ src/diffusers/pipelines/free_noise_utils.py | 88 +++++++++++++++++-- 2 files changed, 93 insertions(+), 38 deletions(-) diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py index 05812f8acd86..968ddd2a874d 
100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py @@ -505,6 +505,16 @@ def check_inputs( def prepare_latents( self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None ): + # If FreeNoise is enabled, generate latents as described in Equation (7) of [FreeNoise](https://arxiv.org/abs/2310.15169) + if self.free_noise_enabled: + latents = self._prepare_latents_free_noise(batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + shape = ( batch_size, num_channels_latents, @@ -512,43 +522,12 @@ def prepare_latents( height // self.vae_scale_factor, width // self.vae_scale_factor, ) - if isinstance(generator, list) and len(generator) != batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of {batch_size}. Make sure the batch size matches the length of the generators." - ) - + if latents is None: latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) else: latents = latents.to(device) - # If FreeNoise is enabled, shuffle latents in every window as described in Equation (7) of - # [FreeNoise](https://arxiv.org/abs/2310.15169) - if self.free_noise_enabled and self._free_noise_shuffle: - for i in range(self._free_noise_context_length, num_frames, self._free_noise_context_stride): - # ensure window is within bounds - window_start = max(0, i - self._free_noise_context_length) - window_end = min(num_frames, window_start + self._free_noise_context_stride) - window_length = window_end - window_start - - if window_length == 0: - break - - indices = torch.LongTensor(list(range(window_start, window_end))) - shuffled_indices = indices[torch.randperm(window_length, generator=generator)] - - current_start = i - current_end = min(num_frames, current_start + window_length) - if current_end == current_start + window_length: - # batch of frames perfectly fits the window - latents[:, :, current_start:current_end] = latents[:, :, shuffled_indices] - else: - # handle the case where the last batch of frames does not fit perfectly with the window - prefix_length = current_end - current_start - shuffled_indices = shuffled_indices[:prefix_length] - latents[:, :, current_start:current_end] = latents[:, :, shuffled_indices] - # scale the initial noise by the standard deviation required by the scheduler latents = latents * self.scheduler.init_noise_sigma return latents diff --git a/src/diffusers/pipelines/free_noise_utils.py b/src/diffusers/pipelines/free_noise_utils.py index 160a7ba1fcd2..8b1d323b9ed0 100644 --- a/src/diffusers/pipelines/free_noise_utils.py +++ b/src/diffusers/pipelines/free_noise_utils.py @@ -14,6 +14,8 @@ from typing import Optional, Union +import torch + from ..models.attention import BasicTransformerBlock from ..models.unets.unet_motion_model import ( CrossAttnDownBlockMotion, @@ -21,6 +23,11 @@ FreeNoiseTransformerBlock, UpBlockMotion, ) +from ..utils import logging +from ..utils.torch_utils import randn_tensor + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name class AnimateDiffFreeNoiseMixin: @@ -91,13 +98,73 
@@ def _disable_free_noise_in_block(self, block: Union[CrossAttnDownBlockMotion, Do motion_module.transformer_blocks[i].load_state_dict( free_noise_transfomer_block.state_dict(), strict=True ) + + def _prepare_latents_free_noise(self, batch_size: int, num_channels_latents: int, num_frames: int, height: int, width: int, dtype: torch.dtype, device: torch.device, generator: Optional[torch.Generator] = None, latents: Optional[torch.Tensor] = None): + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + context_num_frames = self._free_noise_context_length if self._free_noise_context_length == "repeat_context" else num_frames + + shape = ( + batch_size, + num_channels_latents, + context_num_frames, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + if self._free_noise_noise_type == "random": + return latents + else: + if latents.size(2) == num_frames: + return latents + elif latents.size(2) != self._free_noise_context_length: + raise ValueError(f"You have passed `latents` as a parameter to FreeNoise. The expected number of frames is either {num_frames} or {self._free_noise_context_length}, but found {latents.size(2)}") + latents = latents.to(device) + + if self._free_noise_noise_type == "shuffle_context": + for i in range(self._free_noise_context_length, num_frames, self._free_noise_context_stride): + # ensure window is within bounds + window_start = max(0, i - self._free_noise_context_length) + window_end = min(num_frames, window_start + self._free_noise_context_stride) + window_length = window_end - window_start + + if window_length == 0: + break + + indices = torch.LongTensor(list(range(window_start, window_end))) + shuffled_indices = indices[torch.randperm(window_length, generator=generator)] + + current_start = i + current_end = min(num_frames, current_start + window_length) + if current_end == current_start + window_length: + # batch of frames perfectly fits the window + latents[:, :, current_start:current_end] = latents[:, :, shuffled_indices] + else: + # handle the case where the last batch of frames does not fit perfectly with the window + prefix_length = current_end - current_start + shuffled_indices = shuffled_indices[:prefix_length] + latents[:, :, current_start:current_end] = latents[:, :, shuffled_indices] + + elif self._free_noise_noise_type == "repeat_context": + num_repeats = (num_frames + self._free_noise_context_length - 1) // self._free_noise_context_length + latents = torch.cat([latents] * num_repeats, dim=2) + + latents = latents[:, :, :num_frames] + return latents + def enable_free_noise( self, context_length: Optional[int] = 16, context_stride: int = 4, weighting_scheme: str = "pyramid", - shuffle: bool = True, + noise_type: str = "shuffle_context", ) -> None: r""" Enable long video generation using FreeNoise. @@ -117,15 +184,24 @@ def enable_free_noise( schemes are supported currently: - "pyramid" Performs weighted averaging with a pyramid-like weight pattern: [1, 2, 3, 2, 1]. - shuffle (`str`, defaults to `True`): - Shuffling latents is described in Equation 9 of the paper. It is vital for improving video - consistency. Without shuffling, a random batch of `num_frames` latents is created.
With shuffling, - only the first `context_length` latents are shuffled and repeated. + noise_type (`str`, defaults to "shuffle_context"): + TODO """ + + allowed_weighting_scheme = ["pyramid"] + allowed_noise_type = ["shuffle_context", "repeat_context", "random"] + + if context_length > self.motion_adapter.config.motion_max_seq_length: + logger.warning(f"You have set {context_length=} which is greater than {self.motion_adapter.config.motion_max_seq_length=}. This can lead to bad generation results.") + if weighting_scheme not in allowed_weighting_scheme: + raise ValueError(f"The parameter `weighting_scheme` must be one of {allowed_weighting_scheme}, but got {weighting_scheme=}") + if noise_type not in allowed_noise_type: + raise ValueError(f"The parameter `noise_type` must be one of {allowed_noise_type}, but got {noise_type=}") + self._free_noise_context_length = context_length or self.motion_adapter.config.motion_max_seq_length self._free_noise_context_stride = context_stride self._free_noise_weighting_scheme = weighting_scheme - self._free_noise_shuffle = shuffle + self._free_noise_noise_type = noise_type blocks = [*self.unet.down_blocks, self.unet.mid_block, *self.unet.up_blocks] for block in blocks: From 6f48356275874260fabc163d31a50aa18c1af4c4 Mon Sep 17 00:00:00 2001 From: Aryan Date: Fri, 2 Aug 2024 16:17:43 +0200 Subject: [PATCH 29/35] make style --- .../animatediff/pipeline_animatediff.py | 6 ++- src/diffusers/pipelines/free_noise_utils.py | 38 ++++++++++++++----- 2 files changed, 32 insertions(+), 12 deletions(-) diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py index 968ddd2a874d..4c2ba7238cde 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py @@ -507,7 +507,9 @@ def prepare_latents( ): # If FreeNoise is enabled, generate latents as described in Equation (7) of [FreeNoise](https://arxiv.org/abs/2310.15169) if self.free_noise_enabled: - latents = self._prepare_latents_free_noise(batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents) + latents = self._prepare_latents_free_noise( + batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents + ) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( @@ -522,7 +524,7 @@ def prepare_latents( height // self.vae_scale_factor, width // self.vae_scale_factor, ) - + if latents is None: latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) else: diff --git a/src/diffusers/pipelines/free_noise_utils.py b/src/diffusers/pipelines/free_noise_utils.py index 8b1d323b9ed0..f156f6970439 100644 --- a/src/diffusers/pipelines/free_noise_utils.py +++ b/src/diffusers/pipelines/free_noise_utils.py @@ -98,15 +98,28 @@ def _disable_free_noise_in_block(self, block: Union[CrossAttnDownBlockMotion, Do motion_module.transformer_blocks[i].load_state_dict( free_noise_transfomer_block.state_dict(), strict=True ) - - def _prepare_latents_free_noise(self, batch_size: int, num_channels_latents: int, num_frames: int, height: int, width: int, dtype: torch.dtype, device: torch.device, generator: Optional[torch.Generator] = None, latents: Optional[torch.Tensor] = None): + + def _prepare_latents_free_noise( + self, + batch_size: int, + num_channels_latents: int, + num_frames: int, + height: int, + width: int, + dtype: torch.dtype, + device: torch.device, + generator: 
Optional[torch.Generator] = None, + latents: Optional[torch.Tensor] = None, + ): if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" f" size of {batch_size}. Make sure the batch size matches the length of the generators." ) - context_num_frames = self._free_noise_context_length if self._free_noise_context_length == "repeat_context" else num_frames + context_num_frames = ( + self._free_noise_context_length if self._free_noise_context_length == "repeat_context" else num_frames + ) shape = ( batch_size, @@ -124,9 +137,11 @@ def _prepare_latents_free_noise(self, batch_size: int, num_channels_latents: int if latents.size(2) == num_frames: return latents elif latents.size(2) != self._free_noise_context_length: - raise ValueError(f"You have passed `latents` as a parameter to FreeNoise. The expected number of frames is either {num_frames} or {self._free_noise_context_length}, but found {latents.size(2)}") + raise ValueError( + f"You have passed `latents` as a parameter to FreeNoise. The expected number of frames is either {num_frames} or {self._free_noise_context_length}, but found {latents.size(2)}" + ) latents = latents.to(device) - + if self._free_noise_noise_type == "shuffle_context": for i in range(self._free_noise_context_length, num_frames, self._free_noise_context_stride): # ensure window is within bounds @@ -150,15 +165,14 @@ def _prepare_latents_free_noise(self, batch_size: int, num_channels_latents: int prefix_length = current_end - current_start shuffled_indices = shuffled_indices[:prefix_length] latents[:, :, current_start:current_end] = latents[:, :, shuffled_indices] - + elif self._free_noise_noise_type == "repeat_context": num_repeats = (num_frames + self._free_noise_context_length - 1) // self._free_noise_context_length latents = torch.cat([latents] * num_repeats, dim=2) - + latents = latents[:, :, :num_frames] return latents - def enable_free_noise( self, context_length: Optional[int] = 16, @@ -192,9 +206,13 @@ def enable_free_noise( allowed_noise_type = ["shuffle_context", "repeat_context", "random"] if context_length > self.motion_adapter.config.motion_max_seq_length: - logger.warning(f"You have set {context_length=} which is greater than {self.motion_adapter.config.motion_max_seq_length=}. This can lead to bad generation results.") + logger.warning( + f"You have set {context_length=} which is greater than {self.motion_adapter.config.motion_max_seq_length=}. This can lead to bad generation results." 
+ ) if weighting_scheme not in allowed_weighting_scheme: - raise ValueError(f"The parameter `weighting_scheme` must be one of {allowed_weighting_scheme}, but got {weighting_scheme=}") + raise ValueError( + f"The parameter `weighting_scheme` must be one of {allowed_weighting_scheme}, but got {weighting_scheme=}" + ) if noise_type not in allowed_noise_type: raise ValueError(f"The parameter `noise_type` must be one of {allowed_noise_type}, but got {noise_type=}") From 1f0ccfdd4ba1f9deb40b1a68de73c9c3e093f62c Mon Sep 17 00:00:00 2001 From: Aryan Date: Fri, 2 Aug 2024 16:17:53 +0200 Subject: [PATCH 30/35] update tests --- tests/pipelines/animatediff/test_animatediff.py | 2 +- tests/pipelines/animatediff/test_animatediff_controlnet.py | 2 +- tests/pipelines/animatediff/test_animatediff_video2video.py | 2 +- tests/pipelines/pag/test_pag_animatediff.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/pipelines/animatediff/test_animatediff.py b/tests/pipelines/animatediff/test_animatediff.py index 3db63767a544..55e6f5f4da59 100644 --- a/tests/pipelines/animatediff/test_animatediff.py +++ b/tests/pipelines/animatediff/test_animatediff.py @@ -437,7 +437,7 @@ def test_free_noise(self): for context_length in [8, 9]: for context_stride in [4, 6]: - pipe.enable_free_noise(context_length, context_stride, weighting_scheme="pyramid", shuffle=True) + pipe.enable_free_noise(context_length, context_stride) inputs_enable_free_noise = self.get_dummy_inputs(torch_device) frames_enable_free_noise = pipe(**inputs_enable_free_noise).frames[0] diff --git a/tests/pipelines/animatediff/test_animatediff_controlnet.py b/tests/pipelines/animatediff/test_animatediff_controlnet.py index 756ee5f2713c..b72e23895191 100644 --- a/tests/pipelines/animatediff/test_animatediff_controlnet.py +++ b/tests/pipelines/animatediff/test_animatediff_controlnet.py @@ -445,7 +445,7 @@ def test_free_noise(self): for context_length in [8, 9]: for context_stride in [4, 6]: - pipe.enable_free_noise(context_length, context_stride, weighting_scheme="pyramid", shuffle=True) + pipe.enable_free_noise(context_length, context_stride) inputs_enable_free_noise = self.get_dummy_inputs(torch_device, num_frames=16) frames_enable_free_noise = pipe(**inputs_enable_free_noise).frames[0] diff --git a/tests/pipelines/animatediff/test_animatediff_video2video.py b/tests/pipelines/animatediff/test_animatediff_video2video.py index 375e8c9cbb8e..6a05c1284606 100644 --- a/tests/pipelines/animatediff/test_animatediff_video2video.py +++ b/tests/pipelines/animatediff/test_animatediff_video2video.py @@ -466,7 +466,7 @@ def test_free_noise(self): for context_length in [8, 9]: for context_stride in [4, 6]: - pipe.enable_free_noise(context_length, context_stride, weighting_scheme="pyramid", shuffle=True) + pipe.enable_free_noise(context_length, context_stride) inputs_enable_free_noise = self.get_dummy_inputs(torch_device, num_frames=16) inputs_enable_free_noise["num_inference_steps"] = 2 diff --git a/tests/pipelines/pag/test_pag_animatediff.py b/tests/pipelines/pag/test_pag_animatediff.py index d664a7dbedd7..43df516cd1d8 100644 --- a/tests/pipelines/pag/test_pag_animatediff.py +++ b/tests/pipelines/pag/test_pag_animatediff.py @@ -383,7 +383,7 @@ def test_free_noise(self): for context_length in [8, 9]: for context_stride in [4, 6]: - pipe.enable_free_noise(context_length, context_stride, weighting_scheme="pyramid", shuffle=True) + pipe.enable_free_noise(context_length, context_stride) inputs_enable_free_noise = self.get_dummy_inputs(torch_device) 
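The `context_length in [8, 9]` sweep in these tests covers an even and an odd window length, which exercise the two shapes of the pyramid weighting, while `context_stride in [4, 6]` checks strides that do and do not divide the window evenly. The weight helper itself is not shown in this excerpt, so the following is only a plausible sketch of the pattern (an assumption, consistent with the [1, 2, 3, 2, 1] example quoted in the docstrings):

def sketch_pyramid_weights(window_length: int) -> list:
    # Mirrored ramp; odd lengths peak at a single frame, even lengths plateau over two.
    half = list(range(1, window_length // 2 + 1))
    middle = [window_length // 2 + 1] if window_length % 2 else []
    return half + middle + half[::-1]

print(sketch_pyramid_weights(8))  # [1, 2, 3, 4, 4, 3, 2, 1]
print(sketch_pyramid_weights(9))  # [1, 2, 3, 4, 5, 4, 3, 2, 1]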
frames_enable_free_noise = pipe(**inputs_enable_free_noise).frames[0] From 6a4aab8cb6202190086e568716e57158c8620c1e Mon Sep 17 00:00:00 2001 From: Aryan Date: Fri, 2 Aug 2024 16:49:18 +0200 Subject: [PATCH 31/35] make fix-copies --- .../pipeline_animatediff_controlnet.py | 19 +++++++++++++------ .../pag/pipeline_pag_sd_animatediff.py | 19 +++++++++++++------ 2 files changed, 26 insertions(+), 12 deletions(-) diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py index 770982ef4bd3..e30b29778129 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py @@ -611,10 +611,22 @@ def check_inputs( if end > 1.0: raise ValueError(f"control guidance end: {end} can't be larger than 1.0.") - # Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_synth.TextToVideoSDPipeline.prepare_latents + # Copied from diffusers.pipelines.animatediff.pipeline_animatediff.AnimateDiffPipeline.prepare_latents def prepare_latents( self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None ): + # If FreeNoise is enabled, generate latents as described in Equation (7) of [FreeNoise](https://arxiv.org/abs/2310.15169) + if self.free_noise_enabled: + latents = self._prepare_latents_free_noise( + batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents + ) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + shape = ( batch_size, num_channels_latents, @@ -622,11 +634,6 @@ def prepare_latents( height // self.vae_scale_factor, width // self.vae_scale_factor, ) - if isinstance(generator, list) and len(generator) != batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
- ) if latents is None: latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py b/src/diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py index 89782923bb7b..6a563b20bd6a 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py @@ -507,10 +507,22 @@ def check_inputs( f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" ) - # Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_synth.TextToVideoSDPipeline.prepare_latents + # Copied from diffusers.pipelines.animatediff.pipeline_animatediff.AnimateDiffPipeline.prepare_latents def prepare_latents( self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None ): + # If FreeNoise is enabled, generate latents as described in Equation (7) of [FreeNoise](https://arxiv.org/abs/2310.15169) + if self.free_noise_enabled: + latents = self._prepare_latents_free_noise( + batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents + ) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + shape = ( batch_size, num_channels_latents, @@ -518,11 +530,6 @@ def prepare_latents( height // self.vae_scale_factor, width // self.vae_scale_factor, ) - if isinstance(generator, list) and len(generator) != batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
- ) if latents is None: latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) From 8564dc3279c3b1917aaba914af26e8c1e5595c68 Mon Sep 17 00:00:00 2001 From: Aryan Date: Sat, 3 Aug 2024 21:52:19 +0200 Subject: [PATCH 32/35] fix error message --- tests/pipelines/animatediff/test_animatediff.py | 2 +- tests/pipelines/animatediff/test_animatediff_controlnet.py | 2 +- tests/pipelines/animatediff/test_animatediff_video2video.py | 2 +- tests/pipelines/pag/test_pag_animatediff.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/pipelines/animatediff/test_animatediff.py b/tests/pipelines/animatediff/test_animatediff.py index e6fb97c920c5..1354ac9ff1a8 100644 --- a/tests/pipelines/animatediff/test_animatediff.py +++ b/tests/pipelines/animatediff/test_animatediff.py @@ -423,7 +423,7 @@ def test_free_noise_blocks(self): for transformer_block in motion_module.transformer_blocks: self.assertFalse( isinstance(transformer_block, FreeNoiseTransformerBlock), - "Motion module transformer blocks must not be an instance of `FreeNoiseTransformerBlock` after enabling FreeNoise.", + "Motion module transformer blocks must not be an instance of `FreeNoiseTransformerBlock` after disabling FreeNoise.", ) def test_free_noise(self): diff --git a/tests/pipelines/animatediff/test_animatediff_controlnet.py b/tests/pipelines/animatediff/test_animatediff_controlnet.py index b72e23895191..7e98a3b65420 100644 --- a/tests/pipelines/animatediff/test_animatediff_controlnet.py +++ b/tests/pipelines/animatediff/test_animatediff_controlnet.py @@ -431,7 +431,7 @@ def test_free_noise_blocks(self): for transformer_block in motion_module.transformer_blocks: self.assertFalse( isinstance(transformer_block, FreeNoiseTransformerBlock), - "Motion module transformer blocks must not be an instance of `FreeNoiseTransformerBlock` after enabling FreeNoise.", + "Motion module transformer blocks must not be an instance of `FreeNoiseTransformerBlock` after disabling FreeNoise.", ) def test_free_noise(self): diff --git a/tests/pipelines/animatediff/test_animatediff_video2video.py b/tests/pipelines/animatediff/test_animatediff_video2video.py index 6a05c1284606..351ab9cc97d8 100644 --- a/tests/pipelines/animatediff/test_animatediff_video2video.py +++ b/tests/pipelines/animatediff/test_animatediff_video2video.py @@ -450,7 +450,7 @@ def test_free_noise_blocks(self): for transformer_block in motion_module.transformer_blocks: self.assertFalse( isinstance(transformer_block, FreeNoiseTransformerBlock), - "Motion module transformer blocks must not be an instance of `FreeNoiseTransformerBlock` after enabling FreeNoise.", + "Motion module transformer blocks must not be an instance of `FreeNoiseTransformerBlock` after disabling FreeNoise.", ) def test_free_noise(self): diff --git a/tests/pipelines/pag/test_pag_animatediff.py b/tests/pipelines/pag/test_pag_animatediff.py index 43df516cd1d8..9631b1791b38 100644 --- a/tests/pipelines/pag/test_pag_animatediff.py +++ b/tests/pipelines/pag/test_pag_animatediff.py @@ -369,7 +369,7 @@ def test_free_noise_blocks(self): for transformer_block in motion_module.transformer_blocks: self.assertFalse( isinstance(transformer_block, FreeNoiseTransformerBlock), - "Motion module transformer blocks must not be an instance of `FreeNoiseTransformerBlock` after enabling FreeNoise.", + "Motion module transformer blocks must not be an instance of `FreeNoiseTransformerBlock` after disabling FreeNoise.", ) def test_free_noise(self): From b32b1d7ff77dee164ed37e4cbe581cc703614274 Mon Sep 17 
00:00:00 2001 From: Aryan Date: Sat, 3 Aug 2024 21:54:28 +0200 Subject: [PATCH 33/35] remove copied from comment --- src/diffusers/models/unets/unet_motion_model.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/models/unets/unet_motion_model.py b/src/diffusers/models/unets/unet_motion_model.py index 845223aa85ee..86264ce95abd 100644 --- a/src/diffusers/models/unets/unet_motion_model.py +++ b/src/diffusers/models/unets/unet_motion_model.py @@ -2042,7 +2042,6 @@ def fn_recursive_resnet_forward(module: torch.nn.Module, chunk_size: int): for module in self.children(): fn_recursive_resnet_forward(module, chunk_size) - # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.disable_forward_chunking def disable_forward_chunking(self) -> None: def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int): if hasattr(module, "set_chunk_feed_forward"): From 045ae36fd879ce5ab3d6068a422bac436545ffe1 Mon Sep 17 00:00:00 2001 From: Aryan Date: Sat, 3 Aug 2024 22:11:10 +0200 Subject: [PATCH 34/35] fix imports in tests --- tests/pipelines/animatediff/test_animatediff_controlnet.py | 2 +- tests/pipelines/animatediff/test_animatediff_video2video.py | 2 +- tests/pipelines/pag/test_pag_animatediff.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/pipelines/animatediff/test_animatediff_controlnet.py b/tests/pipelines/animatediff/test_animatediff_controlnet.py index 7e98a3b65420..72315bd0c965 100644 --- a/tests/pipelines/animatediff/test_animatediff_controlnet.py +++ b/tests/pipelines/animatediff/test_animatediff_controlnet.py @@ -18,7 +18,7 @@ UNet2DConditionModel, UNetMotionModel, ) -from diffusers.models.unets.unet_motion_model import FreeNoiseTransformerBlock +from diffusers.models.attention import FreeNoiseTransformerBlock from diffusers.utils import logging from diffusers.utils.testing_utils import torch_device diff --git a/tests/pipelines/animatediff/test_animatediff_video2video.py b/tests/pipelines/animatediff/test_animatediff_video2video.py index 351ab9cc97d8..cd33bf0891a5 100644 --- a/tests/pipelines/animatediff/test_animatediff_video2video.py +++ b/tests/pipelines/animatediff/test_animatediff_video2video.py @@ -17,7 +17,7 @@ UNet2DConditionModel, UNetMotionModel, ) -from diffusers.models.unets.unet_motion_model import FreeNoiseTransformerBlock +from diffusers.models.attention import FreeNoiseTransformerBlock from diffusers.utils import is_xformers_available, logging from diffusers.utils.testing_utils import torch_device diff --git a/tests/pipelines/pag/test_pag_animatediff.py b/tests/pipelines/pag/test_pag_animatediff.py index 9631b1791b38..71d6f234fdb4 100644 --- a/tests/pipelines/pag/test_pag_animatediff.py +++ b/tests/pipelines/pag/test_pag_animatediff.py @@ -17,7 +17,7 @@ UNet2DConditionModel, UNetMotionModel, ) -from diffusers.models.unets.unet_motion_model import FreeNoiseTransformerBlock +from diffusers.models.attention import FreeNoiseTransformerBlock from diffusers.utils import is_xformers_available from diffusers.utils.testing_utils import torch_device From 2d9aa42c912f05fe666443479eb64f296eb7ed56 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Tue, 6 Aug 2024 13:37:34 +0000 Subject: [PATCH 35/35] update --- .../models/unets/unet_motion_model.py | 107 +----------------- 1 file changed, 5 insertions(+), 102 deletions(-) diff --git a/src/diffusers/models/unets/unet_motion_model.py b/src/diffusers/models/unets/unet_motion_model.py index 86264ce95abd..7a9b7f2d5afe 100644 --- 
a/src/diffusers/models/unets/unet_motion_model.py +++ b/src/diffusers/models/unets/unet_motion_model.py @@ -49,24 +49,6 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name -def _chunked_resnet(resnet: nn.Module, hidden_states: torch.Tensor, temb: torch.Tensor, chunk_size: int): - # "feed_forward_chunk_size" can be used to save memory - if hidden_states.shape[0] % chunk_size != 0: - raise ValueError( - f"`hidden_states` dimension to be chunked: {hidden_states.shape[0]} has to be divisible by chunk size: {chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`." - ) - - num_chunks = hidden_states.shape[0] // chunk_size - output = torch.cat( - [ - resnet(hid_slice, temb_chunk) - for hid_slice, temb_chunk in zip(hidden_states.chunk(num_chunks, dim=0), temb.chunk(num_chunks, dim=0)) - ], - dim=0, - ) - return output - - @dataclass class UNetMotionOutput(BaseOutput): """ @@ -323,13 +305,6 @@ def __init__( self.gradient_checkpointing = False - # let chunk size default to None - self._chunk_size = None - - def set_chunk_resnet(self, chunk_size: Optional[int]): - # Sets chunk feed-forward - self._chunk_size = chunk_size - def forward( self, hidden_states: torch.Tensor, @@ -367,11 +342,7 @@ def custom_forward(*inputs): ) else: - if self._chunk_size is not None: - # "feed_forward_chunk_size" can be used to save memory - hidden_states = _chunked_resnet(resnet, hidden_states, temb, self._chunk_size) - else: - hidden_states = resnet(hidden_states, temb) + hidden_states = resnet(hidden_states, temb) hidden_states = motion_module(hidden_states, num_frames=num_frames) @@ -518,14 +489,6 @@ def __init__( self.downsamplers = None self.gradient_checkpointing = False - # let chunk size default to None - self._chunk_size = None - - # Copied from diffusers.models.attention.BasicTransformerBlock.set_chunk_feed_forward - def set_chunk_resnet(self, chunk_size: Optional[int], dim: int = 0): - # Sets chunk feed-forward - self._chunk_size = chunk_size - self._chunk_dim = dim def forward( self, @@ -573,11 +536,7 @@ def custom_forward(*inputs): return_dict=False, )[0] else: - if self._chunk_size is not None: - # "feed_forward_chunk_size" can be used to save memory - hidden_states = _chunked_resnet(resnet, hidden_states, temb, self._chunk_size) - else: - hidden_states = resnet(hidden_states, temb) + hidden_states = resnet(hidden_states, temb) hidden_states = attn( hidden_states, @@ -732,14 +691,6 @@ def __init__( self.gradient_checkpointing = False self.resolution_idx = resolution_idx - # let chunk size default to None - self._chunk_size = None - - # Copied from diffusers.models.attention.BasicTransformerBlock.set_chunk_feed_forward - def set_chunk_resnet(self, chunk_size: Optional[int], dim: int = 0): - # Sets chunk feed-forward - self._chunk_size = chunk_size - self._chunk_dim = dim def forward( self, @@ -811,11 +762,7 @@ def custom_forward(*inputs): return_dict=False, )[0] else: - if self._chunk_size is not None: - # "feed_forward_chunk_size" can be used to save memory - hidden_states = _chunked_resnet(resnet, hidden_states, temb, self._chunk_size) - else: - hidden_states = resnet(hidden_states, temb) + hidden_states = resnet(hidden_states, temb) hidden_states = attn( hidden_states, @@ -915,14 +862,6 @@ def __init__( self.gradient_checkpointing = False self.resolution_idx = resolution_idx - # let chunk size default to None - self._chunk_size = None - - # Copied from diffusers.models.attention.BasicTransformerBlock.set_chunk_feed_forward - def 
set_chunk_resnet(self, chunk_size: Optional[int], dim: int = 0): - # Sets chunk feed-forward - self._chunk_size = chunk_size - self._chunk_dim = dim def forward( self, @@ -986,11 +925,7 @@ def custom_forward(*inputs): create_custom_forward(resnet), hidden_states, temb ) else: - if self._chunk_size is not None: - # "feed_forward_chunk_size" can be used to save memory - hidden_states = _chunked_resnet(resnet, hidden_states, temb, self._chunk_size) - else: - hidden_states = resnet(hidden_states, temb) + hidden_states = resnet(hidden_states, temb) hidden_states = motion_module(hidden_states, num_frames=num_frames) @@ -1126,14 +1061,6 @@ def __init__( self.motion_modules = nn.ModuleList(motion_modules) self.gradient_checkpointing = False - # let chunk size default to None - self._chunk_size = None - - # Copied from diffusers.models.attention.BasicTransformerBlock.set_chunk_feed_forward - def set_chunk_resnet(self, chunk_size: Optional[int], dim: int = 0): - # Sets chunk feed-forward - self._chunk_size = chunk_size - self._chunk_dim = dim def forward( self, @@ -1198,11 +1125,7 @@ def custom_forward(*inputs): hidden_states, num_frames=num_frames, ) - if self._chunk_size is not None: - # "feed_forward_chunk_size" can be used to save memory - hidden_states = _chunked_resnet(resnet, hidden_states, temb, self._chunk_size) - else: - hidden_states = resnet(hidden_states, temb) + hidden_states = resnet(hidden_states, temb) return hidden_states @@ -2032,16 +1955,6 @@ def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int for module in self.children(): fn_recursive_feed_forward(module, chunk_size, dim) - def fn_recursive_resnet_forward(module: torch.nn.Module, chunk_size: int): - if hasattr(module, "set_chunk_resnet"): - module.set_chunk_resnet(chunk_size=chunk_size) - - for child in module.children(): - fn_recursive_resnet_forward(child, chunk_size) - - for module in self.children(): - fn_recursive_resnet_forward(module, chunk_size) - def disable_forward_chunking(self) -> None: def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int): if hasattr(module, "set_chunk_feed_forward"): @@ -2053,16 +1966,6 @@ def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int for module in self.children(): fn_recursive_feed_forward(module, None, 0) - def fn_recursive_resnet_forward(module: torch.nn.Module, chunk_size: int): - if hasattr(module, "set_chunk_resnet_forward"): - module.set_chunk_resnet_forward(chunk_size=chunk_size) - - for child in module.children(): - fn_recursive_resnet_forward(child, chunk_size) - - for module in self.children(): - fn_recursive_resnet_forward(module, None) - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor def set_default_attn_processor(self) -> None: """
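Taken together, the series wires FreeNoise into the AnimateDiff text-to-video, ControlNet, video-to-video, and PAG pipelines. A minimal end-to-end usage sketch of the resulting API follows, assuming a diffusers build that contains these patches; the checkpoint names are illustrative placeholders, not part of the diff:

import torch
from diffusers import AnimateDiffPipeline, MotionAdapter
from diffusers.utils import export_to_gif

# Any AnimateDiff-compatible base model and motion adapter should work here.
adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2", torch_dtype=torch.float16)
pipe = AnimateDiffPipeline.from_pretrained(
    "SG161222/Realistic_Vision_V5.1_noVAE", motion_adapter=adapter, torch_dtype=torch.float16
).to("cuda")

# FreeNoise: attend over overlapping context windows, average the overlaps with
# pyramid weights, and initialize latents with the "shuffle_context" noise type.
pipe.enable_free_noise(context_length=16, context_stride=4, noise_type="shuffle_context")

video = pipe(
    prompt="a panda surfing a wave, high quality",
    num_frames=64,  # longer than a single context window
    num_inference_steps=20,
    generator=torch.Generator("cpu").manual_seed(42),
).frames[0]
export_to_gif(video, "freenoise.gif")

pipe.disable_free_noise()  # swaps the FreeNoise transformer blocks back to the basic blocks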