@aandyw
Contributor

@aandyw aandyw commented Mar 2, 2023

PR for issue #1883

@HuggingFaceDocBuilderDev

HuggingFaceDocBuilderDev commented Mar 2, 2023

The documentation is not available anymore as the PR was closed or merged.

@aandyw aandyw changed the title Enabling gradient checkpointing for VAE [WIP] Enabling gradient checkpointing for VAE Mar 2, 2023
@aandyw aandyw changed the title [WIP] Enabling gradient checkpointing for VAE Enabling gradient checkpointing for VAE Mar 2, 2023
@aandyw aandyw marked this pull request as ready for review March 2, 2023 23:20
@aandyw aandyw changed the title Enabling gradient checkpointing for VAE [WIP] Enabling gradient checkpointing for VAE Mar 3, 2023
@aandyw aandyw changed the title [WIP] Enabling gradient checkpointing for VAE Enabling gradient checkpointing for VAE Mar 3, 2023
@aandyw aandyw changed the title Enabling gradient checkpointing for VAE [WIP] Enabling gradient checkpointing for VAE Mar 3, 2023
Contributor

@patrickvonplaten patrickvonplaten left a comment

Let's try to only add gradient checkpointing for the blocks - only there we can really save memory. It doesn't help much to add it to individual layers, such as self.conv_in

@williamberman could you also have a look?
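For readers following along, here is a minimal sketch of what "checkpoint only the blocks" can look like. It assumes a toy encoder: `ToyBlock`, `ToyEncoder`, and the `gradient_checkpointing` flag are hypothetical names for illustration, not the actual diffusers VAE classes. The single `conv_in` layer runs normally, while each multi-layer block is wrapped with `torch.utils.checkpoint.checkpoint`, so its intermediate activations are recomputed during the backward pass instead of being stored.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class ToyBlock(nn.Module):
    """Hypothetical residual block with several intermediate activations."""

    def __init__(self, channels):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(channels, channels, 3, padding=1),
            nn.SiLU(),
            nn.Conv2d(channels, channels, 3, padding=1),
            nn.SiLU(),
        )

    def forward(self, x):
        return x + self.net(x)


class ToyEncoder(nn.Module):
    """Hypothetical encoder: one input conv followed by a stack of blocks."""

    def __init__(self, channels=8, num_blocks=3):
        super().__init__()
        self.conv_in = nn.Conv2d(3, channels, 3, padding=1)
        self.blocks = nn.ModuleList([ToyBlock(channels) for _ in range(num_blocks)])
        self.gradient_checkpointing = False

    def forward(self, x):
        # conv_in is a single layer, so it is not checkpointed.
        x = self.conv_in(x)
        for block in self.blocks:
            if self.training and self.gradient_checkpointing:
                # Recompute the block's activations in backward instead of storing them.
                x = checkpoint(block, x, use_reentrant=False)
            else:
                x = block(x)
        return x


torch.manual_seed(0)
encoder = ToyEncoder().train()
sample = torch.randn(2, 3, 16, 16)

out_plain = encoder(sample)
encoder.gradient_checkpointing = True
out_ckpt = encoder(sample)
out_ckpt.sum().backward()
```

The forward outputs should match with and without the flag; only the memory/compute trade-off during backward changes.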

@aandyw
Contributor Author

aandyw commented Mar 8, 2023

> Let's try to only add gradient checkpointing for the blocks - only there we can really save memory. It doesn't help much to add it to individual layers, such as self.conv_in
>
> @williamberman could you also have a look?

Got it. Thanks for the feedback.

@aandyw aandyw requested review from patrickvonplaten and removed request for williamberman March 10, 2023 16:04
if args.gradient_checkpointing:
    unet.enable_gradient_checkpointing()
    vae.enable_gradient_checkpointing()
else:
Contributor

Note that gradient checkpointing doesn't necessarily mean that the model is training

Contributor Author

@aandyw aandyw Mar 13, 2023

How exactly is fine-tuning the VAE done in train_text_to_image?

It seems like issue #1883 wants to enable gradient checkpointing specifically for training the VAE independently, so should we add some sort of argument for that?

Contributor

@patrickvonplaten patrickvonplaten left a comment

The PR looks in a much better shape now, thanks!

Could you try to revert the reformatting of unrelated functions and then let's make sure to always freeze the unet and vae of train_text_to_image just like before. Note that gradient checkpointing still helps even if the weights are frozen :-)

Good job! Think we're on the final stretch now
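As an illustration of the point about frozen weights, here is a small sketch under stated assumptions: the `frozen` module below is a hypothetical stand-in for a frozen VAE or UNet, not the real model. Even with `requires_grad_(False)` on every parameter, the backward pass to the input still needs the module's intermediate activations, so checkpointing can still trade recomputation for memory. The sketch checks that the input gradient is the same with and without checkpointing.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)

# Hypothetical stand-in for a frozen sub-model.
frozen = nn.Sequential(nn.Linear(16, 64), nn.GELU(), nn.Linear(64, 16))
frozen.requires_grad_(False)

# The input requires grad, e.g. because it comes from a trainable upstream module.
x = torch.randn(4, 16, requires_grad=True)

# Plain forward/backward: activations inside `frozen` are stored for backward.
frozen(x).sum().backward()
grad_plain = x.grad.clone()
x.grad = None

# Checkpointed forward/backward: activations are recomputed during backward.
checkpoint(frozen, x, use_reentrant=False).sum().backward()
grad_ckpt = x.grad.clone()
```

The frozen parameters never receive gradients in either case; only the gradient flowing back to `x` is computed.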

@aandyw
Contributor Author

aandyw commented Mar 13, 2023

> The PR looks in a much better shape now, thanks!
>
> Could you try to revert the reformatting of unrelated functions and then let's make sure to always freeze the unet and vae of train_text_to_image just like before. Note that gradient checkpointing still helps even if the weights are frozen :-)
>
> Good job! Think we're on the final stretch now

Thanks for the feedback! I really appreciate it.

Apologies for the formatting. I'm using black and it seems to really want these functions formatted this way. I might have missed this, but is there a suggested linter/formatter to use?

@patrickvonplaten
Contributor

@pie31415 could you try to use the newest black version, i.e. "black~=23.1"? Does this correspond to the version you have? :-)

@aandyw aandyw force-pushed the gradient-checkpointing-vae branch from 5056d87 to b7471a0 Compare March 13, 2023 22:40
@aandyw
Contributor Author

aandyw commented Mar 13, 2023

Fixed the formatting issue. It was caused by the line endings being CRLF.

However, some of the code reformatting still comes from black (format on save), e.g.

tile = z[:, :, i : i + self.tile_latent_min_size, j : j + self.tile_latent_min_size]

becomes

tile = z[
    :,
    :,
    i : i + self.tile_latent_min_size,
    j : j + self.tile_latent_min_size,
]
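On the CRLF point above: a minimal, standard-library-only sketch of detecting and normalizing line endings, in case anyone else hits the same spurious diffs. The helper names `has_crlf` and `normalize_line_endings` are made up for this illustration and are not part of the repository tooling.

```python
def has_crlf(data: bytes) -> bool:
    """Return True if the raw file contents contain Windows-style line endings."""
    return b"\r\n" in data


def normalize_line_endings(data: bytes) -> bytes:
    """Convert CRLF (and any lone CR) line endings to LF."""
    return data.replace(b"\r\n", b"\n").replace(b"\r", b"\n")


# Example on in-memory bytes; for a real file, read and write in binary mode
# so Python does not translate the line endings itself.
windows_source = b"def f():\r\n    return 1\r\n"
unix_source = normalize_line_endings(windows_source)
```

Working on bytes rather than text keeps the check independent of the platform's newline translation.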

if args.gradient_checkpointing:
    unet.enable_gradient_checkpointing()
    vae.enable_gradient_checkpointing()
    vae.requires_grad_(True)
Contributor

Suggested change
vae.requires_grad_(True)

I think this should help even without requires_grad_, since gradients are still passed down.

Contributor

@patrickvonplaten patrickvonplaten left a comment

Cool this now looks good to me!

Last thing we should do is to add a simple test that makes sure that gradient checkpointing works. Maybe you can get some inspiration from:

def test_gradient_checkpointing(self):

and add a similar test here:
https://github.com/huggingface/diffusers/blob/main/tests/models/test_models_vae.py

After that we can merge I think :-)
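A rough sketch of what such a test could check, assuming a toy model rather than the real VAE (`TinyModel` and `check_gradient_checkpointing` are hypothetical names, and this is not the test that was added to the repository): build two copies of the same model, enable checkpointing on one, and confirm that the loss and every parameter gradient agree.

```python
import copy

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class TinyModel(nn.Module):
    """Hypothetical stand-in for a model with a gradient_checkpointing flag."""

    def __init__(self):
        super().__init__()
        self.blocks = nn.ModuleList(
            [nn.Sequential(nn.Linear(8, 8), nn.Tanh()) for _ in range(3)]
        )
        self.gradient_checkpointing = False

    def forward(self, x):
        for block in self.blocks:
            if self.training and self.gradient_checkpointing:
                x = checkpoint(block, x, use_reentrant=False)
            else:
                x = block(x)
        return x


def check_gradient_checkpointing():
    torch.manual_seed(0)
    model = TinyModel().train()
    # Identical weights, but with checkpointing switched on.
    model_ckpt = copy.deepcopy(model)
    model_ckpt.gradient_checkpointing = True

    x = torch.randn(4, 8)
    target = torch.randn(4, 8)

    loss = (model(x) - target).pow(2).mean()
    loss.backward()
    loss_ckpt = (model_ckpt(x) - target).pow(2).mean()
    loss_ckpt.backward()

    # Checkpointing must not change the loss or any parameter gradient.
    assert torch.allclose(loss, loss_ckpt, atol=1e-6)
    params = zip(model.named_parameters(), model_ckpt.named_parameters())
    for (name, p), (_, p_ckpt) in params:
        assert torch.allclose(p.grad, p_ckpt.grad, atol=1e-6), name
    return True


result = check_gradient_checkpointing()
```

Comparing against an uncheckpointed copy with the same weights catches both forward and backward discrepancies in one pass.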

aandyw and others added 4 commits March 14, 2023 14:13
…t_to_image.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
@williamberman
Contributor

wow super cool :) I learned how our gradient checkpointing works from this PR 😁 will approve once test is written

@williamberman williamberman self-requested a review March 16, 2023 08:54
@aandyw
Contributor Author

aandyw commented Mar 16, 2023

The CI tests for macOS seem to be failing. Is there any way to fix this at the moment?

@williamberman
Contributor

macos tests are unrelated, looks great

@williamberman williamberman changed the title [WIP] Enabling gradient checkpointing for VAE Enabling gradient checkpointing for VAE Mar 17, 2023
@williamberman williamberman merged commit 116f70c into huggingface:main Mar 17, 2023
@patrickvonplaten
Contributor

Good job @pie31415

w4ffl35 pushed a commit to w4ffl35/diffusers that referenced this pull request Apr 14, 2023
* updated black format

* update black format

* make style format

* updated line endings

* update code formatting

* Update examples/research_projects/onnxruntime/text_to_image/train_text_to_image.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update src/diffusers/models/vae.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update src/diffusers/models/vae.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* added vae gradient checkpointing test

* make style

---------

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Will Berman <wlbberman@gmail.com>
yoonseokjin pushed a commit to yoonseokjin/diffusers that referenced this pull request Dec 25, 2023
AmericanPresidentJimmyCarter pushed a commit to AmericanPresidentJimmyCarter/diffusers that referenced this pull request Apr 26, 2024