[Utils] Adds store() and restore() methods to EMAModel
#2302
patrickvonplaten
left a comment
Ok for me! Let's just be a bit careful when adding this functionality to the training scripts to not blow up GPU memory
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
For the code quality tests, I am running
My formatting virtual environment is up-to-date as well:
Still the same issue: #2302 (comment)
Fixed the
We also use https://github.com/huggingface/doc-builder to improve the style of our code (I think it makes sure comments are not too long). You can install it with:
src/diffusers/training_utils.py
Outdated
parameters = list(parameters)
self.collected_params = [param.detach().cpu().clone() for param in parameters]
can we just have the caller do the detach.cpu.clone and save collected_params in a variable themselves? This sort of adhoc storage just between two methods is a bit of an OO antipattern imo
Another thing is I think the initial list() is also unnecessary. If something is an Iterable, I think we should be able to directly iterate over it?
Yes, that is true. The detach code was after, hence.
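To illustrate why the snapshot in store() ends with clone(), here is a minimal sketch, assuming a parameter that already lives on the CPU. The names param, aliased and snapshot are only for illustration and do not come from the PR.

```python
import torch

# A stand-in for one model parameter living on the CPU.
param = torch.nn.Parameter(torch.ones(3))

# Without clone(): detach() shares storage with the parameter, and cpu()
# returns the same tensor when it is already on the CPU.
aliased = param.detach().cpu()

# With clone(): an independent copy of the current values.
snapshot = param.detach().cpu().clone()

# Simulate overwriting the weights in place (e.g. copying EMA weights in).
with torch.no_grad():
    param.mul_(2.0)

print(aliased)   # follows the in-place update
print(snapshot)  # keeps the original values
```

Without the clone(), a later restore() would copy the already-overwritten values back onto themselves whenever the model sits on the CPU.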
src/diffusers/training_utils.py
Outdated
if self.collected_params is None:
    raise RuntimeError("This ExponentialMovingAverage has no `store()`ed weights " "to `restore()`")
If we do https://github.com/huggingface/diffusers/pull/2302/files#r1107495110, then this method wouldn't need to reference self and it wouldn't have this error condition. This would mean maybe it would look better as a function, and then store and restore could be free-standing functions not on a class.
WDYT?
Since the EMAModel class is mainly going to be used in the examples, which are meant to be easier to work with, I like this design approach better. So, with that in mind, for #2302 (comment), I think it's probably better if we do it ourselves.
From what I have seen, the community is fairly eager to customize things as per their needs, but with that in mind, we should aim to provide example scripts that are easy starting points.
WDYT?
I think we can still make the methods easy to work with without the adhoc variable storage. I.e.
# training utils
def store(params):
    # clone() so the snapshot does not share storage with the live weights
    return [x.detach().cpu().clone() for x in params]

def restore(params, stored_params):
    for c_param, param in zip(stored_params, params):
        param.data.copy_(c_param.data)

# training script (params must be re-iterable, e.g. a list)
stored = store(params)
# do something
restore(params, stored)

is still friendly and hackable without the adhoc variable storage
Then it is coming down to a preference now :D
I like the class approach better than what you showed above.
I see your point @williamberman, but IMO from a user's perspective, it would be better to just call class methods than having to manually do this.
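For comparison with the free-standing functions above, the class-based variant under discussion could look roughly like the sketch below. ParamStash is a hypothetical stand-in, not the actual EMAModel, and the error message is shortened; only the collected_params attribute and the store()/restore() names are taken from the diff in this thread.

```python
import torch


class ParamStash:
    """Hypothetical helper mirroring the store()/restore() pair discussed above."""

    def __init__(self):
        self.collected_params = None

    def store(self, parameters):
        # Keep independent CPU copies of the current parameter values.
        self.collected_params = [p.detach().cpu().clone() for p in parameters]

    def restore(self, parameters):
        # The error condition exists only because the copies live on the object.
        if self.collected_params is None:
            raise RuntimeError("No `store()`ed weights to `restore()`")
        for saved, p in zip(self.collected_params, parameters):
            p.data.copy_(saved.data)


model = torch.nn.Linear(2, 1)
original = [p.detach().clone() for p in model.parameters()]

stash = ParamStash()
stash.store(model.parameters())

# Overwrite the weights in place, then bring the stored values back.
with torch.no_grad():
    for p in model.parameters():
        p.zero_()

stash.restore(model.parameters())
```

The trade-off is the one raised in the thread: the caller has no list to carry around, at the cost of hidden state shared between the two methods.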
williamberman
left a comment
I think these shouldn't be methods on the class but ok with merging if other people feel strongly it should be :)
Co-authored-by: Will Berman <wlbberman@gmail.com>
patrickvonplaten
left a comment
I understand @williamberman's points here and to be honest the EMAModel class is also not super intuitive to me in general. To me it's actually more of an EMAModelTrainingUtils class.
But from the point of view of a user of the example training scripts, I also think it's more convenient to be able to attach temporarily stored parameters to the EMAModel instead of having some stored list flying around.
Maybe some renaming could help here a bit to make it super clear that the EMAModel.store(...) function actually does not store and restore ema parameters but the non_ema parameters?
Also good to go for me, but would be happy about potentially renaming the functions a bit to make the difference between the actual EMA and the non-EMA weights clearer
patil-suraj
left a comment
Thanks a lot for the PR. I agree with Patrick regarding naming.
I just left a comment about copying the params to CPU in store. Not sure if it's really needed, will need to dig into it a bit. The rest of it looks great!
I am in agreement regarding the name changes. Applying them now. I will let @patil-suraj play around with the CPU offloading (
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
… into refactor/ema
We could do this in another PR, but it would be nice to add some tests for ema model. @sayakpaul feel free to work on it if you want.
I can but in a separate PR.
Co-authored-by: patil-suraj <surajp815@gmail.com>
@patil-suraj see if the latest changes are good to go.
Thanks a lot for updating this quickly. Just left two quick comments, we can remove the temp params from saving/loading state_dict. Should be good to go after that!
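A rough sketch of what leaving the temporary params out of state_dict saving/loading could mean in practice. EMAState and its decay and shadow_params attributes are assumptions made for illustration; temp_stored_params is the name used in the commit log of this PR.

```python
import torch


class EMAState:
    """Hypothetical container; attribute names are illustrative only."""

    def __init__(self, parameters, decay=0.999):
        self.decay = decay
        self.shadow_params = [p.detach().clone() for p in parameters]
        self.temp_stored_params = None  # scratch space used by store()/restore()

    def store(self, parameters):
        self.temp_stored_params = [p.detach().cpu().clone() for p in parameters]

    def state_dict(self):
        # Only persistent state is serialized; the temporary copy is left out.
        return {"decay": self.decay, "shadow_params": self.shadow_params}

    def load_state_dict(self, state):
        self.decay = state["decay"]
        self.shadow_params = [t.clone() for t in state["shadow_params"]]


model = torch.nn.Linear(2, 1)
ema = EMAState(model.parameters())
ema.store(model.parameters())
state = ema.state_dict()
```

The stored copy only matters between a store() call and the matching restore(), so a checkpoint does not need to carry it.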
I think the error related to running
The
Right. Sorry for missing it out. Should be good now.
Thanks! The failing test is unrelated, merging this now :)
…ce#2302)
* add store and restore() methods to EMAModel.
* Update src/diffusers/training_utils.py
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
* make style with doc builder
* remove explicit listing.
* Apply suggestions from code review
Co-authored-by: Will Berman <wlbberman@gmail.com>
* Apply suggestions from code review
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
* chore: better variable naming.
* better treatment of temp_stored_params
Co-authored-by: patil-suraj <surajp815@gmail.com>
* make style
* remove temporary params from earth 🌎
* make fix-copies.
---------
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Will Berman <wlbberman@gmail.com>
Co-authored-by: patil-suraj <surajp815@gmail.com>
This PR adds the store() and restore() methods to our EMAModel class. The reason is beautifully explained by @patil-suraj here. An end-to-end example of how to use these methods in practice is here.
It also un-blocks #2157
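
As a self-contained illustration of the intended workflow (stash the non-EMA weights, run validation with the EMA weights, then put the training weights back), here is a minimal sketch. MiniEMA, its decay value, and the step() and copy_to() helpers are assumptions made for this example and are not the diffusers EMAModel implementation.

```python
import torch


class MiniEMA:
    """Hypothetical minimal EMA helper; not the diffusers EMAModel."""

    def __init__(self, parameters, decay=0.9):
        self.decay = decay
        self.shadow_params = [p.detach().clone() for p in parameters]
        self.temp_stored_params = None

    @torch.no_grad()
    def step(self, parameters):
        # shadow = decay * shadow + (1 - decay) * param
        for s, p in zip(self.shadow_params, parameters):
            s.mul_(self.decay).add_(p, alpha=1 - self.decay)

    def copy_to(self, parameters):
        # Overwrite the given parameters with the EMA (shadow) values.
        for s, p in zip(self.shadow_params, parameters):
            p.data.copy_(s.data)

    def store(self, parameters):
        # Stash the current (non-EMA) weights on the CPU.
        self.temp_stored_params = [p.detach().cpu().clone() for p in parameters]

    def restore(self, parameters):
        if self.temp_stored_params is None:
            raise RuntimeError("No `store()`ed weights to `restore()`")
        for saved, p in zip(self.temp_stored_params, parameters):
            p.data.copy_(saved.data)
        self.temp_stored_params = None  # free the temporary copy


torch.manual_seed(0)
model = torch.nn.Linear(4, 1)
ema = MiniEMA(model.parameters())
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# A few toy training steps, updating the EMA after each optimizer step.
for _ in range(3):
    loss = model(torch.randn(8, 4)).pow(2).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    ema.step(model.parameters())

trained = [p.detach().clone() for p in model.parameters()]

ema.store(model.parameters())      # stash the non-EMA weights
ema.copy_to(model.parameters())    # validate with the EMA weights
with torch.no_grad():
    val_out = model(torch.randn(2, 4))
ema.restore(model.parameters())    # put the non-EMA weights back
```

Keeping the stashed copy on the CPU is in line with the earlier remark about not blowing up GPU memory in the training scripts.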