diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml
index 88da548bd597..bf6bd39953e7 100644
--- a/docs/source/en/_toctree.yml
+++ b/docs/source/en/_toctree.yml
@@ -200,6 +200,8 @@
title: Tiny AutoEncoder
- local: api/models/transformer2d
title: Transformer2D
+ - local: api/models/transformer3d
+ title: Transformer3D
- local: api/models/transformer_temporal
title: Transformer Temporal
- local: api/models/prior_transformer
@@ -316,6 +318,8 @@
title: Text-to-video
- local: api/pipelines/text_to_video_zero
title: Text2Video-Zero
+ - local: api/pipelines/tune_a_video
+ title: Tune-A-Video
- local: api/pipelines/unclip
title: UnCLIP
- local: api/pipelines/latent_diffusion_uncond
diff --git a/docs/source/en/api/models/transformer3d.md b/docs/source/en/api/models/transformer3d.md
new file mode 100644
index 000000000000..4f73573006da
--- /dev/null
+++ b/docs/source/en/api/models/transformer3d.md
@@ -0,0 +1,11 @@
+# Transformer3D
+
+The Transformer2D model extended for video-like data.
+
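+The example below is a minimal, illustrative sketch of a standalone forward pass. The configuration values are made
+up for demonstration; inside the library the model is built by `UNet3DConditionModel`, and the Tune-A-Video pipeline
+installs the sparse-causal attention processor (defined in the pipeline module) on `attn1`.
+
+```python
+import torch
+
+from diffusers import Transformer3DModel
+from diffusers.pipelines.tune_a_video.pipeline_tune_a_video import TuneAVideoAttnProcessor
+
+# illustrative configuration, not taken from a released checkpoint
+model = Transformer3DModel(num_attention_heads=4, attention_head_dim=8, in_channels=32, cross_attention_dim=64)
+
+# `attn1` of every block expects the sparse-causal attention processor used by Tune-A-Video
+for block in model.transformer_blocks:
+    block.attn1.set_processor(TuneAVideoAttnProcessor())
+
+hidden_states = torch.randn(1, 32, 2, 16, 16)  # (batch, channels, num_frames, height, width)
+encoder_hidden_states = torch.randn(1, 8, 64)  # (batch, sequence_length, cross_attention_dim)
+
+sample = model(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
+print(sample.shape)  # torch.Size([1, 32, 2, 16, 16])
+```
+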
+## Transformer3DModel
+
+[[autodoc]] Transformer3DModel
+
+## Transformer3DModelOutput
+
+[[autodoc]] models.transformer_3d.Transformer3DModelOutput
diff --git a/docs/source/en/api/pipelines/tune_a_video.mdx b/docs/source/en/api/pipelines/tune_a_video.mdx
new file mode 100644
index 000000000000..202051251f15
--- /dev/null
+++ b/docs/source/en/api/pipelines/tune_a_video.mdx
@@ -0,0 +1,122 @@
+
+# Tune-A-Video
+
+## Overview
+
+[Tune-A-Video: One-Shot Tuning of Image Diffusion Models for Text-to-Video Generation](https://arxiv.org/abs/2212.11565) by Jay Zhangjie Wu, Yixiao Ge, Xintao Wang, Stan Weixian Lei, Yuchao Gu, Yufei Shi, Wynne Hsu, Ying Shan, Xiaohu Qie, Mike Zheng Shou
+
+The abstract of the paper is the following:
+
+*To replicate the success of text-to-image (T2I) generation, recent works employ large-scale video datasets to train a text-to-video (T2V) generator. Despite their promising results, such paradigm is computationally expensive. In this work, we propose a new T2V generation setting—One-Shot Video Tuning, where only one text-video pair is presented. Our model is built on state-of-the-art T2I diffusion models pre-trained on massive image data. We make two key observations: 1) T2I models can generate still images that represent verb terms; 2) extending T2I models to generate multiple images concurrently exhibits surprisingly good content consistency. To further learn continuous motion, we introduce Tune-A-Video, which involves a tailored spatio-temporal attention mechanism and an efficient one-shot tuning strategy. At inference, we employ DDIM inversion to provide structure guidance for sampling. Extensive qualitative and numerical experiments demonstrate the remarkable ability of our method across various applications.*
+
+Resources:
+
+* [GitHub repository](https://github.com/showlab/Tune-A-Video)
+* [🤗 Spaces](https://huggingface.co/spaces/Tune-A-Video-library/Tune-A-Video-Training-UI)
+
+## Available Pipelines:
+
+| Pipeline | Tasks | Demo |
+|---|---|:---:|
+| [TuneAVideoPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/tune_a_video/pipeline_tune_a_video.py) | *Text-to-Video Generation* | [🤗 Spaces](https://huggingface.co/spaces/Tune-A-Video-library/Tune-A-Video-inference) |
+
+## Usage example
+
+### Loading with a pre-existing Text2Image checkpoint
+```python
+import torch
+from diffusers import TuneAVideoPipeline, DDIMScheduler, UNet3DConditionModel
+from diffusers.utils import export_to_video
+from PIL import Image
+
+# Use any pretrained text-to-image checkpoint based on Stable Diffusion
+pretrained_model_path = "nitrosocke/mo-di-diffusion"
+scheduler = DDIMScheduler.from_pretrained(pretrained_model_path, subfolder="scheduler")
+unet = UNet3DConditionModel.from_pretrained(
+    "Tune-A-Video-library/df-cpt-mo-di-bear-guitar", subfolder="unet", torch_dtype=torch.float16
+).to("cuda")
+
+pipe = TuneAVideoPipeline.from_pretrained(
+    pretrained_model_path, unet=unet, scheduler=scheduler, torch_dtype=torch.float16
+).to("cuda")
+
+prompt = "A princess playing a guitar, modern disney style"
+generator = torch.Generator(device="cuda").manual_seed(42)
+
+video_frames = pipe(prompt, video_length=3, generator=generator, num_inference_steps=50, output_type="np").frames
+
+# Saving to gif.
+pil_frames = [Image.fromarray(frame) for frame in video_frames]
+duration = 1000 // 8  # display time of each frame in milliseconds (8 fps)
+pil_frames[0].save(
+ "animation.gif",
+ save_all=True,
+ append_images=pil_frames[1:], # append rest of the images
+    duration=duration,  # per-frame duration in milliseconds
+ loop=0,
+)
+
+# Saving to video
+video_path = export_to_video(video_frames)
+```
+
+### Loading a saved Tune-A-Video checkpoint
+```python
+import torch
+from diffusers import DiffusionPipeline
+from diffusers.utils import export_to_video
+from PIL import Image
+
+pipe = DiffusionPipeline.from_pretrained(
+ "Tune-A-Video-library/df-cpt-mo-di-bear-guitar", torch_dtype=torch.float16
+).to("cuda")
+
+prompt = "A princess playing a guitar, modern disney style"
+generator = torch.Generator(device="cuda").manual_seed(42)
+
+video_frames = pipe(prompt, video_length=3, generator=generator, num_inference_steps=50, output_type="np").frames
+
+# Saving to gif.
+pil_frames = [Image.fromarray(frame) for frame in video_frames]
+duration = 1000 // 8  # display time of each frame in milliseconds (8 fps)
+pil_frames[0].save(
+ "animation.gif",
+ save_all=True,
+ append_images=pil_frames[1:], # append rest of the images
+    duration=duration,  # per-frame duration in milliseconds
+ loop=0,
+)
+
+# Saving to video
+video_path = export_to_video(video_frames)
+```
+
+Here are some sample outputs:
+
+*Sample output for the prompt "A princess playing a guitar, modern disney style".*
+
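+### Inflated 3D UNet configuration
+
+Tune-A-Video "inflates" a 2D Stable Diffusion UNet to video: it uses the `*Inflated3D` down/up blocks, the
+`UNetMidBlockInflated3DCrossAttn` mid block, and no temporal transformer before the blocks. The released checkpoints
+already ship a suitable config, so the snippet below is only an illustrative sketch; the channel and cross-attention
+values are assumptions following a Stable Diffusion v1-style layout.
+
+```python
+from diffusers import UNet3DConditionModel
+
+unet = UNet3DConditionModel(
+    sample_size=64,
+    in_channels=4,
+    out_channels=4,
+    down_block_types=(
+        "CrossAttnDownBlockInflated3D",
+        "CrossAttnDownBlockInflated3D",
+        "CrossAttnDownBlockInflated3D",
+        "DownBlockInflated3D",
+    ),
+    up_block_types=(
+        "UpBlockInflated3D",
+        "CrossAttnUpBlockInflated3D",
+        "CrossAttnUpBlockInflated3D",
+        "CrossAttnUpBlockInflated3D",
+    ),
+    mid_block_type="UNetMidBlockInflated3DCrossAttn",
+    block_out_channels=(320, 640, 1280, 1280),  # assumed Stable Diffusion v1-style layout
+    cross_attention_dim=768,  # assumed CLIP ViT-L/14 text embedding size
+    attention_head_dim=8,
+    use_temporal_transformer=False,
+)
+```
+
+In practice the UNet is not built from scratch: fine-tuning on a single video (see the training UI linked in the
+resources above) produces a checkpoint whose `unet` subfolder can be loaded directly with
+`UNet3DConditionModel.from_pretrained`, as in the first example.
+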
+## Available checkpoints
+
+* [Tune-A-Video-library/df-cpt-mo-di-bear-guitar](https://huggingface.co/Tune-A-Video-library/df-cpt-mo-di-bear-guitar)
+
+## TuneAVideoPipeline
+[[autodoc]] TuneAVideoPipeline
+ - all
+ - __call__
diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py
index 42f352c029c8..f55bb69fada2 100644
--- a/src/diffusers/__init__.py
+++ b/src/diffusers/__init__.py
@@ -84,6 +84,7 @@
"T2IAdapter",
"T5FilmDecoder",
"Transformer2DModel",
+ "Transformer3DModel",
"UNet1DModel",
"UNet2DConditionModel",
"UNet2DModel",
@@ -268,6 +269,7 @@
"StableUnCLIPPipeline",
"TextToVideoSDPipeline",
"TextToVideoZeroPipeline",
+ "TuneAVideoPipeline",
"UnCLIPImageVariationPipeline",
"UnCLIPPipeline",
"UniDiffuserModel",
@@ -443,6 +445,7 @@
T2IAdapter,
T5FilmDecoder,
Transformer2DModel,
+ Transformer3DModel,
UNet1DModel,
UNet2DConditionModel,
UNet2DModel,
@@ -606,6 +609,7 @@
StableUnCLIPPipeline,
TextToVideoSDPipeline,
TextToVideoZeroPipeline,
+ TuneAVideoPipeline,
UnCLIPImageVariationPipeline,
UnCLIPPipeline,
UniDiffuserModel,
diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py
index a5d0066d5c40..6afbcb3b3514 100644
--- a/src/diffusers/models/__init__.py
+++ b/src/diffusers/models/__init__.py
@@ -30,6 +30,7 @@
_import_structure["prior_transformer"] = ["PriorTransformer"]
_import_structure["t5_film_transformer"] = ["T5FilmDecoder"]
_import_structure["transformer_2d"] = ["Transformer2DModel"]
+ _import_structure["transformer_3d"] = ["Transformer3DModel"]
_import_structure["transformer_temporal"] = ["TransformerTemporalModel"]
_import_structure["unet_1d"] = ["UNet1DModel"]
_import_structure["unet_2d"] = ["UNet2DModel"]
@@ -55,6 +56,7 @@
from .prior_transformer import PriorTransformer
from .t5_film_transformer import T5FilmDecoder
from .transformer_2d import Transformer2DModel
+ from .transformer_3d import Transformer3DModel
from .transformer_temporal import TransformerTemporalModel
from .unet_1d import UNet1DModel
from .unet_2d import UNet2DModel
diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py
index 80bf269fc4e3..38a12ac81ce2 100644
--- a/src/diffusers/models/resnet.py
+++ b/src/diffusers/models/resnet.py
@@ -757,7 +757,196 @@ def forward(self, input_tensor, temb, scale: float = 1.0):
return output_tensor
-# unet_rl.py
+class Upsample3D(nn.Module):
+ """A 3D upsampling layer. Reshapes the input tensor to video like tensor, applies upsampling conv,
+ converts it back to the original shape.
+
+ Parameters:
+ channels (`int`):
+ number of channels in the inputs and outputs.
+ out_channels (`int`, optional):
+ number of output channels. Defaults to `channels`.
+ """
+
+ def __init__(self, channels, out_channels=None):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+
+ self.conv = nn.Conv2d(self.channels, self.out_channels, 3, padding=1)
+
+ def forward(self, hidden_states, output_size=None):
+ if hidden_states.shape[1] != self.channels:
+ raise ValueError(
+ f"Expected hidden_states tensor at dimension 1 to match the number of channels. Expected: {self.channels} but passed: {hidden_states.shape[1]}"
+ )
+
+        # Cast to float32 as the 'upsample_nearest2d_out_frame' op does not support bfloat16
+ dtype = hidden_states.dtype
+ if dtype == torch.bfloat16:
+ hidden_states = hidden_states.to(torch.float32)
+
+ # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
+ if hidden_states.shape[0] >= 64:
+ hidden_states = hidden_states.contiguous()
+
+        # if `output_size` is passed we force the interpolation output size
+        # and do not use the default spatial scale factor of 2
+ if output_size is None:
+ hidden_states = F.interpolate(hidden_states, scale_factor=[1.0, 2.0, 2.0], mode="nearest")
+ else:
+ hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")
+
+ # If the input is bfloat16, we cast back to bfloat16
+ if dtype == torch.bfloat16:
+ hidden_states = hidden_states.to(dtype)
+
+ # Inflate
+ video_length = hidden_states.shape[2]
+ # b c f h w -> (b f) c h w
+ hidden_states = hidden_states.movedim((0, 1, 2, 3, 4), (0, 2, 1, 3, 4))
+ hidden_states = hidden_states.flatten(0, 1)
+
+ hidden_states = self.conv(hidden_states)
+ # Deflate
+ # (b f) c h w -> b c f h w (f=video_length)
+ hidden_states = hidden_states.reshape([-1, video_length, *hidden_states.shape[1:]])
+ hidden_states = hidden_states.movedim((0, 1, 2, 3, 4), (0, 2, 1, 3, 4))
+
+ return hidden_states
+
+
+class Downsample3D(nn.Module):
+ """A 3D downsampling layer. Reshapes the input tensor to video like tensor, applies conv,
+ converts it back to the original shape.
+
+ Parameters:
+ channels (`int`):
+ number of channels in the inputs and outputs.
+ out_channels (`int`, optional):
+ number of output channels. Defaults to `channels`.
+ """
+
+ def __init__(self, channels, out_channels=None, padding=1):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+ self.conv = nn.Conv2d(self.channels, self.out_channels, 3, stride=2, padding=padding)
+
+ def forward(self, hidden_states):
+ video_length = hidden_states.shape[2]
+ # b c f h w -> (b f) c h w
+ hidden_states = hidden_states.movedim((0, 1, 2, 3, 4), (0, 2, 1, 3, 4))
+ hidden_states = hidden_states.flatten(0, 1)
+ # Conv
+ hidden_states = self.conv(hidden_states)
+ # (b f) c h w -> b c f h w (f=video_length)
+ hidden_states = hidden_states.reshape([-1, video_length, *hidden_states.shape[1:]])
+ hidden_states = hidden_states.movedim((0, 1, 2, 3, 4), (0, 2, 1, 3, 4))
+
+ return hidden_states
+
+
+class ResnetBlock3D(nn.Module):
+ r"""
+    A ResNet block used for video-like data. The 2D convolutions are applied per frame by folding the frames into
+    the batch dimension.
+
+    Parameters:
+        in_channels (`int`): The number of channels in the input.
+        out_channels (`int`, *optional*, defaults to `None`):
+            The number of output channels for the first conv2d layer. If None, same as `in_channels`.
+        dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use.
+        temb_channels (`int`, *optional*, defaults to `512`): The number of channels in the timestep embedding.
+        groups (`int`, *optional*, defaults to `32`): The number of groups to use for the normalization layers.
+        eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization.
+        non_linearity (`str`, *optional*, defaults to `"swish"`): The activation function to use.
+        output_scale_factor (`float`, *optional*, defaults to `1.0`): The scale factor to use for the output.
+ """
+
+ def __init__(
+ self,
+ *,
+ in_channels,
+ out_channels=None,
+ dropout=0.0,
+ temb_channels=512,
+ groups=32,
+ eps=1e-6,
+ non_linearity="swish",
+ output_scale_factor=1.0,
+ ):
+ super().__init__()
+ self.in_channels = in_channels
+ out_channels = in_channels if out_channels is None else out_channels
+ self.out_channels = out_channels
+ self.output_scale_factor = output_scale_factor
+
+ self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
+ self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
+
+ self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels)
+
+ self.norm2 = torch.nn.GroupNorm(num_groups=groups, num_channels=out_channels, eps=eps, affine=True)
+ self.dropout = torch.nn.Dropout(dropout)
+ self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
+
+ self.nonlinearity = get_activation(non_linearity)
+
+ self.use_in_shortcut = self.in_channels != self.out_channels
+ self.conv_shortcut = None
+ if self.use_in_shortcut:
+ self.conv_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
+
+ def forward(self, input_tensor, temb):
+ hidden_states = input_tensor
+
+ hidden_states = self.norm1(hidden_states)
+ hidden_states = self.nonlinearity(hidden_states)
+
+ video_length = hidden_states.shape[2]
+ # b c f h w -> (b f) c h w
+ hidden_states = hidden_states.movedim((0, 1, 2, 3, 4), (0, 2, 1, 3, 4))
+ hidden_states = hidden_states.flatten(0, 1)
+ hidden_states = self.conv1(hidden_states)
+        # (b f) c h w -> b c f h w (f=video_length)
+ hidden_states = hidden_states.reshape([-1, video_length, *hidden_states.shape[1:]])
+ hidden_states = hidden_states.movedim((0, 1, 2, 3, 4), (0, 2, 1, 3, 4))
+
+ if temb is not None:
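+            # project the timestep embedding and broadcast it over the frame, height and width dimensions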
+ temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None, None]
+
+ hidden_states = hidden_states + temb
+
+ hidden_states = self.norm2(hidden_states)
+
+ hidden_states = self.nonlinearity(hidden_states)
+
+ hidden_states = self.dropout(hidden_states)
+
+ video_length = hidden_states.shape[2]
+ # b c f h w -> (b f) c h w
+ hidden_states = hidden_states.movedim((0, 1, 2, 3, 4), (0, 2, 1, 3, 4))
+ hidden_states = hidden_states.flatten(0, 1)
+ hidden_states = self.conv2(hidden_states)
+ # (b f) c h w -> b c f h w (f=video_length)
+ hidden_states = hidden_states.reshape([-1, video_length, *hidden_states.shape[1:]])
+ hidden_states = hidden_states.movedim((0, 1, 2, 3, 4), (0, 2, 1, 3, 4))
+
+ if self.conv_shortcut is not None:
+ video_length = input_tensor.shape[2]
+ # "b c f h w -> (b f) c h w"
+ input_tensor = input_tensor.movedim((0, 1, 2, 3, 4), (0, 2, 1, 3, 4))
+ input_tensor = input_tensor.flatten(0, 1)
+ input_tensor = self.conv_shortcut(input_tensor)
+ # "(b f) c h w -> b c f h w"; f=video_length
+ input_tensor = input_tensor.reshape([-1, video_length, *input_tensor.shape[1:]])
+ input_tensor = input_tensor.movedim((0, 1, 2, 3, 4), (0, 2, 1, 3, 4))
+
+ output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
+
+ return output_tensor
+
+
def rearrange_dims(tensor: torch.Tensor) -> torch.Tensor:
if len(tensor.shape) == 2:
return tensor[:, :, None]
diff --git a/src/diffusers/models/transformer_3d.py b/src/diffusers/models/transformer_3d.py
new file mode 100644
index 000000000000..181b3c47b9e1
--- /dev/null
+++ b/src/diffusers/models/transformer_3d.py
@@ -0,0 +1,269 @@
+# Based on the TuneAVideo Transformer3DModel from Showlab: https://arxiv.org/abs/2212.11565
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass
+from typing import Optional
+
+import torch
+from torch import nn
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from ..utils import BaseOutput
+from .attention import Attention, FeedForward
+from .modeling_utils import ModelMixin
+
+
+@dataclass
+class Transformer3DModelOutput(BaseOutput):
+ """
+ Args:
+        sample (`torch.FloatTensor` of shape `(batch_size, num_channels, num_frames, height, width)`):
+            The hidden states output of the model.
+ """
+
+ sample: torch.FloatTensor
+
+
+class Transformer3DModel(ModelMixin, ConfigMixin):
+ """
+    A Transformer model for video-like data.
+
+    The frames of the input are first folded into the batch dimension, projected, and reshaped to
+    `(batch * num_frames, height * width, channels)`. The sparse spatio-temporal transformer blocks are then
+    applied, and the result is reshaped back to a video tensor.
+
+ Parameters:
+ num_attention_heads (`int`, *optional*, defaults to 16):
+ The number of heads to use for multi-head attention.
+ attention_head_dim (`int`, *optional*, defaults to 88):
+ The number of channels in each head.
+ in_channels (`int`, *optional*):
+ Pass if the input is continuous. The number of channels in the input and output.
+ num_layers (`int`, *optional*, defaults to 1):
+ The number of layers of Transformer blocks to use.
+ dropout (`float`, *optional*, defaults to 0.0):
+ The dropout probability to use.
+        norm_num_groups (`int`, *optional*, defaults to 32):
+ The number of norm groups for the group norm.
+ cross_attention_dim (`int`, *optional*):
+ The number of encoder_hidden_states dimensions to use.
+ activation_fn (`str`, *optional*, defaults to `"geglu"`):
+ Activation function to be used in feed-forward.
+
+ """
+
+ @register_to_config
+ def __init__(
+ self,
+ num_attention_heads: int = 16,
+ attention_head_dim: int = 88,
+ in_channels: Optional[int] = None,
+ num_layers: int = 1,
+ dropout: float = 0.0,
+ norm_num_groups: int = 32,
+ cross_attention_dim: int = 1280,
+ activation_fn: str = "geglu",
+ ):
+ super().__init__()
+ self.num_attention_heads = num_attention_heads
+ self.attention_head_dim = attention_head_dim
+ inner_dim = num_attention_heads * attention_head_dim
+
+ # Define input layers
+ self.in_channels = in_channels
+
+ self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
+ self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
+
+ # Define transformers blocks
+ self.transformer_blocks = nn.ModuleList(
+ [
+ BasicSparse3DTransformerBlock(
+ inner_dim,
+ num_attention_heads,
+ attention_head_dim,
+ dropout=dropout,
+ cross_attention_dim=cross_attention_dim,
+ activation_fn=activation_fn,
+ )
+ for d in range(num_layers)
+ ]
+ )
+
+ # 4. Define output layers
+ self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
+
+ def forward(
+ self,
+ hidden_states,
+ encoder_hidden_states=None,
+ timestep=None,
+ return_dict: bool = True,
+ ):
+ """
+ Args:
+ hidden_states (`torch.FloatTensor` of shape `(batch size, channel, num_frames, height, width)`):
+ Input hidden_states
+            encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, cross_attention_dim)`, *optional*):
+ Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
+ self-attention.
+            timestep (`torch.LongTensor`, *optional*):
+                Optional timestep indicating the denoising step. Currently unused by the sparse 3D transformer blocks.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~models.transformer_3d.Transformer3DModelOutput`] instead of a plain
+ tuple.
+
+ Returns:
+ [`~models.transformer_3d.Transformer3DModelOutput`] or `tuple`:
+ [`~models.transformer_3d.Transformer3DModelOutput`] if `return_dict` is True, otherwise a `tuple`. When
+ returning a tuple, the first element is the sample tensor.
+ """
+ # Input
+ if not hidden_states.dim() == 5:
+ raise ValueError(f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}.")
+ video_length = hidden_states.shape[2]
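+        # b c f h w -> (b f) c h w: fold the frame axis into the batch so the 2D layers can be reused, and
+        # repeat the text embeddings once per frame to match the flattened batch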
+ hidden_states = hidden_states.movedim((0, 1, 2, 3, 4), (0, 2, 1, 3, 4))
+ hidden_states = hidden_states.flatten(0, 1)
+ encoder_hidden_states = encoder_hidden_states.repeat_interleave(repeats=video_length, dim=0)
+
+        batch, channel, height, width = hidden_states.shape
+ residual = hidden_states
+
+ hidden_states = self.norm(hidden_states)
+ hidden_states = self.proj_in(hidden_states)
+ inner_dim = hidden_states.shape[1]
+        hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
+
+ # Blocks
+ for block in self.transformer_blocks:
+ hidden_states = block(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ timestep=timestep,
+ video_length=video_length,
+ )
+
+ # Output
+        hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
+ hidden_states = self.proj_out(hidden_states)
+
+ output = hidden_states + residual
+
+ # "(b f) c h w -> b c f h w"; f=video_length
+ output = output.reshape([-1, video_length, *output.shape[1:]])
+ output = output.movedim((0, 1, 2, 3, 4), (0, 2, 1, 3, 4))
+ if not return_dict:
+ return (output,)
+
+ return Transformer3DModelOutput(sample=output)
+
+
+class BasicSparse3DTransformerBlock(nn.Module):
+ r"""
+    A modified basic Transformer block designed for text-to-video models. Currently only used by the Tune-A-Video
+    pipeline, which sets the `attn1` processor to `TuneAVideoAttnProcessor`.
+
+ Parameters:
+ dim (`int`): The number of channels in the input and output.
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
+ attention_head_dim (`int`): The number of channels in each head.
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+ cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ num_attention_heads: int,
+ attention_head_dim: int,
+ dropout=0.0,
+ cross_attention_dim: int = 1280,
+ activation_fn: str = "geglu",
+ ):
+ super().__init__()
+
+        # Sparse-Causal Attention (attn1); the Tune-A-Video pipeline sets its processor to TuneAVideoAttnProcessor
+ self.attn1 = Attention(
+ query_dim=dim,
+ heads=num_attention_heads,
+ dim_head=attention_head_dim,
+ dropout=dropout,
+ bias=False,
+ cross_attention_dim=None,
+ upcast_attention=False,
+ )
+ self.norm1 = nn.LayerNorm(dim)
+
+ # Cross-Attn
+ self.attn2 = Attention(
+ query_dim=dim,
+ cross_attention_dim=cross_attention_dim,
+ heads=num_attention_heads,
+ dim_head=attention_head_dim,
+ dropout=dropout,
+ bias=False,
+ upcast_attention=False,
+ )
+
+ self.norm2 = nn.LayerNorm(dim)
+
+ # Feed-forward
+ self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
+ self.norm3 = nn.LayerNorm(dim)
+
+ # Temp-Attn
+ self.attn_temp = Attention(
+ query_dim=dim,
+ heads=num_attention_heads,
+ dim_head=attention_head_dim,
+ dropout=dropout,
+ bias=False,
+ upcast_attention=False,
+ )
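+        # zero-initialize the temporal attention output projection so the temporal branch starts as an
+        # identity mapping and only contributes after fine-tuning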
+ nn.init.zeros_(self.attn_temp.to_out[0].weight.data)
+ self.norm_temp = nn.LayerNorm(dim)
+
+ def forward(
+ self, hidden_states, encoder_hidden_states=None, timestep=None, attention_mask=None, video_length=None
+ ):
+ # SparseCausal-Attention
+ norm_hidden_states = self.norm1(hidden_states)
+
+ hidden_states = (
+ self.attn1(norm_hidden_states, attention_mask=attention_mask, video_length=video_length) + hidden_states
+ )
+
+ norm_hidden_states = self.norm2(hidden_states)
+ hidden_states = (
+ self.attn2(norm_hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask)
+ + hidden_states
+ )
+
+ # Feed-forward
+ hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
+
+ # Temporal-Attention
+ d = hidden_states.shape[1]
+ # (b f) d c -> b f d c -> b d f c -> (b d) f c
+ hidden_states = hidden_states.reshape([-1, video_length, *hidden_states.shape[1:]])
+ hidden_states = hidden_states.movedim((0, 1, 2, 3), (0, 2, 1, 3))
+ hidden_states = hidden_states.flatten(0, 1)
+ norm_hidden_states = self.norm_temp(hidden_states)
+ hidden_states = self.attn_temp(norm_hidden_states) + hidden_states
+ # (b d) f c -> b d f c -> b f d c -> (b f) d c
+ hidden_states = hidden_states.reshape([-1, d, *hidden_states.shape[1:]])
+ hidden_states = hidden_states.movedim((0, 1, 2, 3), (0, 2, 1, 3))
+ hidden_states = hidden_states.flatten(0, 1)
+
+ return hidden_states
diff --git a/src/diffusers/models/unet_3d_blocks.py b/src/diffusers/models/unet_3d_blocks.py
index 180ae0dc1a81..68501cf368ed 100644
--- a/src/diffusers/models/unet_3d_blocks.py
+++ b/src/diffusers/models/unet_3d_blocks.py
@@ -16,8 +16,9 @@
from torch import nn
from ..utils.torch_utils import apply_freeu
-from .resnet import Downsample2D, ResnetBlock2D, TemporalConvLayer, Upsample2D
+from .resnet import Downsample2D, Downsample3D, ResnetBlock2D, ResnetBlock3D, TemporalConvLayer, Upsample2D, Upsample3D
from .transformer_2d import Transformer2DModel
+from .transformer_3d import Transformer3DModel
from .transformer_temporal import TransformerTemporalModel
@@ -74,6 +75,34 @@ def get_down_block(
upcast_attention=upcast_attention,
resnet_time_scale_shift=resnet_time_scale_shift,
)
+ elif down_block_type == "DownBlockInflated3D":
+ return DownBlockInflated3D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ add_downsample=add_downsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ downsample_padding=downsample_padding,
+ )
+ elif down_block_type == "CrossAttnDownBlockInflated3D":
+ if cross_attention_dim is None:
+ raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlockInflated3D")
+ return CrossAttnDownBlockInflated3D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ add_downsample=add_downsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ downsample_padding=downsample_padding,
+ cross_attention_dim=cross_attention_dim,
+ num_attention_heads=num_attention_heads,
+ )
raise ValueError(f"{down_block_type} does not exist.")
@@ -133,9 +162,82 @@ def get_up_block(
resnet_time_scale_shift=resnet_time_scale_shift,
resolution_idx=resolution_idx,
)
+ elif up_block_type == "UpBlockInflated3D":
+ return UpBlockInflated3D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ prev_output_channel=prev_output_channel,
+ temb_channels=temb_channels,
+ add_upsample=add_upsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ )
+ elif up_block_type == "CrossAttnUpBlockInflated3D":
+ if cross_attention_dim is None:
+ raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlockInflated3D")
+ return CrossAttnUpBlockInflated3D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ prev_output_channel=prev_output_channel,
+ temb_channels=temb_channels,
+ add_upsample=add_upsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ cross_attention_dim=cross_attention_dim,
+ num_attention_heads=num_attention_heads,
+ )
+
raise ValueError(f"{up_block_type} does not exist.")
+def get_mid_block(
+ mid_block_type,
+ in_channels,
+ temb_channels,
+ resnet_eps,
+ resnet_act_fn,
+ resnet_groups,
+ num_attention_heads,
+ output_scale_factor,
+ cross_attention_dim,
+ dual_cross_attention=False,
+ use_linear_projection=True,
+ upcast_attention=False,
+ resnet_time_scale_shift="default",
+):
+ if mid_block_type == "UNetMidBlock3DCrossAttn":
+ return UNetMidBlock3DCrossAttn(
+ in_channels=in_channels,
+ temb_channels=temb_channels,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ cross_attention_dim=cross_attention_dim,
+ num_attention_heads=num_attention_heads,
+ resnet_groups=resnet_groups,
+ dual_cross_attention=dual_cross_attention,
+ use_linear_projection=use_linear_projection,
+ upcast_attention=upcast_attention,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ )
+ elif mid_block_type == "UNetMidBlockInflated3DCrossAttn":
+ return UNetMidBlockInflated3DCrossAttn(
+ in_channels=in_channels,
+ temb_channels=temb_channels,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ cross_attention_dim=cross_attention_dim,
+ num_attention_heads=num_attention_heads,
+ resnet_groups=resnet_groups,
+ )
+ raise ValueError(f"{mid_block_type} does not exist.")
+
+
class UNetMidBlock3DCrossAttn(nn.Module):
def __init__(
self,
@@ -724,3 +826,397 @@ def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_si
hidden_states = upsampler(hidden_states, upsample_size)
return hidden_states
+
+
+class UpBlockInflated3D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ prev_output_channel: int,
+ out_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ output_scale_factor=1.0,
+ add_upsample=True,
+ ):
+ super().__init__()
+ resnets = []
+
+ for i in range(num_layers):
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
+
+ resnets.append(
+ ResnetBlock3D(
+ in_channels=resnet_in_channels + res_skip_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ )
+ )
+
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_upsample:
+ self.upsamplers = nn.ModuleList([Upsample3D(out_channels, out_channels=out_channels)])
+ else:
+ self.upsamplers = None
+
+ self.gradient_checkpointing = False
+
+ def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, num_frames=1):
+ for resnet in self.resnets:
+ # pop res hidden states
+ res_hidden_states = res_hidden_states_tuple[-1]
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+
+ hidden_states = resnet(hidden_states, temb)
+
+ if self.upsamplers is not None:
+ for upsampler in self.upsamplers:
+ hidden_states = upsampler(hidden_states, upsample_size)
+
+ return hidden_states
+
+
+class DownBlockInflated3D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ output_scale_factor=1.0,
+ add_downsample=True,
+ downsample_padding=1,
+ ):
+ super().__init__()
+ resnets = []
+
+ for i in range(num_layers):
+ in_channels = in_channels if i == 0 else out_channels
+ resnets.append(
+ ResnetBlock3D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ )
+ )
+
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_downsample:
+ self.downsamplers = nn.ModuleList(
+ [Downsample3D(out_channels, out_channels=out_channels, padding=downsample_padding)]
+ )
+ else:
+ self.downsamplers = None
+
+ self.gradient_checkpointing = False
+
+ def forward(self, hidden_states, temb=None, num_frames=1):
+ output_states = ()
+
+ for resnet in self.resnets:
+ hidden_states = resnet(hidden_states, temb)
+ output_states += (hidden_states,)
+
+ if self.downsamplers is not None:
+ for downsampler in self.downsamplers:
+ hidden_states = downsampler(hidden_states)
+
+ output_states += (hidden_states,)
+
+ return hidden_states, output_states
+
+
+class CrossAttnUpBlockInflated3D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ prev_output_channel: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ num_attention_heads=1,
+ cross_attention_dim=1280,
+ output_scale_factor=1.0,
+ add_upsample=True,
+ ):
+ super().__init__()
+ resnets = []
+ attentions = []
+
+ self.has_cross_attention = True
+ self.num_attention_heads = num_attention_heads
+
+ for i in range(num_layers):
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
+
+ resnets.append(
+ ResnetBlock3D(
+ in_channels=resnet_in_channels + res_skip_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ )
+ )
+
+ attentions.append(
+ Transformer3DModel(
+ num_attention_heads=num_attention_heads,
+ attention_head_dim=out_channels // num_attention_heads,
+ in_channels=out_channels,
+ num_layers=1,
+ cross_attention_dim=cross_attention_dim,
+ norm_num_groups=resnet_groups,
+ )
+ )
+
+ self.resnets = nn.ModuleList(resnets)
+ self.attentions = nn.ModuleList(attentions)
+
+ if add_upsample:
+ self.upsamplers = nn.ModuleList([Upsample3D(out_channels, out_channels=out_channels)])
+ else:
+ self.upsamplers = None
+
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states,
+ res_hidden_states_tuple,
+ temb=None,
+ encoder_hidden_states=None,
+ upsample_size=None,
+ attention_mask=None,
+ num_frames=1,
+ cross_attention_kwargs=None,
+ ):
+ # TODO(Patrick, William) - attention mask is not used
+ # TODO - cross_attention_kwargs are not used
+
+ for resnet, attn in zip(self.resnets, self.attentions):
+ # pop res hidden states
+ res_hidden_states = res_hidden_states_tuple[-1]
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+
+ hidden_states = resnet(hidden_states, temb)
+
+ hidden_states = attn(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ ).sample
+
+ if self.upsamplers is not None:
+ for upsampler in self.upsamplers:
+ hidden_states = upsampler(hidden_states, upsample_size)
+
+ return hidden_states
+
+
+class CrossAttnDownBlockInflated3D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ num_attention_heads=1,
+ cross_attention_dim=1280,
+ output_scale_factor=1.0,
+ downsample_padding=1,
+ add_downsample=True,
+ ):
+ super().__init__()
+ resnets = []
+ attentions = []
+
+ self.has_cross_attention = True
+ self.num_attention_heads = num_attention_heads
+
+ for i in range(num_layers):
+ in_channels = in_channels if i == 0 else out_channels
+ resnets.append(
+ ResnetBlock3D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ )
+ )
+
+ attentions.append(
+ Transformer3DModel(
+ num_attention_heads=num_attention_heads,
+ attention_head_dim=out_channels // num_attention_heads,
+ in_channels=out_channels,
+ num_layers=1,
+ cross_attention_dim=cross_attention_dim,
+ norm_num_groups=resnet_groups,
+ )
+ )
+
+ self.resnets = nn.ModuleList(resnets)
+ self.attentions = nn.ModuleList(attentions)
+
+ if add_downsample:
+ self.downsamplers = nn.ModuleList(
+ [Downsample3D(out_channels, out_channels=out_channels, padding=downsample_padding)]
+ )
+ else:
+ self.downsamplers = None
+
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states,
+ temb=None,
+ encoder_hidden_states=None,
+ attention_mask=None,
+ num_frames=1,
+ cross_attention_kwargs=None,
+ ):
+ # TODO(Patrick, William) - attention mask is not used
+ # TODO - cross_attention_kwargs are not used.
+ output_states = ()
+
+ for resnet, attn in zip(self.resnets, self.attentions):
+ hidden_states = resnet(hidden_states, temb)
+ hidden_states = attn(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ ).sample
+
+ output_states += (hidden_states,)
+
+ if self.downsamplers is not None:
+ for downsampler in self.downsamplers:
+ hidden_states = downsampler(hidden_states)
+
+ output_states += (hidden_states,)
+
+ return hidden_states, output_states
+
+
+class UNetMidBlockInflated3DCrossAttn(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ num_attention_heads=1,
+ output_scale_factor=1.0,
+ cross_attention_dim=1280,
+ ):
+ super().__init__()
+
+ self.has_cross_attention = True
+ self.num_attention_heads = num_attention_heads
+ resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
+
+ # there is always at least one resnet
+ resnets = [
+ ResnetBlock3D(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ )
+ ]
+
+ attentions = []
+
+ for _ in range(num_layers):
+ attentions.append(
+ Transformer3DModel(
+ num_attention_heads,
+ in_channels // num_attention_heads,
+ in_channels=in_channels,
+ num_layers=1,
+ cross_attention_dim=cross_attention_dim,
+ norm_num_groups=resnet_groups,
+ )
+ )
+
+ resnets.append(
+ ResnetBlock3D(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ )
+ )
+
+ self.resnets = nn.ModuleList(resnets)
+ self.attentions = nn.ModuleList(attentions)
+
+ def forward(
+ self,
+ hidden_states,
+ temb=None,
+ encoder_hidden_states=None,
+ attention_mask=None,
+ num_frames=1,
+ cross_attention_kwargs=None,
+ ):
+ hidden_states = self.resnets[0](hidden_states, temb)
+
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
+ hidden_states = attn(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ ).sample
+ hidden_states = resnet(hidden_states, temb)
+
+ return hidden_states
diff --git a/src/diffusers/models/unet_3d_condition.py b/src/diffusers/models/unet_3d_condition.py
index 2ab1d4060e17..4e3cf511e882 100644
--- a/src/diffusers/models/unet_3d_condition.py
+++ b/src/diffusers/models/unet_3d_condition.py
@@ -34,11 +34,15 @@
from .transformer_temporal import TransformerTemporalModel
from .unet_3d_blocks import (
CrossAttnDownBlock3D,
+ CrossAttnDownBlockInflated3D,
CrossAttnUpBlock3D,
+ CrossAttnUpBlockInflated3D,
DownBlock3D,
- UNetMidBlock3DCrossAttn,
+ DownBlockInflated3D,
UpBlock3D,
+ UpBlockInflated3D,
get_down_block,
+ get_mid_block,
get_up_block,
)
@@ -72,21 +76,24 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
Height and width of input/output sample.
in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample.
out_channels (`int`, *optional*, defaults to 4): The number of channels in the output.
- down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock3D", "CrossAttnDownBlock3D", "CrossAttnDownBlock3D", "DownBlock3D")`):
The tuple of downsample blocks to use.
- up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D",)`):
The tuple of upsample blocks to use.
block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
The tuple of output channels for each block.
layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
+ mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock3DCrossAttn"`): The midblock to use.
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
If `None`, normalization and activation layers is skipped in post-processing.
norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
cross_attention_dim (`int`, *optional*, defaults to 1280): The dimension of the cross attention features.
attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
+        use_temporal_transformer (`bool`, *optional*, defaults to `True`):
+            If `False`, the temporal transformer that normally runs right after `conv_in` is skipped (as done for
+            Tune-A-Video checkpoints).
num_attention_heads (`int`, *optional*): The number of attention heads.
"""
@@ -109,11 +116,13 @@ def __init__(
layers_per_block: int = 2,
downsample_padding: int = 1,
mid_block_scale_factor: float = 1,
+ mid_block_type: str = "UNetMidBlock3DCrossAttn",
act_fn: str = "silu",
norm_num_groups: Optional[int] = 32,
norm_eps: float = 1e-5,
cross_attention_dim: int = 1024,
attention_head_dim: Union[int, Tuple[int]] = 64,
+ use_temporal_transformer: bool = True,
num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
):
super().__init__()
@@ -153,10 +162,10 @@ def __init__(
conv_in_kernel = 3
conv_out_kernel = 3
conv_in_padding = (conv_in_kernel - 1) // 2
+
self.conv_in = nn.Conv2d(
in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
)
-
# time
time_embed_dim = block_out_channels[0] * 4
self.time_proj = Timesteps(block_out_channels[0], True, 0)
@@ -168,15 +177,18 @@ def __init__(
act_fn=act_fn,
)
- self.transformer_in = TransformerTemporalModel(
- num_attention_heads=8,
- attention_head_dim=attention_head_dim,
- in_channels=block_out_channels[0],
- num_layers=1,
- )
+ if use_temporal_transformer:
+ self.transformer_in = TransformerTemporalModel(
+ num_attention_heads=8,
+ attention_head_dim=attention_head_dim,
+ in_channels=block_out_channels[0],
+ num_layers=1,
+ )
+ else:
+ self.transformer_in = None
- # class embedding
self.down_blocks = nn.ModuleList([])
+ self.mid_block = None
self.up_blocks = nn.ModuleList([])
if isinstance(num_attention_heads, int):
@@ -202,12 +214,12 @@ def __init__(
cross_attention_dim=cross_attention_dim,
num_attention_heads=num_attention_heads[i],
downsample_padding=downsample_padding,
- dual_cross_attention=False,
)
self.down_blocks.append(down_block)
# mid
- self.mid_block = UNetMidBlock3DCrossAttn(
+ self.mid_block = get_mid_block(
+ mid_block_type=mid_block_type,
in_channels=block_out_channels[-1],
temb_channels=time_embed_dim,
resnet_eps=norm_eps,
@@ -216,14 +228,14 @@ def __init__(
cross_attention_dim=cross_attention_dim,
num_attention_heads=num_attention_heads[-1],
resnet_groups=norm_num_groups,
- dual_cross_attention=False,
)
- # count how many layers upsample the images
+ # count how many layers upsample the videos
self.num_upsamplers = 0
# up
reversed_block_out_channels = list(reversed(block_out_channels))
+
reversed_num_attention_heads = list(reversed(num_attention_heads))
output_channel = reversed_block_out_channels[0]
@@ -271,6 +283,7 @@ def __init__(
self.conv_act = None
conv_out_padding = (conv_out_kernel - 1) // 2
+
self.conv_out = nn.Conv2d(
block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
)
@@ -460,7 +473,19 @@ def set_default_attn_processor(self):
self.set_attn_processor(processor, _remove_lora=True)
def _set_gradient_checkpointing(self, module, value=False):
- if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)):
+ if isinstance(
+ module,
+ (
+ CrossAttnDownBlock3D,
+ DownBlock3D,
+ CrossAttnUpBlock3D,
+ UpBlock3D,
+ CrossAttnDownBlockInflated3D,
+ DownBlockInflated3D,
+ CrossAttnUpBlockInflated3D,
+ UpBlockInflated3D,
+ ),
+ ):
module.gradient_checkpointing = value
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.enable_freeu
@@ -515,7 +540,7 @@ def forward(
Args:
sample (`torch.FloatTensor`):
- The noisy input tensor with the following shape `(batch, num_frames, channel, height, width`.
+ The noisy input tensor with the following shape `(batch, channel, num_frames, height, width)`.
timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
encoder_hidden_states (`torch.FloatTensor`):
The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
@@ -581,6 +606,7 @@ def forward(
timesteps = timesteps[None].to(sample.device)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ # num_frames is equivalent to video_length in TuneAVideo
num_frames = sample.shape[2]
timesteps = timesteps.expand(sample.shape[0])
@@ -592,19 +618,28 @@ def forward(
t_emb = t_emb.to(dtype=self.dtype)
emb = self.time_embedding(t_emb, timestep_cond)
- emb = emb.repeat_interleave(repeats=num_frames, dim=0)
- encoder_hidden_states = encoder_hidden_states.repeat_interleave(repeats=num_frames, dim=0)
# 2. pre-process
- sample = sample.permute(0, 2, 1, 3, 4).reshape((sample.shape[0] * num_frames, -1) + sample.shape[3:])
- sample = self.conv_in(sample)
+        # With the temporal transformer, frames are folded into the batch dimension up front; without it
+        # (Tune-A-Video), the 5D layout is kept and only reshaped locally around the 2D layers
+ if self.transformer_in:
+ emb = emb.repeat_interleave(repeats=num_frames, dim=0)
+ encoder_hidden_states = encoder_hidden_states.repeat_interleave(repeats=num_frames, dim=0)
+
+ sample = sample.permute(0, 2, 1, 3, 4).reshape((sample.shape[0] * num_frames, -1) + sample.shape[3:])
+ sample = self.conv_in(sample)
- sample = self.transformer_in(
- sample,
- num_frames=num_frames,
- cross_attention_kwargs=cross_attention_kwargs,
- return_dict=False,
- )[0]
+ sample = self.transformer_in(
+ sample, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs
+ ).sample
+ else:
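+            # No temporal transformer (Tune-A-Video): keep the 5D layout for the blocks and only fold the
+            # frames into the batch dimension around the 2D `conv_in`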
+ video_length = sample.shape[2]
+ # b c f h w -> (b f) c h w
+ sample = sample.movedim((0, 1, 2, 3, 4), (0, 2, 1, 3, 4))
+ sample = sample.flatten(0, 1)
+ sample = self.conv_in(sample)
+ # (b f) c h w -> b c f h w (f=video_length)
+ sample = sample.reshape([-1, video_length, *sample.shape[1:]])
+ sample = sample.movedim((0, 1, 2, 3, 4), (0, 2, 1, 3, 4))
# 3. down
down_block_res_samples = (sample,)
@@ -685,10 +720,19 @@ def forward(
sample = self.conv_norm_out(sample)
sample = self.conv_act(sample)
- sample = self.conv_out(sample)
-
# reshape to (batch, channel, framerate, width, height)
- sample = sample[None, :].reshape((-1, num_frames) + sample.shape[1:]).permute(0, 2, 1, 3, 4)
+ if self.transformer_in:
+ sample = self.conv_out(sample)
+ sample = sample[None, :].reshape((-1, num_frames) + sample.shape[1:]).permute(0, 2, 1, 3, 4)
+ else:
+ video_length = sample.shape[2]
+ # b c f h w -> (b f) c h w
+ sample = sample.movedim((0, 1, 2, 3, 4), (0, 2, 1, 3, 4))
+ sample = sample.flatten(0, 1)
+ sample = self.conv_out(sample)
+ # (b f) c h w -> b c f h w (f=video_length)
+ sample = sample.reshape([-1, video_length, *sample.shape[1:]])
+ sample = sample.movedim((0, 1, 2, 3, 4), (0, 2, 1, 3, 4))
if not return_dict:
return (sample,)
diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py
index 19fe2f72d447..07ff7b1ccb41 100644
--- a/src/diffusers/pipelines/__init__.py
+++ b/src/diffusers/pipelines/__init__.py
@@ -157,6 +157,7 @@
"TextToVideoZeroPipeline",
"VideoToVideoSDPipeline",
]
+ _import_structure["tune_a_video"] = ["TuneAVideoPipeline"]
_import_structure["unclip"] = ["UnCLIPImageVariationPipeline", "UnCLIPPipeline"]
_import_structure["unidiffuser"] = [
"ImageTextPipelineOutput",
@@ -203,7 +204,6 @@
"StableDiffusionOnnxPipeline",
]
)
-
try:
if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()):
raise OptionalDependencyNotAvailable()
@@ -374,6 +374,7 @@
TextToVideoZeroPipeline,
VideoToVideoSDPipeline,
)
+ from .tune_a_video import TuneAVideoPipeline
from .unclip import UnCLIPImageVariationPipeline, UnCLIPPipeline
from .unidiffuser import (
ImageTextPipelineOutput,
diff --git a/src/diffusers/pipelines/tune_a_video/__init__.py b/src/diffusers/pipelines/tune_a_video/__init__.py
new file mode 100644
index 000000000000..d409643f6c35
--- /dev/null
+++ b/src/diffusers/pipelines/tune_a_video/__init__.py
@@ -0,0 +1,36 @@
+from dataclasses import dataclass
+from typing import List, Union
+
+import numpy as np
+import torch
+
+from ...utils import (
+ BaseOutput,
+ OptionalDependencyNotAvailable,
+ is_torch_available,
+ is_transformers_available,
+)
+
+
+@dataclass
+class TuneAVideoPipelineOutput(BaseOutput):
+ """
+ Output class for text to video pipelines.
+
+ Args:
+        frames (`List[np.ndarray]` or `torch.FloatTensor`):
+            List of denoised frames (essentially images) as NumPy arrays of shape `(height, width, num_channels)` or
+            as a `torch` tensor. The NumPy arrays contain the denoised images of the diffusion pipeline. The length
+            of the list denotes the video length, i.e., the number of frames.
+ """
+
+ frames: Union[List[np.ndarray], torch.FloatTensor]
+
+
+try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
+else:
+ from .pipeline_tune_a_video import TuneAVideoPipeline # noqa: F401
diff --git a/src/diffusers/pipelines/tune_a_video/pipeline_tune_a_video.py b/src/diffusers/pipelines/tune_a_video/pipeline_tune_a_video.py
new file mode 100644
index 000000000000..74a607cd3975
--- /dev/null
+++ b/src/diffusers/pipelines/tune_a_video/pipeline_tune_a_video.py
@@ -0,0 +1,686 @@
+# Adapted from https://github.com/showlab/Tune-A-Video/blob/main/tuneavideo/pipelines/pipeline_tuneavideo.py
+
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from transformers import CLIPTextModel, CLIPTokenizer
+
+from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
+from ...models import AutoencoderKL
+from ...models.unet_3d_condition import UNet3DConditionModel
+from ...schedulers import (
+ DDIMScheduler,
+ DPMSolverMultistepScheduler,
+ EulerAncestralDiscreteScheduler,
+ EulerDiscreteScheduler,
+ LMSDiscreteScheduler,
+ PNDMScheduler,
+)
+from ...utils import logging, replace_example_docstring
+from ..pipeline_utils import DiffusionPipeline
+from . import TuneAVideoPipelineOutput
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+
+ >>> from diffusers import UNet3DConditionModel, TuneAVideoPipeline, DDIMScheduler
+ >>> from diffusers.utils import export_to_video
+
+ >>> pretrained_model_path = "nitrosocke/mo-di-diffusion"
+ >>> scheduler = DDIMScheduler.from_pretrained(pretrained_model_path, subfolder="scheduler")
+ >>> unet = UNet3DConditionModel.from_pretrained(
+ ... "Tune-A-Video-library/df-cpt-mo-di-bear-guitar", subfolder="unet", torch_dtype=torch.float16
+ ... ).to("cuda")
+ >>> pipe = TuneAVideoPipeline.from_pretrained(
+ ... pretrained_model_path, unet=unet, scheduler=scheduler, torch_dtype=torch.float16
+ ... ).to("cuda")
+
+ >>> prompt = "a magical princess is playing guitar, modern disney style"
+ >>> video_frames = pipe(
+ ... prompt,
+ ... video_length=4,
+ ... height=256,
+ ... width=256,
+ ... num_inference_steps=38,
+ ... guidance_scale=12.5,
+ ... output_type="np",
+ ... ).frames
+
+ >>> video_path = export_to_video(video_frames)
+ >>> video_path
+ ```
+"""
+
+
+def tensor2vid(video: torch.Tensor) -> List[np.ndarray]:
+ # This code is modified from https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/pipelines/multi_modal/text_to_video_synthesis_pipeline.py#L78
+ i, c, f, h, w = video.shape
+ images = video.permute(2, 3, 0, 4, 1).reshape(
+ f, h, i * w, c
+ ) # 1st (frames, h, batch_size, w, c) 2nd (frames, h, batch_size * w, c)
+    images = images.unbind(dim=0)  # prepare a list of individual (consecutive) frames
+ images = [(image.cpu().numpy() * 255).astype("uint8") for image in images] # f h w c
+ return images
+
+
+class TuneAVideoAttnProcessor:
+ def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None):
+ if attn.group_norm is not None:
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+ query = attn.to_q(hidden_states)
+ query = attn.head_to_batch_dim(query)
+
+ if attn.added_kv_proj_dim is not None:
+ raise NotImplementedError
+
+ encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)
+
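+        # Sparse-causal attention: keys/values for every frame come from the first frame and the immediately
+        # preceding frame (frame 0 uses itself twice), concatenated along the token dimension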
+ former_frame_index = torch.arange(video_length) - 1
+ former_frame_index[0] = 0
+
+ # key = rearrange(key, "(b f) d c -> b f d c", f=video_length)
+ key = key.reshape([-1, video_length, *key.shape[1:]])
+ key = torch.cat([key[:, [0] * video_length], key[:, former_frame_index]], dim=2)
+ # key = rearrange(key, "b f d c -> (b f) d c")
+ key = key.flatten(0, 1)
+
+ # value = rearrange(value, "(b f) d c -> b f d c", f=video_length)
+ value = value.reshape([-1, video_length, *value.shape[1:]])
+ value = torch.cat([value[:, [0] * video_length], value[:, former_frame_index]], dim=2)
+ # value = rearrange(value, "b f d c -> (b f) d c")
+ value = value.flatten(0, 1)
+
+ key = attn.head_to_batch_dim(key)
+ value = attn.head_to_batch_dim(value)
+
+ if attention_mask is not None:
+ if attention_mask.shape[-1] != query.shape[1]:
+ target_length = query.shape[1]
+ attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
+            attention_mask = attention_mask.repeat_interleave(attn.heads, dim=0)
+
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
+ hidden_states = torch.bmm(attention_probs, value)
+ hidden_states = attn.batch_to_head_dim(hidden_states)
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+ return hidden_states
+
+
+class TuneAVideoPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
+ r"""
+ Pipeline for text-to-video generation.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+            Frozen CLIP text-encoder ([`~transformers.CLIPTextModel`]).
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ unet ([`UNet3DConditionModel`]): Conditional U-Net architecture to denoise the encoded video latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], [`EulerDiscreteScheduler`],
+ [`EulerAncestralDiscreteScheduler`], [`DPMSolverMultistepScheduler`], or [`PNDMScheduler`].
+ """
+
+ _optional_components = []
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet3DConditionModel,
+ scheduler: Union[
+ DDIMScheduler,
+ PNDMScheduler,
+ LMSDiscreteScheduler,
+ EulerDiscreteScheduler,
+ EulerAncestralDiscreteScheduler,
+ DPMSolverMultistepScheduler,
+ ],
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ )
+ attn_processor_dict = self.unet.attn_processors
+ for key in attn_processor_dict.keys():
+ # Only the Transformer3DModels attn1 attn processor is TuneAVideoAttnProcessor.
+ if "attn1" in key:
+ attn_processor_dict[key] = TuneAVideoAttnProcessor()
+ self.unet.set_attn_processor(attn_processor_dict)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
+ def enable_vae_slicing(self):
+ r"""
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+ """
+ self.vae.enable_slicing()
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing
+ def disable_vae_slicing(self):
+ r"""
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_slicing()
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling
+ def enable_vae_tiling(self):
+ r"""
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
+ processing larger images.
+ """
+ self.vae.enable_tiling()
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling
+ def disable_vae_tiling(self):
+ r"""
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_tiling()
+
+ def encode_prompt(
+ self,
+ prompt,
+ device,
+ num_videos_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt=None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ lora_scale: Optional[float] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ device (`torch.device`):
+ torch device
+ num_videos_per_prompt (`int`):
+ number of videos that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ lora_scale (`float`, *optional*):
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+ """
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, LoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = self.tokenizer.batch_decode(
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
+ )
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = text_inputs.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ prompt_embeds = self.text_encoder(
+ text_input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ prompt_embeds = prompt_embeds[0]
+
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_videos_per_prompt, seq_len, -1)
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""] * batch_size
+ elif prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
+
+ max_length = prompt_embeds.shape[1]
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = uncond_input.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ negative_prompt_embeds = self.text_encoder(
+ uncond_input.input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ negative_prompt_embeds = negative_prompt_embeds[0]
+
+ if do_classifier_free_guidance:
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_videos_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+
+ return prompt_embeds
+
+ def decode_latents(self, latents):
+ video_length = latents.shape[2]
+ latents = 1 / self.vae.config.scaling_factor * latents
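+ # fold the frame dimension into the batch dimension, (b, c, f, h, w) -> (b*f, c, h, w),
+ # so the 2D VAE can decode each frame independently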
+ # latents = rearrange(latents, "b c f h w -> (b f) c h w")
+ latents = latents.movedim((0, 1, 2, 3, 4), (0, 2, 1, 3, 4))
+ latents = latents.flatten(0, 1)
+ video = self.vae.decode(latents).sample
+ # video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length)
+ video = video.reshape([-1, video_length, *video.shape[1:]])
+ video = video.movedim((0, 1, 2, 3, 4), (0, 2, 1, 3, 4))
+ video = (video / 2 + 0.5).clamp(0, 1)
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+ video = video.cpu().float().numpy()
+ return video
+
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs
+ def check_inputs(
+ self,
+ prompt,
+ height,
+ width,
+ callback_steps,
+ negative_prompt=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ ):
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ def prepare_latents(
+ self, batch_size, num_channels_latents, video_length, height, width, dtype, device, generator, latents=None
+ ):
+ shape = (
+ batch_size,
+ num_channels_latents,
+ video_length,
+ height // self.vae_scale_factor,
+ width // self.vae_scale_factor,
+ )
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if latents is None:
+ rand_device = "cpu" if device.type == "mps" else device
+
+ if isinstance(generator, list):
+ shape = (1,) + shape[1:]
+ latents = [
+ torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype)
+ for i in range(batch_size)
+ ]
+ latents = torch.cat(latents, dim=0).to(device)
+ else:
+ latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
+ else:
+ if latents.shape != shape:
+ raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
+ latents = latents.to(device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+ return latents
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]],
+ video_length: int = 8,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_videos_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "np",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: Optional[int] = 1,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the video generation. If not defined, one has to pass `prompt_embeds`
+ instead.
+ video_length (`int`, *optional*, defaults to 8):
+ The number of video frames that are generated. Defaults to 8 frames, which at 8 frames per second
+ amounts to 1 second of video.
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated video.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated video.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to higher quality videos at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2 of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. A higher guidance scale encourages the model to generate videos that are closely linked to the
+ text `prompt`, usually at the expense of lower video quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the video generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+ Number of videos to be generated for each input prompt
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for video
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`. Latents should be of shape
+ `(batch_size, num_channel, num_frames, height, width)`.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ output_type (`str`, *optional*, defaults to `"np"`):
+ The output format of the generated video. Choose between `torch.FloatTensor` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.tune_a_video.TuneAVideoPipelineOutput`] instead of a plain
+ tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
+
+ Examples:
+
+ Returns:
+ [`~pipelines.tune_a_video.TuneAVideoPipelineOutput`] or `tuple`:
+ [`~pipelines.tune_a_video.TuneAVideoPipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When
+ returning a tuple, the first element is a list with the generated frames.
+ """
+ # Default height and width to unet
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
+ )
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 3. Encode input prompt
+ text_encoder_lora_scale = (
+ cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
+ )
+
+ prompt_embeds = self.encode_prompt(
+ prompt,
+ device,
+ num_videos_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ lora_scale=text_encoder_lora_scale,
+ )
+
+ # 4. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.unet.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_videos_per_prompt,
+ num_channels_latents,
+ video_length,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+ latents_dtype = latents.dtype
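+ # keep track of the latents dtype so the UNet output can be cast back to it below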
+
+ # 6. Prepare extra step kwargs.
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 7. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # predict the noise residual
+ noise_pred = self.unet(
+ latent_model_input,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ cross_attention_kwargs=cross_attention_kwargs,
+ ).sample.to(dtype=latents_dtype)
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ callback(i, t, latents)
+
+ # Post-processing
+ video = self.decode_latents(latents)
+
+ # convert to a torch tensor; if `output_type` is "np", convert it to frames with `tensor2vid`
+ video = torch.from_numpy(video)
+ if output_type == "np":
+ video = tensor2vid(video)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (video,)
+
+ return TuneAVideoPipelineOutput(frames=video)
diff --git a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py
index 717db3bbdb34..47ebf199a210 100644
--- a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py
+++ b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py
@@ -1235,10 +1235,12 @@ def forward(
deprecate(
"T2I should not use down_block_additional_residuals",
"1.3.0",
- "Passing intrablock residual connections with `down_block_additional_residuals` is deprecated "
- " and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only"
- " be used for ControlNet. Please make sure use"
- " `down_intrablock_additional_residuals` instead. ",
+ (
+ "Passing intrablock residual connections with `down_block_additional_residuals` is deprecated "
+ " and will be removed in diffusers 1.3.0. `down_block_additional_residuals`"
+ " should only be used for ControlNet. Please make sure use"
+ " `down_intrablock_additional_residuals` instead. "
+ ),
standard_warn=False,
)
down_intrablock_additional_residuals = down_block_additional_residuals
diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py
index 8e95dde52caf..9fd6f116e878 100644
--- a/src/diffusers/utils/dummy_pt_objects.py
+++ b/src/diffusers/utils/dummy_pt_objects.py
@@ -152,6 +152,21 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
+class Transformer3DModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
class UNet1DModel(metaclass=DummyObject):
_backends = ["torch"]
diff --git a/src/diffusers/utils/dummy_torch_and_note_seq_objects.py b/src/diffusers/utils/dummy_torch_and_note_seq_objects.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py
index d831cc49b495..58f74a9e969a 100644
--- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py
+++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py
@@ -1112,6 +1112,21 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
+class TuneAVideoPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
class UnCLIPImageVariationPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
diff --git a/tests/pipelines/tune_a_video/__init__.py b/tests/pipelines/tune_a_video/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/tests/pipelines/tune_a_video/test_tune_a_video.py b/tests/pipelines/tune_a_video/test_tune_a_video.py
new file mode 100644
index 000000000000..778f0b781cb8
--- /dev/null
+++ b/tests/pipelines/tune_a_video/test_tune_a_video.py
@@ -0,0 +1,227 @@
+# coding=utf-8
+# Copyright 2023 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+import torch
+from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
+
+from diffusers import AutoencoderKL, DDIMScheduler, TuneAVideoPipeline, UNet3DConditionModel
+from diffusers.utils.testing_utils import load_numpy, numpy_cosine_similarity_distance, skip_mps, slow
+
+from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS
+from ..test_pipelines_common import PipelineTesterMixin
+
+
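+# disable TF32 matmuls so the numerical reference checks below are not affected by reduced precision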
+torch.backends.cuda.matmul.allow_tf32 = False
+
+
+@skip_mps
+class TuneAVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+ pipeline_class = TuneAVideoPipeline
+ params = TEXT_TO_IMAGE_PARAMS
+ batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
+ # No `output_type`.
+ required_optional_params = frozenset(
+ [
+ "num_inference_steps",
+ "generator",
+ "latents",
+ "return_dict",
+ "callback",
+ "callback_steps",
+ ]
+ )
+
+ def get_dummy_components(self):
+ torch.manual_seed(0)
+ unet = UNet3DConditionModel(
+ block_out_channels=(32, 64, 64, 64),
+ layers_per_block=2,
+ sample_size=32,
+ in_channels=4,
+ out_channels=4,
+ down_block_types=(
+ "CrossAttnDownBlockInflated3D",
+ "CrossAttnDownBlockInflated3D",
+ "CrossAttnDownBlockInflated3D",
+ "DownBlockInflated3D",
+ ),
+ up_block_types=(
+ "UpBlockInflated3D",
+ "CrossAttnUpBlockInflated3D",
+ "CrossAttnUpBlockInflated3D",
+ "CrossAttnUpBlockInflated3D",
+ ),
+ mid_block_type="UNetMidBlockInflated3DCrossAttn",
+ cross_attention_dim=32,
+ attention_head_dim=4,
+ use_temporal_transformer=False,
+ )
+ scheduler = DDIMScheduler(
+ beta_start=0.00085,
+ beta_end=0.012,
+ beta_schedule="scaled_linear",
+ clip_sample=False,
+ set_alpha_to_one=False,
+ )
+ torch.manual_seed(0)
+ vae = AutoencoderKL(
+ block_out_channels=[32, 64],
+ in_channels=3,
+ out_channels=3,
+ down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
+ up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
+ latent_channels=4,
+ sample_size=128,
+ )
+ torch.manual_seed(0)
+ text_encoder_config = CLIPTextConfig(
+ bos_token_id=0,
+ eos_token_id=2,
+ hidden_size=32,
+ intermediate_size=37,
+ layer_norm_eps=1e-05,
+ num_attention_heads=4,
+ num_hidden_layers=5,
+ pad_token_id=1,
+ vocab_size=1000,
+ hidden_act="gelu",
+ projection_dim=512,
+ )
+ text_encoder = CLIPTextModel(text_encoder_config)
+ tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+
+ components = {
+ "unet": unet,
+ "scheduler": scheduler,
+ "vae": vae,
+ "text_encoder": text_encoder,
+ "tokenizer": tokenizer,
+ }
+ return components
+
+ def get_dummy_inputs(self, device, seed=0):
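+ # device-bound generators are unreliable on mps, so fall back to the globally seeded generator there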
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device=device).manual_seed(seed)
+ inputs = {
+ "prompt": "A painting of a squirrel eating a burger",
+ "generator": generator,
+ "num_inference_steps": 2,
+ "guidance_scale": 6.0,
+ "output_type": "pt",
+ }
+ return inputs
+
+ def test_tune_a_video_default_case(self):
+ device = "cpu" # ensure determinism for the device-dependent torch.Generator
+ components = self.get_dummy_components()
+ sd_pipe = TuneAVideoPipeline(**components)
+ sd_pipe = sd_pipe.to(device)
+ sd_pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ inputs["output_type"] = "np"
+ frames = sd_pipe(**inputs).frames
+ image_slice = frames[0][-3:, -3:, -1]
+
+ assert frames[0].shape == (64, 64, 3)
+ expected_slice = np.array([145, 150, 171, 177, 126, 115, 122, 122, 129])
+
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+
+ @unittest.skip(
+ reason="SlicedAttnProcessor doesn't work with this pipeline, need to add an equivalent processor for TuneAVideoAttnProcessor"
+ )
+ def test_attention_slicing_forward_pass(self):
+ self._test_attention_slicing_forward_pass(test_mean_pixel_difference=False)
+
+ # (todo): sayakpaul
+ @unittest.skip(reason="Batching needs to be properly figured out first for this pipeline.")
+ def test_inference_batch_consistent(self):
+ pass
+
+ # (todo): sayakpaul
+ @unittest.skip(reason="Batching needs to be properly figured out first for this pipeline.")
+ def test_inference_batch_single_identical(self):
+ pass
+
+ @unittest.skip(reason="`num_images_per_prompt` argument is not supported for this pipeline.")
+ def test_num_images_per_prompt(self):
+ pass
+
+ def test_progress_bar(self):
+ return super().test_progress_bar()
+
+ @unittest.skip(reason="`set_default_attn_processor` is not supported as we use a custom attn processor")
+ def test_dict_tuple_outputs_equivalent(self, expected_max_difference=1e-4):
+ pass
+
+ @unittest.skip(reason="`set_default_attn_processor` is not supported as we use a custom attn processor")
+ def test_save_load_local(self, expected_max_difference=5e-4):
+ pass
+
+ @unittest.skip(reason="`set_default_attn_processor` is not supported as we use a custom attn processor")
+ def test_save_load_optional_components(self, expected_max_difference=1e-4):
+ pass
+
+
+@slow
+@skip_mps
+class TuneAVideoPipelineSlowTests(unittest.TestCase):
+ def test_full_model(self):
+ expected_video = load_numpy(
+ "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/text_to_video/tuneavideo-10step-mo-di.npy"
+ )
+
+ pipe = TuneAVideoPipeline.from_pretrained(
+ "Tune-A-Video-library/df-cpt-mo-di-bear-guitar", torch_dtype=torch.float16
+ )
+ pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
+ pipe = pipe.to("cuda")
+
+ prompt = "A princess playing a guitar, modern disney style"
+ generator = torch.Generator(device="cuda").manual_seed(42)
+
+ video_frames = pipe(
+ prompt, video_length=3, generator=generator, num_inference_steps=10, output_type="pt"
+ ).frames
+ video = video_frames.cpu().numpy()
+
+ assert numpy_cosine_similarity_distance(expected_video.flatten(), video.flatten()) < 5e-2
+
+ def test_two_step_model(self):
+ expected_video = load_numpy(
+ "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/text_to_video/tuneavideo-2step-mo-di.npy"
+ )
+
+ pipe = TuneAVideoPipeline.from_pretrained(
+ "Tune-A-Video-library/df-cpt-mo-di-bear-guitar", torch_dtype=torch.float16
+ )
+ pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
+ pipe = pipe.to("cuda")
+
+ prompt = "A princess playing a guitar, modern disney style"
+ generator = torch.Generator(device="cuda").manual_seed(42)
+
+ video_frames = pipe(
+ prompt, video_length=3, generator=generator, num_inference_steps=2, output_type="pt"
+ ).frames
+ video = video_frames.cpu().numpy()
+
+ assert numpy_cosine_similarity_distance(expected_video.flatten(), video.flatten()) < 5e-2