diff --git a/src/diffusers/models/controlnet_sparsectrl.py b/src/diffusers/models/controlnet_sparsectrl.py index bc1273aaab7d..cb577e33c670 100644 --- a/src/diffusers/models/controlnet_sparsectrl.py +++ b/src/diffusers/models/controlnet_sparsectrl.py @@ -32,10 +32,7 @@ from .modeling_utils import ModelMixin from .unets.unet_2d_blocks import UNetMidBlock2DCrossAttn from .unets.unet_2d_condition import UNet2DConditionModel -from .unets.unet_3d_blocks import ( - CrossAttnDownBlockMotion, - DownBlockMotion, -) +from .unets.unet_motion_model import CrossAttnDownBlockMotion, DownBlockMotion logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -317,7 +314,6 @@ def __init__( temporal_num_attention_heads=motion_num_attention_heads[i], temporal_max_seq_length=motion_max_seq_length, temporal_transformer_layers_per_block=temporal_transformer_layers_per_block[i], - temporal_double_self_attention=False, ) elif down_block_type == "DownBlockMotion": down_block = DownBlockMotion( @@ -334,7 +330,6 @@ def __init__( add_downsample=not is_final_block, temporal_num_attention_heads=motion_num_attention_heads[i], temporal_max_seq_length=motion_max_seq_length, - temporal_double_self_attention=False, temporal_transformer_layers_per_block=temporal_transformer_layers_per_block[i], ) else: diff --git a/src/diffusers/models/unets/unet_3d_blocks.py b/src/diffusers/models/unets/unet_3d_blocks.py index 51c743a14d40..8b472a89e13d 100644 --- a/src/diffusers/models/unets/unet_3d_blocks.py +++ b/src/diffusers/models/unets/unet_3d_blocks.py @@ -27,17 +27,58 @@ TemporalConvLayer, Upsample2D, ) -from ..transformers.dual_transformer_2d import DualTransformer2DModel from ..transformers.transformer_2d import Transformer2DModel from ..transformers.transformer_temporal import ( TransformerSpatioTemporalModel, TransformerTemporalModel, ) +from .unet_motion_model import ( + CrossAttnDownBlockMotion, + CrossAttnUpBlockMotion, + DownBlockMotion, + UNetMidBlockCrossAttnMotion, + UpBlockMotion, +) logger = logging.get_logger(__name__) # pylint: disable=invalid-name +class DownBlockMotion(DownBlockMotion): + def __init__(self, *args, **kwargs): + deprecation_message = "Importing `DownBlockMotion` from `diffusers.models.unets.unet_3d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_motion_model import DownBlockMotion` instead." + deprecate("DownBlockMotion", "1.0.0", deprecation_message) + super().__init__(*args, **kwargs) + + +class CrossAttnDownBlockMotion(CrossAttnDownBlockMotion): + def __init__(self, *args, **kwargs): + deprecation_message = "Importing `CrossAttnDownBlockMotion` from `diffusers.models.unets.unet_3d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_motion_model import CrossAttnDownBlockMotion` instead." + deprecate("CrossAttnDownBlockMotion", "1.0.0", deprecation_message) + super().__init__(*args, **kwargs) + + +class UpBlockMotion(UpBlockMotion): + def __init__(self, *args, **kwargs): + deprecation_message = "Importing `UpBlockMotion` from `diffusers.models.unets.unet_3d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_motion_model import UpBlockMotion` instead." 
+ deprecate("UpBlockMotion", "1.0.0", deprecation_message) + super().__init__(*args, **kwargs) + + +class CrossAttnUpBlockMotion(CrossAttnUpBlockMotion): + def __init__(self, *args, **kwargs): + deprecation_message = "Importing `CrossAttnUpBlockMotion` from `diffusers.models.unets.unet_3d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_motion_model import CrossAttnUpBlockMotion` instead." + deprecate("CrossAttnUpBlockMotion", "1.0.0", deprecation_message) + super().__init__(*args, **kwargs) + + +class UNetMidBlockCrossAttnMotion(UNetMidBlockCrossAttnMotion): + def __init__(self, *args, **kwargs): + deprecation_message = "Importing `UNetMidBlockCrossAttnMotion` from `diffusers.models.unets.unet_3d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_motion_model import UNetMidBlockCrossAttnMotion` instead." + deprecate("UNetMidBlockCrossAttnMotion", "1.0.0", deprecation_message) + super().__init__(*args, **kwargs) + + def get_down_block( down_block_type: str, num_layers: int, @@ -64,8 +105,6 @@ def get_down_block( ) -> Union[ "DownBlock3D", "CrossAttnDownBlock3D", - "DownBlockMotion", - "CrossAttnDownBlockMotion", "DownBlockSpatioTemporal", "CrossAttnDownBlockSpatioTemporal", ]: @@ -105,49 +144,6 @@ def get_down_block( resnet_time_scale_shift=resnet_time_scale_shift, dropout=dropout, ) - if down_block_type == "DownBlockMotion": - return DownBlockMotion( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - add_downsample=add_downsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - resnet_groups=resnet_groups, - downsample_padding=downsample_padding, - resnet_time_scale_shift=resnet_time_scale_shift, - temporal_num_attention_heads=temporal_num_attention_heads, - temporal_max_seq_length=temporal_max_seq_length, - temporal_transformer_layers_per_block=temporal_transformer_layers_per_block, - dropout=dropout, - ) - elif down_block_type == "CrossAttnDownBlockMotion": - if cross_attention_dim is None: - raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlockMotion") - return CrossAttnDownBlockMotion( - num_layers=num_layers, - transformer_layers_per_block=transformer_layers_per_block, - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - add_downsample=add_downsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - resnet_groups=resnet_groups, - downsample_padding=downsample_padding, - cross_attention_dim=cross_attention_dim, - num_attention_heads=num_attention_heads, - dual_cross_attention=dual_cross_attention, - use_linear_projection=use_linear_projection, - only_cross_attention=only_cross_attention, - upcast_attention=upcast_attention, - resnet_time_scale_shift=resnet_time_scale_shift, - temporal_num_attention_heads=temporal_num_attention_heads, - temporal_max_seq_length=temporal_max_seq_length, - temporal_transformer_layers_per_block=temporal_transformer_layers_per_block, - dropout=dropout, - ) elif down_block_type == "DownBlockSpatioTemporal": # added for SDV return DownBlockSpatioTemporal( @@ -203,8 +199,6 @@ def get_up_block( ) -> Union[ "UpBlock3D", "CrossAttnUpBlock3D", - "UpBlockMotion", - "CrossAttnUpBlockMotion", "UpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal", ]: @@ -246,51 +240,6 @@ def get_up_block( resolution_idx=resolution_idx, dropout=dropout, ) - if up_block_type == "UpBlockMotion": - return UpBlockMotion( 
- num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - prev_output_channel=prev_output_channel, - temb_channels=temb_channels, - add_upsample=add_upsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - resnet_groups=resnet_groups, - resnet_time_scale_shift=resnet_time_scale_shift, - resolution_idx=resolution_idx, - temporal_num_attention_heads=temporal_num_attention_heads, - temporal_max_seq_length=temporal_max_seq_length, - temporal_transformer_layers_per_block=temporal_transformer_layers_per_block, - dropout=dropout, - ) - elif up_block_type == "CrossAttnUpBlockMotion": - if cross_attention_dim is None: - raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlockMotion") - return CrossAttnUpBlockMotion( - num_layers=num_layers, - transformer_layers_per_block=transformer_layers_per_block, - in_channels=in_channels, - out_channels=out_channels, - prev_output_channel=prev_output_channel, - temb_channels=temb_channels, - add_upsample=add_upsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - resnet_groups=resnet_groups, - cross_attention_dim=cross_attention_dim, - num_attention_heads=num_attention_heads, - dual_cross_attention=dual_cross_attention, - use_linear_projection=use_linear_projection, - only_cross_attention=only_cross_attention, - upcast_attention=upcast_attention, - resnet_time_scale_shift=resnet_time_scale_shift, - resolution_idx=resolution_idx, - temporal_num_attention_heads=temporal_num_attention_heads, - temporal_max_seq_length=temporal_max_seq_length, - temporal_transformer_layers_per_block=temporal_transformer_layers_per_block, - dropout=dropout, - ) elif up_block_type == "UpBlockSpatioTemporal": # added for SDV return UpBlockSpatioTemporal( @@ -947,924 +896,6 @@ def forward( return hidden_states -class DownBlockMotion(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - temb_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - resnet_pre_norm: bool = True, - output_scale_factor: float = 1.0, - add_downsample: bool = True, - downsample_padding: int = 1, - temporal_num_attention_heads: Union[int, Tuple[int]] = 1, - temporal_cross_attention_dim: Optional[int] = None, - temporal_max_seq_length: int = 32, - temporal_double_self_attention: bool = True, - temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1, - ): - super().__init__() - resnets = [] - motion_modules = [] - - # support for variable transformer layers per temporal block - if isinstance(temporal_transformer_layers_per_block, int): - temporal_transformer_layers_per_block = (temporal_transformer_layers_per_block,) * num_layers - elif len(temporal_transformer_layers_per_block) != num_layers: - raise ValueError( - f"`temporal_transformer_layers_per_block` must be an integer or a tuple of integers of length {num_layers}" - ) - - # support for variable number of attention head per temporal layers - if isinstance(temporal_num_attention_heads, int): - temporal_num_attention_heads = (temporal_num_attention_heads,) * num_layers - elif len(temporal_num_attention_heads) != num_layers: - raise ValueError( - f"`temporal_num_attention_heads` must be an integer or a tuple of integers of length {num_layers}" - ) - - for i in range(num_layers): - in_channels = in_channels if i == 0 else out_channels - resnets.append( - ResnetBlock2D( - in_channels=in_channels, - out_channels=out_channels, 
- temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - motion_modules.append( - TransformerTemporalModel( - num_attention_heads=temporal_num_attention_heads[i], - in_channels=out_channels, - num_layers=temporal_transformer_layers_per_block[i], - norm_num_groups=resnet_groups, - cross_attention_dim=temporal_cross_attention_dim, - attention_bias=False, - activation_fn="geglu", - positional_embeddings="sinusoidal", - num_positional_embeddings=temporal_max_seq_length, - attention_head_dim=out_channels // temporal_num_attention_heads[i], - double_self_attention=temporal_double_self_attention, - ) - ) - - self.resnets = nn.ModuleList(resnets) - self.motion_modules = nn.ModuleList(motion_modules) - - if add_downsample: - self.downsamplers = nn.ModuleList( - [ - Downsample2D( - out_channels, - use_conv=True, - out_channels=out_channels, - padding=downsample_padding, - name="op", - ) - ] - ) - else: - self.downsamplers = None - - self.gradient_checkpointing = False - - def forward( - self, - hidden_states: torch.Tensor, - temb: Optional[torch.Tensor] = None, - num_frames: int = 1, - *args, - **kwargs, - ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]: - if len(args) > 0 or kwargs.get("scale", None) is not None: - deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." - deprecate("scale", "1.0.0", deprecation_message) - - output_states = () - - blocks = zip(self.resnets, self.motion_modules) - for resnet, motion_module in blocks: - if self.training and self.gradient_checkpointing: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - if is_torch_version(">=", "1.11.0"): - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), - hidden_states, - temb, - use_reentrant=False, - ) - else: - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), hidden_states, temb - ) - - else: - hidden_states = resnet(hidden_states, temb) - hidden_states = motion_module(hidden_states, num_frames=num_frames)[0] - - output_states = output_states + (hidden_states,) - - if self.downsamplers is not None: - for downsampler in self.downsamplers: - hidden_states = downsampler(hidden_states) - - output_states = output_states + (hidden_states,) - - return hidden_states, output_states - - -class CrossAttnDownBlockMotion(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - temb_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - transformer_layers_per_block: Union[int, Tuple[int]] = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - resnet_pre_norm: bool = True, - num_attention_heads: int = 1, - cross_attention_dim: int = 1280, - output_scale_factor: float = 1.0, - downsample_padding: int = 1, - add_downsample: bool = True, - dual_cross_attention: bool = False, - use_linear_projection: bool = False, - only_cross_attention: bool = False, - upcast_attention: bool = False, - attention_type: str = "default", - temporal_cross_attention_dim: Optional[int] = None, - 
temporal_num_attention_heads: int = 8, - temporal_max_seq_length: int = 32, - temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1, - temporal_double_self_attention: bool = True, - ): - super().__init__() - resnets = [] - attentions = [] - motion_modules = [] - - self.has_cross_attention = True - self.num_attention_heads = num_attention_heads - - # support for variable transformer layers per block - if isinstance(transformer_layers_per_block, int): - transformer_layers_per_block = (transformer_layers_per_block,) * num_layers - elif len(transformer_layers_per_block) != num_layers: - raise ValueError( - f"transformer_layers_per_block must be an integer or a list of integers of length {num_layers}" - ) - - # support for variable transformer layers per temporal block - if isinstance(temporal_transformer_layers_per_block, int): - temporal_transformer_layers_per_block = (temporal_transformer_layers_per_block,) * num_layers - elif len(temporal_transformer_layers_per_block) != num_layers: - raise ValueError( - f"temporal_transformer_layers_per_block must be an integer or a list of integers of length {num_layers}" - ) - - for i in range(num_layers): - in_channels = in_channels if i == 0 else out_channels - resnets.append( - ResnetBlock2D( - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - - if not dual_cross_attention: - attentions.append( - Transformer2DModel( - num_attention_heads, - out_channels // num_attention_heads, - in_channels=out_channels, - num_layers=transformer_layers_per_block[i], - cross_attention_dim=cross_attention_dim, - norm_num_groups=resnet_groups, - use_linear_projection=use_linear_projection, - only_cross_attention=only_cross_attention, - upcast_attention=upcast_attention, - attention_type=attention_type, - ) - ) - else: - attentions.append( - DualTransformer2DModel( - num_attention_heads, - out_channels // num_attention_heads, - in_channels=out_channels, - num_layers=1, - cross_attention_dim=cross_attention_dim, - norm_num_groups=resnet_groups, - ) - ) - - motion_modules.append( - TransformerTemporalModel( - num_attention_heads=temporal_num_attention_heads, - in_channels=out_channels, - num_layers=temporal_transformer_layers_per_block[i], - norm_num_groups=resnet_groups, - cross_attention_dim=temporal_cross_attention_dim, - attention_bias=False, - activation_fn="geglu", - positional_embeddings="sinusoidal", - num_positional_embeddings=temporal_max_seq_length, - attention_head_dim=out_channels // temporal_num_attention_heads, - double_self_attention=temporal_double_self_attention, - ) - ) - - self.attentions = nn.ModuleList(attentions) - self.resnets = nn.ModuleList(resnets) - self.motion_modules = nn.ModuleList(motion_modules) - - if add_downsample: - self.downsamplers = nn.ModuleList( - [ - Downsample2D( - out_channels, - use_conv=True, - out_channels=out_channels, - padding=downsample_padding, - name="op", - ) - ] - ) - else: - self.downsamplers = None - - self.gradient_checkpointing = False - - def forward( - self, - hidden_states: torch.Tensor, - temb: Optional[torch.Tensor] = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - num_frames: int = 1, - encoder_attention_mask: Optional[torch.Tensor] = None, - cross_attention_kwargs: 
Optional[Dict[str, Any]] = None, - additional_residuals: Optional[torch.Tensor] = None, - ): - if cross_attention_kwargs is not None: - if cross_attention_kwargs.get("scale", None) is not None: - logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") - - output_states = () - - blocks = list(zip(self.resnets, self.attentions, self.motion_modules)) - for i, (resnet, attn, motion_module) in enumerate(blocks): - if self.training and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), - hidden_states, - temb, - **ckpt_kwargs, - ) - hidden_states = attn( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - cross_attention_kwargs=cross_attention_kwargs, - attention_mask=attention_mask, - encoder_attention_mask=encoder_attention_mask, - return_dict=False, - )[0] - else: - hidden_states = resnet(hidden_states, temb) - hidden_states = attn( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - cross_attention_kwargs=cross_attention_kwargs, - attention_mask=attention_mask, - encoder_attention_mask=encoder_attention_mask, - return_dict=False, - )[0] - hidden_states = motion_module( - hidden_states, - num_frames=num_frames, - )[0] - - # apply additional residuals to the output of the last pair of resnet and attention blocks - if i == len(blocks) - 1 and additional_residuals is not None: - hidden_states = hidden_states + additional_residuals - - output_states = output_states + (hidden_states,) - - if self.downsamplers is not None: - for downsampler in self.downsamplers: - hidden_states = downsampler(hidden_states) - - output_states = output_states + (hidden_states,) - - return hidden_states, output_states - - -class CrossAttnUpBlockMotion(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - prev_output_channel: int, - temb_channels: int, - resolution_idx: Optional[int] = None, - dropout: float = 0.0, - num_layers: int = 1, - transformer_layers_per_block: Union[int, Tuple[int]] = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - resnet_pre_norm: bool = True, - num_attention_heads: int = 1, - cross_attention_dim: int = 1280, - output_scale_factor: float = 1.0, - add_upsample: bool = True, - dual_cross_attention: bool = False, - use_linear_projection: bool = False, - only_cross_attention: bool = False, - upcast_attention: bool = False, - attention_type: str = "default", - temporal_cross_attention_dim: Optional[int] = None, - temporal_num_attention_heads: int = 8, - temporal_max_seq_length: int = 32, - temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1, - ): - super().__init__() - resnets = [] - attentions = [] - motion_modules = [] - - self.has_cross_attention = True - self.num_attention_heads = num_attention_heads - - # support for variable transformer layers per block - if isinstance(transformer_layers_per_block, int): - transformer_layers_per_block = (transformer_layers_per_block,) * num_layers - elif len(transformer_layers_per_block) != num_layers: - raise ValueError( - f"transformer_layers_per_block must be an integer or 
a list of integers of length {num_layers}, got {len(transformer_layers_per_block)}" - ) - - # support for variable transformer layers per temporal block - if isinstance(temporal_transformer_layers_per_block, int): - temporal_transformer_layers_per_block = (temporal_transformer_layers_per_block,) * num_layers - elif len(temporal_transformer_layers_per_block) != num_layers: - raise ValueError( - f"temporal_transformer_layers_per_block must be an integer or a list of integers of length {num_layers}, got {len(temporal_transformer_layers_per_block)}" - ) - - for i in range(num_layers): - res_skip_channels = in_channels if (i == num_layers - 1) else out_channels - resnet_in_channels = prev_output_channel if i == 0 else out_channels - - resnets.append( - ResnetBlock2D( - in_channels=resnet_in_channels + res_skip_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - - if not dual_cross_attention: - attentions.append( - Transformer2DModel( - num_attention_heads, - out_channels // num_attention_heads, - in_channels=out_channels, - num_layers=transformer_layers_per_block[i], - cross_attention_dim=cross_attention_dim, - norm_num_groups=resnet_groups, - use_linear_projection=use_linear_projection, - only_cross_attention=only_cross_attention, - upcast_attention=upcast_attention, - attention_type=attention_type, - ) - ) - else: - attentions.append( - DualTransformer2DModel( - num_attention_heads, - out_channels // num_attention_heads, - in_channels=out_channels, - num_layers=1, - cross_attention_dim=cross_attention_dim, - norm_num_groups=resnet_groups, - ) - ) - motion_modules.append( - TransformerTemporalModel( - num_attention_heads=temporal_num_attention_heads, - in_channels=out_channels, - num_layers=temporal_transformer_layers_per_block[i], - norm_num_groups=resnet_groups, - cross_attention_dim=temporal_cross_attention_dim, - attention_bias=False, - activation_fn="geglu", - positional_embeddings="sinusoidal", - num_positional_embeddings=temporal_max_seq_length, - attention_head_dim=out_channels // temporal_num_attention_heads, - ) - ) - - self.attentions = nn.ModuleList(attentions) - self.resnets = nn.ModuleList(resnets) - self.motion_modules = nn.ModuleList(motion_modules) - - if add_upsample: - self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) - else: - self.upsamplers = None - - self.gradient_checkpointing = False - self.resolution_idx = resolution_idx - - def forward( - self, - hidden_states: torch.Tensor, - res_hidden_states_tuple: Tuple[torch.Tensor, ...], - temb: Optional[torch.Tensor] = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - cross_attention_kwargs: Optional[Dict[str, Any]] = None, - upsample_size: Optional[int] = None, - attention_mask: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.Tensor] = None, - num_frames: int = 1, - ) -> torch.Tensor: - if cross_attention_kwargs is not None: - if cross_attention_kwargs.get("scale", None) is not None: - logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. 
`scale` will be ignored.") - - is_freeu_enabled = ( - getattr(self, "s1", None) - and getattr(self, "s2", None) - and getattr(self, "b1", None) - and getattr(self, "b2", None) - ) - - blocks = zip(self.resnets, self.attentions, self.motion_modules) - for resnet, attn, motion_module in blocks: - # pop res hidden states - res_hidden_states = res_hidden_states_tuple[-1] - res_hidden_states_tuple = res_hidden_states_tuple[:-1] - - # FreeU: Only operate on the first two stages - if is_freeu_enabled: - hidden_states, res_hidden_states = apply_freeu( - self.resolution_idx, - hidden_states, - res_hidden_states, - s1=self.s1, - s2=self.s2, - b1=self.b1, - b2=self.b2, - ) - - hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) - - if self.training and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), - hidden_states, - temb, - **ckpt_kwargs, - ) - hidden_states = attn( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - cross_attention_kwargs=cross_attention_kwargs, - attention_mask=attention_mask, - encoder_attention_mask=encoder_attention_mask, - return_dict=False, - )[0] - else: - hidden_states = resnet(hidden_states, temb) - hidden_states = attn( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - cross_attention_kwargs=cross_attention_kwargs, - attention_mask=attention_mask, - encoder_attention_mask=encoder_attention_mask, - return_dict=False, - )[0] - hidden_states = motion_module( - hidden_states, - num_frames=num_frames, - )[0] - - if self.upsamplers is not None: - for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states, upsample_size) - - return hidden_states - - -class UpBlockMotion(nn.Module): - def __init__( - self, - in_channels: int, - prev_output_channel: int, - out_channels: int, - temb_channels: int, - resolution_idx: Optional[int] = None, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - resnet_pre_norm: bool = True, - output_scale_factor: float = 1.0, - add_upsample: bool = True, - temporal_cross_attention_dim: Optional[int] = None, - temporal_num_attention_heads: int = 8, - temporal_max_seq_length: int = 32, - temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1, - ): - super().__init__() - resnets = [] - motion_modules = [] - - # support for variable transformer layers per temporal block - if isinstance(temporal_transformer_layers_per_block, int): - temporal_transformer_layers_per_block = (temporal_transformer_layers_per_block,) * num_layers - elif len(temporal_transformer_layers_per_block) != num_layers: - raise ValueError( - f"temporal_transformer_layers_per_block must be an integer or a list of integers of length {num_layers}" - ) - - for i in range(num_layers): - res_skip_channels = in_channels if (i == num_layers - 1) else out_channels - resnet_in_channels = prev_output_channel if i == 0 else out_channels - - resnets.append( - ResnetBlock2D( - in_channels=resnet_in_channels + res_skip_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - 
groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - - motion_modules.append( - TransformerTemporalModel( - num_attention_heads=temporal_num_attention_heads, - in_channels=out_channels, - num_layers=temporal_transformer_layers_per_block[i], - norm_num_groups=resnet_groups, - cross_attention_dim=temporal_cross_attention_dim, - attention_bias=False, - activation_fn="geglu", - positional_embeddings="sinusoidal", - num_positional_embeddings=temporal_max_seq_length, - attention_head_dim=out_channels // temporal_num_attention_heads, - ) - ) - - self.resnets = nn.ModuleList(resnets) - self.motion_modules = nn.ModuleList(motion_modules) - - if add_upsample: - self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) - else: - self.upsamplers = None - - self.gradient_checkpointing = False - self.resolution_idx = resolution_idx - - def forward( - self, - hidden_states: torch.Tensor, - res_hidden_states_tuple: Tuple[torch.Tensor, ...], - temb: Optional[torch.Tensor] = None, - upsample_size=None, - num_frames: int = 1, - *args, - **kwargs, - ) -> torch.Tensor: - if len(args) > 0 or kwargs.get("scale", None) is not None: - deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." - deprecate("scale", "1.0.0", deprecation_message) - - is_freeu_enabled = ( - getattr(self, "s1", None) - and getattr(self, "s2", None) - and getattr(self, "b1", None) - and getattr(self, "b2", None) - ) - - blocks = zip(self.resnets, self.motion_modules) - - for resnet, motion_module in blocks: - # pop res hidden states - res_hidden_states = res_hidden_states_tuple[-1] - res_hidden_states_tuple = res_hidden_states_tuple[:-1] - - # FreeU: Only operate on the first two stages - if is_freeu_enabled: - hidden_states, res_hidden_states = apply_freeu( - self.resolution_idx, - hidden_states, - res_hidden_states, - s1=self.s1, - s2=self.s2, - b1=self.b1, - b2=self.b2, - ) - - hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) - - if self.training and self.gradient_checkpointing: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - if is_torch_version(">=", "1.11.0"): - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), - hidden_states, - temb, - use_reentrant=False, - ) - else: - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), hidden_states, temb - ) - - else: - hidden_states = resnet(hidden_states, temb) - hidden_states = motion_module(hidden_states, num_frames=num_frames)[0] - - if self.upsamplers is not None: - for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states, upsample_size) - - return hidden_states - - -class UNetMidBlockCrossAttnMotion(nn.Module): - def __init__( - self, - in_channels: int, - temb_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - transformer_layers_per_block: Union[int, Tuple[int]] = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - resnet_pre_norm: bool = True, - num_attention_heads: int = 1, - output_scale_factor: float = 1.0, - 
cross_attention_dim: int = 1280, - dual_cross_attention: bool = False, - use_linear_projection: bool = False, - upcast_attention: bool = False, - attention_type: str = "default", - temporal_num_attention_heads: int = 1, - temporal_cross_attention_dim: Optional[int] = None, - temporal_max_seq_length: int = 32, - temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1, - ): - super().__init__() - - self.has_cross_attention = True - self.num_attention_heads = num_attention_heads - resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) - - # support for variable transformer layers per block - if isinstance(transformer_layers_per_block, int): - transformer_layers_per_block = (transformer_layers_per_block,) * num_layers - elif len(transformer_layers_per_block) != num_layers: - raise ValueError( - f"`transformer_layers_per_block` should be an integer or a list of integers of length {num_layers}." - ) - - # support for variable transformer layers per temporal block - if isinstance(temporal_transformer_layers_per_block, int): - temporal_transformer_layers_per_block = (temporal_transformer_layers_per_block,) * num_layers - elif len(temporal_transformer_layers_per_block) != num_layers: - raise ValueError( - f"`temporal_transformer_layers_per_block` should be an integer or a list of integers of length {num_layers}." - ) - - # there is always at least one resnet - resnets = [ - ResnetBlock2D( - in_channels=in_channels, - out_channels=in_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ] - attentions = [] - motion_modules = [] - - for i in range(num_layers): - if not dual_cross_attention: - attentions.append( - Transformer2DModel( - num_attention_heads, - in_channels // num_attention_heads, - in_channels=in_channels, - num_layers=transformer_layers_per_block[i], - cross_attention_dim=cross_attention_dim, - norm_num_groups=resnet_groups, - use_linear_projection=use_linear_projection, - upcast_attention=upcast_attention, - attention_type=attention_type, - ) - ) - else: - attentions.append( - DualTransformer2DModel( - num_attention_heads, - in_channels // num_attention_heads, - in_channels=in_channels, - num_layers=1, - cross_attention_dim=cross_attention_dim, - norm_num_groups=resnet_groups, - ) - ) - resnets.append( - ResnetBlock2D( - in_channels=in_channels, - out_channels=in_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - motion_modules.append( - TransformerTemporalModel( - num_attention_heads=temporal_num_attention_heads, - attention_head_dim=in_channels // temporal_num_attention_heads, - in_channels=in_channels, - num_layers=temporal_transformer_layers_per_block[i], - norm_num_groups=resnet_groups, - cross_attention_dim=temporal_cross_attention_dim, - attention_bias=False, - positional_embeddings="sinusoidal", - num_positional_embeddings=temporal_max_seq_length, - activation_fn="geglu", - ) - ) - - self.attentions = nn.ModuleList(attentions) - self.resnets = nn.ModuleList(resnets) - self.motion_modules = nn.ModuleList(motion_modules) - - self.gradient_checkpointing = False - - def forward( - self, - hidden_states: torch.Tensor, - temb: 
Optional[torch.Tensor] = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - cross_attention_kwargs: Optional[Dict[str, Any]] = None, - encoder_attention_mask: Optional[torch.Tensor] = None, - num_frames: int = 1, - ) -> torch.Tensor: - if cross_attention_kwargs is not None: - if cross_attention_kwargs.get("scale", None) is not None: - logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") - - hidden_states = self.resnets[0](hidden_states, temb) - - blocks = zip(self.attentions, self.resnets[1:], self.motion_modules) - for attn, resnet, motion_module in blocks: - if self.training and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - hidden_states = attn( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - cross_attention_kwargs=cross_attention_kwargs, - attention_mask=attention_mask, - encoder_attention_mask=encoder_attention_mask, - return_dict=False, - )[0] - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(motion_module), - hidden_states, - temb, - **ckpt_kwargs, - ) - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), - hidden_states, - temb, - **ckpt_kwargs, - ) - else: - hidden_states = attn( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - cross_attention_kwargs=cross_attention_kwargs, - attention_mask=attention_mask, - encoder_attention_mask=encoder_attention_mask, - return_dict=False, - )[0] - hidden_states = motion_module( - hidden_states, - num_frames=num_frames, - )[0] - hidden_states = resnet(hidden_states, temb) - - return hidden_states - - class MidBlockTemporalDecoder(nn.Module): def __init__( self, diff --git a/src/diffusers/models/unets/unet_motion_model.py b/src/diffusers/models/unets/unet_motion_model.py index c8ea0ecc3feb..e96867bc3ed0 100644 --- a/src/diffusers/models/unets/unet_motion_model.py +++ b/src/diffusers/models/unets/unet_motion_model.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+
+from dataclasses import dataclass
 from typing import Any, Dict, Optional, Tuple, Union
 
 import torch
@@ -20,7 +22,9 @@
 
 from ...configuration_utils import ConfigMixin, FrozenDict, register_to_config
 from ...loaders import FromOriginalModelMixin, PeftAdapterMixin, UNet2DConditionLoadersMixin
-from ...utils import logging
+from ...utils import BaseOutput, deprecate, is_torch_version, logging
+from ...utils.torch_utils import apply_freeu
+from ..attention import BasicTransformerBlock
 from ..attention_processor import (
     ADDED_KV_ATTENTION_PROCESSORS,
     CROSS_ATTENTION_PROCESSORS,
@@ -35,24 +39,1094 @@
 )
 from ..embeddings import TimestepEmbedding, Timesteps
 from ..modeling_utils import ModelMixin
-from ..transformers.transformer_temporal import TransformerTemporalModel
+from ..resnet import Downsample2D, ResnetBlock2D, Upsample2D
+from ..transformers.dual_transformer_2d import DualTransformer2DModel
+from ..transformers.transformer_2d import Transformer2DModel
 from .unet_2d_blocks import UNetMidBlock2DCrossAttn
 from .unet_2d_condition import UNet2DConditionModel
-from .unet_3d_blocks import (
-    CrossAttnDownBlockMotion,
-    CrossAttnUpBlockMotion,
-    DownBlockMotion,
-    UNetMidBlockCrossAttnMotion,
-    UpBlockMotion,
-    get_down_block,
-    get_up_block,
-)
-from .unet_3d_condition import UNet3DConditionOutput
 
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
+@dataclass
+class UNetMotionOutput(BaseOutput):
+    """
+    The output of [`UNetMotionModel`].
+
+    Args:
+        sample (`torch.Tensor` of shape `(batch_size, num_channels, num_frames, height, width)`):
+            The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
+    """
+
+    sample: torch.Tensor
+
+
+class AnimateDiffTransformer3D(nn.Module):
+    """
+    A Transformer model for video-like data.
+
+    Parameters:
+        num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
+        attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
+        in_channels (`int`, *optional*):
+            The number of channels in the input and output (specify if the input is **continuous**).
+        num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
+        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+        cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
+        attention_bias (`bool`, *optional*):
+            Configure if the `TransformerBlock` attention should contain a bias parameter.
+        sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
+            This is fixed during training since it is used to learn a number of position embeddings.
+        activation_fn (`str`, *optional*, defaults to `"geglu"`):
+            Activation function to use in feed-forward. See `diffusers.models.activations.get_activation` for supported
+            activation functions.
+        norm_elementwise_affine (`bool`, *optional*):
+            Configure if the `TransformerBlock` should use learnable elementwise affine parameters for normalization.
+        double_self_attention (`bool`, *optional*):
+            Configure if each `TransformerBlock` should contain two self-attention layers.
+        positional_embeddings: (`str`, *optional*):
+            The type of positional embeddings to apply to the sequence input before use.
+        num_positional_embeddings: (`int`, *optional*):
+            The maximum length of the sequence over which to apply positional embeddings.
+ """ + + def __init__( + self, + num_attention_heads: int = 16, + attention_head_dim: int = 88, + in_channels: Optional[int] = None, + out_channels: Optional[int] = None, + num_layers: int = 1, + dropout: float = 0.0, + norm_num_groups: int = 32, + cross_attention_dim: Optional[int] = None, + attention_bias: bool = False, + sample_size: Optional[int] = None, + activation_fn: str = "geglu", + norm_elementwise_affine: bool = True, + double_self_attention: bool = True, + positional_embeddings: Optional[str] = None, + num_positional_embeddings: Optional[int] = None, + ): + super().__init__() + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + inner_dim = num_attention_heads * attention_head_dim + + self.in_channels = in_channels + + self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) + self.proj_in = nn.Linear(in_channels, inner_dim) + + # 3. Define transformers blocks + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + inner_dim, + num_attention_heads, + attention_head_dim, + dropout=dropout, + cross_attention_dim=cross_attention_dim, + activation_fn=activation_fn, + attention_bias=attention_bias, + double_self_attention=double_self_attention, + norm_elementwise_affine=norm_elementwise_affine, + positional_embeddings=positional_embeddings, + num_positional_embeddings=num_positional_embeddings, + ) + for _ in range(num_layers) + ] + ) + + self.proj_out = nn.Linear(inner_dim, in_channels) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.LongTensor] = None, + timestep: Optional[torch.LongTensor] = None, + class_labels: Optional[torch.LongTensor] = None, + num_frames: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + ) -> torch.Tensor: + """ + The [`AnimateDiffTransformer3D`] forward method. + + Args: + hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.Tensor` of shape `(batch size, channel, height, width)` if continuous): + Input hidden_states. + encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*): + Conditional embeddings for cross attention layer. If not given, cross-attention defaults to + self-attention. + timestep ( `torch.LongTensor`, *optional*): + Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`. + class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*): + Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in + `AdaLayerZeroNorm`. + num_frames (`int`, *optional*, defaults to 1): + The number of frames to be processed per batch. This is used to reshape the hidden states. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + + Returns: + torch.Tensor: + The output tensor. + """ + # 1. 
Input + batch_frames, channel, height, width = hidden_states.shape + batch_size = batch_frames // num_frames + + residual = hidden_states + + hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, channel, height, width) + hidden_states = hidden_states.permute(0, 2, 1, 3, 4) + + hidden_states = self.norm(hidden_states) + hidden_states = hidden_states.permute(0, 3, 4, 2, 1).reshape(batch_size * height * width, num_frames, channel) + + hidden_states = self.proj_in(hidden_states) + + # 2. Blocks + for block in self.transformer_blocks: + hidden_states = block( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + timestep=timestep, + cross_attention_kwargs=cross_attention_kwargs, + class_labels=class_labels, + ) + + # 3. Output + hidden_states = self.proj_out(hidden_states) + hidden_states = ( + hidden_states[None, None, :] + .reshape(batch_size, height, width, num_frames, channel) + .permute(0, 3, 4, 1, 2) + .contiguous() + ) + hidden_states = hidden_states.reshape(batch_frames, channel, height, width) + + output = hidden_states + residual + return output + + +class DownBlockMotion(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor: float = 1.0, + add_downsample: bool = True, + downsample_padding: int = 1, + temporal_num_attention_heads: Union[int, Tuple[int]] = 1, + temporal_cross_attention_dim: Optional[int] = None, + temporal_max_seq_length: int = 32, + temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1, + ): + super().__init__() + resnets = [] + motion_modules = [] + + # support for variable transformer layers per temporal block + if isinstance(temporal_transformer_layers_per_block, int): + temporal_transformer_layers_per_block = (temporal_transformer_layers_per_block,) * num_layers + elif len(temporal_transformer_layers_per_block) != num_layers: + raise ValueError( + f"`temporal_transformer_layers_per_block` must be an integer or a tuple of integers of length {num_layers}" + ) + + # support for variable number of attention head per temporal layers + if isinstance(temporal_num_attention_heads, int): + temporal_num_attention_heads = (temporal_num_attention_heads,) * num_layers + elif len(temporal_num_attention_heads) != num_layers: + raise ValueError( + f"`temporal_num_attention_heads` must be an integer or a tuple of integers of length {num_layers}" + ) + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + motion_modules.append( + AnimateDiffTransformer3D( + num_attention_heads=temporal_num_attention_heads[i], + in_channels=out_channels, + num_layers=temporal_transformer_layers_per_block[i], + norm_num_groups=resnet_groups, + cross_attention_dim=temporal_cross_attention_dim, + attention_bias=False, + activation_fn="geglu", + positional_embeddings="sinusoidal", + num_positional_embeddings=temporal_max_seq_length, + attention_head_dim=out_channels // temporal_num_attention_heads[i], + ) + ) + + 
self.resnets = nn.ModuleList(resnets) + self.motion_modules = nn.ModuleList(motion_modules) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample2D( + out_channels, + use_conv=True, + out_channels=out_channels, + padding=downsample_padding, + name="op", + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + temb: Optional[torch.Tensor] = None, + num_frames: int = 1, + *args, + **kwargs, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + deprecate("scale", "1.0.0", deprecation_message) + + output_states = () + + blocks = zip(self.resnets, self.motion_modules) + for resnet, motion_module in blocks: + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + if is_torch_version(">=", "1.11.0"): + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + use_reentrant=False, + ) + else: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) + + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = motion_module(hidden_states, num_frames=num_frames) + + output_states = output_states + (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states = output_states + (hidden_states,) + + return hidden_states, output_states + + +class CrossAttnDownBlockMotion(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + transformer_layers_per_block: Union[int, Tuple[int]] = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + num_attention_heads: int = 1, + cross_attention_dim: int = 1280, + output_scale_factor: float = 1.0, + downsample_padding: int = 1, + add_downsample: bool = True, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + attention_type: str = "default", + temporal_cross_attention_dim: Optional[int] = None, + temporal_num_attention_heads: int = 8, + temporal_max_seq_length: int = 32, + temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1, + ): + super().__init__() + resnets = [] + attentions = [] + motion_modules = [] + + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + + # support for variable transformer layers per block + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = (transformer_layers_per_block,) * num_layers + elif len(transformer_layers_per_block) != num_layers: + raise ValueError( + f"transformer_layers_per_block must be an integer or a list of integers of length {num_layers}" + ) + + # support for variable transformer layers per temporal block + if isinstance(temporal_transformer_layers_per_block, int): + 
            temporal_transformer_layers_per_block = (temporal_transformer_layers_per_block,) * num_layers
+        elif len(temporal_transformer_layers_per_block) != num_layers:
+            raise ValueError(
+                f"temporal_transformer_layers_per_block must be an integer or a list of integers of length {num_layers}"
+            )
+
+        for i in range(num_layers):
+            in_channels = in_channels if i == 0 else out_channels
+            resnets.append(
+                ResnetBlock2D(
+                    in_channels=in_channels,
+                    out_channels=out_channels,
+                    temb_channels=temb_channels,
+                    eps=resnet_eps,
+                    groups=resnet_groups,
+                    dropout=dropout,
+                    time_embedding_norm=resnet_time_scale_shift,
+                    non_linearity=resnet_act_fn,
+                    output_scale_factor=output_scale_factor,
+                    pre_norm=resnet_pre_norm,
+                )
+            )
+
+            if not dual_cross_attention:
+                attentions.append(
+                    Transformer2DModel(
+                        num_attention_heads,
+                        out_channels // num_attention_heads,
+                        in_channels=out_channels,
+                        num_layers=transformer_layers_per_block[i],
+                        cross_attention_dim=cross_attention_dim,
+                        norm_num_groups=resnet_groups,
+                        use_linear_projection=use_linear_projection,
+                        only_cross_attention=only_cross_attention,
+                        upcast_attention=upcast_attention,
+                        attention_type=attention_type,
+                    )
+                )
+            else:
+                attentions.append(
+                    DualTransformer2DModel(
+                        num_attention_heads,
+                        out_channels // num_attention_heads,
+                        in_channels=out_channels,
+                        num_layers=1,
+                        cross_attention_dim=cross_attention_dim,
+                        norm_num_groups=resnet_groups,
+                    )
+                )
+
+            motion_modules.append(
+                AnimateDiffTransformer3D(
+                    num_attention_heads=temporal_num_attention_heads,
+                    in_channels=out_channels,
+                    num_layers=temporal_transformer_layers_per_block[i],
+                    norm_num_groups=resnet_groups,
+                    cross_attention_dim=temporal_cross_attention_dim,
+                    attention_bias=False,
+                    activation_fn="geglu",
+                    positional_embeddings="sinusoidal",
+                    num_positional_embeddings=temporal_max_seq_length,
+                    attention_head_dim=out_channels // temporal_num_attention_heads,
+                )
+            )
+
+        self.attentions = nn.ModuleList(attentions)
+        self.resnets = nn.ModuleList(resnets)
+        self.motion_modules = nn.ModuleList(motion_modules)
+
+        if add_downsample:
+            self.downsamplers = nn.ModuleList(
+                [
+                    Downsample2D(
+                        out_channels,
+                        use_conv=True,
+                        out_channels=out_channels,
+                        padding=downsample_padding,
+                        name="op",
+                    )
+                ]
+            )
+        else:
+            self.downsamplers = None
+
+        self.gradient_checkpointing = False
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        temb: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        num_frames: int = 1,
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        additional_residuals: Optional[torch.Tensor] = None,
+    ):
+        if cross_attention_kwargs is not None:
+            if cross_attention_kwargs.get("scale", None) is not None:
+                logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
+
+        output_states = ()
+
+        blocks = list(zip(self.resnets, self.attentions, self.motion_modules))
+        for i, (resnet, attn, motion_module) in enumerate(blocks):
+            if self.training and self.gradient_checkpointing:
+
+                def create_custom_forward(module, return_dict=None):
+                    def custom_forward(*inputs):
+                        if return_dict is not None:
+                            return module(*inputs, return_dict=return_dict)
+                        else:
+                            return module(*inputs)
+
+                    return custom_forward
+
+                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(resnet),
+                    hidden_states,
+                    temb,
+                    **ckpt_kwargs,
+                )
+                hidden_states = attn(
+                    hidden_states,
+                    encoder_hidden_states=encoder_hidden_states,
+                    cross_attention_kwargs=cross_attention_kwargs,
+                    attention_mask=attention_mask,
+                    encoder_attention_mask=encoder_attention_mask,
+                    return_dict=False,
+                )[0]
+            else:
+                hidden_states = resnet(hidden_states, temb)
+                hidden_states = attn(
+                    hidden_states,
+                    encoder_hidden_states=encoder_hidden_states,
+                    cross_attention_kwargs=cross_attention_kwargs,
+                    attention_mask=attention_mask,
+                    encoder_attention_mask=encoder_attention_mask,
+                    return_dict=False,
+                )[0]
+            hidden_states = motion_module(
+                hidden_states,
+                num_frames=num_frames,
+            )
+
+            # apply additional residuals to the output of the last pair of resnet and attention blocks
+            if i == len(blocks) - 1 and additional_residuals is not None:
+                hidden_states = hidden_states + additional_residuals
+
+            output_states = output_states + (hidden_states,)
+
+        if self.downsamplers is not None:
+            for downsampler in self.downsamplers:
+                hidden_states = downsampler(hidden_states)
+
+            output_states = output_states + (hidden_states,)
+
+        return hidden_states, output_states
+
+
+class CrossAttnUpBlockMotion(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        prev_output_channel: int,
+        temb_channels: int,
+        resolution_idx: Optional[int] = None,
+        dropout: float = 0.0,
+        num_layers: int = 1,
+        transformer_layers_per_block: Union[int, Tuple[int]] = 1,
+        resnet_eps: float = 1e-6,
+        resnet_time_scale_shift: str = "default",
+        resnet_act_fn: str = "swish",
+        resnet_groups: int = 32,
+        resnet_pre_norm: bool = True,
+        num_attention_heads: int = 1,
+        cross_attention_dim: int = 1280,
+        output_scale_factor: float = 1.0,
+        add_upsample: bool = True,
+        dual_cross_attention: bool = False,
+        use_linear_projection: bool = False,
+        only_cross_attention: bool = False,
+        upcast_attention: bool = False,
+        attention_type: str = "default",
+        temporal_cross_attention_dim: Optional[int] = None,
+        temporal_num_attention_heads: int = 8,
+        temporal_max_seq_length: int = 32,
+        temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1,
+    ):
+        super().__init__()
+        resnets = []
+        attentions = []
+        motion_modules = []
+
+        self.has_cross_attention = True
+        self.num_attention_heads = num_attention_heads
+
+        # support for variable transformer layers per block
+        if isinstance(transformer_layers_per_block, int):
+            transformer_layers_per_block = (transformer_layers_per_block,) * num_layers
+        elif len(transformer_layers_per_block) != num_layers:
+            raise ValueError(
+                f"transformer_layers_per_block must be an integer or a list of integers of length {num_layers}, got {len(transformer_layers_per_block)}"
+            )
+
+        # support for variable transformer layers per temporal block
+        if isinstance(temporal_transformer_layers_per_block, int):
+            temporal_transformer_layers_per_block = (temporal_transformer_layers_per_block,) * num_layers
+        elif len(temporal_transformer_layers_per_block) != num_layers:
+            raise ValueError(
+                f"temporal_transformer_layers_per_block must be an integer or a list of integers of length {num_layers}, got {len(temporal_transformer_layers_per_block)}"
+            )
+
+        for i in range(num_layers):
+            res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
+            resnet_in_channels = prev_output_channel if i == 0 else out_channels
+
+            resnets.append(
+                ResnetBlock2D(
+                    in_channels=resnet_in_channels + res_skip_channels,
+                    out_channels=out_channels,
+                    temb_channels=temb_channels,
+                    eps=resnet_eps,
+                    groups=resnet_groups,
+                    dropout=dropout,
+                    time_embedding_norm=resnet_time_scale_shift,
+                    non_linearity=resnet_act_fn,
+                    output_scale_factor=output_scale_factor,
+                    pre_norm=resnet_pre_norm,
+                )
+            )
+
+            if not dual_cross_attention:
+                attentions.append(
+                    Transformer2DModel(
+                        num_attention_heads,
+                        out_channels // num_attention_heads,
+                        in_channels=out_channels,
+                        num_layers=transformer_layers_per_block[i],
+                        cross_attention_dim=cross_attention_dim,
+                        norm_num_groups=resnet_groups,
+                        use_linear_projection=use_linear_projection,
+                        only_cross_attention=only_cross_attention,
+                        upcast_attention=upcast_attention,
+                        attention_type=attention_type,
+                    )
+                )
+            else:
+                attentions.append(
+                    DualTransformer2DModel(
+                        num_attention_heads,
+                        out_channels // num_attention_heads,
+                        in_channels=out_channels,
+                        num_layers=1,
+                        cross_attention_dim=cross_attention_dim,
+                        norm_num_groups=resnet_groups,
+                    )
+                )
+            motion_modules.append(
+                AnimateDiffTransformer3D(
+                    num_attention_heads=temporal_num_attention_heads,
+                    in_channels=out_channels,
+                    num_layers=temporal_transformer_layers_per_block[i],
+                    norm_num_groups=resnet_groups,
+                    cross_attention_dim=temporal_cross_attention_dim,
+                    attention_bias=False,
+                    activation_fn="geglu",
+                    positional_embeddings="sinusoidal",
+                    num_positional_embeddings=temporal_max_seq_length,
+                    attention_head_dim=out_channels // temporal_num_attention_heads,
+                )
+            )
+
+        self.attentions = nn.ModuleList(attentions)
+        self.resnets = nn.ModuleList(resnets)
+        self.motion_modules = nn.ModuleList(motion_modules)
+
+        if add_upsample:
+            self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
+        else:
+            self.upsamplers = None
+
+        self.gradient_checkpointing = False
+        self.resolution_idx = resolution_idx
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        res_hidden_states_tuple: Tuple[torch.Tensor, ...],
+        temb: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        upsample_size: Optional[int] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+        num_frames: int = 1,
+    ) -> torch.Tensor:
+        if cross_attention_kwargs is not None:
+            if cross_attention_kwargs.get("scale", None) is not None:
+                logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
+
+        is_freeu_enabled = (
+            getattr(self, "s1", None)
+            and getattr(self, "s2", None)
+            and getattr(self, "b1", None)
+            and getattr(self, "b2", None)
+        )
+
+        blocks = zip(self.resnets, self.attentions, self.motion_modules)
+        for resnet, attn, motion_module in blocks:
+            # pop res hidden states
+            res_hidden_states = res_hidden_states_tuple[-1]
+            res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+
+            # FreeU: Only operate on the first two stages
+            if is_freeu_enabled:
+                hidden_states, res_hidden_states = apply_freeu(
+                    self.resolution_idx,
+                    hidden_states,
+                    res_hidden_states,
+                    s1=self.s1,
+                    s2=self.s2,
+                    b1=self.b1,
+                    b2=self.b2,
+                )
+
+            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+
+            if self.training and self.gradient_checkpointing:
+
+                def create_custom_forward(module, return_dict=None):
+                    def custom_forward(*inputs):
+                        if return_dict is not None:
+                            return module(*inputs, return_dict=return_dict)
+                        else:
+                            return module(*inputs)
+
+                    return custom_forward
+
+                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(resnet),
+                    hidden_states,
+                    temb,
+                    **ckpt_kwargs,
+                )
+                hidden_states = attn(
+                    hidden_states,
+                    encoder_hidden_states=encoder_hidden_states,
+                    cross_attention_kwargs=cross_attention_kwargs,
+                    attention_mask=attention_mask,
+                    encoder_attention_mask=encoder_attention_mask,
+                    return_dict=False,
+                )[0]
+            else:
+                hidden_states = resnet(hidden_states, temb)
+                hidden_states = attn(
+                    hidden_states,
+                    encoder_hidden_states=encoder_hidden_states,
+                    cross_attention_kwargs=cross_attention_kwargs,
+                    attention_mask=attention_mask,
+                    encoder_attention_mask=encoder_attention_mask,
+                    return_dict=False,
+                )[0]
+            hidden_states = motion_module(
+                hidden_states,
+                num_frames=num_frames,
+            )
+
+        if self.upsamplers is not None:
+            for upsampler in self.upsamplers:
+                hidden_states = upsampler(hidden_states, upsample_size)
+
+        return hidden_states
+
+
+class UpBlockMotion(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        prev_output_channel: int,
+        out_channels: int,
+        temb_channels: int,
+        resolution_idx: Optional[int] = None,
+        dropout: float = 0.0,
+        num_layers: int = 1,
+        resnet_eps: float = 1e-6,
+        resnet_time_scale_shift: str = "default",
+        resnet_act_fn: str = "swish",
+        resnet_groups: int = 32,
+        resnet_pre_norm: bool = True,
+        output_scale_factor: float = 1.0,
+        add_upsample: bool = True,
+        temporal_cross_attention_dim: Optional[int] = None,
+        temporal_num_attention_heads: int = 8,
+        temporal_max_seq_length: int = 32,
+        temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1,
+    ):
+        super().__init__()
+        resnets = []
+        motion_modules = []
+
+        # support for variable transformer layers per temporal block
+        if isinstance(temporal_transformer_layers_per_block, int):
+            temporal_transformer_layers_per_block = (temporal_transformer_layers_per_block,) * num_layers
+        elif len(temporal_transformer_layers_per_block) != num_layers:
+            raise ValueError(
+                f"temporal_transformer_layers_per_block must be an integer or a list of integers of length {num_layers}"
+            )
+
+        for i in range(num_layers):
+            res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
+            resnet_in_channels = prev_output_channel if i == 0 else out_channels
+
+            resnets.append(
+                ResnetBlock2D(
+                    in_channels=resnet_in_channels + res_skip_channels,
+                    out_channels=out_channels,
+                    temb_channels=temb_channels,
+                    eps=resnet_eps,
+                    groups=resnet_groups,
+                    dropout=dropout,
+                    time_embedding_norm=resnet_time_scale_shift,
+                    non_linearity=resnet_act_fn,
+                    output_scale_factor=output_scale_factor,
+                    pre_norm=resnet_pre_norm,
+                )
+            )
+
+            motion_modules.append(
+                AnimateDiffTransformer3D(
+                    num_attention_heads=temporal_num_attention_heads,
+                    in_channels=out_channels,
+                    num_layers=temporal_transformer_layers_per_block[i],
+                    norm_num_groups=resnet_groups,
+                    cross_attention_dim=temporal_cross_attention_dim,
+                    attention_bias=False,
+                    activation_fn="geglu",
+                    positional_embeddings="sinusoidal",
+                    num_positional_embeddings=temporal_max_seq_length,
+                    attention_head_dim=out_channels // temporal_num_attention_heads,
+                )
+            )
+
+        self.resnets = nn.ModuleList(resnets)
+        self.motion_modules = nn.ModuleList(motion_modules)
+
+        if add_upsample:
+            self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
+        else:
+            self.upsamplers = None
+
+        self.gradient_checkpointing = False
+        self.resolution_idx = resolution_idx
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        res_hidden_states_tuple: Tuple[torch.Tensor, ...],
+        temb: Optional[torch.Tensor] = None,
+        upsample_size=None,
+        num_frames: int = 1,
+        *args,
+        **kwargs,
+    ) -> torch.Tensor:
+        if len(args) > 0 or kwargs.get("scale", None) is not None:
+            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
+            deprecate("scale", "1.0.0", deprecation_message)
+
+        is_freeu_enabled = (
+            getattr(self, "s1", None)
+            and getattr(self, "s2", None)
+            and getattr(self, "b1", None)
+            and getattr(self, "b2", None)
+        )
+
+        blocks = zip(self.resnets, self.motion_modules)
+
+        for resnet, motion_module in blocks:
+            # pop res hidden states
+            res_hidden_states = res_hidden_states_tuple[-1]
+            res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+
+            # FreeU: Only operate on the first two stages
+            if is_freeu_enabled:
+                hidden_states, res_hidden_states = apply_freeu(
+                    self.resolution_idx,
+                    hidden_states,
+                    res_hidden_states,
+                    s1=self.s1,
+                    s2=self.s2,
+                    b1=self.b1,
+                    b2=self.b2,
+                )
+
+            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+
+            if self.training and self.gradient_checkpointing:
+
+                def create_custom_forward(module):
+                    def custom_forward(*inputs):
+                        return module(*inputs)
+
+                    return custom_forward
+
+                if is_torch_version(">=", "1.11.0"):
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(resnet),
+                        hidden_states,
+                        temb,
+                        use_reentrant=False,
+                    )
+                else:
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(resnet), hidden_states, temb
+                    )
+
+            else:
+                hidden_states = resnet(hidden_states, temb)
+            hidden_states = motion_module(hidden_states, num_frames=num_frames)
+
+        if self.upsamplers is not None:
+            for upsampler in self.upsamplers:
+                hidden_states = upsampler(hidden_states, upsample_size)
+
+        return hidden_states
+
+
+class UNetMidBlockCrossAttnMotion(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        temb_channels: int,
+        dropout: float = 0.0,
+        num_layers: int = 1,
+        transformer_layers_per_block: Union[int, Tuple[int]] = 1,
+        resnet_eps: float = 1e-6,
+        resnet_time_scale_shift: str = "default",
+        resnet_act_fn: str = "swish",
+        resnet_groups: int = 32,
+        resnet_pre_norm: bool = True,
+        num_attention_heads: int = 1,
+        output_scale_factor: float = 1.0,
+        cross_attention_dim: int = 1280,
+        dual_cross_attention: bool = False,
+        use_linear_projection: bool = False,
+        upcast_attention: bool = False,
+        attention_type: str = "default",
+        temporal_num_attention_heads: int = 1,
+        temporal_cross_attention_dim: Optional[int] = None,
+        temporal_max_seq_length: int = 32,
+        temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1,
+    ):
+        super().__init__()
+
+        self.has_cross_attention = True
+        self.num_attention_heads = num_attention_heads
+        resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
+
+        # support for variable transformer layers per block
+        if isinstance(transformer_layers_per_block, int):
+            transformer_layers_per_block = (transformer_layers_per_block,) * num_layers
+        elif len(transformer_layers_per_block) != num_layers:
+            raise ValueError(
+                f"`transformer_layers_per_block` should be an integer or a list of integers of length {num_layers}."
+            )
+
+        # support for variable transformer layers per temporal block
+        if isinstance(temporal_transformer_layers_per_block, int):
+            temporal_transformer_layers_per_block = (temporal_transformer_layers_per_block,) * num_layers
+        elif len(temporal_transformer_layers_per_block) != num_layers:
+            raise ValueError(
+                f"`temporal_transformer_layers_per_block` should be an integer or a list of integers of length {num_layers}."
+            )
+
+        # there is always at least one resnet
+        resnets = [
+            ResnetBlock2D(
+                in_channels=in_channels,
+                out_channels=in_channels,
+                temb_channels=temb_channels,
+                eps=resnet_eps,
+                groups=resnet_groups,
+                dropout=dropout,
+                time_embedding_norm=resnet_time_scale_shift,
+                non_linearity=resnet_act_fn,
+                output_scale_factor=output_scale_factor,
+                pre_norm=resnet_pre_norm,
+            )
+        ]
+        attentions = []
+        motion_modules = []
+
+        for i in range(num_layers):
+            if not dual_cross_attention:
+                attentions.append(
+                    Transformer2DModel(
+                        num_attention_heads,
+                        in_channels // num_attention_heads,
+                        in_channels=in_channels,
+                        num_layers=transformer_layers_per_block[i],
+                        cross_attention_dim=cross_attention_dim,
+                        norm_num_groups=resnet_groups,
+                        use_linear_projection=use_linear_projection,
+                        upcast_attention=upcast_attention,
+                        attention_type=attention_type,
+                    )
+                )
+            else:
+                attentions.append(
+                    DualTransformer2DModel(
+                        num_attention_heads,
+                        in_channels // num_attention_heads,
+                        in_channels=in_channels,
+                        num_layers=1,
+                        cross_attention_dim=cross_attention_dim,
+                        norm_num_groups=resnet_groups,
+                    )
+                )
+            resnets.append(
+                ResnetBlock2D(
+                    in_channels=in_channels,
+                    out_channels=in_channels,
+                    temb_channels=temb_channels,
+                    eps=resnet_eps,
+                    groups=resnet_groups,
+                    dropout=dropout,
+                    time_embedding_norm=resnet_time_scale_shift,
+                    non_linearity=resnet_act_fn,
+                    output_scale_factor=output_scale_factor,
+                    pre_norm=resnet_pre_norm,
+                )
+            )
+            motion_modules.append(
+                AnimateDiffTransformer3D(
+                    num_attention_heads=temporal_num_attention_heads,
+                    attention_head_dim=in_channels // temporal_num_attention_heads,
+                    in_channels=in_channels,
+                    num_layers=temporal_transformer_layers_per_block[i],
+                    norm_num_groups=resnet_groups,
+                    cross_attention_dim=temporal_cross_attention_dim,
+                    attention_bias=False,
+                    positional_embeddings="sinusoidal",
+                    num_positional_embeddings=temporal_max_seq_length,
+                    activation_fn="geglu",
+                )
+            )
+
+        self.attentions = nn.ModuleList(attentions)
+        self.resnets = nn.ModuleList(resnets)
+        self.motion_modules = nn.ModuleList(motion_modules)
+
+        self.gradient_checkpointing = False
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        temb: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+        num_frames: int = 1,
+    ) -> torch.Tensor:
+        if cross_attention_kwargs is not None:
+            if cross_attention_kwargs.get("scale", None) is not None:
+                logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
+
+        hidden_states = self.resnets[0](hidden_states, temb)
+
+        blocks = zip(self.attentions, self.resnets[1:], self.motion_modules)
+        for attn, resnet, motion_module in blocks:
+            if self.training and self.gradient_checkpointing:
+
+                def create_custom_forward(module, return_dict=None):
+                    def custom_forward(*inputs):
+                        if return_dict is not None:
+                            return module(*inputs, return_dict=return_dict)
+                        else:
+                            return module(*inputs)
+
+                    return custom_forward
+
+                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+                hidden_states = attn(
+                    hidden_states,
+                    encoder_hidden_states=encoder_hidden_states,
+                    cross_attention_kwargs=cross_attention_kwargs,
+                    attention_mask=attention_mask,
+                    encoder_attention_mask=encoder_attention_mask,
+                    return_dict=False,
+                )[0]
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(motion_module),
+                    hidden_states,
+                    temb,
+                    **ckpt_kwargs,
+                )
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(resnet),
+                    hidden_states,
+                    temb,
+                    **ckpt_kwargs,
+                )
+            else:
+                hidden_states = attn(
+                    hidden_states,
+                    encoder_hidden_states=encoder_hidden_states,
+                    cross_attention_kwargs=cross_attention_kwargs,
+                    attention_mask=attention_mask,
+                    encoder_attention_mask=encoder_attention_mask,
+                    return_dict=False,
+                )[0]
+                hidden_states = motion_module(
+                    hidden_states,
+                    num_frames=num_frames,
+                )
+                hidden_states = resnet(hidden_states, temb)
+
+        return hidden_states
+
+
 class MotionModules(nn.Module):
     def __init__(
         self,
@@ -79,7 +1153,7 @@ def __init__(
 
         for i in range(layers_per_block):
             self.motion_modules.append(
-                TransformerTemporalModel(
+                AnimateDiffTransformer3D(
                     in_channels=in_channels,
                     num_layers=transformer_layers_per_block[i],
                     norm_num_groups=norm_num_groups,
@@ -394,26 +1468,45 @@ def __init__(
             output_channel = block_out_channels[i]
             is_final_block = i == len(block_out_channels) - 1
 
-            down_block = get_down_block(
-                down_block_type,
-                num_layers=layers_per_block[i],
-                in_channels=input_channel,
-                out_channels=output_channel,
-                temb_channels=time_embed_dim,
-                add_downsample=not is_final_block,
-                resnet_eps=norm_eps,
-                resnet_act_fn=act_fn,
-                resnet_groups=norm_num_groups,
-                cross_attention_dim=cross_attention_dim[i],
-                num_attention_heads=num_attention_heads[i],
-                downsample_padding=downsample_padding,
-                use_linear_projection=use_linear_projection,
-                dual_cross_attention=False,
-                temporal_num_attention_heads=motion_num_attention_heads[i],
-                temporal_max_seq_length=motion_max_seq_length,
-                transformer_layers_per_block=transformer_layers_per_block[i],
-                temporal_transformer_layers_per_block=temporal_transformer_layers_per_block[i],
-            )
+            if down_block_type == "CrossAttnDownBlockMotion":
+                down_block = CrossAttnDownBlockMotion(
+                    in_channels=input_channel,
+                    out_channels=output_channel,
+                    temb_channels=time_embed_dim,
+                    num_layers=layers_per_block[i],
+                    transformer_layers_per_block=transformer_layers_per_block[i],
+                    resnet_eps=norm_eps,
+                    resnet_act_fn=act_fn,
+                    resnet_groups=norm_num_groups,
+                    num_attention_heads=num_attention_heads[i],
+                    cross_attention_dim=cross_attention_dim[i],
+                    downsample_padding=downsample_padding,
+                    add_downsample=not is_final_block,
+                    use_linear_projection=use_linear_projection,
+                    temporal_num_attention_heads=motion_num_attention_heads[i],
+                    temporal_max_seq_length=motion_max_seq_length,
+                    temporal_transformer_layers_per_block=temporal_transformer_layers_per_block[i],
+                )
+            elif down_block_type == "DownBlockMotion":
+                down_block = DownBlockMotion(
+                    in_channels=input_channel,
+                    out_channels=output_channel,
+                    temb_channels=time_embed_dim,
+                    num_layers=layers_per_block[i],
+                    resnet_eps=norm_eps,
+                    resnet_act_fn=act_fn,
+                    resnet_groups=norm_num_groups,
+                    add_downsample=not is_final_block,
+                    downsample_padding=downsample_padding,
+                    temporal_num_attention_heads=motion_num_attention_heads[i],
+                    temporal_max_seq_length=motion_max_seq_length,
+                    temporal_transformer_layers_per_block=temporal_transformer_layers_per_block[i],
+                )
+            else:
+                raise ValueError(
+                    "Invalid `down_block_type` encountered. Must be one of `CrossAttnDownBlockMotion` or `DownBlockMotion`"
+                )
+
             self.down_blocks.append(down_block)
 
         # mid
@@ -488,27 +1581,47 @@ def __init__(
             else:
                 add_upsample = False
 
-            up_block = get_up_block(
-                up_block_type,
-                num_layers=reversed_layers_per_block[i] + 1,
-                in_channels=input_channel,
-                out_channels=output_channel,
-                prev_output_channel=prev_output_channel,
-                temb_channels=time_embed_dim,
-                add_upsample=add_upsample,
-                resnet_eps=norm_eps,
-                resnet_act_fn=act_fn,
-                resnet_groups=norm_num_groups,
-                cross_attention_dim=reversed_cross_attention_dim[i],
-                num_attention_heads=reversed_num_attention_heads[i],
-                dual_cross_attention=False,
-                resolution_idx=i,
-                use_linear_projection=use_linear_projection,
-                temporal_num_attention_heads=reversed_motion_num_attention_heads[i],
-                temporal_max_seq_length=motion_max_seq_length,
-                transformer_layers_per_block=reverse_transformer_layers_per_block[i],
-                temporal_transformer_layers_per_block=reverse_temporal_transformer_layers_per_block[i],
-            )
+            if up_block_type == "CrossAttnUpBlockMotion":
+                up_block = CrossAttnUpBlockMotion(
+                    in_channels=input_channel,
+                    out_channels=output_channel,
+                    prev_output_channel=prev_output_channel,
+                    temb_channels=time_embed_dim,
+                    resolution_idx=i,
+                    num_layers=reversed_layers_per_block[i] + 1,
+                    transformer_layers_per_block=reverse_transformer_layers_per_block[i],
+                    resnet_eps=norm_eps,
+                    resnet_act_fn=act_fn,
+                    resnet_groups=norm_num_groups,
+                    num_attention_heads=reversed_num_attention_heads[i],
+                    cross_attention_dim=reversed_cross_attention_dim[i],
+                    add_upsample=add_upsample,
+                    use_linear_projection=use_linear_projection,
+                    temporal_num_attention_heads=reversed_motion_num_attention_heads[i],
+                    temporal_max_seq_length=motion_max_seq_length,
+                    temporal_transformer_layers_per_block=reverse_temporal_transformer_layers_per_block[i],
+                )
+            elif up_block_type == "UpBlockMotion":
+                up_block = UpBlockMotion(
+                    in_channels=input_channel,
+                    prev_output_channel=prev_output_channel,
+                    out_channels=output_channel,
+                    temb_channels=time_embed_dim,
+                    resolution_idx=i,
+                    num_layers=reversed_layers_per_block[i] + 1,
+                    resnet_eps=norm_eps,
+                    resnet_act_fn=act_fn,
+                    resnet_groups=norm_num_groups,
+                    add_upsample=add_upsample,
+                    temporal_num_attention_heads=reversed_motion_num_attention_heads[i],
+                    temporal_max_seq_length=motion_max_seq_length,
+                    temporal_transformer_layers_per_block=reverse_temporal_transformer_layers_per_block[i],
+                )
+            else:
+                raise ValueError(
+                    "Invalid `up_block_type` encountered. Must be one of `CrossAttnUpBlockMotion` or `UpBlockMotion`"
+                )
+
             self.up_blocks.append(up_block)
             prev_output_channel = output_channel
 
@@ -958,7 +2071,7 @@ def forward(
         down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
         mid_block_additional_residual: Optional[torch.Tensor] = None,
         return_dict: bool = True,
-    ) -> Union[UNet3DConditionOutput, Tuple[torch.Tensor]]:
+    ) -> Union[UNetMotionOutput, Tuple[torch.Tensor]]:
         r"""
        The [`UNetMotionModel`] forward method.
 
@@ -984,12 +2097,12 @@ def forward(
             mid_block_additional_residual: (`torch.Tensor`, *optional*):
                 A tensor that if specified is added to the residual of the middle unet block.
             return_dict (`bool`, *optional*, defaults to `True`):
-                Whether or not to return a [`~models.unets.unet_3d_condition.UNet3DConditionOutput`] instead of a plain
+                Whether or not to return a [`~models.unets.unet_motion_model.UNetMotionOutput`] instead of a plain
                 tuple.
 
        Returns:
-            [`~models.unets.unet_3d_condition.UNet3DConditionOutput`] or `tuple`:
-                If `return_dict` is True, an [`~models.unets.unet_3d_condition.UNet3DConditionOutput`] is returned,
+            [`~models.unets.unet_motion_model.UNetMotionOutput`] or `tuple`:
+                If `return_dict` is True, an [`~models.unets.unet_motion_model.UNetMotionOutput`] is returned,
                 otherwise a `tuple` is returned where the first element is the sample tensor.
         """
         # By default samples have to be AT least a multiple of the overall upsampling factor.
@@ -1173,4 +2286,4 @@ def forward(
         if not return_dict:
             return (sample,)
 
-        return UNet3DConditionOutput(sample=sample)
+        return UNetMotionOutput(sample=sample)
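
As a usage reference (not part of the patch itself), the sketch below exercises one of the blocks defined above from its new location in unet_motion_model.py. The channel widths, head counts, and tensor shapes are illustrative assumptions only; video frames are folded into the batch dimension, and the num_frames argument lets the motion module unfold them again internally.

import torch
from diffusers.models.unets.unet_motion_model import UNetMidBlockCrossAttnMotion

# Illustrative sizes: 64 channels, 128-dim timestep embedding, 32-dim text features.
block = UNetMidBlockCrossAttnMotion(
    in_channels=64,
    temb_channels=128,
    num_attention_heads=4,           # spatial attention heads (64 // 4 = 16 dims per head)
    cross_attention_dim=32,          # width of the encoder hidden states below
    temporal_num_attention_heads=4,  # heads used by the AnimateDiffTransformer3D motion module
)

batch, num_frames = 2, 4
hidden_states = torch.randn(batch * num_frames, 64, 8, 8)        # frames folded into the batch dim
temb = torch.randn(batch * num_frames, 128)                       # timestep embedding per frame
encoder_hidden_states = torch.randn(batch * num_frames, 16, 32)   # text conditioning per frame

out = block(
    hidden_states,
    temb=temb,
    encoder_hidden_states=encoder_hidden_states,
    num_frames=num_frames,  # used by the motion module to reshape (batch * frames) back into frames
)
print(out.shape)  # torch.Size([8, 64, 8, 8])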