
Conversation

@Abhinay1997 (Contributor) commented Feb 22, 2023

This PR adds the Tune-A-Video pipeline from https://github.com/showlab/Tune-A-Video, which is built on diffusers, to the set of diffusers pipelines.

See discussion at: #2432

Code Sample:

import torch
import torchvision
import os
import numpy as np
import imageio
from einops import rearrange

from diffusers import UNet3DConditionModel, TuneAVideoPipeline, DDIMScheduler

def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=4, fps=8):
    # (batch, channels, time, height, width) -> iterate over frames
    videos = rearrange(videos, "b c t h w -> t b c h w")
    outputs = []
    for x in videos:
        x = torchvision.utils.make_grid(x, nrow=n_rows)  # tile the batch into one grid per frame
        x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)  # (c, h, w) -> (h, w, c); squeeze drops a singleton channel
        if rescale:
            x = (x + 1.0) / 2.0  # [-1, 1] -> [0, 1]
        x = (x * 255).numpy().astype(np.uint8)
        outputs.append(x)

    os.makedirs(os.path.dirname(path), exist_ok=True)
    imageio.mimsave(path, outputs, fps=fps)

pretrained_model_path = "nitrosocke/mo-di-diffusion"
scheduler = DDIMScheduler.from_pretrained(pretrained_model_path, subfolder="scheduler")
unet = UNet3DConditionModel.from_pretrained(
    "NagaSaiAbhinay/tune-a-video-mo-di-bear-guitar-v1", subfolder="unet", torch_dtype=torch.float16
).to("cuda")
pipe = TuneAVideoPipeline.from_pretrained(
    pretrained_model_path, unet=unet, scheduler=scheduler, torch_dtype=torch.float16
).to("cuda")

prompt = "a magical princess is playing guitar, modern disney style"
video = pipe(prompt, video_length=4, height=256, width=256, num_inference_steps=38, guidance_scale=12.5, output_type="pt").frames

# To use export_to_video from diffusers.utils, uncomment the line below instead:
# video = pipe(prompt, video_length=4, height=256, width=256, num_inference_steps=38, guidance_scale=12.5, output_type="np").frames

save_path = './out.gif'
save_videos_grid(video, save_path)
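For completeness, a minimal sketch of the export_to_video path mentioned in the comment above, assuming the pipeline's "np" output matches the frame format used by the other diffusers video pipelines:

from diffusers.utils import export_to_video

# Sketch only: assumes TuneAVideoPipeline's "np" frames are compatible with
# what export_to_video expects (a list of H x W x C numpy arrays).
video_np = pipe(prompt, video_length=4, height=256, width=256, num_inference_steps=38, guidance_scale=12.5, output_type="np").frames
export_to_video(video_np, "out.mp4")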

TODOs

  • Tests for the new UNet3DConditionModel
  • Tests for the Tune-A-Video pipeline
  • Host the pretrained UNet3DConditionModel weights under hf.co/showlab
  • Documentation (docstrings, .mdx documentation of the pipeline, update models.mdx to include UNet3DConditionModel, update outputs.mdx to include the new video pipeline output)
  • Make TuneAVideoPipeline inherit from TextualInversionLoaderMixin.
  • Ensure that dummy objects are added for TuneAVideoPipeline under dummy_torch_and_transformers.py
  • Move Transformer3DModel to its own module
  • Make the TuneAVideo output compatible with the Text2Video output so we can use the export_to_video utility

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

@sayakpaul (Member)

Thanks so much, @Abhinay1997! This is quick.

W.r.t #2432 (comment)

We'll be replacing einops.rearrange and einops.repeat with equivalent torch operations where applicable, correct? I see einops being commented out in other files too. Just wanted to confirm.

Yeah, let's replace with native PyTorch ops.
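For illustration, one such replacement looks like this (the shapes here are arbitrary, just to demonstrate the equivalence):

import torch
from einops import rearrange

x = torch.randn(2, 4, 8, 16, 16)  # (batch, channels, frames, height, width)
b, c, f, h, w = x.shape

# einops version: "b c f h w -> (b f) c h w"
y_einops = rearrange(x, "b c f h w -> (b f) c h w")

# equivalent native PyTorch ops
y_torch = x.permute(0, 2, 1, 3, 4).reshape(b * f, c, h, w)

assert torch.equal(y_einops, y_torch)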

The BasicTransformerBlock in Tune-A-Video's attention.py uses SparseCausalAttention, but diffusers already has a BasicTransformerBlock with CrossAttention. Can we make the BasicTransformerBlock configurable then? For now I am calling it BasicSparseTransformerBlock to differentiate.

I think we'd want to keep BasicTransformerBlock as is for now. Differentiating between BasicTransformerBlock and BasicSparseTransformerBlock would be the better option, IMO.
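For context, the key idea behind SparseCausalAttention in the original Tune-A-Video code is that each frame's attention keys/values come only from the first frame and the immediately previous frame. A simplified sketch of that key/value construction (not the PR's actual implementation):

import torch

def sparse_causal_kv(hidden_states: torch.Tensor, video_length: int) -> torch.Tensor:
    # hidden_states: (batch * frames, seq_len, dim), frames stacked along batch
    bf, s, d = hidden_states.shape
    x = hidden_states.reshape(bf // video_length, video_length, s, d)

    first = x[:, [0] * video_length]                  # every frame attends to frame 0...
    prev = x[:, [0] + list(range(video_length - 1))]  # ...and to the previous frame

    # concatenate along the sequence axis to form the sparse-causal keys/values
    kv = torch.cat([first, prev], dim=2)
    return kv.reshape(bf, 2 * s, d)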

@Abhinay1997 (Contributor, Author)

Thank you @sayakpaul :)

Got it. Will update to use torch ops.

Ok. Will retain the BasicSparseTransformerBlock.

@Abhinay1997 (Contributor, Author)

@sayakpaul just realized that the original code was built on an older CrossAttention block. Making changes to get it to work with the current CrossAttention module.

@patrickvonplaten (Contributor) commented Mar 7, 2023

cc @williamberman @yiyixuxu @sayakpaul could you help here?

@sayakpaul (Member)

cc @williamberman @yiyixuxu @sayakpaul could you help here?

I will take the lead here. If need be, I'll ping you folks! Take it easy till then.

@Abhinay1997 (Contributor, Author)

Sorry everyone, work has been hectic and I couldn't find time to fix this. I'll be working on it today :)

@sayakpaul (Member)

No problem at all. We deeply appreciate your contributions :)

@Abhinay1997 (Contributor, Author)

So, I compared my output (where I was getting noise) with the original code's output for the same inputs, and it's markedly different: even when the original code returns noise, there is temporal coherence.

So for tomorrow, I'll be comparing both UNets with the same checkpoints and a fixed-seed random input. Hopefully they won't match, and I can drill down to the smallest failing component.

@sayakpaul (Member)

So for tomorrow, I'll be comparing both UNets with the same checkpoints and a fixed-seed random input. Hopefully they won't match, and I can drill down to the smallest failing component.

In these cases, I usually do the following:

  • Sort the individual model components w.r.t. my hunch about where there might be differences in the implementation details.
  • I then test those components in isolation. I generate dummy inputs and verify if the outputs from the original impl. and my impl. match for the underlying component.

Maybe we can follow something similar here? For starters, I'd begin by inspecting the einops calls that were replaced by native torch ops.
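A minimal sketch of that kind of component-level check (the module names and input shape are placeholders, not the PR's actual components):

import torch

def compare_components(ref_module, new_module, input_shape=(1, 4, 8, 16, 16), atol=1e-5):
    # Feed both implementations the same fixed-seed dummy input and compare outputs.
    torch.manual_seed(0)
    dummy = torch.randn(*input_shape)

    with torch.no_grad():
        ref_out = ref_module(dummy)
        new_out = new_module(dummy)

    max_diff = (ref_out - new_out).abs().max().item()
    print(f"max abs diff: {max_diff:.2e}")
    return torch.allclose(ref_out, new_out, atol=atol)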

@Abhinay1997 (Contributor, Author)

@sayakpaul, thank you for the input. Actually, I tested the einops equivalence before I replaced the original code. See: https://colab.research.google.com/drive/1BG1b-YVNUAsy9OBEUS7cYe843k14jwCI?usp=sharing

So I'll start with the CrossAttention block and then move on to the einops equivalents, instead of directly comparing the UNets themselves.

@Abhinay1997 (Contributor, Author)

Done with suggested changes @patrickvonplaten.

Comment on lines +777 to +778
*,
in_channels,
Contributor

Suggested change
- *,
- in_channels,
+ in_channels,
+ *,
no?
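For anyone reading along: parameters after a bare * in a Python signature are keyword-only, so the ordering determines which arguments callers may pass positionally. A quick illustration (hypothetical signatures, not the PR's actual code):

# in_channels before the `*`: it can be passed positionally
def block_a(in_channels, *, num_layers=1): ...

# in_channels after the `*`: it must be passed by keyword
def block_b(*, in_channels, num_layers=1): ...

block_a(4)               # OK
block_b(in_channels=4)   # OK
# block_b(4)             # TypeError: takes 0 positional arguments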

@Abhinay1997 (Contributor, Author)

@patrickvonplaten I retained comments like # "b c f h w -> (b f) c h w" to help with readability. Would you prefer not having them at all?

@patrickvonplaten (Contributor)

@patrickvonplaten I retained comments like # "b c f h w -> (b f) c h w" to help with readability. Would you prefer not having them at all?

No, that works for me!

@patrickvonplaten patrickvonplaten requested a review from DN6 October 4, 2023 09:59
@patrickvonplaten (Contributor)

@DN6 can you also review this in-detail especially with respect to AnimateDiff?

* src/diffusers/models/resnet.py -> type hints for rearrange_dims
* src/diffusers/models/unet_3d_blocks.py -> New apply_freeu changes
* src/diffusers/models/unet_3d_condition.py -> Updated docstring [Still needs checking]
@Abhinay1997 (Contributor, Author)

Resolved merge conflicts. Can someone approve the test workflow?

@Abhinay1997 (Contributor, Author)

Thanks Sayak!

Team, just an FYI: the failing tests are unrelated to this PR.

@patrickvonplaten (Contributor)

@DN6 can you also review this in-detail especially with respect to AnimateDiff?

@DN6 could you check this PR as well?

def get_dummy_components(self):
    torch.manual_seed(0)
    unet = UNet3DConditionModel(
        block_out_channels=(32, 64, 64, 64),
Collaborator

Is it possible to lower the number of channels here so that the tests run faster?

Comment on lines +172 to +181
def test_dict_tuple_outputs_equivalent(self, expected_max_difference=1e-4):
pass

@unittest.skip(reason="`set_default_attn_processor` is not supported as we use a custom attn processor")
def test_save_load_local(self, expected_max_difference=5e-4):
pass

@unittest.skip(reason="`set_default_attn_processor` is not supported as we use a custom attn processor")
def test_save_load_optional_components(self, expected_max_difference=1e-4):
pass
@DN6 (Collaborator) commented Oct 24, 2023

I think rather than using the inherited tests, we should implement these tests from scratch here. We can omit setting the default processor.
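A minimal sketch of what such a from-scratch test could look like (names and tolerances are illustrative, assuming the usual get_dummy_components/get_dummy_inputs helpers from the pipeline test mixin and torch_device from diffusers.utils.testing_utils):

import tempfile
import numpy as np

def test_save_load_local(self, expected_max_difference=5e-4):
    components = self.get_dummy_components()
    pipe = self.pipeline_class(**components)
    pipe.to(torch_device)
    pipe.set_progress_bar_config(disable=None)

    inputs = self.get_dummy_inputs(torch_device)
    output = pipe(**inputs)[0]

    # Round-trip through save_pretrained / from_pretrained without calling
    # set_default_attn_processor, so the custom attention processor is kept.
    with tempfile.TemporaryDirectory() as tmpdir:
        pipe.save_pretrained(tmpdir)
        pipe_loaded = self.pipeline_class.from_pretrained(tmpdir)
    pipe_loaded.to(torch_device)
    pipe_loaded.set_progress_bar_config(disable=None)

    inputs = self.get_dummy_inputs(torch_device)
    output_loaded = pipe_loaded(**inputs)[0]

    max_diff = np.abs(np.asarray(output) - np.asarray(output_loaded)).max()
    self.assertLess(max_diff, expected_max_difference)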

@DN6 (Collaborator) left a review comment

Just a couple of small requests related to tests, but otherwise LGTM 👍🏽

Abhinay1997 and others added 2 commits October 24, 2023 11:17
@Abhinay1997 (Contributor, Author)

Will add the changes to the tests soon!

@github-actions (bot)

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@Abhinay1997 (Contributor, Author)

Apologies for the delay; I need to resolve the merge conflict in the UNet blocks and reduce the test runtime. I didn't get a chance to work on this over the last few weeks.

@patrickvonplaten (Contributor)

Hey @Abhinay1997,

I'm super sorry, but I think the codebase is changing too fast to keep up with everything. In order to quickly finish the PR, could we maybe move everything into the research folder: https://github.com/patrickvonplaten/diffusers/tree/c27e30dcb5eaaabc30f3dbc48587ad52ee345b79/examples/research_projects? There are a couple of powerful video models out there now (such as SVD), and I'm not sure it still makes sense to pursue this PR as a core diffusers integration. I do think, however, that it would be very valuable as a contribution to research_projects.

@Abhinay1997 (Contributor, Author)

@patrickvonplaten, agree with you about its relevance to core diffusers post-SVD. I'll move it to the research_projects section as recommended.

Thank you for your patience on this. I wasn't planning to drag this out for so long. I'll pick this back up as a new PR once the unicontrol PR is merged.

@patrickvonplaten (Contributor)

@patrickvonplaten, agree with you about its relevance to core diffusers post-SVD. I'll move it to the research_projects section as recommended.

Thank you for your patience on this. I wasn't planning to drag this out for so long. I'll pick this back up as a new PR once the unicontrol PR is merged.

Thanks!

@github-actions (bot)

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@github-actions github-actions bot closed this Jan 9, 2024
@toastymcvoid

Any updates on this? Would still love to see this done.

@sayakpaul (Member)

We decided to de-prioritize it in light of the other video pipelines we have in the library.

Labels: Good Example PR, stale