diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 88da548bd597..bf6bd39953e7 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -200,6 +200,8 @@ title: Tiny AutoEncoder - local: api/models/transformer2d title: Transformer2D + - local: api/models/transformer3d + title: Transformer3D - local: api/models/transformer_temporal title: Transformer Temporal - local: api/models/prior_transformer @@ -316,6 +318,8 @@ title: Text-to-video - local: api/pipelines/text_to_video_zero title: Text2Video-Zero + - local: api/pipelines/tune_a_video + title: Tune-A-Video - local: api/pipelines/unclip title: UnCLIP - local: api/pipelines/latent_diffusion_uncond diff --git a/docs/source/en/api/models/transformer3d.md b/docs/source/en/api/models/transformer3d.md new file mode 100644 index 000000000000..4f73573006da --- /dev/null +++ b/docs/source/en/api/models/transformer3d.md @@ -0,0 +1,11 @@ +# Transformer3D + +The Transformer2D model extended for video-like data. + +## Transformer3DModel + +[[autodoc]] Transformer3DModel + +## Transformer3DModelOutput + +[[autodoc]] models.transformer_3d.Transformer3DModelOutput diff --git a/docs/source/en/api/pipelines/tune_a_video.mdx b/docs/source/en/api/pipelines/tune_a_video.mdx new file mode 100644 index 000000000000..202051251f15 --- /dev/null +++ b/docs/source/en/api/pipelines/tune_a_video.mdx @@ -0,0 +1,122 @@ + + +# Tune-A-Video + +## Overview + +[Tune-A-Video: One-Shot Tuning of Image Diffusion Models for Text-to-Video Generation](https://arxiv.org/abs/2212.11565) by Jay Zhangjie Wu, Yixiao Ge, Xintao Wang, Stan Weixian Lei, Yuchao Gu, Yufei Shi, Wynne Hsu, Ying Shan, Xiaohu Qie, Mike Zheng Shou +The abstract of the paper is the following: + +*To replicate the success of text-to-image (T2I) generation, recent works employ large-scale video datasets to train a text-to-video (T2V) generator. Despite their promising results, such paradigm is computationally expensive. In this work, we propose a new T2V generation setting—One-Shot Video Tuning, where only one text-video pair is presented. Our model is built on state-of-the-art T2I diffusion models pre-trained on massive image data. We make two key observations: 1) T2I models can generate still images that represent verb terms; 2) extending T2I models to generate multiple images concurrently exhibits surprisingly good content consistency. To further learn continuous motion, we introduce Tune-A-Video, which involves a tailored spatio-temporal attention mechanism and an efficient one-shot tuning strategy. At inference, we employ DDIM inversion to provide structure guidance for sampling. 
Extensive qualitative and numerical experiments demonstrate the remarkable ability of our method across various applications.* + +Resources: + +* [GitHub repository](https://github.com/showlab/Tune-A-Video) +* [🤗 Spaces](https://huggingface.co/spaces/Tune-A-Video-library/Tune-A-Video-Training-UI) + +## Available Pipelines: + +| Pipeline | Tasks | Demo +|---|---|:---:| +| [TuneAVideoPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/tune_a_video/pipeline_tune_a_video.py) | *Text-to-Video Generation* | [🤗 Spaces](https://huggingface.co/spaces/Tune-A-Video-library/Tune-A-Video-inference) + +## Usage example + +### Loading with a pre-existing Text2Image checkpoint +```python +import torch +from diffusers import TuneAVideoPipeline, DDIMScheduler, UNet3DConditionModel +from diffusers.utils import export_to_video +from PIL import Image + +# Use any pretrained Text2Image checkpoint based on stable diffusion +pretrained_model_path = "nitrosocke/mo-di-diffusion" +unet = UNet3DConditionModel.from_pretrained( + "Tune-A-Video-library/df-cpt-mo-di-bear-guitar", subfolder="unet", torch_dtype=torch.float16 +).to("cuda") + +pipe = TuneAVideoPipeline.from_pretrained(pretrained_model_path, unet=unet, torch_dtype=torch.float16).to("cuda") + +prompt = "A princess playing a guitar, modern disney style" +generator = torch.Generator(device="cuda").manual_seed(42) + +video_frames = pipe(prompt, video_length=3, generator=generator, num_inference_steps=50, output_type="np").frames + +# Saving to gif. +pil_frames = [Image.fromarray(frame) for frame in video_frames] +duration = len(pil_frames) / 8 +pil_frames[0].save( + "animation.gif", + save_all=True, + append_images=pil_frames[1:], # append rest of the images + duration=duration * 1000, # in milliseconds + loop=0, +) + +# Saving to video +video_path = export_to_video(video_frames) +``` + +### Loading a saved Tune-A-Video checkpoint +```python +import torch +from diffusers import DiffusionPipeline, DDIMScheduler +from diffusers.utils import export_to_video +from PIL import Image + +pipe = DiffusionPipeline.from_pretrained( + "Tune-A-Video-library/df-cpt-mo-di-bear-guitar", torch_dtype=torch.float16 +).to("cuda") + +prompt = "A princess playing a guitar, modern disney style" +generator = torch.Generator(device="cuda").manual_seed(42) + +video_frames = pipe(prompt, video_length=3, generator=generator, num_inference_steps=50, output_type="np").frames + +# Saving to gif. +pil_frames = [Image.fromarray(frame) for frame in video_frames] +duration = len(pil_frames) / 8 +pil_frames[0].save( + "animation.gif", + save_all=True, + append_images=pil_frames[1:], # append rest of the images + duration=duration * 1000, # in milliseconds + loop=0, +) + +# Saving to video +video_path = export_to_video(video_frames) +``` + +Here are some sample outputs: + + + + + +
+ A princess playing a guitar, modern disney style (sample output GIFs; embedded media not reproduced in this diff)
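Because all `video_length` frames are decoded by the VAE in a single batch, memory usage grows with the number of frames. The pipeline exposes the same sliced/tiled VAE decoding helpers as the Stable Diffusion pipelines; a minimal sketch (reusing the checkpoints from the usage example above):

```python
import torch
from diffusers import TuneAVideoPipeline, UNet3DConditionModel

unet = UNet3DConditionModel.from_pretrained(
    "Tune-A-Video-library/df-cpt-mo-di-bear-guitar", subfolder="unet", torch_dtype=torch.float16
).to("cuda")
pipe = TuneAVideoPipeline.from_pretrained(
    "nitrosocke/mo-di-diffusion", unet=unet, torch_dtype=torch.float16
).to("cuda")

# Decode the latent video frame by frame instead of in one large batch.
pipe.enable_vae_slicing()

video_frames = pipe(
    "A princess playing a guitar, modern disney style", video_length=3, output_type="np"
).frames
```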
+ +## Available checkpoints + +* [Tune-A-Video-library/df-cpt-mo-di-bear-guitar](https://huggingface.co/Tune-A-Video-library/df-cpt-mo-di-bear-guitar) + +## TuneAVideoPipeline +[[autodoc]] TuneAVideoPipeline + - all + - __call__ diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 42f352c029c8..f55bb69fada2 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -84,6 +84,7 @@ "T2IAdapter", "T5FilmDecoder", "Transformer2DModel", + "Transformer3DModel", "UNet1DModel", "UNet2DConditionModel", "UNet2DModel", @@ -268,6 +269,7 @@ "StableUnCLIPPipeline", "TextToVideoSDPipeline", "TextToVideoZeroPipeline", + "TuneAVideoPipeline", "UnCLIPImageVariationPipeline", "UnCLIPPipeline", "UniDiffuserModel", @@ -443,6 +445,7 @@ T2IAdapter, T5FilmDecoder, Transformer2DModel, + Transformer3DModel, UNet1DModel, UNet2DConditionModel, UNet2DModel, @@ -606,6 +609,7 @@ StableUnCLIPPipeline, TextToVideoSDPipeline, TextToVideoZeroPipeline, + TuneAVideoPipeline, UnCLIPImageVariationPipeline, UnCLIPPipeline, UniDiffuserModel, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index a5d0066d5c40..6afbcb3b3514 100644 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -30,6 +30,7 @@ _import_structure["prior_transformer"] = ["PriorTransformer"] _import_structure["t5_film_transformer"] = ["T5FilmDecoder"] _import_structure["transformer_2d"] = ["Transformer2DModel"] + _import_structure["transformer_3d"] = ["Transformer3DModel"] _import_structure["transformer_temporal"] = ["TransformerTemporalModel"] _import_structure["unet_1d"] = ["UNet1DModel"] _import_structure["unet_2d"] = ["UNet2DModel"] @@ -55,6 +56,7 @@ from .prior_transformer import PriorTransformer from .t5_film_transformer import T5FilmDecoder from .transformer_2d import Transformer2DModel + from .transformer_3d import Transformer3DModel from .transformer_temporal import TransformerTemporalModel from .unet_1d import UNet1DModel from .unet_2d import UNet2DModel diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index 80bf269fc4e3..38a12ac81ce2 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -757,7 +757,196 @@ def forward(self, input_tensor, temb, scale: float = 1.0): return output_tensor -# unet_rl.py +class Upsample3D(nn.Module): + """A 3D upsampling layer. Reshapes the input tensor to video like tensor, applies upsampling conv, + converts it back to the original shape. + + Parameters: + channels (`int`): + number of channels in the inputs and outputs. + out_channels (`int`, optional): + number of output channels. Defaults to `channels`. + """ + + def __init__(self, channels, out_channels=None): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + + self.conv = nn.Conv2d(self.channels, self.out_channels, 3, padding=1) + + def forward(self, hidden_states, output_size=None): + if hidden_states.shape[1] != self.channels: + raise ValueError( + f"Expected hidden_states tensor at dimension 1 to match the number of channels. Expected: {self.channels} but passed: {hidden_states.shape[1]}" + ) + + # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16 + dtype = hidden_states.dtype + if dtype == torch.bfloat16: + hidden_states = hidden_states.to(torch.float32) + + # upsample_nearest_nhwc fails with large batch sizes. 
see https://github.com/huggingface/diffusers/issues/984 + if hidden_states.shape[0] >= 64: + hidden_states = hidden_states.contiguous() + + # if `output_size` is passed we force the interpolation output + # size and do not make use of `scale_factor=2` + if output_size is None: + hidden_states = F.interpolate(hidden_states, scale_factor=[1.0, 2.0, 2.0], mode="nearest") + else: + hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest") + + # If the input is bfloat16, we cast back to bfloat16 + if dtype == torch.bfloat16: + hidden_states = hidden_states.to(dtype) + + # Inflate + video_length = hidden_states.shape[2] + # b c f h w -> (b f) c h w + hidden_states = hidden_states.movedim((0, 1, 2, 3, 4), (0, 2, 1, 3, 4)) + hidden_states = hidden_states.flatten(0, 1) + + hidden_states = self.conv(hidden_states) + # Deflate + # (b f) c h w -> b c f h w (f=video_length) + hidden_states = hidden_states.reshape([-1, video_length, *hidden_states.shape[1:]]) + hidden_states = hidden_states.movedim((0, 1, 2, 3, 4), (0, 2, 1, 3, 4)) + + return hidden_states + + +class Downsample3D(nn.Module): + """A 3D downsampling layer. Reshapes the input tensor to video like tensor, applies conv, + converts it back to the original shape. + + Parameters: + channels (`int`): + number of channels in the inputs and outputs. + out_channels (`int`, optional): + number of output channels. Defaults to `channels`. + """ + + def __init__(self, channels, out_channels=None, padding=1): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.conv = nn.Conv2d(self.channels, self.out_channels, 3, stride=2, padding=padding) + + def forward(self, hidden_states): + video_length = hidden_states.shape[2] + # b c f h w -> (b f) c h w + hidden_states = hidden_states.movedim((0, 1, 2, 3, 4), (0, 2, 1, 3, 4)) + hidden_states = hidden_states.flatten(0, 1) + # Conv + hidden_states = self.conv(hidden_states) + # (b f) c h w -> b c f h w (f=video_length) + hidden_states = hidden_states.reshape([-1, video_length, *hidden_states.shape[1:]]) + hidden_states = hidden_states.movedim((0, 1, 2, 3, 4), (0, 2, 1, 3, 4)) + + return hidden_states + + +class ResnetBlock3D(nn.Module): + r""" + A Resnet block. Used specifically for video like data. + + Parameters: + in_channels (`int`): The number of channels in the input. + out_channels (`int`, *optional*, default to be `None`): + The number of output channels for the first conv2d layer. If None, same as `in_channels`. + dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use. + temb_channels (`int`, *optional*, default to `512`): the number of channels in timestep embedding. + groups (`int`, *optional*, default to `32`): The number of groups to use for the first normalization layer. + eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization. + non_linearity (`str`, *optional*, default to `"swish"`): the activation function to use. + output_scale_factor (`float`, *optional*, default to be `1.0`): the scale factor to use for the output. 
+ """ + + def __init__( + self, + *, + in_channels, + out_channels=None, + dropout=0.0, + temb_channels=512, + groups=32, + eps=1e-6, + non_linearity="swish", + output_scale_factor=1.0, + ): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.output_scale_factor = output_scale_factor + + self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True) + self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + + self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels) + + self.norm2 = torch.nn.GroupNorm(num_groups=groups, num_channels=out_channels, eps=eps, affine=True) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) + + self.nonlinearity = get_activation(non_linearity) + + self.use_in_shortcut = self.in_channels != self.out_channels + self.conv_shortcut = None + if self.use_in_shortcut: + self.conv_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, input_tensor, temb): + hidden_states = input_tensor + + hidden_states = self.norm1(hidden_states) + hidden_states = self.nonlinearity(hidden_states) + + video_length = hidden_states.shape[2] + # b c f h w -> (b f) c h w + hidden_states = hidden_states.movedim((0, 1, 2, 3, 4), (0, 2, 1, 3, 4)) + hidden_states = hidden_states.flatten(0, 1) + hidden_states = self.conv1(hidden_states) + # (b f) c h w -> b c f h w (f=video_length + hidden_states = hidden_states.reshape([-1, video_length, *hidden_states.shape[1:]]) + hidden_states = hidden_states.movedim((0, 1, 2, 3, 4), (0, 2, 1, 3, 4)) + + if temb is not None: + temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None, None] + + hidden_states = hidden_states + temb + + hidden_states = self.norm2(hidden_states) + + hidden_states = self.nonlinearity(hidden_states) + + hidden_states = self.dropout(hidden_states) + + video_length = hidden_states.shape[2] + # b c f h w -> (b f) c h w + hidden_states = hidden_states.movedim((0, 1, 2, 3, 4), (0, 2, 1, 3, 4)) + hidden_states = hidden_states.flatten(0, 1) + hidden_states = self.conv2(hidden_states) + # (b f) c h w -> b c f h w (f=video_length) + hidden_states = hidden_states.reshape([-1, video_length, *hidden_states.shape[1:]]) + hidden_states = hidden_states.movedim((0, 1, 2, 3, 4), (0, 2, 1, 3, 4)) + + if self.conv_shortcut is not None: + video_length = input_tensor.shape[2] + # "b c f h w -> (b f) c h w" + input_tensor = input_tensor.movedim((0, 1, 2, 3, 4), (0, 2, 1, 3, 4)) + input_tensor = input_tensor.flatten(0, 1) + input_tensor = self.conv_shortcut(input_tensor) + # "(b f) c h w -> b c f h w"; f=video_length + input_tensor = input_tensor.reshape([-1, video_length, *input_tensor.shape[1:]]) + input_tensor = input_tensor.movedim((0, 1, 2, 3, 4), (0, 2, 1, 3, 4)) + + output_tensor = (input_tensor + hidden_states) / self.output_scale_factor + + return output_tensor + + def rearrange_dims(tensor: torch.Tensor) -> torch.Tensor: if len(tensor.shape) == 2: return tensor[:, :, None] diff --git a/src/diffusers/models/transformer_3d.py b/src/diffusers/models/transformer_3d.py new file mode 100644 index 000000000000..181b3c47b9e1 --- /dev/null +++ b/src/diffusers/models/transformer_3d.py @@ -0,0 +1,269 @@ +# Based on the TuneAVideo Transformer3DModel from Showlab: https://arxiv.org/abs/2212.11565 +# Copyright 2023 The HuggingFace 
Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import Optional + +import torch +from torch import nn + +from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import BaseOutput +from .attention import Attention, FeedForward +from .modeling_utils import ModelMixin + + +@dataclass +class Transformer3DModelOutput(BaseOutput): + """ + Args: + sample (`torch.FloatTensor` of shape `(batch_size, num_channels, num_frames, height, width)`. + """ + + sample: torch.FloatTensor + + +class Transformer3DModel(ModelMixin, ConfigMixin): + """ + Transformer model for a video-like data. + + When input is continuous: First, project the input (aka embedding) and reshape to b, h * w, c. Then apply the + sparse 3d transformer action. Finally, reshape to video again. + + Parameters: + num_attention_heads (`int`, *optional*, defaults to 16): + The number of heads to use for multi-head attention. + attention_head_dim (`int`, *optional*, defaults to 88): + The number of channels in each head. + in_channels (`int`, *optional*): + Pass if the input is continuous. The number of channels in the input and output. + num_layers (`int`, *optional*, defaults to 1): + The number of layers of Transformer blocks to use. + dropout (`float`, *optional*, defaults to 0.0): + The dropout probability to use. + norm_num_groups: (`int`, *optional*, defaults to 32): + The number of norm groups for the group norm. + cross_attention_dim (`int`, *optional*): + The number of encoder_hidden_states dimensions to use. + activation_fn (`str`, *optional*, defaults to `"geglu"`): + Activation function to be used in feed-forward. + + """ + + @register_to_config + def __init__( + self, + num_attention_heads: int = 16, + attention_head_dim: int = 88, + in_channels: Optional[int] = None, + num_layers: int = 1, + dropout: float = 0.0, + norm_num_groups: int = 32, + cross_attention_dim: int = 1280, + activation_fn: str = "geglu", + ): + super().__init__() + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + inner_dim = num_attention_heads * attention_head_dim + + # Define input layers + self.in_channels = in_channels + + self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) + self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) + + # Define transformers blocks + self.transformer_blocks = nn.ModuleList( + [ + BasicSparse3DTransformerBlock( + inner_dim, + num_attention_heads, + attention_head_dim, + dropout=dropout, + cross_attention_dim=cross_attention_dim, + activation_fn=activation_fn, + ) + for d in range(num_layers) + ] + ) + + # 4. 
Define output layers + self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) + + def forward( + self, + hidden_states, + encoder_hidden_states=None, + timestep=None, + return_dict: bool = True, + ): + """ + Args: + hidden_states (`torch.FloatTensor` of shape `(batch size, channel, num_frames, height, width)`): + Input hidden_states + encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*): + Conditional embeddings for cross attention layer. If not given, cross-attention defaults to + self-attention. + timestep ( `torch.long`, *optional*): + Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.transformer_3d.Transformer3DModelOutput`] instead of a plain + tuple. + + Returns: + [`~models.transformer_3d.Transformer3DModelOutput`] or `tuple`: + [`~models.transformer_3d.Transformer3DModelOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is the sample tensor. + """ + # Input + if not hidden_states.dim() == 5: + raise ValueError(f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}.") + video_length = hidden_states.shape[2] + hidden_states = hidden_states.movedim((0, 1, 2, 3, 4), (0, 2, 1, 3, 4)) + hidden_states = hidden_states.flatten(0, 1) + encoder_hidden_states = encoder_hidden_states.repeat_interleave(repeats=video_length, dim=0) + + batch, channel, height, weight = hidden_states.shape + residual = hidden_states + + hidden_states = self.norm(hidden_states) + hidden_states = self.proj_in(hidden_states) + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) + + # Blocks + for block in self.transformer_blocks: + hidden_states = block( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + timestep=timestep, + video_length=video_length, + ) + + # Output + hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous() + hidden_states = self.proj_out(hidden_states) + + output = hidden_states + residual + + # "(b f) c h w -> b c f h w"; f=video_length + output = output.reshape([-1, video_length, *output.shape[1:]]) + output = output.movedim((0, 1, 2, 3, 4), (0, 2, 1, 3, 4)) + if not return_dict: + return (output,) + + return Transformer3DModelOutput(sample=output) + + +class BasicSparse3DTransformerBlock(nn.Module): + r""" + A modified basic Transformer block designed for use with Text to Video models. Currently only used by Tune A Video + pipeline with attn1 processor set to the TuneAVideoAttnProcessor. + + Parameters: + dim (`int`): The number of channels in the input and output. + num_attention_heads (`int`): The number of heads to use for multi-head attention. + attention_head_dim (`int`): The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + dropout=0.0, + cross_attention_dim: int = 1280, + activation_fn: str = "geglu", + ): + super().__init__() + + # Temporal-Attention. 
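+        # attn1 is a self-attention layer (cross_attention_dim=None); TuneAVideoPipeline swaps its
+        # processor for TuneAVideoAttnProcessor so keys/values are also gathered from the first and
+        # previous frames (sparse-causal attention). The dedicated temporal attention is `attn_temp` below.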
+ self.attn1 = Attention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=False, + cross_attention_dim=None, + upcast_attention=False, + ) + self.norm1 = nn.LayerNorm(dim) + + # Cross-Attn + self.attn2 = Attention( + query_dim=dim, + cross_attention_dim=cross_attention_dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=False, + upcast_attention=False, + ) + + self.norm2 = nn.LayerNorm(dim) + + # Feed-forward + self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn) + self.norm3 = nn.LayerNorm(dim) + + # Temp-Attn + self.attn_temp = Attention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=False, + upcast_attention=False, + ) + nn.init.zeros_(self.attn_temp.to_out[0].weight.data) + self.norm_temp = nn.LayerNorm(dim) + + def forward( + self, hidden_states, encoder_hidden_states=None, timestep=None, attention_mask=None, video_length=None + ): + # SparseCausal-Attention + norm_hidden_states = self.norm1(hidden_states) + + hidden_states = ( + self.attn1(norm_hidden_states, attention_mask=attention_mask, video_length=video_length) + hidden_states + ) + + norm_hidden_states = self.norm2(hidden_states) + hidden_states = ( + self.attn2(norm_hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask) + + hidden_states + ) + + # Feed-forward + hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states + + # Temporal-Attention + d = hidden_states.shape[1] + # (b f) d c -> b f d c -> b d f c -> (b d) f c + hidden_states = hidden_states.reshape([-1, video_length, *hidden_states.shape[1:]]) + hidden_states = hidden_states.movedim((0, 1, 2, 3), (0, 2, 1, 3)) + hidden_states = hidden_states.flatten(0, 1) + norm_hidden_states = self.norm_temp(hidden_states) + hidden_states = self.attn_temp(norm_hidden_states) + hidden_states + # (b d) f c -> b d f c -> b f d c -> (b f) d c + hidden_states = hidden_states.reshape([-1, d, *hidden_states.shape[1:]]) + hidden_states = hidden_states.movedim((0, 1, 2, 3), (0, 2, 1, 3)) + hidden_states = hidden_states.flatten(0, 1) + + return hidden_states diff --git a/src/diffusers/models/unet_3d_blocks.py b/src/diffusers/models/unet_3d_blocks.py index 180ae0dc1a81..68501cf368ed 100644 --- a/src/diffusers/models/unet_3d_blocks.py +++ b/src/diffusers/models/unet_3d_blocks.py @@ -16,8 +16,9 @@ from torch import nn from ..utils.torch_utils import apply_freeu -from .resnet import Downsample2D, ResnetBlock2D, TemporalConvLayer, Upsample2D +from .resnet import Downsample2D, Downsample3D, ResnetBlock2D, ResnetBlock3D, TemporalConvLayer, Upsample2D, Upsample3D from .transformer_2d import Transformer2DModel +from .transformer_3d import Transformer3DModel from .transformer_temporal import TransformerTemporalModel @@ -74,6 +75,34 @@ def get_down_block( upcast_attention=upcast_attention, resnet_time_scale_shift=resnet_time_scale_shift, ) + elif down_block_type == "DownBlockInflated3D": + return DownBlockInflated3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + ) + elif down_block_type == "CrossAttnDownBlockInflated3D": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlockInflated3D") + return 
CrossAttnDownBlockInflated3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads, + ) raise ValueError(f"{down_block_type} does not exist.") @@ -133,9 +162,82 @@ def get_up_block( resnet_time_scale_shift=resnet_time_scale_shift, resolution_idx=resolution_idx, ) + elif up_block_type == "UpBlockInflated3D": + return UpBlockInflated3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + ) + elif up_block_type == "CrossAttnUpBlockInflated3D": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlockInflated3D") + return CrossAttnUpBlockInflated3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads, + ) + raise ValueError(f"{up_block_type} does not exist.") +def get_mid_block( + mid_block_type, + in_channels, + temb_channels, + resnet_eps, + resnet_act_fn, + resnet_groups, + num_attention_heads, + output_scale_factor, + cross_attention_dim, + dual_cross_attention=False, + use_linear_projection=True, + upcast_attention=False, + resnet_time_scale_shift="default", +): + if mid_block_type == "UNetMidBlock3DCrossAttn": + return UNetMidBlock3DCrossAttn( + in_channels=in_channels, + temb_channels=temb_channels, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + output_scale_factor=output_scale_factor, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads, + resnet_groups=resnet_groups, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif mid_block_type == "UNetMidBlockInflated3DCrossAttn": + return UNetMidBlockInflated3DCrossAttn( + in_channels=in_channels, + temb_channels=temb_channels, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + output_scale_factor=output_scale_factor, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads, + resnet_groups=resnet_groups, + ) + raise ValueError(f"{mid_block_type} does not exist.") + + class UNetMidBlock3DCrossAttn(nn.Module): def __init__( self, @@ -724,3 +826,397 @@ def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_si hidden_states = upsampler(hidden_states, upsample_size) return hidden_states + + +class UpBlockInflated3D(nn.Module): + def __init__( + self, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + output_scale_factor=1.0, + add_upsample=True, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 
1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock3D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample3D(out_channels, out_channels=out_channels)]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + + def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, num_frames=1): + for resnet in self.resnets: + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + hidden_states = resnet(hidden_states, temb) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states + + +class DownBlockInflated3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + output_scale_factor=1.0, + add_downsample=True, + downsample_padding=1, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock3D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [Downsample3D(out_channels, out_channels=out_channels, padding=downsample_padding)] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward(self, hidden_states, temb=None, num_frames=1): + output_states = () + + for resnet in self.resnets: + hidden_states = resnet(hidden_states, temb) + output_states += (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states += (hidden_states,) + + return hidden_states, output_states + + +class CrossAttnUpBlockInflated3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + prev_output_channel: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + num_attention_heads=1, + cross_attention_dim=1280, + output_scale_factor=1.0, + add_upsample=True, + ): + super().__init__() + resnets = [] + attentions = [] + + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock3D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + non_linearity=resnet_act_fn, + 
output_scale_factor=output_scale_factor, + ) + ) + + attentions.append( + Transformer3DModel( + num_attention_heads=num_attention_heads, + attention_head_dim=out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + + self.resnets = nn.ModuleList(resnets) + self.attentions = nn.ModuleList(attentions) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample3D(out_channels, out_channels=out_channels)]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states, + res_hidden_states_tuple, + temb=None, + encoder_hidden_states=None, + upsample_size=None, + attention_mask=None, + num_frames=1, + cross_attention_kwargs=None, + ): + # TODO(Patrick, William) - attention mask is not used + # TODO - cross_attention_kwargs are not used + + for resnet, attn in zip(self.resnets, self.attentions): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + hidden_states = resnet(hidden_states, temb) + + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + ).sample + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states + + +class CrossAttnDownBlockInflated3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + num_attention_heads=1, + cross_attention_dim=1280, + output_scale_factor=1.0, + downsample_padding=1, + add_downsample=True, + ): + super().__init__() + resnets = [] + attentions = [] + + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock3D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + ) + ) + + attentions.append( + Transformer3DModel( + num_attention_heads=num_attention_heads, + attention_head_dim=out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + + self.resnets = nn.ModuleList(resnets) + self.attentions = nn.ModuleList(attentions) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [Downsample3D(out_channels, out_channels=out_channels, padding=downsample_padding)] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states, + temb=None, + encoder_hidden_states=None, + attention_mask=None, + num_frames=1, + cross_attention_kwargs=None, + ): + # TODO(Patrick, William) - attention mask is not used + # TODO - cross_attention_kwargs are not used. 
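+        # `hidden_states` enters and leaves this block as a 5-D tensor (batch, channels, frames, height, width);
+        # ResnetBlock3D and Transformer3DModel fold the frame axis into the batch axis internally.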
+ output_states = () + + for resnet, attn in zip(self.resnets, self.attentions): + hidden_states = resnet(hidden_states, temb) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + ).sample + + output_states += (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states += (hidden_states,) + + return hidden_states, output_states + + +class UNetMidBlockInflated3DCrossAttn(nn.Module): + def __init__( + self, + in_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + num_attention_heads=1, + output_scale_factor=1.0, + cross_attention_dim=1280, + ): + super().__init__() + + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + + # there is always at least one resnet + resnets = [ + ResnetBlock3D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + ) + ] + + attentions = [] + + for _ in range(num_layers): + attentions.append( + Transformer3DModel( + num_attention_heads, + in_channels // num_attention_heads, + in_channels=in_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + + resnets.append( + ResnetBlock3D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + ) + ) + + self.resnets = nn.ModuleList(resnets) + self.attentions = nn.ModuleList(attentions) + + def forward( + self, + hidden_states, + temb=None, + encoder_hidden_states=None, + attention_mask=None, + num_frames=1, + cross_attention_kwargs=None, + ): + hidden_states = self.resnets[0](hidden_states, temb) + + for attn, resnet in zip(self.attentions, self.resnets[1:]): + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + ).sample + hidden_states = resnet(hidden_states, temb) + + return hidden_states diff --git a/src/diffusers/models/unet_3d_condition.py b/src/diffusers/models/unet_3d_condition.py index 2ab1d4060e17..4e3cf511e882 100644 --- a/src/diffusers/models/unet_3d_condition.py +++ b/src/diffusers/models/unet_3d_condition.py @@ -34,11 +34,15 @@ from .transformer_temporal import TransformerTemporalModel from .unet_3d_blocks import ( CrossAttnDownBlock3D, + CrossAttnDownBlockInflated3D, CrossAttnUpBlock3D, + CrossAttnUpBlockInflated3D, DownBlock3D, - UNetMidBlock3DCrossAttn, + DownBlockInflated3D, UpBlock3D, + UpBlockInflated3D, get_down_block, + get_mid_block, get_up_block, ) @@ -72,21 +76,24 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin) Height and width of input/output sample. in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample. out_channels (`int`, *optional*, defaults to 4): The number of channels in the output. 
- down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`): + down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock3D", "CrossAttnDownBlock3D", "CrossAttnDownBlock3D", "DownBlock3D")`): The tuple of downsample blocks to use. - up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`): + up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D",)`): The tuple of upsample blocks to use. block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): The tuple of output channels for each block. layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution. mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block. + mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock3DCrossAttn"`): The midblock to use. act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization. If `None`, normalization and activation layers is skipped in post-processing. norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization. cross_attention_dim (`int`, *optional*, defaults to 1280): The dimension of the cross attention features. attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads. + use_temporal_transformer (`bool`, defaults to `True`): + If `False`, skips the temporal attention layer before processing the input. num_attention_heads (`int`, *optional*): The number of attention heads. 
""" @@ -109,11 +116,13 @@ def __init__( layers_per_block: int = 2, downsample_padding: int = 1, mid_block_scale_factor: float = 1, + mid_block_type: str = "UNetMidBlock3DCrossAttn", act_fn: str = "silu", norm_num_groups: Optional[int] = 32, norm_eps: float = 1e-5, cross_attention_dim: int = 1024, attention_head_dim: Union[int, Tuple[int]] = 64, + use_temporal_transformer: bool = True, num_attention_heads: Optional[Union[int, Tuple[int]]] = None, ): super().__init__() @@ -153,10 +162,10 @@ def __init__( conv_in_kernel = 3 conv_out_kernel = 3 conv_in_padding = (conv_in_kernel - 1) // 2 + self.conv_in = nn.Conv2d( in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding ) - # time time_embed_dim = block_out_channels[0] * 4 self.time_proj = Timesteps(block_out_channels[0], True, 0) @@ -168,15 +177,18 @@ def __init__( act_fn=act_fn, ) - self.transformer_in = TransformerTemporalModel( - num_attention_heads=8, - attention_head_dim=attention_head_dim, - in_channels=block_out_channels[0], - num_layers=1, - ) + if use_temporal_transformer: + self.transformer_in = TransformerTemporalModel( + num_attention_heads=8, + attention_head_dim=attention_head_dim, + in_channels=block_out_channels[0], + num_layers=1, + ) + else: + self.transformer_in = None - # class embedding self.down_blocks = nn.ModuleList([]) + self.mid_block = None self.up_blocks = nn.ModuleList([]) if isinstance(num_attention_heads, int): @@ -202,12 +214,12 @@ def __init__( cross_attention_dim=cross_attention_dim, num_attention_heads=num_attention_heads[i], downsample_padding=downsample_padding, - dual_cross_attention=False, ) self.down_blocks.append(down_block) # mid - self.mid_block = UNetMidBlock3DCrossAttn( + self.mid_block = get_mid_block( + mid_block_type=mid_block_type, in_channels=block_out_channels[-1], temb_channels=time_embed_dim, resnet_eps=norm_eps, @@ -216,14 +228,14 @@ def __init__( cross_attention_dim=cross_attention_dim, num_attention_heads=num_attention_heads[-1], resnet_groups=norm_num_groups, - dual_cross_attention=False, ) - # count how many layers upsample the images + # count how many layers upsample the videos self.num_upsamplers = 0 # up reversed_block_out_channels = list(reversed(block_out_channels)) + reversed_num_attention_heads = list(reversed(num_attention_heads)) output_channel = reversed_block_out_channels[0] @@ -271,6 +283,7 @@ def __init__( self.conv_act = None conv_out_padding = (conv_out_kernel - 1) // 2 + self.conv_out = nn.Conv2d( block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding ) @@ -460,7 +473,19 @@ def set_default_attn_processor(self): self.set_attn_processor(processor, _remove_lora=True) def _set_gradient_checkpointing(self, module, value=False): - if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)): + if isinstance( + module, + ( + CrossAttnDownBlock3D, + DownBlock3D, + CrossAttnUpBlock3D, + UpBlock3D, + CrossAttnDownBlockInflated3D, + DownBlockInflated3D, + CrossAttnUpBlockInflated3D, + UpBlockInflated3D, + ), + ): module.gradient_checkpointing = value # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.enable_freeu @@ -515,7 +540,7 @@ def forward( Args: sample (`torch.FloatTensor`): - The noisy input tensor with the following shape `(batch, num_frames, channel, height, width`. + The noisy input tensor with the following shape `(batch, channel, num_frames, height, width)`. timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input. 
encoder_hidden_states (`torch.FloatTensor`): The encoder hidden states with shape `(batch, sequence_length, feature_dim)`. @@ -581,6 +606,7 @@ def forward( timesteps = timesteps[None].to(sample.device) # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + # num_frames is equivalent to video_length in TuneAVideo num_frames = sample.shape[2] timesteps = timesteps.expand(sample.shape[0]) @@ -592,19 +618,28 @@ def forward( t_emb = t_emb.to(dtype=self.dtype) emb = self.time_embedding(t_emb, timestep_cond) - emb = emb.repeat_interleave(repeats=num_frames, dim=0) - encoder_hidden_states = encoder_hidden_states.repeat_interleave(repeats=num_frames, dim=0) # 2. pre-process - sample = sample.permute(0, 2, 1, 3, 4).reshape((sample.shape[0] * num_frames, -1) + sample.shape[3:]) - sample = self.conv_in(sample) + # If not using temporal transfomer, use num_frames in unet + if self.transformer_in: + emb = emb.repeat_interleave(repeats=num_frames, dim=0) + encoder_hidden_states = encoder_hidden_states.repeat_interleave(repeats=num_frames, dim=0) + + sample = sample.permute(0, 2, 1, 3, 4).reshape((sample.shape[0] * num_frames, -1) + sample.shape[3:]) + sample = self.conv_in(sample) - sample = self.transformer_in( - sample, - num_frames=num_frames, - cross_attention_kwargs=cross_attention_kwargs, - return_dict=False, - )[0] + sample = self.transformer_in( + sample, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs + ).sample + else: + video_length = sample.shape[2] + # b c f h w -> (b f) c h w + sample = sample.movedim((0, 1, 2, 3, 4), (0, 2, 1, 3, 4)) + sample = sample.flatten(0, 1) + sample = self.conv_in(sample) + # (b f) c h w -> b c f h w (f=video_length) + sample = sample.reshape([-1, video_length, *sample.shape[1:]]) + sample = sample.movedim((0, 1, 2, 3, 4), (0, 2, 1, 3, 4)) # 3. 
down down_block_res_samples = (sample,) @@ -685,10 +720,19 @@ def forward( sample = self.conv_norm_out(sample) sample = self.conv_act(sample) - sample = self.conv_out(sample) - # reshape to (batch, channel, framerate, width, height) - sample = sample[None, :].reshape((-1, num_frames) + sample.shape[1:]).permute(0, 2, 1, 3, 4) + if self.transformer_in: + sample = self.conv_out(sample) + sample = sample[None, :].reshape((-1, num_frames) + sample.shape[1:]).permute(0, 2, 1, 3, 4) + else: + video_length = sample.shape[2] + # b c f h w -> (b f) c h w + sample = sample.movedim((0, 1, 2, 3, 4), (0, 2, 1, 3, 4)) + sample = sample.flatten(0, 1) + sample = self.conv_out(sample) + # (b f) c h w -> b c f h w (f=video_length) + sample = sample.reshape([-1, video_length, *sample.shape[1:]]) + sample = sample.movedim((0, 1, 2, 3, 4), (0, 2, 1, 3, 4)) if not return_dict: return (sample,) diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 19fe2f72d447..07ff7b1ccb41 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -157,6 +157,7 @@ "TextToVideoZeroPipeline", "VideoToVideoSDPipeline", ] + _import_structure["tune_a_video"] = ["TuneAVideoPipeline"] _import_structure["unclip"] = ["UnCLIPImageVariationPipeline", "UnCLIPPipeline"] _import_structure["unidiffuser"] = [ "ImageTextPipelineOutput", @@ -203,7 +204,6 @@ "StableDiffusionOnnxPipeline", ] ) - try: if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()): raise OptionalDependencyNotAvailable() @@ -374,6 +374,7 @@ TextToVideoZeroPipeline, VideoToVideoSDPipeline, ) + from .tune_a_video import TuneAVideoPipeline from .unclip import UnCLIPImageVariationPipeline, UnCLIPPipeline from .unidiffuser import ( ImageTextPipelineOutput, diff --git a/src/diffusers/pipelines/tune_a_video/__init__.py b/src/diffusers/pipelines/tune_a_video/__init__.py new file mode 100644 index 000000000000..d409643f6c35 --- /dev/null +++ b/src/diffusers/pipelines/tune_a_video/__init__.py @@ -0,0 +1,36 @@ +from dataclasses import dataclass +from typing import List, Union + +import numpy as np +import torch + +from ...utils import ( + BaseOutput, + OptionalDependencyNotAvailable, + is_torch_available, + is_transformers_available, +) + + +@dataclass +class TuneAVideoPipelineOutput(BaseOutput): + """ + Output class for text to video pipelines. + + Args: + frames (`List[np.ndarray]` or `torch.FloatTensor`) + List of denoised frames (essentially images) as NumPy arrays of shape `(height, width, num_channels)` or as + a `torch` tensor. NumPy array present the denoised images of the diffusion pipeline. The length of the list + denotes the video length i.e., the number of frames. 
+ """ + + frames: Union[List[np.ndarray], torch.FloatTensor] + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 +else: + from .pipeline_tune_a_video import TuneAVideoPipeline # noqa: F401 diff --git a/src/diffusers/pipelines/tune_a_video/pipeline_tune_a_video.py b/src/diffusers/pipelines/tune_a_video/pipeline_tune_a_video.py new file mode 100644 index 000000000000..74a607cd3975 --- /dev/null +++ b/src/diffusers/pipelines/tune_a_video/pipeline_tune_a_video.py @@ -0,0 +1,686 @@ +# Adapted from https://github.com/showlab/Tune-A-Video/blob/main/tuneavideo/pipelines/pipeline_tuneavideo.py + +import inspect +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import torch +import torch.nn.functional as F +from transformers import CLIPTextModel, CLIPTokenizer + +from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL +from ...models.unet_3d_condition import UNet3DConditionModel +from ...schedulers import ( + DDIMScheduler, + DPMSolverMultistepScheduler, + EulerAncestralDiscreteScheduler, + EulerDiscreteScheduler, + LMSDiscreteScheduler, + PNDMScheduler, +) +from ...utils import logging, replace_example_docstring +from ..pipeline_utils import DiffusionPipeline +from . import TuneAVideoPipelineOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + + >>> from diffusers import UNet3DConditionModel, TuneAVideoPipeline, DDIMScheduler + >>> from diffusers.utils import export_to_video + + >>> pretrained_model_path = "nitrosocke/mo-di-diffusion" + >>> scheduler = DDIMScheduler.from_pretrained(pretrained_model_path, subfolder="scheduler") + >>> unet = UNet3DConditionModel.from_pretrained( + ... "Tune-A-Video-library/df-cpt-mo-di-bear-guitar", subfolder="unet", torch_dtype=torch.float16 + ... ).to("cuda") + >>> pipe = TuneAVideoPipeline.from_pretrained( + ... pretrained_model_path, unet=unet, scheduler=scheduler, torch_dtype=torch.float16 + ... ).to("cuda") + + >>> prompt = "a magical princess is playing guitar, modern disney style" + >>> video_frames = pipe( + ... prompt, + ... video_length=4, + ... height=256, + ... width=256, + ... num_inference_steps=38, + ... guidance_scale=12.5, + ... output_type="np", + ... 
).frames + + >>> video_frames = pipe(prompt).frames + >>> video_path = export_to_video(video_frames) + >>> video_path + ``` +""" + + +def tensor2vid(video: torch.Tensor) -> List[np.ndarray]: + # This code is modified from https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/pipelines/multi_modal/text_to_video_synthesis_pipeline.py#L78 + i, c, f, h, w = video.shape + images = video.permute(2, 3, 0, 4, 1).reshape( + f, h, i * w, c + ) # 1st (frames, h, batch_size, w, c) 2nd (frames, h, batch_size * w, c) + images = images.unbind(dim=0) # prepare a list of indvidual (consecutive frames) + images = [(image.cpu().numpy() * 255).astype("uint8") for image in images] # f h w c + return images + + +class TuneAVideoAttnProcessor: + def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None): + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + query.shape[-1] + query = attn.head_to_batch_dim(query) + + if attn.added_kv_proj_dim is not None: + raise NotImplementedError + + encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + former_frame_index = torch.arange(video_length) - 1 + former_frame_index[0] = 0 + + # key = rearrange(key, "(b f) d c -> b f d c", f=video_length) + key = key.reshape([-1, video_length, *key.shape[1:]]) + key = torch.cat([key[:, [0] * video_length], key[:, former_frame_index]], dim=2) + # key = rearrange(key, "b f d c -> (b f) d c") + key = key.flatten(0, 1) + + # value = rearrange(value, "(b f) d c -> b f d c", f=video_length) + value = value.reshape([-1, video_length, *value.shape[1:]]) + value = torch.cat([value[:, [0] * video_length], value[:, former_frame_index]], dim=2) + # value = rearrange(value, "b f d c -> (b f) d c") + value = value.flatten(0, 1) + + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + + if attention_mask is not None: + if attention_mask.shape[-1] != query.shape[1]: + target_length = query.shape[1] + attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) + attention_mask = attention_mask.repeat_interleave(self.heads, dim=0) + + attention_probs = attn.get_attention_scores(query, key, attention_mask) + hidden_states = torch.bmm(attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + + # dropout + hidden_states = attn.to_out[1](hidden_states) + return hidden_states + + +class TuneAVideoPipeline(DiffusionPipeline, TextualInversionLoaderMixin): + r""" + Pipeline for text-to-video generation. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Same as Stable Diffusion 2. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet3DConditionModel`]): Conditional U-Net architecture to denoise the encoded video latents. 
+ scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], [`EulerDiscreteScheduler`], + [`EulerAncestralDiscreteScheduler`], [`DPMSolverMultistepScheduler`], or [`PNDMScheduler`]. + """ + + _optional_components = [] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet3DConditionModel, + scheduler: Union[ + DDIMScheduler, + PNDMScheduler, + LMSDiscreteScheduler, + EulerDiscreteScheduler, + EulerAncestralDiscreteScheduler, + DPMSolverMultistepScheduler, + ], + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + ) + attn_processor_dict = self.unet.attn_processors + for key in attn_processor_dict.keys(): + # Only the Transformer3DModels attn1 attn processor is TuneAVideoAttnProcessor. + if "attn1" in key: + attn_processor_dict[key] = TuneAVideoAttnProcessor() + self.unet.set_attn_processor(attn_processor_dict) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + self.vae.enable_tiling() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() + + def encode_prompt( + self, + prompt, + device, + num_videos_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_videos_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. 
Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, LoraLoaderMixin): + self._lora_scale = lora_scale + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + prompt_embeds = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_videos_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. 
Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_videos_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + return prompt_embeds + + def decode_latents(self, latents): + video_length = latents.shape[2] + latents = 1 / 0.18215 * latents + # latents = rearrange(latents, "b c f h w -> (b f) c h w") + latents = latents.movedim((0, 1, 2, 3, 4), (0, 2, 1, 3, 4)) + latents = latents.flatten(0, 1) + video = self.vae.decode(latents).sample + # video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length) + video = video.reshape([-1, video_length, *video.shape[1:]]) + video = video.movedim((0, 1, 2, 3, 4), (0, 2, 1, 3, 4)) + video = (video / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 + video = video.cpu().float().numpy() + return video + + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs + def check_inputs( + self, + prompt, + height, + width, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + def prepare_latents( + self, batch_size, num_channels_latents, video_length, height, width, dtype, device, generator, latents=None + ): + shape = ( + batch_size, + num_channels_latents, + video_length, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
+            )
+
+        if latents is None:
+            rand_device = "cpu" if device.type == "mps" else device
+
+            if isinstance(generator, list):
+                shape = (1,) + shape[1:]
+                latents = [
+                    torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype)
+                    for i in range(batch_size)
+                ]
+                latents = torch.cat(latents, dim=0).to(device)
+            else:
+                latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
+        else:
+            if latents.shape != shape:
+                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
+            latents = latents.to(device)
+
+        # scale the initial noise by the standard deviation required by the scheduler
+        latents = latents * self.scheduler.init_noise_sigma
+        return latents
+
+    @torch.no_grad()
+    @replace_example_docstring(EXAMPLE_DOC_STRING)
+    def __call__(
+        self,
+        prompt: Union[str, List[str]],
+        video_length: int = 8,
+        height: Optional[int] = None,
+        width: Optional[int] = None,
+        num_inference_steps: int = 50,
+        guidance_scale: float = 7.5,
+        negative_prompt: Optional[Union[str, List[str]]] = None,
+        num_videos_per_prompt: Optional[int] = 1,
+        eta: float = 0.0,
+        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+        latents: Optional[torch.FloatTensor] = None,
+        prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        output_type: Optional[str] = "np",
+        return_dict: bool = True,
+        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+        callback_steps: Optional[int] = 1,
+        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+    ):
+        r"""
+        Function invoked when calling the pipeline for generation.
+
+        Args:
+            prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts to guide the video generation. If not defined, one has to pass `prompt_embeds`
+                instead.
+            video_length (`int`, *optional*, defaults to 8):
+                The number of video frames that are generated. Defaults to 8 frames, which at 8 frames per second
+                amounts to 1 second of video.
+            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+                The height in pixels of the generated video.
+            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+                The width in pixels of the generated video.
+            num_inference_steps (`int`, *optional*, defaults to 50):
+                The number of denoising steps. More denoising steps usually lead to higher quality videos at the
+                expense of slower inference.
+            guidance_scale (`float`, *optional*, defaults to 7.5):
+                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+                `guidance_scale` is defined as `w` of equation 2 of the [Imagen
+                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+                1`. A higher guidance scale encourages the model to generate videos that are closely linked to the
+                text `prompt`, usually at the expense of lower video quality.
+            negative_prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the video generation. If not defined, one has to pass
+                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale`
+                is less than `1`).
+            num_videos_per_prompt (`int`, *optional*, defaults to 1):
+                Number of videos to be generated for each input prompt.
+            eta (`float`, *optional*, defaults to 0.0):
+                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+                [`schedulers.DDIMScheduler`], will be ignored for others.
+            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+                to make generation deterministic.
+            latents (`torch.FloatTensor`, *optional*):
+                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for video
+                generation. Can be used to tweak the same generation with different prompts. If not provided, a
+                latents tensor will be generated by sampling using the supplied random `generator`. Latents should be
+                of shape `(batch_size, num_channel, num_frames, height, width)`.
+            prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If
+                not provided, text embeddings will be generated from `prompt` input argument.
+            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+                argument.
+            output_type (`str`, *optional*, defaults to `"np"`):
+                The output format of the generated video. Choose between `torch.FloatTensor` or `np.array`.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~pipelines.tune_a_video.TuneAVideoPipelineOutput`] instead of a plain
+                tuple.
+            callback (`Callable`, *optional*):
+                A function that will be called every `callback_steps` steps during inference. The function will be
+                called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+            callback_steps (`int`, *optional*, defaults to 1):
+                The frequency at which the `callback` function will be called. If not specified, the callback will be
+                called at every step.
+            cross_attention_kwargs (`dict`, *optional*):
+                A kwargs dictionary that, if specified, is passed along to the `AttentionProcessor` as defined under
+                `self.processor` in
+                [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
+
+        Examples:
+
+        Returns:
+            [`~pipelines.tune_a_video.TuneAVideoPipelineOutput`] or `tuple`:
+            [`~pipelines.tune_a_video.TuneAVideoPipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When
+            returning a tuple, the first element is a list with the generated frames.
+        """
+        # Default height and width to unet
+        height = height or self.unet.config.sample_size * self.vae_scale_factor
+        width = width or self.unet.config.sample_size * self.vae_scale_factor
+
+        # 1. Check inputs. Raise error if not correct
+        self.check_inputs(
+            prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
+        )
+
+        # 2. Define call parameters
+        if prompt is not None and isinstance(prompt, str):
+            batch_size = 1
+        elif prompt is not None and isinstance(prompt, list):
+            batch_size = len(prompt)
+        else:
+            batch_size = prompt_embeds.shape[0]
+
+        device = self._execution_device
+        # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
+        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+        # corresponds to doing no classifier free guidance.
+        do_classifier_free_guidance = guidance_scale > 1.0
+
+        # 3.
Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) + + prompt_embeds = self.encode_prompt( + prompt, + device, + num_videos_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + ) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + num_channels_latents, + video_length, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + latents_dtype = latents.dtype + + # 6. Prepare extra step kwargs. + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + ).sample.to(dtype=latents_dtype) + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + # Post-processing + video = self.decode_latents(latents) + + # output a video tensor if output_type is not np. + video = torch.from_numpy(video) + if output_type == "np": + video = tensor2vid(video) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return video + + return TuneAVideoPipelineOutput(frames=video) diff --git a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py index 717db3bbdb34..47ebf199a210 100644 --- a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +++ b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py @@ -1235,10 +1235,12 @@ def forward( deprecate( "T2I should not use down_block_additional_residuals", "1.3.0", - "Passing intrablock residual connections with `down_block_additional_residuals` is deprecated " - " and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only" - " be used for ControlNet. Please make sure use" - " `down_intrablock_additional_residuals` instead. ", + ( + "Passing intrablock residual connections with `down_block_additional_residuals` is deprecated " + " and will be removed in diffusers 1.3.0. `down_block_additional_residuals`" + " should only be used for ControlNet. 
Please make sure use" + " `down_intrablock_additional_residuals` instead. " + ), standard_warn=False, ) down_intrablock_additional_residuals = down_block_additional_residuals diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 8e95dde52caf..9fd6f116e878 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -152,6 +152,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class Transformer3DModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class UNet1DModel(metaclass=DummyObject): _backends = ["torch"] diff --git a/src/diffusers/utils/dummy_torch_and_note_seq_objects.py b/src/diffusers/utils/dummy_torch_and_note_seq_objects.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index d831cc49b495..58f74a9e969a 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -1112,6 +1112,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class TuneAVideoPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class UnCLIPImageVariationPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] diff --git a/tests/pipelines/tune_a_video/__init__.py b/tests/pipelines/tune_a_video/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/pipelines/tune_a_video/test_tune_a_video.py b/tests/pipelines/tune_a_video/test_tune_a_video.py new file mode 100644 index 000000000000..778f0b781cb8 --- /dev/null +++ b/tests/pipelines/tune_a_video/test_tune_a_video.py @@ -0,0 +1,227 @@ +# coding=utf-8 +# Copyright 2023 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest + +import numpy as np +import torch +from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer + +from diffusers import AutoencoderKL, DDIMScheduler, TuneAVideoPipeline, UNet3DConditionModel +from diffusers.utils.testing_utils import load_numpy, skip_mps, slow, numpy_cosine_similarity_distance + +from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS +from ..test_pipelines_common import PipelineTesterMixin + + +torch.backends.cuda.matmul.allow_tf32 = False + + +@skip_mps +class TuneAVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = TuneAVideoPipeline + params = TEXT_TO_IMAGE_PARAMS + batch_params = TEXT_TO_IMAGE_BATCH_PARAMS + # No `output_type`. + required_optional_params = frozenset( + [ + "num_inference_steps", + "generator", + "latents", + "return_dict", + "callback", + "callback_steps", + ] + ) + + def get_dummy_components(self): + torch.manual_seed(0) + unet = UNet3DConditionModel( + block_out_channels=(32, 64, 64, 64), + layers_per_block=2, + sample_size=32, + in_channels=4, + out_channels=4, + down_block_types=( + "CrossAttnDownBlockInflated3D", + "CrossAttnDownBlockInflated3D", + "CrossAttnDownBlockInflated3D", + "DownBlockInflated3D", + ), + up_block_types=( + "UpBlockInflated3D", + "CrossAttnUpBlockInflated3D", + "CrossAttnUpBlockInflated3D", + "CrossAttnUpBlockInflated3D", + ), + mid_block_type="UNetMidBlockInflated3DCrossAttn", + cross_attention_dim=32, + attention_head_dim=4, + use_temporal_transformer=False, + ) + scheduler = DDIMScheduler( + beta_start=0.00085, + beta_end=0.012, + beta_schedule="scaled_linear", + clip_sample=False, + set_alpha_to_one=False, + ) + torch.manual_seed(0) + vae = AutoencoderKL( + block_out_channels=[32, 64], + in_channels=3, + out_channels=3, + down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"], + up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"], + latent_channels=4, + sample_size=128, + ) + torch.manual_seed(0) + text_encoder_config = CLIPTextConfig( + bos_token_id=0, + eos_token_id=2, + hidden_size=32, + intermediate_size=37, + layer_norm_eps=1e-05, + num_attention_heads=4, + num_hidden_layers=5, + pad_token_id=1, + vocab_size=1000, + hidden_act="gelu", + projection_dim=512, + ) + text_encoder = CLIPTextModel(text_encoder_config) + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + components = { + "unet": unet, + "scheduler": scheduler, + "vae": vae, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + } + return components + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + inputs = { + "prompt": "A painting of a squirrel eating a burger", + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 6.0, + "output_type": "pt", + } + return inputs + + def test_tune_a_video_default_case(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + components = self.get_dummy_components() + sd_pipe = TuneAVideoPipeline(**components) + sd_pipe = sd_pipe.to(device) + sd_pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + inputs["output_type"] = "np" + frames = sd_pipe(**inputs).frames + image_slice = frames[0][-3:, -3:, -1] + + assert frames[0].shape == (64, 64, 3) + expected_slice = np.array([145, 150, 171, 177, 126, 115, 122, 122, 129]) + + assert np.abs(image_slice.flatten() - 
expected_slice).max() < 1e-2
+
+    @unittest.skip(
+        reason="SlicedAttnProcessor doesn't work with this pipeline; an equivalent processor for TuneAVideoAttnProcessor needs to be added."
+    )
+    def test_attention_slicing_forward_pass(self):
+        self._test_attention_slicing_forward_pass(test_mean_pixel_difference=False)
+
+    # (todo): sayakpaul
+    @unittest.skip(reason="Batching needs to be properly figured out first for this pipeline.")
+    def test_inference_batch_consistent(self):
+        pass
+
+    # (todo): sayakpaul
+    @unittest.skip(reason="Batching needs to be properly figured out first for this pipeline.")
+    def test_inference_batch_single_identical(self):
+        pass
+
+    @unittest.skip(reason="`num_images_per_prompt` argument is not supported for this pipeline.")
+    def test_num_images_per_prompt(self):
+        pass
+
+    def test_progress_bar(self):
+        return super().test_progress_bar()
+
+    @unittest.skip(reason="`set_default_attn_processor` is not supported as we use a custom attn processor")
+    def test_dict_tuple_outputs_equivalent(self, expected_max_difference=1e-4):
+        pass
+
+    @unittest.skip(reason="`set_default_attn_processor` is not supported as we use a custom attn processor")
+    def test_save_load_local(self, expected_max_difference=5e-4):
+        pass
+
+    @unittest.skip(reason="`set_default_attn_processor` is not supported as we use a custom attn processor")
+    def test_save_load_optional_components(self, expected_max_difference=1e-4):
+        pass
+
+
+@slow
+@skip_mps
+class TuneAVideoPipelineSlowTests(unittest.TestCase):
+    def test_full_model(self):
+        expected_video = load_numpy(
+            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/text_to_video/tuneavideo-10step-mo-di.npy"
+        )
+
+        pipe = TuneAVideoPipeline.from_pretrained(
+            "Tune-A-Video-library/df-cpt-mo-di-bear-guitar", torch_dtype=torch.float16
+        )
+        pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
+        pipe = pipe.to("cuda")
+
+        prompt = "A princess playing a guitar, modern disney style"
+        generator = torch.Generator(device="cuda").manual_seed(42)
+
+        video_frames = pipe(
+            prompt, video_length=3, generator=generator, num_inference_steps=10, output_type="pt"
+        ).frames
+        video = video_frames.cpu().numpy()
+
+        # the comparison must be asserted, otherwise the test always passes
+        assert numpy_cosine_similarity_distance(expected_video.flatten(), video.flatten()) < 5e-2
+
+    def test_two_step_model(self):
+        expected_video = load_numpy(
+            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/text_to_video/tuneavideo-2step-mo-di.npy"
+        )
+
+        pipe = TuneAVideoPipeline.from_pretrained(
+            "Tune-A-Video-library/df-cpt-mo-di-bear-guitar", torch_dtype=torch.float16
+        )
+        pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
+        pipe = pipe.to("cuda")
+
+        prompt = "A princess playing a guitar, modern disney style"
+        generator = torch.Generator(device="cuda").manual_seed(42)
+
+        video_frames = pipe(
+            prompt, video_length=3, generator=generator, num_inference_steps=2, output_type="pt"
+        ).frames
+        video = video_frames.cpu().numpy()
+
+        assert numpy_cosine_similarity_distance(expected_video.flatten(), video.flatten()) < 5e-2