[Pipeline] Port Tune-A-Video pipeline to diffusers #2455
New documentation file (`@@ -0,0 +1,11 @@`):

# Transformer3D

The Transformer2D model extended for video-like data.

## Transformer3DModel

[[autodoc]] Transformer3DModel

## Transformer3DModelOutput

[[autodoc]] models.transformer_3d.Transformer3DModelOutput

New documentation file (`@@ -0,0 +1,122 @@`):

<!--Copyright 2023 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# Tune-A-Video

## Overview

[Tune-A-Video: One-Shot Tuning of Image Diffusion Models for Text-to-Video Generation](https://arxiv.org/abs/2212.11565) by Jay Zhangjie Wu, Yixiao Ge, Xintao Wang, Stan Weixian Lei, Yuchao Gu, Yufei Shi, Wynne Hsu, Ying Shan, Xiaohu Qie, and Mike Zheng Shou.

The abstract of the paper is the following:

*To replicate the success of text-to-image (T2I) generation, recent works employ large-scale video datasets to train a text-to-video (T2V) generator. Despite their promising results, such paradigm is computationally expensive. In this work, we propose a new T2V generation setting—One-Shot Video Tuning, where only one text-video pair is presented. Our model is built on state-of-the-art T2I diffusion models pre-trained on massive image data. We make two key observations: 1) T2I models can generate still images that represent verb terms; 2) extending T2I models to generate multiple images concurrently exhibits surprisingly good content consistency. To further learn continuous motion, we introduce Tune-A-Video, which involves a tailored spatio-temporal attention mechanism and an efficient one-shot tuning strategy. At inference, we employ DDIM inversion to provide structure guidance for sampling. Extensive qualitative and numerical experiments demonstrate the remarkable ability of our method across various applications.*

Resources:

* [GitHub repository](https://github.com/showlab/Tune-A-Video)
* [🤗 Spaces](https://huggingface.co/spaces/Tune-A-Video-library/Tune-A-Video-Training-UI)

## Available Pipelines

| Pipeline | Tasks | Demo |
|---|---|:---:|
| [TuneAVideoPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/tune_a_video/pipeline_tune_a_video.py) | *Text-to-Video Generation* | [🤗 Spaces](https://huggingface.co/spaces/Tune-A-Video-library/Tune-A-Video-inference) |

## Usage example

### Loading with a pre-existing Text2Image checkpoint

```python
import torch
from diffusers import TuneAVideoPipeline, DDIMScheduler, UNet3DConditionModel
from diffusers.utils import export_to_video
from PIL import Image

# Use any pretrained Text2Image checkpoint based on Stable Diffusion
pretrained_model_path = "nitrosocke/mo-di-diffusion"
unet = UNet3DConditionModel.from_pretrained(
    "Tune-A-Video-library/df-cpt-mo-di-bear-guitar", subfolder="unet", torch_dtype=torch.float16
).to("cuda")

pipe = TuneAVideoPipeline.from_pretrained(pretrained_model_path, unet=unet, torch_dtype=torch.float16).to("cuda")

prompt = "A princess playing a guitar, modern disney style"
generator = torch.Generator(device="cuda").manual_seed(42)

video_frames = pipe(prompt, video_length=3, generator=generator, num_inference_steps=50, output_type="np").frames

# Save to GIF (assuming 8 frames per second)
pil_frames = [Image.fromarray(frame) for frame in video_frames]
duration = 1000 / 8  # per-frame duration in milliseconds
pil_frames[0].save(
    "animation.gif",
    save_all=True,
    append_images=pil_frames[1:],  # append the remaining frames
    duration=duration,
    loop=0,
)

# Save to video
video_path = export_to_video(video_frames)
```
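
The snippet above imports `DDIMScheduler` but does not use it explicitly, so the pipeline presumably falls back to the scheduler stored in the base checkpoint. If you prefer to set it explicitly, a minimal sketch continuing from the example above (the `subfolder="scheduler"` layout is an assumption based on the usual Stable Diffusion repository structure):

```python
from diffusers import DDIMScheduler

# Hypothetical explicit scheduler setup; the pipeline may already default to this.
pipe.scheduler = DDIMScheduler.from_pretrained(pretrained_model_path, subfolder="scheduler")
```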

### Loading a saved Tune-A-Video checkpoint

```python
import torch
from diffusers import DiffusionPipeline, DDIMScheduler
from diffusers.utils import export_to_video
from PIL import Image

pipe = DiffusionPipeline.from_pretrained(
    "Tune-A-Video-library/df-cpt-mo-di-bear-guitar", torch_dtype=torch.float16
).to("cuda")

prompt = "A princess playing a guitar, modern disney style"
generator = torch.Generator(device="cuda").manual_seed(42)

video_frames = pipe(prompt, video_length=3, generator=generator, num_inference_steps=50, output_type="np").frames

# Save to GIF (assuming 8 frames per second)
pil_frames = [Image.fromarray(frame) for frame in video_frames]
duration = 1000 / 8  # per-frame duration in milliseconds
pil_frames[0].save(
    "animation.gif",
    save_all=True,
    append_images=pil_frames[1:],  # append the remaining frames
    duration=duration,
    loop=0,
)

# Save to video
video_path = export_to_video(video_frames)
```
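
Both examples repeat the same GIF-saving code, so a small helper can keep things tidy. A sketch only: `save_frames_as_gif` is a hypothetical name, and the 8 fps default mirrors the frame rate assumed above.

```python
from typing import List

import numpy as np
from PIL import Image


def save_frames_as_gif(frames: List[np.ndarray], path: str = "animation.gif", fps: int = 8) -> str:
    """Save a list of HxWx3 uint8 frames as a looping GIF."""
    pil_frames = [Image.fromarray(frame) for frame in frames]
    pil_frames[0].save(
        path,
        save_all=True,
        append_images=pil_frames[1:],
        duration=int(1000 / fps),  # per-frame duration in milliseconds
        loop=0,
    )
    return path


gif_path = save_frames_as_gif(video_frames, fps=8)  # `video_frames` from the example above
```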

Here are some sample outputs:

<table>
  <tr>
    <td><center>
      A princess playing a guitar, modern disney style
      <br>
      <img src="https://huggingface.co/Tune-A-Video-library/df-cpt-mo-di-bear-guitar/resolve/main/samples/princess.gif"
           alt="A princess playing a guitar, modern disney style"
           style="width: 300px;" />
    </center></td>
  </tr>
</table>

## Available checkpoints

* [Tune-A-Video-library/df-cpt-mo-di-bear-guitar](https://huggingface.co/Tune-A-Video-library/df-cpt-mo-di-bear-guitar)

## TuneAVideoPipeline

[[autodoc]] TuneAVideoPipeline
	- all
	- __call__

Changes to an existing model module (diff hunk `@@ -757,7 +757,196 @@ def forward(self, input_tensor, temb, scale: float = 1.0):`):

```python
        return output_tensor  # end of the existing `forward`, shown for diff context


# unet_rl.py
class Upsample3D(nn.Module):
    """A 3D upsampling layer. Reshapes the video-like input tensor into a batch of frames, applies the
    upsampling conv, and converts the result back to the original shape.

    Parameters:
        channels (`int`):
            number of channels in the inputs and outputs.
        out_channels (`int`, *optional*):
            number of output channels. Defaults to `channels`.
    """

    def __init__(self, channels, out_channels=None):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels

        self.conv = nn.Conv2d(self.channels, self.out_channels, 3, padding=1)

    def forward(self, hidden_states, output_size=None):
        if hidden_states.shape[1] != self.channels:
            raise ValueError(
                f"Expected hidden_states tensor at dimension 1 to match the number of channels. Expected: {self.channels} but passed: {hidden_states.shape[1]}"
            )

        # Cast to float32 as the 'upsample_nearest2d_out_frame' op does not support bfloat16
        dtype = hidden_states.dtype
        if dtype == torch.bfloat16:
            hidden_states = hidden_states.to(torch.float32)

        # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
        if hidden_states.shape[0] >= 64:
            hidden_states = hidden_states.contiguous()

        # if `output_size` is passed we force the interpolation output
        # size and do not make use of `scale_factor=2`
        if output_size is None:
            hidden_states = F.interpolate(hidden_states, scale_factor=[1.0, 2.0, 2.0], mode="nearest")
        else:
            hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")

        # If the input was bfloat16, cast back to bfloat16
        if dtype == torch.bfloat16:
            hidden_states = hidden_states.to(dtype)

        # Inflate
        video_length = hidden_states.shape[2]
        # b c f h w -> (b f) c h w
        hidden_states = hidden_states.movedim((0, 1, 2, 3, 4), (0, 2, 1, 3, 4))
        hidden_states = hidden_states.flatten(0, 1)

        hidden_states = self.conv(hidden_states)
        # Deflate
        # (b f) c h w -> b c f h w (f=video_length)
        hidden_states = hidden_states.reshape([-1, video_length, *hidden_states.shape[1:]])
        hidden_states = hidden_states.movedim((0, 1, 2, 3, 4), (0, 2, 1, 3, 4))

        return hidden_states
```
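
These blocks repeatedly "inflate" a video tensor of shape `(batch, channels, frames, height, width)` into a frame batch of shape `(batch * frames, channels, height, width)` via `movedim` + `flatten`, apply a 2D conv, and "deflate" back. A standalone sanity-check sketch of that round trip (not part of the PR):

```python
import torch

b, c, f, h, w = 2, 4, 3, 8, 8
video = torch.randn(b, c, f, h, w)

# Inflate: b c f h w -> (b f) c h w
frames = video.movedim((0, 1, 2, 3, 4), (0, 2, 1, 3, 4)).flatten(0, 1)
assert frames.shape == (b * f, c, h, w)

# Deflate: (b f) c h w -> b c f h w
restored = frames.reshape(b, f, c, h, w).movedim((0, 1, 2, 3, 4), (0, 2, 1, 3, 4))
assert restored.shape == (b, c, f, h, w)
assert torch.equal(restored, video)  # the round trip is lossless
```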

```python
class Downsample3D(nn.Module):
    """A 3D downsampling layer. Reshapes the video-like input tensor into a batch of frames, applies the
    strided conv, and converts the result back to the original shape.

    Parameters:
        channels (`int`):
            number of channels in the inputs and outputs.
        out_channels (`int`, *optional*):
            number of output channels. Defaults to `channels`.
    """

    def __init__(self, channels, out_channels=None, padding=1):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.conv = nn.Conv2d(self.channels, self.out_channels, 3, stride=2, padding=padding)

    def forward(self, hidden_states):
        video_length = hidden_states.shape[2]
        # b c f h w -> (b f) c h w
        hidden_states = hidden_states.movedim((0, 1, 2, 3, 4), (0, 2, 1, 3, 4))
        hidden_states = hidden_states.flatten(0, 1)
        # Conv
        hidden_states = self.conv(hidden_states)
        # (b f) c h w -> b c f h w (f=video_length)
        hidden_states = hidden_states.reshape([-1, video_length, *hidden_states.shape[1:]])
        hidden_states = hidden_states.movedim((0, 1, 2, 3, 4), (0, 2, 1, 3, 4))

        return hidden_states
```

Review discussion on the `ResnetBlock3D` class below:

**Member:** Missing docstrings.

**Contributor:** What is the main difference to the existing resnet class here?

**Contributor (author):** `ResnetBlock3D` doesn't support `AdaGroupNorm`, only `torch.nn.GroupNorm`, and the convolution used is `InflatedConv3d` instead of `torch.nn.Conv2d`. I didn't want to merge it with the 2D block, because then `ResnetBlock2D` would also need additional parameters. If you're OK with that, I can add some parameters to make it flexible.

**Member:** I like this better as it encourages disentanglement between the modules and separation of concerns.

**Contributor** (on lines +868 to +869), with a suggested change: no?

```python
class ResnetBlock3D(nn.Module):
    r"""
    A Resnet block used specifically for video-like data.

    Parameters:
        in_channels (`int`): The number of channels in the input.
        out_channels (`int`, *optional*, defaults to `None`):
            The number of output channels for the first conv2d layer. If None, same as `in_channels`.
        dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use.
        temb_channels (`int`, *optional*, defaults to `512`): the number of channels in the timestep embedding.
        groups (`int`, *optional*, defaults to `32`): The number of groups to use for the first normalization layer.
        eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization.
        non_linearity (`str`, *optional*, defaults to `"swish"`): the activation function to use.
        output_scale_factor (`float`, *optional*, defaults to `1.0`): the scale factor to use for the output.
    """

    def __init__(
        self,
        *,
        in_channels,
||||||||||
| out_channels=None, | ||||||||||
| dropout=0.0, | ||||||||||
| temb_channels=512, | ||||||||||
| groups=32, | ||||||||||
| eps=1e-6, | ||||||||||
| non_linearity="swish", | ||||||||||
| output_scale_factor=1.0, | ||||||||||
| ): | ||||||||||
| super().__init__() | ||||||||||
| self.in_channels = in_channels | ||||||||||
| out_channels = in_channels if out_channels is None else out_channels | ||||||||||
| self.out_channels = out_channels | ||||||||||
| self.output_scale_factor = output_scale_factor | ||||||||||
|
|
||||||||||
| self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True) | ||||||||||
| self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) | ||||||||||
|
|
||||||||||
| self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels) | ||||||||||
|
|
||||||||||
| self.norm2 = torch.nn.GroupNorm(num_groups=groups, num_channels=out_channels, eps=eps, affine=True) | ||||||||||
| self.dropout = torch.nn.Dropout(dropout) | ||||||||||
| self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) | ||||||||||
|
|
||||||||||
| self.nonlinearity = get_activation(non_linearity) | ||||||||||
|
|
||||||||||
| self.use_in_shortcut = self.in_channels != self.out_channels | ||||||||||
| self.conv_shortcut = None | ||||||||||
| if self.use_in_shortcut: | ||||||||||
| self.conv_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) | ||||||||||
|
|
||||||||||
| def forward(self, input_tensor, temb): | ||||||||||
| hidden_states = input_tensor | ||||||||||
|
|
||||||||||
| hidden_states = self.norm1(hidden_states) | ||||||||||
| hidden_states = self.nonlinearity(hidden_states) | ||||||||||
|
|
||||||||||
| video_length = hidden_states.shape[2] | ||||||||||
| # b c f h w -> (b f) c h w | ||||||||||
| hidden_states = hidden_states.movedim((0, 1, 2, 3, 4), (0, 2, 1, 3, 4)) | ||||||||||
| hidden_states = hidden_states.flatten(0, 1) | ||||||||||
| hidden_states = self.conv1(hidden_states) | ||||||||||
| # (b f) c h w -> b c f h w (f=video_length | ||||||||||
| hidden_states = hidden_states.reshape([-1, video_length, *hidden_states.shape[1:]]) | ||||||||||
| hidden_states = hidden_states.movedim((0, 1, 2, 3, 4), (0, 2, 1, 3, 4)) | ||||||||||
|
|
||||||||||
| if temb is not None: | ||||||||||
| temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None, None] | ||||||||||
|
|
||||||||||
| hidden_states = hidden_states + temb | ||||||||||
|
|
||||||||||
| hidden_states = self.norm2(hidden_states) | ||||||||||
|
|
||||||||||
| hidden_states = self.nonlinearity(hidden_states) | ||||||||||
|
|
||||||||||
| hidden_states = self.dropout(hidden_states) | ||||||||||
|
|
||||||||||
| video_length = hidden_states.shape[2] | ||||||||||
| # b c f h w -> (b f) c h w | ||||||||||
| hidden_states = hidden_states.movedim((0, 1, 2, 3, 4), (0, 2, 1, 3, 4)) | ||||||||||
| hidden_states = hidden_states.flatten(0, 1) | ||||||||||
| hidden_states = self.conv2(hidden_states) | ||||||||||
| # (b f) c h w -> b c f h w (f=video_length) | ||||||||||
| hidden_states = hidden_states.reshape([-1, video_length, *hidden_states.shape[1:]]) | ||||||||||
| hidden_states = hidden_states.movedim((0, 1, 2, 3, 4), (0, 2, 1, 3, 4)) | ||||||||||
|
|
||||||||||
| if self.conv_shortcut is not None: | ||||||||||
| video_length = input_tensor.shape[2] | ||||||||||
| # "b c f h w -> (b f) c h w" | ||||||||||
| input_tensor = input_tensor.movedim((0, 1, 2, 3, 4), (0, 2, 1, 3, 4)) | ||||||||||
| input_tensor = input_tensor.flatten(0, 1) | ||||||||||
| input_tensor = self.conv_shortcut(input_tensor) | ||||||||||
| # "(b f) c h w -> b c f h w"; f=video_length | ||||||||||
| input_tensor = input_tensor.reshape([-1, video_length, *input_tensor.shape[1:]]) | ||||||||||
| input_tensor = input_tensor.movedim((0, 1, 2, 3, 4), (0, 2, 1, 3, 4)) | ||||||||||
|
|
||||||||||
| output_tensor = (input_tensor + hidden_states) / self.output_scale_factor | ||||||||||
|
|
||||||||||
| return output_tensor | ||||||||||
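
For reference, a quick shape check for `ResnetBlock3D`. This is a sketch that assumes the class defined above (and its module-level imports) is in scope; the channel sizes are arbitrary.

```python
import torch

# Assumes ResnetBlock3D from the diff above is importable in the current scope.
block = ResnetBlock3D(in_channels=32, out_channels=64, temb_channels=512)

video = torch.randn(1, 32, 3, 16, 16)  # (batch, channels, frames, height, width)
temb = torch.randn(1, 512)             # timestep embedding

out = block(video, temb)
print(out.shape)  # torch.Size([1, 64, 3, 16, 16]) -- spatial and temporal dims unchanged
```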

```python
def rearrange_dims(tensor: torch.Tensor) -> torch.Tensor:
    if len(tensor.shape) == 2:
        return tensor[:, :, None]
```

Review comment: Missing docstrings.