From ba504ade78ddf4e857a4f2a4b7d929c898bf628b Mon Sep 17 00:00:00 2001 From: Aryan Date: Fri, 26 Jul 2024 00:28:20 +0200 Subject: [PATCH 01/14] animatediff specific transformer model --- src/diffusers/__init__.py | 2 + src/diffusers/models/__init__.py | 2 + src/diffusers/models/transformers/__init__.py | 1 + .../animatediff_transformer_3d.py | 187 ++++++++++++++++++ src/diffusers/models/unets/unet_3d_blocks.py | 11 +- .../models/unets/unet_motion_model.py | 4 +- 6 files changed, 200 insertions(+), 7 deletions(-) create mode 100644 src/diffusers/models/transformers/animatediff_transformer_3d.py diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 6a6607cc376f..77cdebf0eab5 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -75,6 +75,7 @@ else: _import_structure["models"].extend( [ + "AnimateDiffTransformer3DModel", "AsymmetricAutoencoderKL", "AuraFlowTransformer2DModel", "AutoencoderKL", @@ -509,6 +510,7 @@ from .utils.dummy_pt_objects import * # noqa F403 else: from .models import ( + AnimateDiffTransformer3DModel, AsymmetricAutoencoderKL, AuraFlowTransformer2DModel, AutoencoderKL, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 39dc149ff6d1..6cb74fd2d207 100644 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -38,6 +38,7 @@ _import_structure["controlnet_xs"] = ["ControlNetXSAdapter", "UNetControlNetXSModel"] _import_structure["embeddings"] = ["ImageProjection"] _import_structure["modeling_utils"] = ["ModelMixin"] + _import_structure["transformers.animatediff_transformer_3d"] = ["AnimateDiffTransformer3DModel"] _import_structure["transformers.auraflow_transformer_2d"] = ["AuraFlowTransformer2DModel"] _import_structure["transformers.dit_transformer_2d"] = ["DiTTransformer2DModel"] _import_structure["transformers.dual_transformer_2d"] = ["DualTransformer2DModel"] @@ -85,6 +86,7 @@ from .embeddings import ImageProjection from .modeling_utils import ModelMixin from .transformers import ( + AnimateDiffTransformer3DModel, AuraFlowTransformer2DModel, DiTTransformer2DModel, DualTransformer2DModel, diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py index ae5103160790..c14705449e92 100644 --- a/src/diffusers/models/transformers/__init__.py +++ b/src/diffusers/models/transformers/__init__.py @@ -2,6 +2,7 @@ if is_torch_available(): + from .animatediff_transformer_3d import AnimateDiffTransformer3DModel from .auraflow_transformer_2d import AuraFlowTransformer2DModel from .dit_transformer_2d import DiTTransformer2DModel from .dual_transformer_2d import DualTransformer2DModel diff --git a/src/diffusers/models/transformers/animatediff_transformer_3d.py b/src/diffusers/models/transformers/animatediff_transformer_3d.py new file mode 100644 index 000000000000..422430919b9f --- /dev/null +++ b/src/diffusers/models/transformers/animatediff_transformer_3d.py @@ -0,0 +1,187 @@ +from dataclasses import dataclass +from typing import Any, Dict, Optional + +import torch +from torch import nn + +from ...configuration_utils import ConfigMixin, register_to_config +from ...utils import BaseOutput +from ..attention import BasicTransformerBlock +from ..modeling_utils import ModelMixin + + +@dataclass +class AnimateDiffTransformer3DModelOutput(BaseOutput): + """ + The output of [`AnimateDiffTransformer3DModel`]. 
+ + Args: + sample (`torch.Tensor` of shape `(batch_size * num_frames, num_channels, height, width)`): + The hidden states output conditioned on `encoder_hidden_states` input. + """ + + sample: torch.Tensor + + +# TODO(aryan): Should we have the (ModelMixin, ConfigMixin) here, or is this as per what you meant in our chat? +class AnimateDiffTransformer3DModel(nn.Module): + """ + A Transformer model for video-like data. + + Parameters: + num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. + attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. + in_channels (`int`, *optional*): + The number of channels in the input and output (specify if the input is **continuous**). + num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. + attention_bias (`bool`, *optional*): + Configure if the `TransformerBlock` attention should contain a bias parameter. + sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**). + This is fixed during training since it is used to learn a number of position embeddings. + activation_fn (`str`, *optional*, defaults to `"geglu"`): + Activation function to use in feed-forward. See `diffusers.models.activations.get_activation` for supported + activation functions. + norm_elementwise_affine (`bool`, *optional*): + Configure if the `TransformerBlock` should use learnable elementwise affine parameters for normalization. + double_self_attention (`bool`, *optional*): + Configure if each `TransformerBlock` should contain two self-attention layers. + positional_embeddings: (`str`, *optional*): + The type of positional embeddings to apply to the sequence input before passing use. + num_positional_embeddings: (`int`, *optional*): + The maximum length of the sequence over which to apply positional embeddings. + """ + + # TODO(aryan): Since we removed ConfigMixin, this isn't required anymore too unless I interpreted our message incorrectly + # @register_to_config + def __init__( + self, + num_attention_heads: int = 16, + attention_head_dim: int = 88, + in_channels: Optional[int] = None, + out_channels: Optional[int] = None, # TODO(aryan): seems like unused parameter, do we remove? + num_layers: int = 1, + dropout: float = 0.0, + norm_num_groups: int = 32, + cross_attention_dim: Optional[int] = None, + attention_bias: bool = False, + sample_size: Optional[int] = None, # TODO(aryan): seems like unused parameter, do we remove? + activation_fn: str = "geglu", + norm_elementwise_affine: bool = True, + double_self_attention: bool = True, + positional_embeddings: Optional[str] = None, + num_positional_embeddings: Optional[int] = None, + ): + super().__init__() + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + inner_dim = num_attention_heads * attention_head_dim + + self.in_channels = in_channels + + self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) + self.proj_in = nn.Linear(in_channels, inner_dim) + + # 3. 
Define transformers blocks + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + inner_dim, + num_attention_heads, + attention_head_dim, + dropout=dropout, + cross_attention_dim=cross_attention_dim, + activation_fn=activation_fn, + attention_bias=attention_bias, + double_self_attention=double_self_attention, + norm_elementwise_affine=norm_elementwise_affine, + positional_embeddings=positional_embeddings, + num_positional_embeddings=num_positional_embeddings, + ) + for _ in range(num_layers) + ] + ) + + self.proj_out = nn.Linear(inner_dim, in_channels) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.LongTensor] = None, + timestep: Optional[torch.LongTensor] = None, + class_labels: torch.LongTensor = None, + num_frames: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + return_dict: bool = True, + ) -> AnimateDiffTransformer3DModelOutput: + """ + The [`AnimateDiffTransformer3DModel`] forward method. + + Args: + hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.Tensor` of shape `(batch size, channel, height, width)` if continuous): + Input hidden_states. + encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*): + Conditional embeddings for cross attention layer. If not given, cross-attention defaults to + self-attention. + timestep ( `torch.LongTensor`, *optional*): + Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`. + class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*): + Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in + `AdaLayerZeroNorm`. + num_frames (`int`, *optional*, defaults to 1): + The number of frames to be processed per batch. This is used to reshape the hidden states. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.transformers.animatediff_transformer_3d.AnimateDiffTransformer3DModelOutput`] + instead of a plain tuple. + + Returns: + [`~models.transformers.animatediff_transformer_3d.AnimateDiffTransformer3DModelOutput`] or `tuple`: + If `return_dict` is True, an + [`~models.transformers.animatediff_transformer_3d.AnimateDiffTransformer3DModelOutput`] is returned, otherwise a + `tuple` where the first element is the sample tensor. + """ + # 1. Input + batch_frames, channel, height, width = hidden_states.shape + batch_size = batch_frames // num_frames + + residual = hidden_states + + hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, channel, height, width) + hidden_states = hidden_states.permute(0, 2, 1, 3, 4) + + hidden_states = self.norm(hidden_states) + hidden_states = hidden_states.permute(0, 3, 4, 2, 1).reshape(batch_size * height * width, num_frames, channel) + + hidden_states = self.proj_in(hidden_states) + + # 2. Blocks + for block in self.transformer_blocks: + hidden_states = block( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + timestep=timestep, + cross_attention_kwargs=cross_attention_kwargs, + class_labels=class_labels, + ) + + # 3. 
Output + hidden_states = self.proj_out(hidden_states) + hidden_states = ( + hidden_states[None, None, :] + .reshape(batch_size, height, width, num_frames, channel) + .permute(0, 3, 4, 1, 2) + .contiguous() + ) + hidden_states = hidden_states.reshape(batch_frames, channel, height, width) + + output = hidden_states + residual + + if not return_dict: + return (output,) + + return AnimateDiffTransformer3DModelOutput(sample=output) diff --git a/src/diffusers/models/unets/unet_3d_blocks.py b/src/diffusers/models/unets/unet_3d_blocks.py index 14f2df0b582c..16dc3be59bcb 100644 --- a/src/diffusers/models/unets/unet_3d_blocks.py +++ b/src/diffusers/models/unets/unet_3d_blocks.py @@ -27,6 +27,7 @@ TemporalConvLayer, Upsample2D, ) +from ..transformers.animatediff_transformer_3d import AnimateDiffTransformer3DModel from ..transformers.dual_transformer_2d import DualTransformer2DModel from ..transformers.transformer_2d import Transformer2DModel from ..transformers.transformer_temporal import ( @@ -1005,7 +1006,7 @@ def __init__( ) ) motion_modules.append( - TransformerTemporalModel( + AnimateDiffTransformer3DModel( num_attention_heads=temporal_num_attention_heads[i], in_channels=out_channels, num_layers=temporal_transformer_layers_per_block[i], @@ -1188,7 +1189,7 @@ def __init__( ) motion_modules.append( - TransformerTemporalModel( + AnimateDiffTransformer3DModel( num_attention_heads=temporal_num_attention_heads, in_channels=out_channels, num_layers=temporal_transformer_layers_per_block[i], @@ -1398,7 +1399,7 @@ def __init__( ) ) motion_modules.append( - TransformerTemporalModel( + AnimateDiffTransformer3DModel( num_attention_heads=temporal_num_attention_heads, in_channels=out_channels, num_layers=temporal_transformer_layers_per_block[i], @@ -1569,7 +1570,7 @@ def __init__( ) motion_modules.append( - TransformerTemporalModel( + AnimateDiffTransformer3DModel( num_attention_heads=temporal_num_attention_heads, in_channels=out_channels, num_layers=temporal_transformer_layers_per_block[i], @@ -1773,7 +1774,7 @@ def __init__( ) ) motion_modules.append( - TransformerTemporalModel( + AnimateDiffTransformer3DModel( num_attention_heads=temporal_num_attention_heads, attention_head_dim=in_channels // temporal_num_attention_heads, in_channels=in_channels, diff --git a/src/diffusers/models/unets/unet_motion_model.py b/src/diffusers/models/unets/unet_motion_model.py index 196f947d599b..3551bdb3443e 100644 --- a/src/diffusers/models/unets/unet_motion_model.py +++ b/src/diffusers/models/unets/unet_motion_model.py @@ -35,7 +35,7 @@ ) from ..embeddings import TimestepEmbedding, Timesteps from ..modeling_utils import ModelMixin -from ..transformers.transformer_temporal import TransformerTemporalModel +from ..transformers.animatediff_transformer_3d import AnimateDiffTransformer3DModel from .unet_2d_blocks import UNetMidBlock2DCrossAttn from .unet_2d_condition import UNet2DConditionModel from .unet_3d_blocks import ( @@ -79,7 +79,7 @@ def __init__( for i in range(layers_per_block): self.motion_modules.append( - TransformerTemporalModel( + AnimateDiffTransformer3DModel( in_channels=in_channels, num_layers=transformer_layers_per_block[i], norm_num_groups=norm_num_groups, From d214117d3a3b402b15a22a3d969cd77cc20130be Mon Sep 17 00:00:00 2001 From: Aryan Date: Fri, 26 Jul 2024 00:29:18 +0200 Subject: [PATCH 02/14] make style --- .../models/transformers/animatediff_transformer_3d.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/src/diffusers/models/transformers/animatediff_transformer_3d.py 
b/src/diffusers/models/transformers/animatediff_transformer_3d.py index 422430919b9f..a5122305e6ae 100644 --- a/src/diffusers/models/transformers/animatediff_transformer_3d.py +++ b/src/diffusers/models/transformers/animatediff_transformer_3d.py @@ -4,10 +4,8 @@ import torch from torch import nn -from ...configuration_utils import ConfigMixin, register_to_config from ...utils import BaseOutput from ..attention import BasicTransformerBlock -from ..modeling_utils import ModelMixin @dataclass @@ -136,14 +134,15 @@ def forward( `self.processor` in [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~models.transformers.animatediff_transformer_3d.AnimateDiffTransformer3DModelOutput`] - instead of a plain tuple. + Whether or not to return a + [`~models.transformers.animatediff_transformer_3d.AnimateDiffTransformer3DModelOutput`] instead of a + plain tuple. Returns: [`~models.transformers.animatediff_transformer_3d.AnimateDiffTransformer3DModelOutput`] or `tuple`: If `return_dict` is True, an - [`~models.transformers.animatediff_transformer_3d.AnimateDiffTransformer3DModelOutput`] is returned, otherwise a - `tuple` where the first element is the sample tensor. + [`~models.transformers.animatediff_transformer_3d.AnimateDiffTransformer3DModelOutput`] is returned, + otherwise a `tuple` where the first element is the sample tensor. """ # 1. Input batch_frames, channel, height, width = hidden_states.shape From 11746dce8c0d928c28d14a544cb4ad10481c757d Mon Sep 17 00:00:00 2001 From: Aryan Date: Fri, 26 Jul 2024 00:29:36 +0200 Subject: [PATCH 03/14] make fix-copies --- src/diffusers/utils/dummy_pt_objects.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 5df0d6d28f53..80f77ca4256b 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -2,6 +2,21 @@ from ..utils import DummyObject, requires_backends +class AnimateDiffTransformer3DModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class AsymmetricAutoencoderKL(metaclass=DummyObject): _backends = ["torch"] From 20919673fa07f38f87b1b8d0abe38e36b38c234b Mon Sep 17 00:00:00 2001 From: Aryan Date: Fri, 26 Jul 2024 10:26:02 +0200 Subject: [PATCH 04/14] move blocks to unet motion model --- src/diffusers/__init__.py | 2 - src/diffusers/models/__init__.py | 2 - src/diffusers/models/transformers/__init__.py | 1 - .../animatediff_transformer_3d.py | 186 --- src/diffusers/models/unets/unet_3d_blocks.py | 1007 -------------- .../models/unets/unet_motion_model.py | 1227 ++++++++++++++++- 6 files changed, 1175 insertions(+), 1250 deletions(-) delete mode 100644 src/diffusers/models/transformers/animatediff_transformer_3d.py diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 77cdebf0eab5..6a6607cc376f 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -75,7 +75,6 @@ else: _import_structure["models"].extend( [ - "AnimateDiffTransformer3DModel", "AsymmetricAutoencoderKL", "AuraFlowTransformer2DModel", "AutoencoderKL", @@ -510,7 +509,6 @@ 
from .utils.dummy_pt_objects import * # noqa F403 else: from .models import ( - AnimateDiffTransformer3DModel, AsymmetricAutoencoderKL, AuraFlowTransformer2DModel, AutoencoderKL, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 6cb74fd2d207..39dc149ff6d1 100644 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -38,7 +38,6 @@ _import_structure["controlnet_xs"] = ["ControlNetXSAdapter", "UNetControlNetXSModel"] _import_structure["embeddings"] = ["ImageProjection"] _import_structure["modeling_utils"] = ["ModelMixin"] - _import_structure["transformers.animatediff_transformer_3d"] = ["AnimateDiffTransformer3DModel"] _import_structure["transformers.auraflow_transformer_2d"] = ["AuraFlowTransformer2DModel"] _import_structure["transformers.dit_transformer_2d"] = ["DiTTransformer2DModel"] _import_structure["transformers.dual_transformer_2d"] = ["DualTransformer2DModel"] @@ -86,7 +85,6 @@ from .embeddings import ImageProjection from .modeling_utils import ModelMixin from .transformers import ( - AnimateDiffTransformer3DModel, AuraFlowTransformer2DModel, DiTTransformer2DModel, DualTransformer2DModel, diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py index c14705449e92..ae5103160790 100644 --- a/src/diffusers/models/transformers/__init__.py +++ b/src/diffusers/models/transformers/__init__.py @@ -2,7 +2,6 @@ if is_torch_available(): - from .animatediff_transformer_3d import AnimateDiffTransformer3DModel from .auraflow_transformer_2d import AuraFlowTransformer2DModel from .dit_transformer_2d import DiTTransformer2DModel from .dual_transformer_2d import DualTransformer2DModel diff --git a/src/diffusers/models/transformers/animatediff_transformer_3d.py b/src/diffusers/models/transformers/animatediff_transformer_3d.py deleted file mode 100644 index a5122305e6ae..000000000000 --- a/src/diffusers/models/transformers/animatediff_transformer_3d.py +++ /dev/null @@ -1,186 +0,0 @@ -from dataclasses import dataclass -from typing import Any, Dict, Optional - -import torch -from torch import nn - -from ...utils import BaseOutput -from ..attention import BasicTransformerBlock - - -@dataclass -class AnimateDiffTransformer3DModelOutput(BaseOutput): - """ - The output of [`AnimateDiffTransformer3DModel`]. - - Args: - sample (`torch.Tensor` of shape `(batch_size * num_frames, num_channels, height, width)`): - The hidden states output conditioned on `encoder_hidden_states` input. - """ - - sample: torch.Tensor - - -# TODO(aryan): Should we have the (ModelMixin, ConfigMixin) here, or is this as per what you meant in our chat? -class AnimateDiffTransformer3DModel(nn.Module): - """ - A Transformer model for video-like data. - - Parameters: - num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. - attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. - in_channels (`int`, *optional*): - The number of channels in the input and output (specify if the input is **continuous**). - num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. - dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. - cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. - attention_bias (`bool`, *optional*): - Configure if the `TransformerBlock` attention should contain a bias parameter. 
- sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**). - This is fixed during training since it is used to learn a number of position embeddings. - activation_fn (`str`, *optional*, defaults to `"geglu"`): - Activation function to use in feed-forward. See `diffusers.models.activations.get_activation` for supported - activation functions. - norm_elementwise_affine (`bool`, *optional*): - Configure if the `TransformerBlock` should use learnable elementwise affine parameters for normalization. - double_self_attention (`bool`, *optional*): - Configure if each `TransformerBlock` should contain two self-attention layers. - positional_embeddings: (`str`, *optional*): - The type of positional embeddings to apply to the sequence input before passing use. - num_positional_embeddings: (`int`, *optional*): - The maximum length of the sequence over which to apply positional embeddings. - """ - - # TODO(aryan): Since we removed ConfigMixin, this isn't required anymore too unless I interpreted our message incorrectly - # @register_to_config - def __init__( - self, - num_attention_heads: int = 16, - attention_head_dim: int = 88, - in_channels: Optional[int] = None, - out_channels: Optional[int] = None, # TODO(aryan): seems like unused parameter, do we remove? - num_layers: int = 1, - dropout: float = 0.0, - norm_num_groups: int = 32, - cross_attention_dim: Optional[int] = None, - attention_bias: bool = False, - sample_size: Optional[int] = None, # TODO(aryan): seems like unused parameter, do we remove? - activation_fn: str = "geglu", - norm_elementwise_affine: bool = True, - double_self_attention: bool = True, - positional_embeddings: Optional[str] = None, - num_positional_embeddings: Optional[int] = None, - ): - super().__init__() - self.num_attention_heads = num_attention_heads - self.attention_head_dim = attention_head_dim - inner_dim = num_attention_heads * attention_head_dim - - self.in_channels = in_channels - - self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) - self.proj_in = nn.Linear(in_channels, inner_dim) - - # 3. Define transformers blocks - self.transformer_blocks = nn.ModuleList( - [ - BasicTransformerBlock( - inner_dim, - num_attention_heads, - attention_head_dim, - dropout=dropout, - cross_attention_dim=cross_attention_dim, - activation_fn=activation_fn, - attention_bias=attention_bias, - double_self_attention=double_self_attention, - norm_elementwise_affine=norm_elementwise_affine, - positional_embeddings=positional_embeddings, - num_positional_embeddings=num_positional_embeddings, - ) - for _ in range(num_layers) - ] - ) - - self.proj_out = nn.Linear(inner_dim, in_channels) - - def forward( - self, - hidden_states: torch.Tensor, - encoder_hidden_states: Optional[torch.LongTensor] = None, - timestep: Optional[torch.LongTensor] = None, - class_labels: torch.LongTensor = None, - num_frames: int = 1, - cross_attention_kwargs: Optional[Dict[str, Any]] = None, - return_dict: bool = True, - ) -> AnimateDiffTransformer3DModelOutput: - """ - The [`AnimateDiffTransformer3DModel`] forward method. - - Args: - hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.Tensor` of shape `(batch size, channel, height, width)` if continuous): - Input hidden_states. - encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*): - Conditional embeddings for cross attention layer. 
If not given, cross-attention defaults to - self-attention. - timestep ( `torch.LongTensor`, *optional*): - Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`. - class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*): - Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in - `AdaLayerZeroNorm`. - num_frames (`int`, *optional*, defaults to 1): - The number of frames to be processed per batch. This is used to reshape the hidden states. - cross_attention_kwargs (`dict`, *optional*): - A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under - `self.processor` in - [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a - [`~models.transformers.animatediff_transformer_3d.AnimateDiffTransformer3DModelOutput`] instead of a - plain tuple. - - Returns: - [`~models.transformers.animatediff_transformer_3d.AnimateDiffTransformer3DModelOutput`] or `tuple`: - If `return_dict` is True, an - [`~models.transformers.animatediff_transformer_3d.AnimateDiffTransformer3DModelOutput`] is returned, - otherwise a `tuple` where the first element is the sample tensor. - """ - # 1. Input - batch_frames, channel, height, width = hidden_states.shape - batch_size = batch_frames // num_frames - - residual = hidden_states - - hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, channel, height, width) - hidden_states = hidden_states.permute(0, 2, 1, 3, 4) - - hidden_states = self.norm(hidden_states) - hidden_states = hidden_states.permute(0, 3, 4, 2, 1).reshape(batch_size * height * width, num_frames, channel) - - hidden_states = self.proj_in(hidden_states) - - # 2. Blocks - for block in self.transformer_blocks: - hidden_states = block( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - timestep=timestep, - cross_attention_kwargs=cross_attention_kwargs, - class_labels=class_labels, - ) - - # 3. 
Output - hidden_states = self.proj_out(hidden_states) - hidden_states = ( - hidden_states[None, None, :] - .reshape(batch_size, height, width, num_frames, channel) - .permute(0, 3, 4, 1, 2) - .contiguous() - ) - hidden_states = hidden_states.reshape(batch_frames, channel, height, width) - - output = hidden_states + residual - - if not return_dict: - return (output,) - - return AnimateDiffTransformer3DModelOutput(sample=output) diff --git a/src/diffusers/models/unets/unet_3d_blocks.py b/src/diffusers/models/unets/unet_3d_blocks.py index 16dc3be59bcb..c58026b4cf86 100644 --- a/src/diffusers/models/unets/unet_3d_blocks.py +++ b/src/diffusers/models/unets/unet_3d_blocks.py @@ -27,7 +27,6 @@ TemporalConvLayer, Upsample2D, ) -from ..transformers.animatediff_transformer_3d import AnimateDiffTransformer3DModel from ..transformers.dual_transformer_2d import DualTransformer2DModel from ..transformers.transformer_2d import Transformer2DModel from ..transformers.transformer_temporal import ( @@ -65,8 +64,6 @@ def get_down_block( ) -> Union[ "DownBlock3D", "CrossAttnDownBlock3D", - "DownBlockMotion", - "CrossAttnDownBlockMotion", "DownBlockSpatioTemporal", "CrossAttnDownBlockSpatioTemporal", ]: @@ -106,49 +103,6 @@ def get_down_block( resnet_time_scale_shift=resnet_time_scale_shift, dropout=dropout, ) - if down_block_type == "DownBlockMotion": - return DownBlockMotion( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - add_downsample=add_downsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - resnet_groups=resnet_groups, - downsample_padding=downsample_padding, - resnet_time_scale_shift=resnet_time_scale_shift, - temporal_num_attention_heads=temporal_num_attention_heads, - temporal_max_seq_length=temporal_max_seq_length, - temporal_transformer_layers_per_block=temporal_transformer_layers_per_block, - dropout=dropout, - ) - elif down_block_type == "CrossAttnDownBlockMotion": - if cross_attention_dim is None: - raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlockMotion") - return CrossAttnDownBlockMotion( - num_layers=num_layers, - transformer_layers_per_block=transformer_layers_per_block, - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - add_downsample=add_downsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - resnet_groups=resnet_groups, - downsample_padding=downsample_padding, - cross_attention_dim=cross_attention_dim, - num_attention_heads=num_attention_heads, - dual_cross_attention=dual_cross_attention, - use_linear_projection=use_linear_projection, - only_cross_attention=only_cross_attention, - upcast_attention=upcast_attention, - resnet_time_scale_shift=resnet_time_scale_shift, - temporal_num_attention_heads=temporal_num_attention_heads, - temporal_max_seq_length=temporal_max_seq_length, - temporal_transformer_layers_per_block=temporal_transformer_layers_per_block, - dropout=dropout, - ) elif down_block_type == "DownBlockSpatioTemporal": # added for SDV return DownBlockSpatioTemporal( @@ -204,8 +158,6 @@ def get_up_block( ) -> Union[ "UpBlock3D", "CrossAttnUpBlock3D", - "UpBlockMotion", - "CrossAttnUpBlockMotion", "UpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal", ]: @@ -247,51 +199,6 @@ def get_up_block( resolution_idx=resolution_idx, dropout=dropout, ) - if up_block_type == "UpBlockMotion": - return UpBlockMotion( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - 
prev_output_channel=prev_output_channel, - temb_channels=temb_channels, - add_upsample=add_upsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - resnet_groups=resnet_groups, - resnet_time_scale_shift=resnet_time_scale_shift, - resolution_idx=resolution_idx, - temporal_num_attention_heads=temporal_num_attention_heads, - temporal_max_seq_length=temporal_max_seq_length, - temporal_transformer_layers_per_block=temporal_transformer_layers_per_block, - dropout=dropout, - ) - elif up_block_type == "CrossAttnUpBlockMotion": - if cross_attention_dim is None: - raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlockMotion") - return CrossAttnUpBlockMotion( - num_layers=num_layers, - transformer_layers_per_block=transformer_layers_per_block, - in_channels=in_channels, - out_channels=out_channels, - prev_output_channel=prev_output_channel, - temb_channels=temb_channels, - add_upsample=add_upsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - resnet_groups=resnet_groups, - cross_attention_dim=cross_attention_dim, - num_attention_heads=num_attention_heads, - dual_cross_attention=dual_cross_attention, - use_linear_projection=use_linear_projection, - only_cross_attention=only_cross_attention, - upcast_attention=upcast_attention, - resnet_time_scale_shift=resnet_time_scale_shift, - resolution_idx=resolution_idx, - temporal_num_attention_heads=temporal_num_attention_heads, - temporal_max_seq_length=temporal_max_seq_length, - temporal_transformer_layers_per_block=temporal_transformer_layers_per_block, - dropout=dropout, - ) elif up_block_type == "UpBlockSpatioTemporal": # added for SDV return UpBlockSpatioTemporal( @@ -948,920 +855,6 @@ def forward( return hidden_states -class DownBlockMotion(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - temb_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - resnet_pre_norm: bool = True, - output_scale_factor: float = 1.0, - add_downsample: bool = True, - downsample_padding: int = 1, - temporal_num_attention_heads: Union[int, Tuple[int]] = 1, - temporal_cross_attention_dim: Optional[int] = None, - temporal_max_seq_length: int = 32, - temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1, - ): - super().__init__() - resnets = [] - motion_modules = [] - - # support for variable transformer layers per temporal block - if isinstance(temporal_transformer_layers_per_block, int): - temporal_transformer_layers_per_block = (temporal_transformer_layers_per_block,) * num_layers - elif len(temporal_transformer_layers_per_block) != num_layers: - raise ValueError( - f"`temporal_transformer_layers_per_block` must be an integer or a tuple of integers of length {num_layers}" - ) - - # support for variable number of attention head per temporal layers - if isinstance(temporal_num_attention_heads, int): - temporal_num_attention_heads = (temporal_num_attention_heads,) * num_layers - elif len(temporal_num_attention_heads) != num_layers: - raise ValueError( - f"`temporal_num_attention_heads` must be an integer or a tuple of integers of length {num_layers}" - ) - - for i in range(num_layers): - in_channels = in_channels if i == 0 else out_channels - resnets.append( - ResnetBlock2D( - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - 
time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - motion_modules.append( - AnimateDiffTransformer3DModel( - num_attention_heads=temporal_num_attention_heads[i], - in_channels=out_channels, - num_layers=temporal_transformer_layers_per_block[i], - norm_num_groups=resnet_groups, - cross_attention_dim=temporal_cross_attention_dim, - attention_bias=False, - activation_fn="geglu", - positional_embeddings="sinusoidal", - num_positional_embeddings=temporal_max_seq_length, - attention_head_dim=out_channels // temporal_num_attention_heads[i], - ) - ) - - self.resnets = nn.ModuleList(resnets) - self.motion_modules = nn.ModuleList(motion_modules) - - if add_downsample: - self.downsamplers = nn.ModuleList( - [ - Downsample2D( - out_channels, - use_conv=True, - out_channels=out_channels, - padding=downsample_padding, - name="op", - ) - ] - ) - else: - self.downsamplers = None - - self.gradient_checkpointing = False - - def forward( - self, - hidden_states: torch.Tensor, - temb: Optional[torch.Tensor] = None, - num_frames: int = 1, - *args, - **kwargs, - ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]: - if len(args) > 0 or kwargs.get("scale", None) is not None: - deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." - deprecate("scale", "1.0.0", deprecation_message) - - output_states = () - - blocks = zip(self.resnets, self.motion_modules) - for resnet, motion_module in blocks: - if self.training and self.gradient_checkpointing: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - if is_torch_version(">=", "1.11.0"): - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), - hidden_states, - temb, - use_reentrant=False, - ) - else: - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), hidden_states, temb - ) - - else: - hidden_states = resnet(hidden_states, temb) - hidden_states = motion_module(hidden_states, num_frames=num_frames)[0] - - output_states = output_states + (hidden_states,) - - if self.downsamplers is not None: - for downsampler in self.downsamplers: - hidden_states = downsampler(hidden_states) - - output_states = output_states + (hidden_states,) - - return hidden_states, output_states - - -class CrossAttnDownBlockMotion(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - temb_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - transformer_layers_per_block: Union[int, Tuple[int]] = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - resnet_pre_norm: bool = True, - num_attention_heads: int = 1, - cross_attention_dim: int = 1280, - output_scale_factor: float = 1.0, - downsample_padding: int = 1, - add_downsample: bool = True, - dual_cross_attention: bool = False, - use_linear_projection: bool = False, - only_cross_attention: bool = False, - upcast_attention: bool = False, - attention_type: str = "default", - temporal_cross_attention_dim: Optional[int] = None, - temporal_num_attention_heads: int = 8, - temporal_max_seq_length: int = 32, - temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1, - ): - 
super().__init__() - resnets = [] - attentions = [] - motion_modules = [] - - self.has_cross_attention = True - self.num_attention_heads = num_attention_heads - - # support for variable transformer layers per block - if isinstance(transformer_layers_per_block, int): - transformer_layers_per_block = (transformer_layers_per_block,) * num_layers - elif len(transformer_layers_per_block) != num_layers: - raise ValueError( - f"transformer_layers_per_block must be an integer or a list of integers of length {num_layers}" - ) - - # support for variable transformer layers per temporal block - if isinstance(temporal_transformer_layers_per_block, int): - temporal_transformer_layers_per_block = (temporal_transformer_layers_per_block,) * num_layers - elif len(temporal_transformer_layers_per_block) != num_layers: - raise ValueError( - f"temporal_transformer_layers_per_block must be an integer or a list of integers of length {num_layers}" - ) - - for i in range(num_layers): - in_channels = in_channels if i == 0 else out_channels - resnets.append( - ResnetBlock2D( - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - - if not dual_cross_attention: - attentions.append( - Transformer2DModel( - num_attention_heads, - out_channels // num_attention_heads, - in_channels=out_channels, - num_layers=transformer_layers_per_block[i], - cross_attention_dim=cross_attention_dim, - norm_num_groups=resnet_groups, - use_linear_projection=use_linear_projection, - only_cross_attention=only_cross_attention, - upcast_attention=upcast_attention, - attention_type=attention_type, - ) - ) - else: - attentions.append( - DualTransformer2DModel( - num_attention_heads, - out_channels // num_attention_heads, - in_channels=out_channels, - num_layers=1, - cross_attention_dim=cross_attention_dim, - norm_num_groups=resnet_groups, - ) - ) - - motion_modules.append( - AnimateDiffTransformer3DModel( - num_attention_heads=temporal_num_attention_heads, - in_channels=out_channels, - num_layers=temporal_transformer_layers_per_block[i], - norm_num_groups=resnet_groups, - cross_attention_dim=temporal_cross_attention_dim, - attention_bias=False, - activation_fn="geglu", - positional_embeddings="sinusoidal", - num_positional_embeddings=temporal_max_seq_length, - attention_head_dim=out_channels // temporal_num_attention_heads, - ) - ) - - self.attentions = nn.ModuleList(attentions) - self.resnets = nn.ModuleList(resnets) - self.motion_modules = nn.ModuleList(motion_modules) - - if add_downsample: - self.downsamplers = nn.ModuleList( - [ - Downsample2D( - out_channels, - use_conv=True, - out_channels=out_channels, - padding=downsample_padding, - name="op", - ) - ] - ) - else: - self.downsamplers = None - - self.gradient_checkpointing = False - - def forward( - self, - hidden_states: torch.Tensor, - temb: Optional[torch.Tensor] = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - num_frames: int = 1, - encoder_attention_mask: Optional[torch.Tensor] = None, - cross_attention_kwargs: Optional[Dict[str, Any]] = None, - additional_residuals: Optional[torch.Tensor] = None, - ): - if cross_attention_kwargs is not None: - if cross_attention_kwargs.get("scale", None) is not None: - logger.warning("Passing `scale` to `cross_attention_kwargs` is 
deprecated. `scale` will be ignored.") - - output_states = () - - blocks = list(zip(self.resnets, self.attentions, self.motion_modules)) - for i, (resnet, attn, motion_module) in enumerate(blocks): - if self.training and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), - hidden_states, - temb, - **ckpt_kwargs, - ) - hidden_states = attn( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - cross_attention_kwargs=cross_attention_kwargs, - attention_mask=attention_mask, - encoder_attention_mask=encoder_attention_mask, - return_dict=False, - )[0] - else: - hidden_states = resnet(hidden_states, temb) - hidden_states = attn( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - cross_attention_kwargs=cross_attention_kwargs, - attention_mask=attention_mask, - encoder_attention_mask=encoder_attention_mask, - return_dict=False, - )[0] - hidden_states = motion_module( - hidden_states, - num_frames=num_frames, - )[0] - - # apply additional residuals to the output of the last pair of resnet and attention blocks - if i == len(blocks) - 1 and additional_residuals is not None: - hidden_states = hidden_states + additional_residuals - - output_states = output_states + (hidden_states,) - - if self.downsamplers is not None: - for downsampler in self.downsamplers: - hidden_states = downsampler(hidden_states) - - output_states = output_states + (hidden_states,) - - return hidden_states, output_states - - -class CrossAttnUpBlockMotion(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - prev_output_channel: int, - temb_channels: int, - resolution_idx: Optional[int] = None, - dropout: float = 0.0, - num_layers: int = 1, - transformer_layers_per_block: Union[int, Tuple[int]] = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - resnet_pre_norm: bool = True, - num_attention_heads: int = 1, - cross_attention_dim: int = 1280, - output_scale_factor: float = 1.0, - add_upsample: bool = True, - dual_cross_attention: bool = False, - use_linear_projection: bool = False, - only_cross_attention: bool = False, - upcast_attention: bool = False, - attention_type: str = "default", - temporal_cross_attention_dim: Optional[int] = None, - temporal_num_attention_heads: int = 8, - temporal_max_seq_length: int = 32, - temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1, - ): - super().__init__() - resnets = [] - attentions = [] - motion_modules = [] - - self.has_cross_attention = True - self.num_attention_heads = num_attention_heads - - # support for variable transformer layers per block - if isinstance(transformer_layers_per_block, int): - transformer_layers_per_block = (transformer_layers_per_block,) * num_layers - elif len(transformer_layers_per_block) != num_layers: - raise ValueError( - f"transformer_layers_per_block must be an integer or a list of integers of length {num_layers}, got {len(transformer_layers_per_block)}" - ) - - # support for variable transformer layers per temporal block - if isinstance(temporal_transformer_layers_per_block, int): - temporal_transformer_layers_per_block = 
(temporal_transformer_layers_per_block,) * num_layers - elif len(temporal_transformer_layers_per_block) != num_layers: - raise ValueError( - f"temporal_transformer_layers_per_block must be an integer or a list of integers of length {num_layers}, got {len(temporal_transformer_layers_per_block)}" - ) - - for i in range(num_layers): - res_skip_channels = in_channels if (i == num_layers - 1) else out_channels - resnet_in_channels = prev_output_channel if i == 0 else out_channels - - resnets.append( - ResnetBlock2D( - in_channels=resnet_in_channels + res_skip_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - - if not dual_cross_attention: - attentions.append( - Transformer2DModel( - num_attention_heads, - out_channels // num_attention_heads, - in_channels=out_channels, - num_layers=transformer_layers_per_block[i], - cross_attention_dim=cross_attention_dim, - norm_num_groups=resnet_groups, - use_linear_projection=use_linear_projection, - only_cross_attention=only_cross_attention, - upcast_attention=upcast_attention, - attention_type=attention_type, - ) - ) - else: - attentions.append( - DualTransformer2DModel( - num_attention_heads, - out_channels // num_attention_heads, - in_channels=out_channels, - num_layers=1, - cross_attention_dim=cross_attention_dim, - norm_num_groups=resnet_groups, - ) - ) - motion_modules.append( - AnimateDiffTransformer3DModel( - num_attention_heads=temporal_num_attention_heads, - in_channels=out_channels, - num_layers=temporal_transformer_layers_per_block[i], - norm_num_groups=resnet_groups, - cross_attention_dim=temporal_cross_attention_dim, - attention_bias=False, - activation_fn="geglu", - positional_embeddings="sinusoidal", - num_positional_embeddings=temporal_max_seq_length, - attention_head_dim=out_channels // temporal_num_attention_heads, - ) - ) - - self.attentions = nn.ModuleList(attentions) - self.resnets = nn.ModuleList(resnets) - self.motion_modules = nn.ModuleList(motion_modules) - - if add_upsample: - self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) - else: - self.upsamplers = None - - self.gradient_checkpointing = False - self.resolution_idx = resolution_idx - - def forward( - self, - hidden_states: torch.Tensor, - res_hidden_states_tuple: Tuple[torch.Tensor, ...], - temb: Optional[torch.Tensor] = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - cross_attention_kwargs: Optional[Dict[str, Any]] = None, - upsample_size: Optional[int] = None, - attention_mask: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.Tensor] = None, - num_frames: int = 1, - ) -> torch.Tensor: - if cross_attention_kwargs is not None: - if cross_attention_kwargs.get("scale", None) is not None: - logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. 
`scale` will be ignored.") - - is_freeu_enabled = ( - getattr(self, "s1", None) - and getattr(self, "s2", None) - and getattr(self, "b1", None) - and getattr(self, "b2", None) - ) - - blocks = zip(self.resnets, self.attentions, self.motion_modules) - for resnet, attn, motion_module in blocks: - # pop res hidden states - res_hidden_states = res_hidden_states_tuple[-1] - res_hidden_states_tuple = res_hidden_states_tuple[:-1] - - # FreeU: Only operate on the first two stages - if is_freeu_enabled: - hidden_states, res_hidden_states = apply_freeu( - self.resolution_idx, - hidden_states, - res_hidden_states, - s1=self.s1, - s2=self.s2, - b1=self.b1, - b2=self.b2, - ) - - hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) - - if self.training and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), - hidden_states, - temb, - **ckpt_kwargs, - ) - hidden_states = attn( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - cross_attention_kwargs=cross_attention_kwargs, - attention_mask=attention_mask, - encoder_attention_mask=encoder_attention_mask, - return_dict=False, - )[0] - else: - hidden_states = resnet(hidden_states, temb) - hidden_states = attn( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - cross_attention_kwargs=cross_attention_kwargs, - attention_mask=attention_mask, - encoder_attention_mask=encoder_attention_mask, - return_dict=False, - )[0] - hidden_states = motion_module( - hidden_states, - num_frames=num_frames, - )[0] - - if self.upsamplers is not None: - for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states, upsample_size) - - return hidden_states - - -class UpBlockMotion(nn.Module): - def __init__( - self, - in_channels: int, - prev_output_channel: int, - out_channels: int, - temb_channels: int, - resolution_idx: Optional[int] = None, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - resnet_pre_norm: bool = True, - output_scale_factor: float = 1.0, - add_upsample: bool = True, - temporal_cross_attention_dim: Optional[int] = None, - temporal_num_attention_heads: int = 8, - temporal_max_seq_length: int = 32, - temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1, - ): - super().__init__() - resnets = [] - motion_modules = [] - - # support for variable transformer layers per temporal block - if isinstance(temporal_transformer_layers_per_block, int): - temporal_transformer_layers_per_block = (temporal_transformer_layers_per_block,) * num_layers - elif len(temporal_transformer_layers_per_block) != num_layers: - raise ValueError( - f"temporal_transformer_layers_per_block must be an integer or a list of integers of length {num_layers}" - ) - - for i in range(num_layers): - res_skip_channels = in_channels if (i == num_layers - 1) else out_channels - resnet_in_channels = prev_output_channel if i == 0 else out_channels - - resnets.append( - ResnetBlock2D( - in_channels=resnet_in_channels + res_skip_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - 
groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - - motion_modules.append( - AnimateDiffTransformer3DModel( - num_attention_heads=temporal_num_attention_heads, - in_channels=out_channels, - num_layers=temporal_transformer_layers_per_block[i], - norm_num_groups=resnet_groups, - cross_attention_dim=temporal_cross_attention_dim, - attention_bias=False, - activation_fn="geglu", - positional_embeddings="sinusoidal", - num_positional_embeddings=temporal_max_seq_length, - attention_head_dim=out_channels // temporal_num_attention_heads, - ) - ) - - self.resnets = nn.ModuleList(resnets) - self.motion_modules = nn.ModuleList(motion_modules) - - if add_upsample: - self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) - else: - self.upsamplers = None - - self.gradient_checkpointing = False - self.resolution_idx = resolution_idx - - def forward( - self, - hidden_states: torch.Tensor, - res_hidden_states_tuple: Tuple[torch.Tensor, ...], - temb: Optional[torch.Tensor] = None, - upsample_size=None, - num_frames: int = 1, - *args, - **kwargs, - ) -> torch.Tensor: - if len(args) > 0 or kwargs.get("scale", None) is not None: - deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." - deprecate("scale", "1.0.0", deprecation_message) - - is_freeu_enabled = ( - getattr(self, "s1", None) - and getattr(self, "s2", None) - and getattr(self, "b1", None) - and getattr(self, "b2", None) - ) - - blocks = zip(self.resnets, self.motion_modules) - - for resnet, motion_module in blocks: - # pop res hidden states - res_hidden_states = res_hidden_states_tuple[-1] - res_hidden_states_tuple = res_hidden_states_tuple[:-1] - - # FreeU: Only operate on the first two stages - if is_freeu_enabled: - hidden_states, res_hidden_states = apply_freeu( - self.resolution_idx, - hidden_states, - res_hidden_states, - s1=self.s1, - s2=self.s2, - b1=self.b1, - b2=self.b2, - ) - - hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) - - if self.training and self.gradient_checkpointing: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - if is_torch_version(">=", "1.11.0"): - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), - hidden_states, - temb, - use_reentrant=False, - ) - else: - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), hidden_states, temb - ) - - else: - hidden_states = resnet(hidden_states, temb) - hidden_states = motion_module(hidden_states, num_frames=num_frames)[0] - - if self.upsamplers is not None: - for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states, upsample_size) - - return hidden_states - - -class UNetMidBlockCrossAttnMotion(nn.Module): - def __init__( - self, - in_channels: int, - temb_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - transformer_layers_per_block: Union[int, Tuple[int]] = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - resnet_pre_norm: bool = True, - num_attention_heads: int = 1, - output_scale_factor: float = 1.0, - 
cross_attention_dim: int = 1280, - dual_cross_attention: bool = False, - use_linear_projection: bool = False, - upcast_attention: bool = False, - attention_type: str = "default", - temporal_num_attention_heads: int = 1, - temporal_cross_attention_dim: Optional[int] = None, - temporal_max_seq_length: int = 32, - temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1, - ): - super().__init__() - - self.has_cross_attention = True - self.num_attention_heads = num_attention_heads - resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) - - # support for variable transformer layers per block - if isinstance(transformer_layers_per_block, int): - transformer_layers_per_block = (transformer_layers_per_block,) * num_layers - elif len(transformer_layers_per_block) != num_layers: - raise ValueError( - f"`transformer_layers_per_block` should be an integer or a list of integers of length {num_layers}." - ) - - # support for variable transformer layers per temporal block - if isinstance(temporal_transformer_layers_per_block, int): - temporal_transformer_layers_per_block = (temporal_transformer_layers_per_block,) * num_layers - elif len(temporal_transformer_layers_per_block) != num_layers: - raise ValueError( - f"`temporal_transformer_layers_per_block` should be an integer or a list of integers of length {num_layers}." - ) - - # there is always at least one resnet - resnets = [ - ResnetBlock2D( - in_channels=in_channels, - out_channels=in_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ] - attentions = [] - motion_modules = [] - - for i in range(num_layers): - if not dual_cross_attention: - attentions.append( - Transformer2DModel( - num_attention_heads, - in_channels // num_attention_heads, - in_channels=in_channels, - num_layers=transformer_layers_per_block[i], - cross_attention_dim=cross_attention_dim, - norm_num_groups=resnet_groups, - use_linear_projection=use_linear_projection, - upcast_attention=upcast_attention, - attention_type=attention_type, - ) - ) - else: - attentions.append( - DualTransformer2DModel( - num_attention_heads, - in_channels // num_attention_heads, - in_channels=in_channels, - num_layers=1, - cross_attention_dim=cross_attention_dim, - norm_num_groups=resnet_groups, - ) - ) - resnets.append( - ResnetBlock2D( - in_channels=in_channels, - out_channels=in_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - motion_modules.append( - AnimateDiffTransformer3DModel( - num_attention_heads=temporal_num_attention_heads, - attention_head_dim=in_channels // temporal_num_attention_heads, - in_channels=in_channels, - num_layers=temporal_transformer_layers_per_block[i], - norm_num_groups=resnet_groups, - cross_attention_dim=temporal_cross_attention_dim, - attention_bias=False, - positional_embeddings="sinusoidal", - num_positional_embeddings=temporal_max_seq_length, - activation_fn="geglu", - ) - ) - - self.attentions = nn.ModuleList(attentions) - self.resnets = nn.ModuleList(resnets) - self.motion_modules = nn.ModuleList(motion_modules) - - self.gradient_checkpointing = False - - def forward( - self, - hidden_states: torch.Tensor, - temb: 
Optional[torch.Tensor] = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - cross_attention_kwargs: Optional[Dict[str, Any]] = None, - encoder_attention_mask: Optional[torch.Tensor] = None, - num_frames: int = 1, - ) -> torch.Tensor: - if cross_attention_kwargs is not None: - if cross_attention_kwargs.get("scale", None) is not None: - logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") - - hidden_states = self.resnets[0](hidden_states, temb) - - blocks = zip(self.attentions, self.resnets[1:], self.motion_modules) - for attn, resnet, motion_module in blocks: - if self.training and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - hidden_states = attn( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - cross_attention_kwargs=cross_attention_kwargs, - attention_mask=attention_mask, - encoder_attention_mask=encoder_attention_mask, - return_dict=False, - )[0] - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(motion_module), - hidden_states, - temb, - **ckpt_kwargs, - ) - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), - hidden_states, - temb, - **ckpt_kwargs, - ) - else: - hidden_states = attn( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - cross_attention_kwargs=cross_attention_kwargs, - attention_mask=attention_mask, - encoder_attention_mask=encoder_attention_mask, - return_dict=False, - )[0] - hidden_states = motion_module( - hidden_states, - num_frames=num_frames, - )[0] - hidden_states = resnet(hidden_states, temb) - - return hidden_states - - class MidBlockTemporalDecoder(nn.Module): def __init__( self, diff --git a/src/diffusers/models/unets/unet_motion_model.py b/src/diffusers/models/unets/unet_motion_model.py index 3551bdb3443e..fd0d1d9c99c0 100644 --- a/src/diffusers/models/unets/unet_motion_model.py +++ b/src/diffusers/models/unets/unet_motion_model.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+ +from dataclasses import dataclass from typing import Any, Dict, Optional, Tuple, Union import torch @@ -20,7 +22,9 @@ from ...configuration_utils import ConfigMixin, FrozenDict, register_to_config from ...loaders import FromOriginalModelMixin, UNet2DConditionLoadersMixin -from ...utils import logging +from ...utils import BaseOutput, deprecate, is_torch_version, logging +from ...utils.torch_utils import apply_freeu +from ..attention import BasicTransformerBlock from ..attention_processor import ( ADDED_KV_ATTENTION_PROCESSORS, CROSS_ATTENTION_PROCESSORS, @@ -35,24 +39,1106 @@ ) from ..embeddings import TimestepEmbedding, Timesteps from ..modeling_utils import ModelMixin -from ..transformers.animatediff_transformer_3d import AnimateDiffTransformer3DModel +from ..resnet import Downsample2D, ResnetBlock2D, Upsample2D +from ..transformers.dual_transformer_2d import DualTransformer2DModel +from ..transformers.transformer_2d import Transformer2DModel from .unet_2d_blocks import UNetMidBlock2DCrossAttn from .unet_2d_condition import UNet2DConditionModel -from .unet_3d_blocks import ( - CrossAttnDownBlockMotion, - CrossAttnUpBlockMotion, - DownBlockMotion, - UNetMidBlockCrossAttnMotion, - UpBlockMotion, - get_down_block, - get_up_block, -) from .unet_3d_condition import UNet3DConditionOutput logger = logging.get_logger(__name__) # pylint: disable=invalid-name +@dataclass +class AnimateDiffTransformer3DModelOutput(BaseOutput): + """ + The output of [`AnimateDiffTransformer3DModel`]. + + Args: + sample (`torch.Tensor` of shape `(batch_size * num_frames, num_channels, height, width)`): + The hidden states output conditioned on `encoder_hidden_states` input. + """ + + sample: torch.Tensor + + +class AnimateDiffTransformer3DModel(nn.Module): + """ + A Transformer model for video-like data. + + Parameters: + num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. + attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. + in_channels (`int`, *optional*): + The number of channels in the input and output (specify if the input is **continuous**). + num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. + attention_bias (`bool`, *optional*): + Configure if the `TransformerBlock` attention should contain a bias parameter. + sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**). + This is fixed during training since it is used to learn a number of position embeddings. + activation_fn (`str`, *optional*, defaults to `"geglu"`): + Activation function to use in feed-forward. See `diffusers.models.activations.get_activation` for supported + activation functions. + norm_elementwise_affine (`bool`, *optional*): + Configure if the `TransformerBlock` should use learnable elementwise affine parameters for normalization. + double_self_attention (`bool`, *optional*): + Configure if each `TransformerBlock` should contain two self-attention layers. + positional_embeddings: (`str`, *optional*): + The type of positional embeddings to apply to the sequence input before passing use. + num_positional_embeddings: (`int`, *optional*): + The maximum length of the sequence over which to apply positional embeddings. 
+ """ + + def __init__( + self, + num_attention_heads: int = 16, + attention_head_dim: int = 88, + in_channels: Optional[int] = None, + out_channels: Optional[int] = None, # TODO(aryan): seems like unused parameter, do we remove? + num_layers: int = 1, + dropout: float = 0.0, + norm_num_groups: int = 32, + cross_attention_dim: Optional[int] = None, + attention_bias: bool = False, + sample_size: Optional[int] = None, # TODO(aryan): seems like unused parameter, do we remove? + activation_fn: str = "geglu", + norm_elementwise_affine: bool = True, + double_self_attention: bool = True, + positional_embeddings: Optional[str] = None, + num_positional_embeddings: Optional[int] = None, + ): + super().__init__() + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + inner_dim = num_attention_heads * attention_head_dim + + self.in_channels = in_channels + + self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) + self.proj_in = nn.Linear(in_channels, inner_dim) + + # 3. Define transformers blocks + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + inner_dim, + num_attention_heads, + attention_head_dim, + dropout=dropout, + cross_attention_dim=cross_attention_dim, + activation_fn=activation_fn, + attention_bias=attention_bias, + double_self_attention=double_self_attention, + norm_elementwise_affine=norm_elementwise_affine, + positional_embeddings=positional_embeddings, + num_positional_embeddings=num_positional_embeddings, + ) + for _ in range(num_layers) + ] + ) + + self.proj_out = nn.Linear(inner_dim, in_channels) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.LongTensor] = None, + timestep: Optional[torch.LongTensor] = None, + class_labels: Optional[torch.LongTensor] = None, + num_frames: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + return_dict: bool = True, + ) -> AnimateDiffTransformer3DModelOutput: + """ + The [`AnimateDiffTransformer3DModel`] forward method. + + Args: + hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.Tensor` of shape `(batch size, channel, height, width)` if continuous): + Input hidden_states. + encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*): + Conditional embeddings for cross attention layer. If not given, cross-attention defaults to + self-attention. + timestep ( `torch.LongTensor`, *optional*): + Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`. + class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*): + Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in + `AdaLayerZeroNorm`. + num_frames (`int`, *optional*, defaults to 1): + The number of frames to be processed per batch. This is used to reshape the hidden states. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a + [`~models.transformers.animatediff_transformer_3d.AnimateDiffTransformer3DModelOutput`] instead of a + plain tuple. 
+ + Returns: + [`~models.transformers.animatediff_transformer_3d.AnimateDiffTransformer3DModelOutput`] or `tuple`: + If `return_dict` is True, an + [`~models.transformers.animatediff_transformer_3d.AnimateDiffTransformer3DModelOutput`] is returned, + otherwise a `tuple` where the first element is the sample tensor. + """ + # 1. Input + batch_frames, channel, height, width = hidden_states.shape + batch_size = batch_frames // num_frames + + residual = hidden_states + + hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, channel, height, width) + hidden_states = hidden_states.permute(0, 2, 1, 3, 4) + + hidden_states = self.norm(hidden_states) + hidden_states = hidden_states.permute(0, 3, 4, 2, 1).reshape(batch_size * height * width, num_frames, channel) + + hidden_states = self.proj_in(hidden_states) + + # 2. Blocks + for block in self.transformer_blocks: + hidden_states = block( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + timestep=timestep, + cross_attention_kwargs=cross_attention_kwargs, + class_labels=class_labels, + ) + + # 3. Output + hidden_states = self.proj_out(hidden_states) + hidden_states = ( + hidden_states[None, None, :] + .reshape(batch_size, height, width, num_frames, channel) + .permute(0, 3, 4, 1, 2) + .contiguous() + ) + hidden_states = hidden_states.reshape(batch_frames, channel, height, width) + + output = hidden_states + residual + + if not return_dict: + return (output,) + + return AnimateDiffTransformer3DModelOutput(sample=output) + + +class DownBlockMotion(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor: float = 1.0, + add_downsample: bool = True, + downsample_padding: int = 1, + temporal_num_attention_heads: Union[int, Tuple[int]] = 1, + temporal_cross_attention_dim: Optional[int] = None, + temporal_max_seq_length: int = 32, + temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1, + ): + super().__init__() + resnets = [] + motion_modules = [] + + # support for variable transformer layers per temporal block + if isinstance(temporal_transformer_layers_per_block, int): + temporal_transformer_layers_per_block = (temporal_transformer_layers_per_block,) * num_layers + elif len(temporal_transformer_layers_per_block) != num_layers: + raise ValueError( + f"`temporal_transformer_layers_per_block` must be an integer or a tuple of integers of length {num_layers}" + ) + + # support for variable number of attention head per temporal layers + if isinstance(temporal_num_attention_heads, int): + temporal_num_attention_heads = (temporal_num_attention_heads,) * num_layers + elif len(temporal_num_attention_heads) != num_layers: + raise ValueError( + f"`temporal_num_attention_heads` must be an integer or a tuple of integers of length {num_layers}" + ) + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + motion_modules.append( + AnimateDiffTransformer3DModel( + 
num_attention_heads=temporal_num_attention_heads[i], + in_channels=out_channels, + num_layers=temporal_transformer_layers_per_block[i], + norm_num_groups=resnet_groups, + cross_attention_dim=temporal_cross_attention_dim, + attention_bias=False, + activation_fn="geglu", + positional_embeddings="sinusoidal", + num_positional_embeddings=temporal_max_seq_length, + attention_head_dim=out_channels // temporal_num_attention_heads[i], + ) + ) + + self.resnets = nn.ModuleList(resnets) + self.motion_modules = nn.ModuleList(motion_modules) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample2D( + out_channels, + use_conv=True, + out_channels=out_channels, + padding=downsample_padding, + name="op", + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + temb: Optional[torch.Tensor] = None, + num_frames: int = 1, + *args, + **kwargs, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + deprecate("scale", "1.0.0", deprecation_message) + + output_states = () + + blocks = zip(self.resnets, self.motion_modules) + for resnet, motion_module in blocks: + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + if is_torch_version(">=", "1.11.0"): + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + use_reentrant=False, + ) + else: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) + + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = motion_module(hidden_states, num_frames=num_frames)[0] + + output_states = output_states + (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states = output_states + (hidden_states,) + + return hidden_states, output_states + + +class CrossAttnDownBlockMotion(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + transformer_layers_per_block: Union[int, Tuple[int]] = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + num_attention_heads: int = 1, + cross_attention_dim: int = 1280, + output_scale_factor: float = 1.0, + downsample_padding: int = 1, + add_downsample: bool = True, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + attention_type: str = "default", + temporal_cross_attention_dim: Optional[int] = None, + temporal_num_attention_heads: int = 8, + temporal_max_seq_length: int = 32, + temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1, + ): + super().__init__() + resnets = [] + attentions = [] + motion_modules = [] + + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + + # support for variable transformer layers per block + if 
isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = (transformer_layers_per_block,) * num_layers + elif len(transformer_layers_per_block) != num_layers: + raise ValueError( + f"transformer_layers_per_block must be an integer or a list of integers of length {num_layers}" + ) + + # support for variable transformer layers per temporal block + if isinstance(temporal_transformer_layers_per_block, int): + temporal_transformer_layers_per_block = (temporal_transformer_layers_per_block,) * num_layers + elif len(temporal_transformer_layers_per_block) != num_layers: + raise ValueError( + f"temporal_transformer_layers_per_block must be an integer or a list of integers of length {num_layers}" + ) + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + if not dual_cross_attention: + attentions.append( + Transformer2DModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=transformer_layers_per_block[i], + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + attention_type=attention_type, + ) + ) + else: + attentions.append( + DualTransformer2DModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + + motion_modules.append( + AnimateDiffTransformer3DModel( + num_attention_heads=temporal_num_attention_heads, + in_channels=out_channels, + num_layers=temporal_transformer_layers_per_block[i], + norm_num_groups=resnet_groups, + cross_attention_dim=temporal_cross_attention_dim, + attention_bias=False, + activation_fn="geglu", + positional_embeddings="sinusoidal", + num_positional_embeddings=temporal_max_seq_length, + attention_head_dim=out_channels // temporal_num_attention_heads, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + self.motion_modules = nn.ModuleList(motion_modules) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample2D( + out_channels, + use_conv=True, + out_channels=out_channels, + padding=downsample_padding, + name="op", + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + temb: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + num_frames: int = 1, + encoder_attention_mask: Optional[torch.Tensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + additional_residuals: Optional[torch.Tensor] = None, + ): + if cross_attention_kwargs is not None: + if cross_attention_kwargs.get("scale", None) is not None: + logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. 
`scale` will be ignored.") + + output_states = () + + blocks = list(zip(self.resnets, self.attentions, self.motion_modules)) + for i, (resnet, attn, motion_module) in enumerate(blocks): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + **ckpt_kwargs, + ) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + hidden_states = motion_module( + hidden_states, + num_frames=num_frames, + )[0] + + # apply additional residuals to the output of the last pair of resnet and attention blocks + if i == len(blocks) - 1 and additional_residuals is not None: + hidden_states = hidden_states + additional_residuals + + output_states = output_states + (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states = output_states + (hidden_states,) + + return hidden_states, output_states + + +class CrossAttnUpBlockMotion(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + prev_output_channel: int, + temb_channels: int, + resolution_idx: Optional[int] = None, + dropout: float = 0.0, + num_layers: int = 1, + transformer_layers_per_block: Union[int, Tuple[int]] = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + num_attention_heads: int = 1, + cross_attention_dim: int = 1280, + output_scale_factor: float = 1.0, + add_upsample: bool = True, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + attention_type: str = "default", + temporal_cross_attention_dim: Optional[int] = None, + temporal_num_attention_heads: int = 8, + temporal_max_seq_length: int = 32, + temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1, + ): + super().__init__() + resnets = [] + attentions = [] + motion_modules = [] + + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + + # support for variable transformer layers per block + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = (transformer_layers_per_block,) * num_layers + elif len(transformer_layers_per_block) != num_layers: + raise ValueError( + f"transformer_layers_per_block must be an integer or a list of integers of length {num_layers}, got {len(transformer_layers_per_block)}" + ) + + # support for variable transformer layers per temporal block + if isinstance(temporal_transformer_layers_per_block, int): + temporal_transformer_layers_per_block = 
(temporal_transformer_layers_per_block,) * num_layers + elif len(temporal_transformer_layers_per_block) != num_layers: + raise ValueError( + f"temporal_transformer_layers_per_block must be an integer or a list of integers of length {num_layers}, got {len(temporal_transformer_layers_per_block)}" + ) + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + if not dual_cross_attention: + attentions.append( + Transformer2DModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=transformer_layers_per_block[i], + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + attention_type=attention_type, + ) + ) + else: + attentions.append( + DualTransformer2DModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + motion_modules.append( + AnimateDiffTransformer3DModel( + num_attention_heads=temporal_num_attention_heads, + in_channels=out_channels, + num_layers=temporal_transformer_layers_per_block[i], + norm_num_groups=resnet_groups, + cross_attention_dim=temporal_cross_attention_dim, + attention_bias=False, + activation_fn="geglu", + positional_embeddings="sinusoidal", + num_positional_embeddings=temporal_max_seq_length, + attention_head_dim=out_channels // temporal_num_attention_heads, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + self.motion_modules = nn.ModuleList(motion_modules) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + self.resolution_idx = resolution_idx + + def forward( + self, + hidden_states: torch.Tensor, + res_hidden_states_tuple: Tuple[torch.Tensor, ...], + temb: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + upsample_size: Optional[int] = None, + attention_mask: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + num_frames: int = 1, + ) -> torch.Tensor: + if cross_attention_kwargs is not None: + if cross_attention_kwargs.get("scale", None) is not None: + logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. 
`scale` will be ignored.") + + is_freeu_enabled = ( + getattr(self, "s1", None) + and getattr(self, "s2", None) + and getattr(self, "b1", None) + and getattr(self, "b2", None) + ) + + blocks = zip(self.resnets, self.attentions, self.motion_modules) + for resnet, attn, motion_module in blocks: + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + + # FreeU: Only operate on the first two stages + if is_freeu_enabled: + hidden_states, res_hidden_states = apply_freeu( + self.resolution_idx, + hidden_states, + res_hidden_states, + s1=self.s1, + s2=self.s2, + b1=self.b1, + b2=self.b2, + ) + + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + **ckpt_kwargs, + ) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + hidden_states = motion_module( + hidden_states, + num_frames=num_frames, + )[0] + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states + + +class UpBlockMotion(nn.Module): + def __init__( + self, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + resolution_idx: Optional[int] = None, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor: float = 1.0, + add_upsample: bool = True, + temporal_cross_attention_dim: Optional[int] = None, + temporal_num_attention_heads: int = 8, + temporal_max_seq_length: int = 32, + temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1, + ): + super().__init__() + resnets = [] + motion_modules = [] + + # support for variable transformer layers per temporal block + if isinstance(temporal_transformer_layers_per_block, int): + temporal_transformer_layers_per_block = (temporal_transformer_layers_per_block,) * num_layers + elif len(temporal_transformer_layers_per_block) != num_layers: + raise ValueError( + f"temporal_transformer_layers_per_block must be an integer or a list of integers of length {num_layers}" + ) + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + 
groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + motion_modules.append( + AnimateDiffTransformer3DModel( + num_attention_heads=temporal_num_attention_heads, + in_channels=out_channels, + num_layers=temporal_transformer_layers_per_block[i], + norm_num_groups=resnet_groups, + cross_attention_dim=temporal_cross_attention_dim, + attention_bias=False, + activation_fn="geglu", + positional_embeddings="sinusoidal", + num_positional_embeddings=temporal_max_seq_length, + attention_head_dim=out_channels // temporal_num_attention_heads, + ) + ) + + self.resnets = nn.ModuleList(resnets) + self.motion_modules = nn.ModuleList(motion_modules) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + self.resolution_idx = resolution_idx + + def forward( + self, + hidden_states: torch.Tensor, + res_hidden_states_tuple: Tuple[torch.Tensor, ...], + temb: Optional[torch.Tensor] = None, + upsample_size=None, + num_frames: int = 1, + *args, + **kwargs, + ) -> torch.Tensor: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + deprecate("scale", "1.0.0", deprecation_message) + + is_freeu_enabled = ( + getattr(self, "s1", None) + and getattr(self, "s2", None) + and getattr(self, "b1", None) + and getattr(self, "b2", None) + ) + + blocks = zip(self.resnets, self.motion_modules) + + for resnet, motion_module in blocks: + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + + # FreeU: Only operate on the first two stages + if is_freeu_enabled: + hidden_states, res_hidden_states = apply_freeu( + self.resolution_idx, + hidden_states, + res_hidden_states, + s1=self.s1, + s2=self.s2, + b1=self.b1, + b2=self.b2, + ) + + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + if is_torch_version(">=", "1.11.0"): + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + use_reentrant=False, + ) + else: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) + + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = motion_module(hidden_states, num_frames=num_frames)[0] + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states + + +class UNetMidBlockCrossAttnMotion(nn.Module): + def __init__( + self, + in_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + transformer_layers_per_block: Union[int, Tuple[int]] = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + num_attention_heads: int = 1, + output_scale_factor: float = 1.0, + 
cross_attention_dim: int = 1280, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + upcast_attention: bool = False, + attention_type: str = "default", + temporal_num_attention_heads: int = 1, + temporal_cross_attention_dim: Optional[int] = None, + temporal_max_seq_length: int = 32, + temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1, + ): + super().__init__() + + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + + # support for variable transformer layers per block + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = (transformer_layers_per_block,) * num_layers + elif len(transformer_layers_per_block) != num_layers: + raise ValueError( + f"`transformer_layers_per_block` should be an integer or a list of integers of length {num_layers}." + ) + + # support for variable transformer layers per temporal block + if isinstance(temporal_transformer_layers_per_block, int): + temporal_transformer_layers_per_block = (temporal_transformer_layers_per_block,) * num_layers + elif len(temporal_transformer_layers_per_block) != num_layers: + raise ValueError( + f"`temporal_transformer_layers_per_block` should be an integer or a list of integers of length {num_layers}." + ) + + # there is always at least one resnet + resnets = [ + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ] + attentions = [] + motion_modules = [] + + for i in range(num_layers): + if not dual_cross_attention: + attentions.append( + Transformer2DModel( + num_attention_heads, + in_channels // num_attention_heads, + in_channels=in_channels, + num_layers=transformer_layers_per_block[i], + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + attention_type=attention_type, + ) + ) + else: + attentions.append( + DualTransformer2DModel( + num_attention_heads, + in_channels // num_attention_heads, + in_channels=in_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + motion_modules.append( + AnimateDiffTransformer3DModel( + num_attention_heads=temporal_num_attention_heads, + attention_head_dim=in_channels // temporal_num_attention_heads, + in_channels=in_channels, + num_layers=temporal_transformer_layers_per_block[i], + norm_num_groups=resnet_groups, + cross_attention_dim=temporal_cross_attention_dim, + attention_bias=False, + positional_embeddings="sinusoidal", + num_positional_embeddings=temporal_max_seq_length, + activation_fn="geglu", + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + self.motion_modules = nn.ModuleList(motion_modules) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + temb: 
Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + num_frames: int = 1, + ) -> torch.Tensor: + if cross_attention_kwargs is not None: + if cross_attention_kwargs.get("scale", None) is not None: + logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") + + hidden_states = self.resnets[0](hidden_states, temb) + + blocks = zip(self.attentions, self.resnets[1:], self.motion_modules) + for attn, resnet, motion_module in blocks: + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(motion_module), + hidden_states, + temb, + **ckpt_kwargs, + ) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + **ckpt_kwargs, + ) + else: + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + hidden_states = motion_module( + hidden_states, + num_frames=num_frames, + )[0] + hidden_states = resnet(hidden_states, temb) + + return hidden_states + + class MotionModules(nn.Module): def __init__( self, @@ -394,26 +1480,45 @@ def __init__( output_channel = block_out_channels[i] is_final_block = i == len(block_out_channels) - 1 - down_block = get_down_block( - down_block_type, - num_layers=layers_per_block[i], - in_channels=input_channel, - out_channels=output_channel, - temb_channels=time_embed_dim, - add_downsample=not is_final_block, - resnet_eps=norm_eps, - resnet_act_fn=act_fn, - resnet_groups=norm_num_groups, - cross_attention_dim=cross_attention_dim[i], - num_attention_heads=num_attention_heads[i], - downsample_padding=downsample_padding, - use_linear_projection=use_linear_projection, - dual_cross_attention=False, - temporal_num_attention_heads=motion_num_attention_heads[i], - temporal_max_seq_length=motion_max_seq_length, - transformer_layers_per_block=transformer_layers_per_block[i], - temporal_transformer_layers_per_block=temporal_transformer_layers_per_block[i], - ) + # TODO(aryan): Can we reduce LOC here by creating a dictionary of common arguments and then passing **kwargs? + # Many params are repeated here. 
+ if down_block_type == "CrossAttnDownBlockMotion": + down_block = CrossAttnDownBlockMotion( + in_channels=input_channel, + out_channels=output_channel, + temb_channels=time_embed_dim, + num_layers=layers_per_block[i], + transformer_layers_per_block=transformer_layers_per_block[i], + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + num_attention_heads=num_attention_heads[i], + cross_attention_dim=cross_attention_dim[i], + downsample_padding=downsample_padding, + add_downsample=not is_final_block, + use_linear_projection=use_linear_projection, + temporal_num_attention_heads=motion_num_attention_heads[i], + temporal_max_seq_length=motion_max_seq_length, + temporal_transformer_layers_per_block=temporal_transformer_layers_per_block[i], + ) + elif down_block_type == "DownBlockMotion": + down_block = DownBlockMotion( + in_channels=input_channel, + out_channels=output_channel, + temb_channels=time_embed_dim, + num_layers=layers_per_block[i], + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + add_downsample=not is_final_block, + downsample_padding=downsample_padding, + temporal_num_attention_heads=motion_num_attention_heads, + temporal_max_seq_length=motion_max_seq_length, + temporal_transformer_layers_per_block=temporal_transformer_layers_per_block[i], + ) + else: + raise ValueError("Invalid `down_block_type` encountered. Must be one of `CrossAttnDownBlockMotion` or `DownBlockMotion`") + self.down_blocks.append(down_block) # mid @@ -487,28 +1592,46 @@ def __init__( self.num_upsamplers += 1 else: add_upsample = False + + if up_block_type == "CrossAttnUpBlockMotion": + up_block = CrossAttnUpBlockMotion( + in_channels=input_channel, + out_channels=output_channel, + prev_output_channel=prev_output_channel, + temb_channels=time_embed_dim, + resolution_idx=i, + num_layers=reversed_layers_per_block[i] + 1, + transformer_layers_per_block=reverse_transformer_layers_per_block[i], + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + num_attention_heads=reversed_num_attention_heads[i], + cross_attention_dim=reversed_cross_attention_dim[i], + add_upsample=add_upsample, + use_linear_projection=use_linear_projection, + temporal_num_attention_heads=reversed_motion_num_attention_heads[i], + temporal_max_seq_length=motion_max_seq_length, + temporal_transformer_layers_per_block=reverse_temporal_transformer_layers_per_block[i], + ) + elif up_block_type == "UpBlockMotion": + up_block = UpBlockMotion( + in_channels=input_channel, + prev_output_channel=prev_output_channel, + out_channels=output_channel, + temb_channels=time_embed_dim, + resolution_idx=i, + num_layers=reversed_layers_per_block[i] + 1, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + add_upsample=add_upsample, + temporal_num_attention_heads=reversed_motion_num_attention_heads[i], + temporal_max_seq_length=motion_max_seq_length, + temporal_transformer_layers_per_block=reverse_temporal_transformer_layers_per_block[i], + ) + else: + raise ValueError("Invalid `up_block_type` encountered. 
Must be one of `CrossAttnUpBlockMotion` or `UpBlockMotion`") - up_block = get_up_block( - up_block_type, - num_layers=reversed_layers_per_block[i] + 1, - in_channels=input_channel, - out_channels=output_channel, - prev_output_channel=prev_output_channel, - temb_channels=time_embed_dim, - add_upsample=add_upsample, - resnet_eps=norm_eps, - resnet_act_fn=act_fn, - resnet_groups=norm_num_groups, - cross_attention_dim=reversed_cross_attention_dim[i], - num_attention_heads=reversed_num_attention_heads[i], - dual_cross_attention=False, - resolution_idx=i, - use_linear_projection=use_linear_projection, - temporal_num_attention_heads=reversed_motion_num_attention_heads[i], - temporal_max_seq_length=motion_max_seq_length, - transformer_layers_per_block=reverse_transformer_layers_per_block[i], - temporal_transformer_layers_per_block=reverse_temporal_transformer_layers_per_block[i], - ) self.up_blocks.append(up_block) prev_output_channel = output_channel From fd83a546d3ddde2e8f930cc060d56a22ef230c4c Mon Sep 17 00:00:00 2001 From: Aryan Date: Fri, 26 Jul 2024 10:26:33 +0200 Subject: [PATCH 05/14] make style --- src/diffusers/models/unets/unet_3d_blocks.py | 3 +-- src/diffusers/models/unets/unet_motion_model.py | 10 +++++++--- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/src/diffusers/models/unets/unet_3d_blocks.py b/src/diffusers/models/unets/unet_3d_blocks.py index c58026b4cf86..eb8cc3ed5e3a 100644 --- a/src/diffusers/models/unets/unet_3d_blocks.py +++ b/src/diffusers/models/unets/unet_3d_blocks.py @@ -17,7 +17,7 @@ import torch from torch import nn -from ...utils import deprecate, is_torch_version, logging +from ...utils import is_torch_version, logging from ...utils.torch_utils import apply_freeu from ..attention import Attention from ..resnet import ( @@ -27,7 +27,6 @@ TemporalConvLayer, Upsample2D, ) -from ..transformers.dual_transformer_2d import DualTransformer2DModel from ..transformers.transformer_2d import Transformer2DModel from ..transformers.transformer_temporal import ( TransformerSpatioTemporalModel, diff --git a/src/diffusers/models/unets/unet_motion_model.py b/src/diffusers/models/unets/unet_motion_model.py index fd0d1d9c99c0..42f5cc87359f 100644 --- a/src/diffusers/models/unets/unet_motion_model.py +++ b/src/diffusers/models/unets/unet_motion_model.py @@ -1517,7 +1517,9 @@ def __init__( temporal_transformer_layers_per_block=temporal_transformer_layers_per_block[i], ) else: - raise ValueError("Invalid `down_block_type` encountered. Must be one of `CrossAttnDownBlockMotion` or `DownBlockMotion`") + raise ValueError( + "Invalid `down_block_type` encountered. Must be one of `CrossAttnDownBlockMotion` or `DownBlockMotion`" + ) self.down_blocks.append(down_block) @@ -1592,7 +1594,7 @@ def __init__( self.num_upsamplers += 1 else: add_upsample = False - + if up_block_type == "CrossAttnUpBlockMotion": up_block = CrossAttnUpBlockMotion( in_channels=input_channel, @@ -1630,7 +1632,9 @@ def __init__( temporal_transformer_layers_per_block=reverse_temporal_transformer_layers_per_block[i], ) else: - raise ValueError("Invalid `up_block_type` encountered. Must be one of `CrossAttnUpBlockMotion` or `UpBlockMotion`") + raise ValueError( + "Invalid `up_block_type` encountered. 
Must be one of `CrossAttnUpBlockMotion` or `UpBlockMotion`" + ) self.up_blocks.append(up_block) prev_output_channel = output_channel From 60e21a0855ba6d824a29d0565476bef9ba4d1604 Mon Sep 17 00:00:00 2001 From: Aryan Date: Fri, 26 Jul 2024 10:29:32 +0200 Subject: [PATCH 06/14] remove dummy object --- src/diffusers/utils/dummy_pt_objects.py | 15 --------------- 1 file changed, 15 deletions(-) diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 80f77ca4256b..5df0d6d28f53 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -2,21 +2,6 @@ from ..utils import DummyObject, requires_backends -class AnimateDiffTransformer3DModel(metaclass=DummyObject): - _backends = ["torch"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch"]) - - class AsymmetricAutoencoderKL(metaclass=DummyObject): _backends = ["torch"] From a175bdbae98152839f5d8b0339fdaeb44089343c Mon Sep 17 00:00:00 2001 From: Aryan Date: Fri, 26 Jul 2024 13:05:36 +0200 Subject: [PATCH 07/14] fix incorrectly passed param causing test failures --- src/diffusers/models/unets/unet_motion_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/unets/unet_motion_model.py b/src/diffusers/models/unets/unet_motion_model.py index 42f5cc87359f..7e27bad565e1 100644 --- a/src/diffusers/models/unets/unet_motion_model.py +++ b/src/diffusers/models/unets/unet_motion_model.py @@ -1512,7 +1512,7 @@ def __init__( resnet_groups=norm_num_groups, add_downsample=not is_final_block, downsample_padding=downsample_padding, - temporal_num_attention_heads=motion_num_attention_heads, + temporal_num_attention_heads=motion_num_attention_heads[i], temporal_max_seq_length=motion_max_seq_length, temporal_transformer_layers_per_block=temporal_transformer_layers_per_block[i], ) From 097e9c16368fdb1b0fce2cc49682b0e033d3305d Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 29 Jul 2024 10:07:57 +0200 Subject: [PATCH 08/14] rename model and output class --- .../models/unets/unet_motion_model.py | 32 +++++++++---------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/src/diffusers/models/unets/unet_motion_model.py b/src/diffusers/models/unets/unet_motion_model.py index 7e27bad565e1..66a78d1f9c68 100644 --- a/src/diffusers/models/unets/unet_motion_model.py +++ b/src/diffusers/models/unets/unet_motion_model.py @@ -51,9 +51,9 @@ @dataclass -class AnimateDiffTransformer3DModelOutput(BaseOutput): +class AnimateDiffTransformer3DOutput(BaseOutput): """ - The output of [`AnimateDiffTransformer3DModel`]. + The output of [`AnimateDiffTransformer3D`]. Args: sample (`torch.Tensor` of shape `(batch_size * num_frames, num_channels, height, width)`): @@ -63,7 +63,7 @@ class AnimateDiffTransformer3DModelOutput(BaseOutput): sample: torch.Tensor -class AnimateDiffTransformer3DModel(nn.Module): +class AnimateDiffTransformer3D(nn.Module): """ A Transformer model for video-like data. @@ -151,9 +151,9 @@ def forward( num_frames: int = 1, cross_attention_kwargs: Optional[Dict[str, Any]] = None, return_dict: bool = True, - ) -> AnimateDiffTransformer3DModelOutput: + ) -> AnimateDiffTransformer3DOutput: """ - The [`AnimateDiffTransformer3DModel`] forward method. + The [`AnimateDiffTransformer3D`] forward method. 
Args: hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.Tensor` of shape `(batch size, channel, height, width)` if continuous): @@ -174,13 +174,13 @@ def forward( [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a - [`~models.transformers.animatediff_transformer_3d.AnimateDiffTransformer3DModelOutput`] instead of a - plain tuple. + [`~models.transformers.animatediff_transformer_3d.AnimateDiffTransformer3DOutput`] instead of a plain + tuple. Returns: - [`~models.transformers.animatediff_transformer_3d.AnimateDiffTransformer3DModelOutput`] or `tuple`: + [`~models.transformers.animatediff_transformer_3d.AnimateDiffTransformer3DOutput`] or `tuple`: If `return_dict` is True, an - [`~models.transformers.animatediff_transformer_3d.AnimateDiffTransformer3DModelOutput`] is returned, + [`~models.transformers.animatediff_transformer_3d.AnimateDiffTransformer3DOutput`] is returned, otherwise a `tuple` where the first element is the sample tensor. """ # 1. Input @@ -222,7 +222,7 @@ def forward( if not return_dict: return (output,) - return AnimateDiffTransformer3DModelOutput(sample=output) + return AnimateDiffTransformer3DOutput(sample=output) class DownBlockMotion(nn.Module): @@ -283,7 +283,7 @@ def __init__( ) ) motion_modules.append( - AnimateDiffTransformer3DModel( + AnimateDiffTransformer3D( num_attention_heads=temporal_num_attention_heads[i], in_channels=out_channels, num_layers=temporal_transformer_layers_per_block[i], @@ -466,7 +466,7 @@ def __init__( ) motion_modules.append( - AnimateDiffTransformer3DModel( + AnimateDiffTransformer3D( num_attention_heads=temporal_num_attention_heads, in_channels=out_channels, num_layers=temporal_transformer_layers_per_block[i], @@ -676,7 +676,7 @@ def __init__( ) ) motion_modules.append( - AnimateDiffTransformer3DModel( + AnimateDiffTransformer3D( num_attention_heads=temporal_num_attention_heads, in_channels=out_channels, num_layers=temporal_transformer_layers_per_block[i], @@ -847,7 +847,7 @@ def __init__( ) motion_modules.append( - AnimateDiffTransformer3DModel( + AnimateDiffTransformer3D( num_attention_heads=temporal_num_attention_heads, in_channels=out_channels, num_layers=temporal_transformer_layers_per_block[i], @@ -1051,7 +1051,7 @@ def __init__( ) ) motion_modules.append( - AnimateDiffTransformer3DModel( + AnimateDiffTransformer3D( num_attention_heads=temporal_num_attention_heads, attention_head_dim=in_channels // temporal_num_attention_heads, in_channels=in_channels, @@ -1165,7 +1165,7 @@ def __init__( for i in range(layers_per_block): self.motion_modules.append( - AnimateDiffTransformer3DModel( + AnimateDiffTransformer3D( in_channels=in_channels, num_layers=transformer_layers_per_block[i], norm_num_groups=norm_num_groups, From b4ff62dd20b8fad6186bc090108137909facd9cd Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 29 Jul 2024 10:08:56 +0200 Subject: [PATCH 09/14] fix sparsectrl imports --- src/diffusers/models/controlnet_sparsectrl.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/diffusers/models/controlnet_sparsectrl.py b/src/diffusers/models/controlnet_sparsectrl.py index bc1273aaab7d..ccdfc01e7b2e 100644 --- a/src/diffusers/models/controlnet_sparsectrl.py +++ b/src/diffusers/models/controlnet_sparsectrl.py @@ -32,10 +32,7 @@ from .modeling_utils import ModelMixin from .unets.unet_2d_blocks import 
UNetMidBlock2DCrossAttn from .unets.unet_2d_condition import UNet2DConditionModel -from .unets.unet_3d_blocks import ( - CrossAttnDownBlockMotion, - DownBlockMotion, -) +from .unets.unet_motion_model import CrossAttnDownBlockMotion, DownBlockMotion logger = logging.get_logger(__name__) # pylint: disable=invalid-name From 606afa6798b6e80fb1ac9a4a549a83b28cef769a Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 29 Jul 2024 10:12:33 +0200 Subject: [PATCH 10/14] remove todo comments --- src/diffusers/models/unets/unet_motion_model.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/diffusers/models/unets/unet_motion_model.py b/src/diffusers/models/unets/unet_motion_model.py index 66a78d1f9c68..f42d5a4f9629 100644 --- a/src/diffusers/models/unets/unet_motion_model.py +++ b/src/diffusers/models/unets/unet_motion_model.py @@ -97,13 +97,13 @@ def __init__( num_attention_heads: int = 16, attention_head_dim: int = 88, in_channels: Optional[int] = None, - out_channels: Optional[int] = None, # TODO(aryan): seems like unused parameter, do we remove? + out_channels: Optional[int] = None, num_layers: int = 1, dropout: float = 0.0, norm_num_groups: int = 32, cross_attention_dim: Optional[int] = None, attention_bias: bool = False, - sample_size: Optional[int] = None, # TODO(aryan): seems like unused parameter, do we remove? + sample_size: Optional[int] = None, activation_fn: str = "geglu", norm_elementwise_affine: bool = True, double_self_attention: bool = True, @@ -1480,8 +1480,6 @@ def __init__( output_channel = block_out_channels[i] is_final_block = i == len(block_out_channels) - 1 - # TODO(aryan): Can we reduce LOC here by creating a dictionary of common arguments and then passing **kwargs? - # Many params are repeated here. if down_block_type == "CrossAttnDownBlockMotion": down_block = CrossAttnDownBlockMotion( in_channels=input_channel, From b37d4661edc94d9a14778dcf5083cbeb23da125f Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 29 Jul 2024 10:18:21 +0200 Subject: [PATCH 11/14] remove temporal double self attn param from controlnet sparsectrl --- src/diffusers/models/controlnet_sparsectrl.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/diffusers/models/controlnet_sparsectrl.py b/src/diffusers/models/controlnet_sparsectrl.py index ccdfc01e7b2e..cb577e33c670 100644 --- a/src/diffusers/models/controlnet_sparsectrl.py +++ b/src/diffusers/models/controlnet_sparsectrl.py @@ -314,7 +314,6 @@ def __init__( temporal_num_attention_heads=motion_num_attention_heads[i], temporal_max_seq_length=motion_max_seq_length, temporal_transformer_layers_per_block=temporal_transformer_layers_per_block[i], - temporal_double_self_attention=False, ) elif down_block_type == "DownBlockMotion": down_block = DownBlockMotion( @@ -331,7 +330,6 @@ def __init__( add_downsample=not is_final_block, temporal_num_attention_heads=motion_num_attention_heads[i], temporal_max_seq_length=motion_max_seq_length, - temporal_double_self_attention=False, temporal_transformer_layers_per_block=temporal_transformer_layers_per_block[i], ) else: From 087a340388a04822139827c4b510f6b468e4883e Mon Sep 17 00:00:00 2001 From: Aryan Date: Fri, 2 Aug 2024 11:05:31 +0200 Subject: [PATCH 12/14] add deprecated versions of blocks --- src/diffusers/models/unets/unet_3d_blocks.py | 44 ++++++++++++++++++- .../models/unets/unet_motion_model.py | 24 +++++++--- 2 files changed, 61 insertions(+), 7 deletions(-) diff --git a/src/diffusers/models/unets/unet_3d_blocks.py b/src/diffusers/models/unets/unet_3d_blocks.py index 
--- a/src/diffusers/models/unets/unet_3d_blocks.py
+++ b/src/diffusers/models/unets/unet_3d_blocks.py
@@ -17,7 +17,7 @@
 import torch
 from torch import nn

-from ...utils import is_torch_version, logging
+from ...utils import deprecate, is_torch_version, logging
 from ...utils.torch_utils import apply_freeu
 from ..attention import Attention
 from ..resnet import (
@@ -32,11 +32,53 @@
     TransformerSpatioTemporalModel,
     TransformerTemporalModel,
 )
+from .unet_motion_model import (
+    CrossAttnDownBlockMotion,
+    CrossAttnUpBlockMotion,
+    DownBlockMotion,
+    UNetMidBlockCrossAttnMotion,
+    UpBlockMotion,
+)


 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


+class DownBlockMotion(DownBlockMotion):
+    def __init__(self, *args, **kwargs):
+        deprecation_message = "Importing `DownBlockMotion` from `diffusers.models.unets.unet_3d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_motion_model import DownBlockMotion` instead."
+        deprecate("DownBlockMotion", "1.0.0", deprecation_message)
+        super().__init__(*args, **kwargs)
+
+
+class CrossAttnDownBlockMotion(CrossAttnDownBlockMotion):
+    def __init__(self, *args, **kwargs):
+        deprecation_message = "Importing `CrossAttnDownBlockMotion` from `diffusers.models.unets.unet_3d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_motion_model import CrossAttnDownBlockMotion` instead."
+        deprecate("CrossAttnDownBlockMotion", "1.0.0", deprecation_message)
+        super().__init__(*args, **kwargs)
+
+
+class UpBlockMotion(UpBlockMotion):
+    def __init__(self, *args, **kwargs):
+        deprecation_message = "Importing `UpBlockMotion` from `diffusers.models.unets.unet_3d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_motion_model import UpBlockMotion` instead."
+        deprecate("UpBlockMotion", "1.0.0", deprecation_message)
+        super().__init__(*args, **kwargs)
+
+
+class CrossAttnUpBlockMotion(CrossAttnUpBlockMotion):
+    def __init__(self, *args, **kwargs):
+        deprecation_message = "Importing `CrossAttnUpBlockMotion` from `diffusers.models.unets.unet_3d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_motion_model import CrossAttnUpBlockMotion` instead."
+        deprecate("CrossAttnUpBlockMotion", "1.0.0", deprecation_message)
+        super().__init__(*args, **kwargs)
+
+
+class UNetMidBlockCrossAttnMotion(UNetMidBlockCrossAttnMotion):
+    def __init__(self, *args, **kwargs):
+        deprecation_message = "Importing `UNetMidBlockCrossAttnMotion` from `diffusers.models.unets.unet_3d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_motion_model import UNetMidBlockCrossAttnMotion` instead."
+ deprecate("UNetMidBlockCrossAttnMotion", "1.0.0", deprecation_message) + super().__init__(*args, **kwargs) + + def get_down_block( down_block_type: str, num_layers: int, diff --git a/src/diffusers/models/unets/unet_motion_model.py b/src/diffusers/models/unets/unet_motion_model.py index 5de36dbeab15..08ad85f6b595 100644 --- a/src/diffusers/models/unets/unet_motion_model.py +++ b/src/diffusers/models/unets/unet_motion_model.py @@ -44,7 +44,6 @@ from ..transformers.transformer_2d import Transformer2DModel from .unet_2d_blocks import UNetMidBlock2DCrossAttn from .unet_2d_condition import UNet2DConditionModel -from .unet_3d_condition import UNet3DConditionOutput logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -63,6 +62,19 @@ class AnimateDiffTransformer3DOutput(BaseOutput): sample: torch.Tensor +@dataclass +class UNetMotionOutput(BaseOutput): + """ + The output of [`UNetMotionOutput`]. + + Args: + sample (`torch.Tensor` of shape `(batch_size, num_channels, num_frames, height, width)`): + The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model. + """ + + sample: torch.Tensor + + class AnimateDiffTransformer3D(nn.Module): """ A Transformer model for video-like data. @@ -2083,7 +2095,7 @@ def forward( down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, mid_block_additional_residual: Optional[torch.Tensor] = None, return_dict: bool = True, - ) -> Union[UNet3DConditionOutput, Tuple[torch.Tensor]]: + ) -> Union[UNetMotionOutput, Tuple[torch.Tensor]]: r""" The [`UNetMotionModel`] forward method. @@ -2109,12 +2121,12 @@ def forward( mid_block_additional_residual: (`torch.Tensor`, *optional*): A tensor that if specified is added to the residual of the middle unet block. return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~models.unets.unet_3d_condition.UNet3DConditionOutput`] instead of a plain + Whether or not to return a [`~models.unets.unet_motion_model.UNetMotionOutput`] instead of a plain tuple. Returns: - [`~models.unets.unet_3d_condition.UNet3DConditionOutput`] or `tuple`: - If `return_dict` is True, an [`~models.unets.unet_3d_condition.UNet3DConditionOutput`] is returned, + [`~models.unets.unet_motion_model.UNetMotionOutput`] or `tuple`: + If `return_dict` is True, an [`~models.unets.unet_motion_model.UNetMotionOutput`] is returned, otherwise a `tuple` is returned where the first element is the sample tensor. """ # By default samples have to be AT least a multiple of the overall upsampling factor. @@ -2298,4 +2310,4 @@ def forward( if not return_dict: return (sample,) - return UNet3DConditionOutput(sample=sample) + return UNetMotionOutput(sample=sample) From 60ba7135042b907159274cc0212960c223bfe798 Mon Sep 17 00:00:00 2001 From: Aryan Date: Fri, 2 Aug 2024 16:59:49 +0200 Subject: [PATCH 13/14] apply suggestions from review --- .../models/unets/unet_motion_model.py | 32 +++---------------- 1 file changed, 4 insertions(+), 28 deletions(-) diff --git a/src/diffusers/models/unets/unet_motion_model.py b/src/diffusers/models/unets/unet_motion_model.py index 08ad85f6b595..c87bb8161bd5 100644 --- a/src/diffusers/models/unets/unet_motion_model.py +++ b/src/diffusers/models/unets/unet_motion_model.py @@ -49,19 +49,6 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name -@dataclass -class AnimateDiffTransformer3DOutput(BaseOutput): - """ - The output of [`AnimateDiffTransformer3D`]. 
-
-    Args:
-        sample (`torch.Tensor` of shape `(batch_size * num_frames, num_channels, height, width)`):
-            The hidden states output conditioned on `encoder_hidden_states` input.
-    """
-
-    sample: torch.Tensor
-
-
 @dataclass
 class UNetMotionOutput(BaseOutput):
     """
@@ -162,8 +149,7 @@ def forward(
         class_labels: Optional[torch.LongTensor] = None,
         num_frames: int = 1,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
-        return_dict: bool = True,
-    ) -> AnimateDiffTransformer3DOutput:
+    ) -> torch.Tensor:
         """
         The [`AnimateDiffTransformer3D`] forward method.

@@ -184,16 +170,10 @@ def forward(
             A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
             `self.processor` in
             [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
-            return_dict (`bool`, *optional*, defaults to `True`):
-                Whether or not to return a
-                [`~models.transformers.animatediff_transformer_3d.AnimateDiffTransformer3DOutput`] instead of a plain
-                tuple.

         Returns:
-            [`~models.transformers.animatediff_transformer_3d.AnimateDiffTransformer3DOutput`] or `tuple`:
-                If `return_dict` is True, an
-                [`~models.transformers.animatediff_transformer_3d.AnimateDiffTransformer3DOutput`] is returned,
-                otherwise a `tuple` where the first element is the sample tensor.
+            torch.Tensor:
+                The output tensor.
         """
         # 1. Input
         batch_frames, channel, height, width = hidden_states.shape
@@ -230,11 +210,7 @@ def forward(
         hidden_states = hidden_states.reshape(batch_frames, channel, height, width)

         output = hidden_states + residual
-
-        if not return_dict:
-            return (output,)
-
-        return AnimateDiffTransformer3DOutput(sample=output)
+        return output


 class DownBlockMotion(nn.Module):

From 0da444685104706388d6aa343993297b8df7baeb Mon Sep 17 00:00:00 2001
From: Dhruv Nair
Date: Sat, 3 Aug 2024 06:55:44 +0000
Subject: [PATCH 14/14] update

---
 src/diffusers/models/unets/unet_motion_model.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/src/diffusers/models/unets/unet_motion_model.py b/src/diffusers/models/unets/unet_motion_model.py
index c87bb8161bd5..e96867bc3ed0 100644
--- a/src/diffusers/models/unets/unet_motion_model.py
+++ b/src/diffusers/models/unets/unet_motion_model.py
@@ -343,7 +343,7 @@ def custom_forward(*inputs):
         else:
             hidden_states = resnet(hidden_states, temb)

-        hidden_states = motion_module(hidden_states, num_frames=num_frames)[0]
+        hidden_states = motion_module(hidden_states, num_frames=num_frames)

         output_states = output_states + (hidden_states,)

@@ -547,7 +547,7 @@ def custom_forward(*inputs):
         hidden_states = motion_module(
             hidden_states,
             num_frames=num_frames,
-        )[0]
+        )

         # apply additional residuals to the output of the last pair of resnet and attention blocks
         if i == len(blocks) - 1 and additional_residuals is not None:
@@ -772,7 +772,7 @@ def custom_forward(*inputs):
         hidden_states = motion_module(
             hidden_states,
             num_frames=num_frames,
-        )[0]
+        )

         if self.upsamplers is not None:
             for upsampler in self.upsamplers:
@@ -924,7 +924,7 @@ def custom_forward(*inputs):
         else:
             hidden_states = resnet(hidden_states, temb)

-        hidden_states = motion_module(hidden_states, num_frames=num_frames)[0]
+        hidden_states = motion_module(hidden_states, num_frames=num_frames)

         if self.upsamplers is not None:
             for upsampler in self.upsamplers:
@@ -1121,7 +1121,7 @@ def custom_forward(*inputs):
         hidden_states = motion_module(
             hidden_states,
             num_frames=num_frames,
-        )[0]
+        )
         hidden_states = resnet(hidden_states, temb)

         return hidden_states
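
Taken together, the patches above change two things that downstream code can observe: the motion block classes and the temporal transformer now live in `diffusers.models.unets.unet_motion_model` (the old `unet_3d_blocks` import path is kept only as a deprecation shim, see PATCH 12), and `AnimateDiffTransformer3D.forward` returns a plain `torch.Tensor` instead of an output dataclass (PATCH 13/14), so call sites drop the `[0]` indexing. Below is a minimal usage sketch, assuming a diffusers build that already contains these patches; the configuration values and tensor sizes are illustrative only and are not taken from the patches.

import torch

# New home of the temporal transformer and motion blocks after the refactor.
# Importing DownBlockMotion and friends from diffusers.models.unets.unet_3d_blocks
# still works, but only through the deprecation shims added in PATCH 12.
from diffusers.models.unets.unet_motion_model import AnimateDiffTransformer3D

# Illustrative (hypothetical) configuration: inner_dim = 8 * 16 = 128 matches in_channels.
module = AnimateDiffTransformer3D(
    num_attention_heads=8,
    attention_head_dim=16,
    in_channels=128,
    norm_num_groups=32,
)

batch_size, num_frames = 2, 16
# forward expects hidden states of shape (batch_size * num_frames, channels, height, width)
hidden_states = torch.randn(batch_size * num_frames, 128, 8, 8)

# Before PATCH 13/14, call sites unpacked an output object, e.g. module(...)[0];
# after these patches, forward returns the tensor directly.
output = module(hidden_states, num_frames=num_frames)
assert output.shape == hidden_states.shape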