add: logging to text2image. #2173
Conversation
The documentation is not available anymore as the PR was closed or merged.

Adding the safety checker explicitly (f2a143f) also didn't help.
patil-suraj left a comment
Thanks a lot for the PR! Looks good to me; I just left a comment about loading text_encoder and vae again for inference, which I'm not really in favour of.
Also, when doing mixed-precision training, I would expect the generation to be in mixed precision as well. For that we'll probably need to use torch.autocast, for the reasons explained here: #2163 (comment)
We don't really want to promote autocast for inference, but I don't see any other clean way of handling it here. We could explain with a comment in the script why it's used there and why it's not needed for general inference.
Also, this will be enabled only when using mixed precision; otherwise everything will default to fp32. This can be achieved using:
```python
with torch.autocast(accelerator.device, enabled=args.mixed_precision == "fp16"):
    ...
```
Happy to hear suggestions :)
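For illustration, a rough sketch of how that could look around the image-logging step. This is only a sketch, assuming `pipeline`, `args`, and `accelerator` from the training script; the validation prompt and step count are placeholders:

```python
import torch
import wandb

# Run the validation generation in mixed precision only when training in fp16;
# otherwise autocast is disabled and everything stays in fp32.
with torch.autocast(accelerator.device.type, enabled=args.mixed_precision == "fp16"):
    images = pipeline(
        "a photo of an astronaut riding a horse",  # placeholder validation prompt
        num_inference_steps=30,
    ).images

# Log the generated images when wandb is the configured tracker.
if args.report_to == "wandb":
    wandb.log({"validation": [wandb.Image(img) for img in images]})
```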
```python
if args.report_to == "wandb":
    if not is_wandb_available():
        raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
```
Suggested change:

```diff
-raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
+raise ImportError("Make sure to install wandb if you want to use it for logging during training. You can do so by doing `pip install wandb`")
```
```python
safety_checker = StableDiffusionSafetyChecker.from_pretrained(
    args.pretrained_model_name_or_path, subfolder="safety_checker", revision=args.non_ema_revision
)
```
StableDiffusionPipeline.from_pretrained should automatically load the safety_checker when it is available, so is there any reason we need to load it here explicitly?
Also, I'm not sure what difference it makes when we load the safety_checker separately like this; StableDiffusionPipeline.from_pretrained does pretty much the same thing.
```python
safety_checker=safety_checker,
revision=args.revision,
```
We could also directly pass vae and text_encoder here; I'm not really in favour of loading them again, as that takes more memory and time and might also lead to OOM (depending on the GPU).
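Something along these lines, as a sketch only (reusing the `text_encoder`, `vae`, and `unet` that are already in memory; the exact variable names depend on the training script):

```python
from diffusers import StableDiffusionPipeline

# Reuse the modules already loaded for training instead of re-loading them from disk;
# from_pretrained will still pull in the safety_checker from the repo on its own.
pipeline = StableDiffusionPipeline.from_pretrained(
    args.pretrained_model_name_or_path,
    text_encoder=text_encoder,
    vae=vae,
    unet=accelerator.unwrap_model(unet),
    revision=args.revision,
)
pipeline = pipeline.to(accelerator.device)
```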
```python
# safety_checker.to(accelerator.device, dtype=weight_dtype)
pipeline = StableDiffusionPipeline.from_pretrained(
    args.pretrained_model_name_or_path,
    unet=accelerator.unwrap_model(unet),
```
When doing EMA, we should use the EMA weights for inference.
To do that, we'll need to:
- temporarily store the non-ema weights
- copy the ema weights to the unet
- restore the non-ema weights back in the unet.
For that we'll need to add store and restore methods to EMAModel, as defined in https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L139
Happy to take care of this if you want :)
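As a rough sketch of what that swap could look like around the inference call (assuming EMAModel gains `store`/`restore` methods along the lines of pytorch_ema, and that `ema_unet` is the EMAModel instance from the training script):

```python
# Temporarily swap the EMA weights into the unet for validation inference.
ema_unet.store(unet.parameters())     # keep the current (non-EMA) weights around
ema_unet.copy_to(unet.parameters())   # load the EMA weights into the unet

# ... run the validation / image-logging inference with `pipeline` here ...

ema_unet.restore(unet.parameters())   # put the non-EMA weights back before training continues
```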
@patil-suraj thanks a lot for your comments. Let me address them in some time (it might take until around midday tomorrow) as I have another commitment and will be travelling for some time. I hope that's not a blocker.

No worries at all, not super urgent anyway :)

Actually, there's already a similar PR: #2157

Sure, I can close this one then.

Closing in favor of #2157

Thanks a lot!
Context: #2163
Potentially closes #2163
Command to fire training:
Leads to:
The PR also follows this comment: #2163 (comment)
It's the safety checker that causes the problem.