Enabling gradient checkpointing for VAE #2536
patrickvonplaten left a comment
Let's try to only add gradient checkpointing for the blocks - only there we can really save memory. It doesn't help much to add it to individual layers, such as self.conv_in
@williamberman could you also have a look?
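To illustrate why checkpointing whole blocks is where the memory is saved, here is a minimal framework-free sketch in plain Python. It is not the diffusers implementation: it assumes a hypothetical chain of scalar `tanh` layers and only counts, roughly, how many activations have to be held at the peak of the backward pass. The function names and the `block_size` parameter are made up for this illustration.

```python
import math


def layer_forward(x):
    """One toy 'layer': y = tanh(x)."""
    return math.tanh(x)


def layer_backward(x, grad_out):
    """Gradient of tanh at input x, chained with the incoming gradient."""
    return grad_out * (1.0 - math.tanh(x) ** 2)


def grad_plain(x, n_layers):
    """Ordinary backprop: keep every layer input until the backward pass."""
    saved = []
    for _ in range(n_layers):
        saved.append(x)
        x = layer_forward(x)
    grad = 1.0
    for layer_input in reversed(saved):
        grad = layer_backward(layer_input, grad)
    return grad, len(saved)


def grad_checkpointed(x, n_layers, block_size):
    """Keep only block inputs; recompute the inside of each block on backward."""
    starts = list(range(0, n_layers, block_size))
    checkpoints = []
    for start in starts:
        checkpoints.append(x)  # the only value stored for this block
        for _ in range(min(block_size, n_layers - start)):
            x = layer_forward(x)

    grad = 1.0
    peak = 0
    for start, block_input in zip(reversed(starts), reversed(checkpoints)):
        # Recompute the activations of this block only.
        recomputed = []
        h = block_input
        for _ in range(min(block_size, n_layers - start)):
            recomputed.append(h)
            h = layer_forward(h)
        # Rough upper bound: all checkpoints plus one block's activations.
        peak = max(peak, len(checkpoints) + len(recomputed))
        for layer_input in reversed(recomputed):
            grad = layer_backward(layer_input, grad)
    return grad, peak


grad_ref, stored_plain = grad_plain(0.5, 16)
grad_blocks, peak_blocks = grad_checkpointed(0.5, 16, block_size=4)
grad_layers, peak_layers = grad_checkpointed(0.5, 16, block_size=1)
print(stored_plain, peak_blocks, peak_layers)
```

In this toy count, 16 layers checkpointed in blocks of 4 keep about half as many values alive as the plain backward pass, while a block size of 1 (wrapping each layer on its own) saves nothing. The gradients are identical in all three cases; the price is one extra forward pass per block.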
Got it. Thanks for the feedback.
examples/research_projects/onnxruntime/text_to_image/train_text_to_image.py
    if args.gradient_checkpointing:
        unet.enable_gradient_checkpointing()
        vae.enable_gradient_checkpointing()
    else:
Note that gradient checkpointing doesn't necessarily mean that the model is training
How exactly is fine-tuning the VAE done in train_text_to_image?
It seems like issue #1883 wants to enable gradient checkpointing specifically for training the VAE independently, so should we add some sort of argument for that?
patrickvonplaten left a comment
The PR looks to be in much better shape now, thanks!
Could you try to revert the reformatting of unrelated functions? Then let's make sure to always freeze the unet and vae in train_text_to_image, just like before. Note that gradient checkpointing still helps even if the weights are frozen :-)
Good job! I think we're on the final stretch now.
Thanks for the feedback! I really appreciate it. Apologies for the formatting. I'm using black and it seems to really want these functions formatted this way. I might have missed this, but is there a suggested linter/formatter to use?
5056d87 to b7471a0
Fixed the issue with formatting. It was caused by the line endings being CRLF. However, some of the code formatting is still due to black (formatting on save). For example,
    tile = z[:, :, i : i + self.tile_latent_min_size, j : j + self.tile_latent_min_size]
becomes
    tile = z[
        :,
        :,
        i : i + self.tile_latent_min_size,
        j : j + self.tile_latent_min_size,
    ]
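For the CRLF part of the problem, a quick way to check a file and convert it back to LF is sketched below. This uses only the Python standard library; the helper names are made up for this illustration and are not part of the repository's tooling.

```python
from pathlib import Path


def has_crlf(data: bytes) -> bool:
    """Return True if the byte string contains Windows-style line endings."""
    return b"\r\n" in data


def normalize_to_lf(data: bytes) -> bytes:
    """Convert CRLF line endings to LF, leaving everything else untouched."""
    return data.replace(b"\r\n", b"\n")


def fix_file(path: str) -> bool:
    """Rewrite the file with LF endings; return True if it was changed."""
    file_path = Path(path)
    data = file_path.read_bytes()  # bytes, so Python does not translate newlines
    if not has_crlf(data):
        return False
    file_path.write_bytes(normalize_to_lf(data))
    return True
```

Reading and writing bytes matters here: opening the file in text mode would let Python translate the line endings and hide the problem.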
    if args.gradient_checkpointing:
        unet.enable_gradient_checkpointing()
        vae.enable_gradient_checkpointing()
        vae.requires_grad_(True)
    vae.requires_grad_(True)
I think this should help even without requires_grad_, as gradients are passed down.
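A small sketch of the point about frozen weights, assuming PyTorch's `torch.utils.checkpoint.checkpoint` called with `use_reentrant=False`. The two modules are hypothetical stand-ins (not the diffusers VAE or UNet), and the sketch assumes the trainable part sits upstream of the frozen one, so the backward pass has to travel through the frozen module to reach it.

```python
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)

# Hypothetical stand-in for a frozen model block (weights receive no gradients).
frozen_block = nn.Sequential(nn.Linear(4, 4), nn.Tanh(), nn.Linear(4, 4))
frozen_block.requires_grad_(False)

# Hypothetical stand-in for whatever is actually being trained upstream.
trainable = nn.Linear(4, 4)

x = torch.randn(2, 4)
hidden = trainable(x)  # requires grad because of the trainable weights

# The frozen block's activations are not kept; they are recomputed on backward.
out = checkpoint(frozen_block, hidden, use_reentrant=False)
out.sum().backward()

frozen_has_grads = any(p.grad is not None for p in frozen_block.parameters())
trainable_has_grads = all(p.grad is not None for p in trainable.parameters())
print(frozen_has_grads, trainable_has_grads)
```

The frozen parameters never accumulate gradients, yet autograd still needs the frozen block's intermediate activations to propagate the gradient back to `trainable`. Those activations are what checkpointing avoids storing.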
patrickvonplaten left a comment
Cool, this now looks good to me!
Last thing we should do is to add a simple test that makes sure that gradient checkpointing works. Maybe you can get some inspiration from:
    def test_gradient_checkpointing(self):
and add a similar test here:
https://github.com/huggingface/diffusers/blob/main/tests/models/test_models_vae.py
After that we can merge I think :-)
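As a rough idea of what such a test can check, here is a self-contained sketch. It is not the test that was added in this PR: `TinyModel`, `run` and `check_gradient_checkpointing` are hypothetical names, and the model is a stand-in that only mimics a `gradient_checkpointing` flag. The idea is that two copies of the same model, one with checkpointing enabled, should produce the same loss and the same parameter gradients.

```python
import copy

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint


class TinyModel(nn.Module):
    """Hypothetical stand-in for a model with a gradient_checkpointing flag."""

    def __init__(self):
        super().__init__()
        self.blocks = nn.ModuleList(
            [nn.Sequential(nn.Linear(8, 8), nn.SiLU()) for _ in range(3)]
        )
        self.gradient_checkpointing = False

    def forward(self, x):
        for block in self.blocks:
            if self.training and self.gradient_checkpointing:
                # Only the block input is stored; the rest is recomputed.
                x = checkpoint(block, x, use_reentrant=False)
            else:
                x = block(x)
        return x


def run(model, x):
    """One forward/backward pass; returns the loss and a copy of all grads."""
    model.zero_grad()
    loss = model(x).mean()
    loss.backward()
    grads = {name: p.grad.clone() for name, p in model.named_parameters()}
    return loss.detach(), grads


def check_gradient_checkpointing():
    torch.manual_seed(0)
    model = TinyModel().train()
    model_ckpt = copy.deepcopy(model)  # identical weights
    model_ckpt.gradient_checkpointing = True

    x = torch.randn(4, 8)
    loss, grads = run(model, x)
    loss_ckpt, grads_ckpt = run(model_ckpt, x)

    # Checkpointing must change memory use only, never the numbers.
    assert torch.allclose(loss, loss_ckpt, atol=1e-5)
    for name in grads:
        assert torch.allclose(grads[name], grads_ckpt[name], atol=1e-5), name
    return True


check_gradient_checkpointing()
```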
…t_to_image.py Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Wow, super cool :) I learned how our gradient checkpointing works from this PR 😁 Will approve once the test is written.
The CI/CD tests for macOS seem to be failing. Is there any way to fix this at the moment?
The macOS test failures are unrelated. Looks great!
Good job @pie31415
* updated black format
* update black format
* make style format
* updated line endings
* update code formatting
* Update examples/research_projects/onnxruntime/text_to_image/train_text_to_image.py
  Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
* Update src/diffusers/models/vae.py
  Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
* Update src/diffusers/models/vae.py
  Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
* added vae gradient checkpointing test
* make style
---------
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Will Berman <wlbberman@gmail.com>
PR for issue #1883