@sayakpaul
Member

@sayakpaul sayakpaul commented Feb 9, 2023

This PR adds the store() and restore() methods to our EMAModel class. The reason is beautifully explained by @patil-suraj here.

An end-to-end example of how to use these methods in practice is here.

It also un-blocks #2157
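
As a rough illustration of why the two methods are useful, here is a minimal sketch. `TinyEMA`, its `decay` value, and the use of plain Python lists in place of tensors are assumptions made for this example, not the `EMAModel` implementation. The idea is to stash the live training weights, swap in the EMA weights for evaluation or checkpointing, and then put the training weights back:

```python
class TinyEMA:
    """Toy stand-in for an EMA helper; parameters are plain lists of floats."""

    def __init__(self, parameters, decay=0.9):
        self.decay = decay
        # Shadow (EMA) copy of every parameter.
        self.shadow_params = [list(p) for p in parameters]
        # Temporary copy of the non-EMA weights, filled by store().
        self.temp_stored_params = None

    def step(self, parameters):
        """Move each shadow value towards the current parameter value."""
        for shadow, param in zip(self.shadow_params, parameters):
            for i, value in enumerate(param):
                shadow[i] = self.decay * shadow[i] + (1.0 - self.decay) * value

    def copy_to(self, parameters):
        """Overwrite the live parameters in place with the EMA values."""
        for shadow, param in zip(self.shadow_params, parameters):
            param[:] = shadow

    def store(self, parameters):
        """Save a copy of the live (non-EMA) parameters."""
        self.temp_stored_params = [list(p) for p in parameters]

    def restore(self, parameters):
        """Put the saved non-EMA parameters back and drop the temporary copy."""
        if self.temp_stored_params is None:
            raise RuntimeError("No weights were stored; call store() before restore().")
        for saved, param in zip(self.temp_stored_params, parameters):
            param[:] = saved
        self.temp_stored_params = None


params = [[1.0, 2.0], [3.0]]
ema = TinyEMA(params)

# Pretend an optimizer step changed the weights, then update the EMA.
params[0][0] = 2.0
ema.step(params)

ema.store(params)      # keep the training weights
ema.copy_to(params)    # evaluate or save with the EMA weights
ema_view = [list(p) for p in params]
ema.restore(params)    # continue training from the original weights
```

After `restore()`, `params` holds the training weights again, while `ema_view` captured what the model looked like with the averaged weights swapped in.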

@HuggingFaceDocBuilderDev

HuggingFaceDocBuilderDev commented Feb 9, 2023

The documentation is not available anymore as the PR was closed or merged.

Contributor

@patrickvonplaten patrickvonplaten left a comment

Ok for me! Let's just be a bit careful when adding this functionality to the training scripts so that we don't blow up GPU memory.

sayakpaul and others added 2 commits February 13, 2023 20:09
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
@sayakpaul
Member Author

For the code quality tests, I am running make style and make quality. It picks up a script unrelated to this PR:

  • src/diffusers/training_utils.py

My formatting virtual environment is up-to-date as well:

  • black: 23.1.0
  • isort: 5.12.0
  • ruff: 0.0.244

@sayakpaul
Member Author

Still the same issue: #2302 (comment)

@patrickvonplaten
Contributor

Fixed the make quality issue. @sayakpaul, your hf-doc-builder library might be outdated.

We also use https://github.com/huggingface/doc-builder to improve the style of our code (I think it makes sure comments are not too long).

You can install it with:

pip install hf-doc-builder
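
A quick way to compare formatter versions between environments is to query the installed distributions directly. This is only a sketch: `tool_versions` is a hypothetical helper, and the distribution names are the ones mentioned in this thread.

```python
from importlib.metadata import PackageNotFoundError, version


def tool_versions(names):
    """Return {distribution name: installed version or 'not installed'}."""
    found = {}
    for name in names:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = "not installed"
    return found


if __name__ == "__main__":
    for name, ver in tool_versions(["black", "isort", "ruff", "hf-doc-builder"]).items():
        print(f"{name}: {ver}")
```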

Comment on lines 261 to 262
parameters = list(parameters)
self.collected_params = [param.detach().cpu().clone() for param in parameters]
Contributor

Can we just have the caller do the detach().cpu().clone() and keep collected_params in a variable themselves? This sort of ad hoc storage shared between two methods is a bit of an OO antipattern, IMO.

Contributor

Another thing is I think the initial list() is also unnecessary. If something is an Iterable, I think we should be able to directly iterate over it?
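
A small sketch of the trade-off, using a generator as a stand-in for something like `model.parameters()` (the names here are illustrative only): iterating directly is fine for a single pass, and `list()` only matters if the same iterable has to be walked more than once.

```python
def fake_parameters():
    """Yield a few values, like a parameters generator would."""
    yield from (0.1, 0.2, 0.3)


# Single pass: no list() needed, a comprehension consumes the generator directly.
stored = [value for value in fake_parameters()]

# A generator can only be consumed once, so a second pass sees nothing.
gen = fake_parameters()
first_pass = [value for value in gen]
second_pass = [value for value in gen]

# Materializing it once makes repeated passes safe.
params = list(fake_parameters())
both_passes = ([v for v in params], [v for v in params])
```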

Member Author

Yes, that is true. The detach code was added afterwards, hence the leftover list().

Comment on lines 274 to 275
if self.collected_params is None:
    raise RuntimeError("This ExponentialMovingAverage has no `store()`ed weights " "to `restore()`")
Contributor

If we do https://github.com/huggingface/diffusers/pull/2302/files#r1107495110, then this method wouldn't need to reference self and it wouldn't have this error condition. This would mean maybe it would look better as a function, and then store and restore could be free-standing functions not on a class.

WDYT?

Member Author

Since the EMAModel class is mainly going to be used in the examples, which are meant to be easy to work with, I like this design approach better. So, with that in mind, for #2302 (comment), I think it's probably better if we do it ourselves.

From what I have seen, the community is fairly eager to customize things to fit their needs, but even so, we should aim to provide example scripts that are easy starting points.

WDYT?

Contributor

@williamberman williamberman Feb 16, 2023

I think we can still make the methods easy to work with without the adhoc variable storage. I.e.

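# training utils
def store(params):
    # Snapshot each parameter on CPU; clone() so the copy never shares storage.
    return [x.detach().cpu().clone() for x in params]

def restore(params, stored_params):
    # Copy the stored values back into the live parameters in place.
    for c_param, param in zip(stored_params, params):
        param.data.copy_(c_param.data)


# training script
# (params must be re-iterable, e.g. a list, because it is walked twice)

stored = store(params)
# do something
restore(params, stored)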

This is still friendly and hackable without the ad hoc variable storage.

Member Author

@sayakpaul sayakpaul Feb 16, 2023

Then it is coming down to a preference now :D

I like the class approach better than what you showed above.

Contributor

I see your point @williamberman, but IMO from a user's perspective, it would be better to just call class methods than having to manually do this.

Contributor

@williamberman williamberman left a comment

I think these shouldn't be methods on the class but ok with merging if other people feel strongly it should be :)

Co-authored-by: Will Berman <wlbberman@gmail.com>
Contributor

@patrickvonplaten patrickvonplaten left a comment

I understand @williamberman's points here, and to be honest the EMAModel class is also not super intuitive to me in general. To me it's actually more of an EMAModelTrainingUtils class.

But from the point of view of a user of the example training scripts, I also think it's more convenient to be able to attach temporarily stored parameters to the EMAModel instead of having some stored list flying around.

Maybe some renaming could help here a bit to make it super clear that the EMAModel.store(...) function actually does not store and restore ema parameters but the non_ema parameters?

@patrickvonplaten
Contributor

Also good to go for me, but would be happy about potentially renaming the functions a bit to make the difference between the actual EMA and the non-EMA weights clearer

Contributor

@patil-suraj patil-suraj left a comment

Thanks a lot for the PR. I agree with Patrick regarding naming.
I just left a comment about copying the params to CPU in store. Not sure if it's really needed; I will need to dig into it a bit. The rest of it looks great!

@sayakpaul
Member Author

I am in agreement regarding the name changes. Applying them now. I will let @patil-suraj play around with the CPU offloading (.cpu()) before then.

@patil-suraj
Contributor

We could do this in another PR, but it would be nice to add some tests for ema model. @sayakpaul feel free to work on it if you want.

@sayakpaul
Member Author

> We could do this in another PR, but it would be nice to add some tests for ema model. @sayakpaul feel free to work on it if you want.

I can but in a separate PR.

sayakpaul and others added 2 commits February 16, 2023 17:24
Co-authored-by: patil-suraj <surajp815@gmail.com>
@sayakpaul
Member Author

@patil-suraj see if the latest changes are good to go.

Contributor

@patil-suraj patil-suraj left a comment

Thanks a lot for updating this quickly. Just left two quick comments, we can remove the temp params from saving/loading state_dict. Should be good to go after that!
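
One way to read this suggestion, sketched with a hypothetical `TinyEMAState` class rather than the real `EMAModel` (field names are assumptions for the example): the temporary copy made by `store()` is scratch data for the current run, so it is left out of what gets saved and is not expected when loading.

```python
import copy


class TinyEMAState:
    """Hypothetical helper; only the fields needed for the illustration."""

    def __init__(self, shadow_params, decay=0.999):
        self.decay = decay
        self.optimization_step = 0
        self.shadow_params = shadow_params
        self.temp_stored_params = None  # scratch space used by store()/restore()

    def state_dict(self):
        # Persist only what is needed to resume EMA updates.
        return {
            "decay": self.decay,
            "optimization_step": self.optimization_step,
            "shadow_params": copy.deepcopy(self.shadow_params),
        }

    def load_state_dict(self, state):
        self.decay = state["decay"]
        self.optimization_step = state["optimization_step"]
        self.shadow_params = copy.deepcopy(state["shadow_params"])
        # The scratch copy is never restored from a saved state.
        self.temp_stored_params = None


ema = TinyEMAState([[0.5, 0.25]])
ema.temp_stored_params = [[1.0, 1.0]]  # as if store() had been called
saved = ema.state_dict()

fresh = TinyEMAState([[0.0, 0.0]])
fresh.load_state_dict(saved)
```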

@sayakpaul
Member Author

I think the error related to running fix-copies is very likely from an unrelated issue.

@patil-suraj
Contributor

The EMAModel is in the main init, and hence it's also in the dummy objects (https://github.com/huggingface/diffusers/blob/main/src/diffusers/utils/dummy_pt_objects.py#L618). Whenever we update any object in the main init we need to run make fix-copies.
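
For readers unfamiliar with the dummy-object mechanism, here is a deliberately simplified sketch. `PlaceholderEMAModel`, `require_backend`, and the backend name are made up for this example and are not the contents of `dummy_pt_objects.py`. Roughly, a placeholder class stands in for the real one when an optional backend is missing and raises an informative error on use:

```python
import importlib.util


def require_backend(obj_name, backend):
    """Raise ImportError if the (hypothetical) backend module is not importable."""
    if importlib.util.find_spec(backend) is None:
        raise ImportError(f"{obj_name} requires the `{backend}` library, which was not found.")


class PlaceholderEMAModel:
    """Stands in for the real class when its backend is unavailable."""

    _backend = "a_backend_that_is_not_installed"  # made-up module name

    def __init__(self, *args, **kwargs):
        require_backend(type(self).__name__, self._backend)


try:
    PlaceholderEMAModel()
    message = None
except ImportError as err:
    message = str(err)
```

Because such placeholders are kept in a separate file from the real classes, the two can drift apart, which is why a sync step like `make fix-copies` is needed after changes.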

@sayakpaul
Member Author

> The EMAModel is in the main init, and hence it's also in the dummy objects (https://github.com/huggingface/diffusers/blob/main/src/diffusers/utils/dummy_pt_objects.py#L618). Whenever we update any object in the main init we need to run make fix-copies.

Right. Sorry for missing it. Should be good now.

@patil-suraj
Contributor

Thanks! The failing test is unrelated, merging this now:)

@patil-suraj patil-suraj merged commit 6eaebe8 into main Feb 16, 2023
@patil-suraj patil-suraj deleted the refactor/ema branch February 16, 2023 14:20
mengfei25 pushed a commit to mengfei25/diffusers that referenced this pull request Mar 27, 2023
…ce#2302)

* add store and restore() methods to EMAModel.

* Update src/diffusers/training_utils.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* make style with doc builder

* remove explicit listing.

* Apply suggestions from code review

Co-authored-by: Will Berman <wlbberman@gmail.com>

* Apply suggestions from code review

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* chore: better variable naming.

* better treatment of temp_stored_params

Co-authored-by: patil-suraj <surajp815@gmail.com>

* make style

* remove temporary params from earth 🌎

* make fix-copies.

---------

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Will Berman <wlbberman@gmail.com>
Co-authored-by: patil-suraj <surajp815@gmail.com>
yoonseokjin pushed a commit to yoonseokjin/diffusers that referenced this pull request Dec 25, 2023
AmericanPresidentJimmyCarter pushed a commit to AmericanPresidentJimmyCarter/diffusers that referenced this pull request Apr 26, 2024
6 participants