[Pipelines] Adds pix2pix zero #2334
Conversation
The documentation is not available anymore as the PR was closed or merged.
Wow, this is super cool. I certainly did a double take when I saw the futzing with gradients inside the inference pipeline. It's not necessary, but if you have time, a quick blurb in the docstring on why pix2pix zero uses gradients during inference would be really cool/helpful for future readers.
Great work! I'm in favor of not passing a path to the call function, but instead forcing the user to load the tensor beforehand and just pass a torch tensor. Apart from this it looks great, thanks for iterating so quickly :-)
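A quick sketch of the suggested API shape (`pipe`, `prompt`, and the argument names here are assumptions for illustration, not the PR's final API):

```python
import torch

# Instead of passing a file path into __call__, the user loads the
# embeddings up front and hands the pipeline plain torch tensors.
source_embeds = torch.load("source_embeds.pt")  # hypothetical file
target_embeds = torch.load("target_embeds.pt")  # hypothetical file

images = pipe(prompt, source_embeds=source_embeds, target_embeds=target_embeds).images
```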
patil-suraj left a comment
Very cool! Thanks a lot for adding this pipeline so quickly. I agree with Patrick here that we should just pass the embeds directly instead of a path.
I left some comments. Let's make sure the pipeline works in fp16: some computation here might happen in fp32, and we don't explicitly cast it to the dtype of the params. Would be cool to handle this.
    pix2pix_zero_attn_procs[name] = Pix2PixZeroCrossAttnProcessor(is_pix2pix_zero=True)
    module.requires_grad_(True)
else:
    pix2pix_zero_attn_procs[name] = Pix2PixZeroCrossAttnProcessor(is_pix2pix_zero=False)
Very cool!
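For context, here is a sketch of what the surrounding prepare_unet helper could look like, reconstructed from the hunk above (exact names and structure in the PR may differ):

```python
def prepare_unet(unet):
    """Register pix2pix-zero attention processors on the UNet.

    Cross-attention layers ("attn2") get the guidance-enabled processor and
    keep gradients enabled; all other attention layers get a pass-through
    processor with gradients disabled.
    """
    pix2pix_zero_attn_procs = {}
    for name in unet.attn_processors.keys():
        module_name = name.replace(".processor", "")
        module = unet.get_submodule(module_name)
        if "attn2" in name:
            pix2pix_zero_attn_procs[name] = Pix2PixZeroCrossAttnProcessor(is_pix2pix_zero=True)
            module.requires_grad_(True)
        else:
            pix2pix_zero_attn_procs[name] = Pix2PixZeroCrossAttnProcessor(is_pix2pix_zero=False)
            module.requires_grad_(False)
    unet.set_attn_processor(pix2pix_zero_attn_procs)
    return unet
```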
text_encoder: CLIPTextModel,
tokenizer: CLIPTokenizer,
unet: UNet2DConditionModel,
scheduler: KarrasDiffusionSchedulers,
In the example, we are using only DDIM. Does it work with all schedulers? If not, we should change the type annotation and mention it in the docs.
I have tested it with the following ones (all of which can be seen in the test script, except for DDPM):
- DDIM (default)
- EulerAncestralDiscreteScheduler
- LMSDiscreteScheduler
- DDPMScheduler
I will change the type annotation.
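Swapping the scheduler follows the usual diffusers pattern; for example, to use one of the schedulers listed above:

```python
from diffusers import EulerAncestralDiscreteScheduler

# `pipe` is an already-loaded pix2pix-zero pipeline (see the sample code below).
pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
```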
logger.info("Loading caption generator since `conditions_input_image` is True.")
checkpoint = "Salesforce/blip-image-captioning-base"
captioner_processor = AutoProcessor.from_pretrained(checkpoint)
captioner = BlipForConditionalGeneration.from_pretrained(checkpoint)
We should make sure that the captioner has the same weight dtype as the other models.
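One way to address this, sketched under the assumption that `self.unet` is a reasonable dtype reference:

```python
# Load the captioner directly in the pipeline's weight dtype...
captioner = BlipForConditionalGeneration.from_pretrained(
    checkpoint, torch_dtype=self.unet.dtype
)
# ...or cast it after loading.
captioner.to(dtype=self.unet.dtype)
```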
# 2. Generate a caption for the input image if we are conditioning the
# pipeline based on some input image.
if self.conditions_input_image:
    caption, preprocessed_image = generate_caption(image, self._captioner, self._captioner_processor)
Instead of passing the model and processor here, could this become a class method?
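A hypothetical version of that refactor, using the attribute names visible in the hunk above, so the model and processor no longer need to be passed around:

```python
def generate_caption(self, image):
    """Generate a caption for the input image with the stored BLIP captioner."""
    inputs = self._captioner_processor(images=image, return_tensors="pt")
    output_ids = self._captioner.generate(**inputs)
    return self._captioner_processor.decode(output_ids[0], skip_special_tokens=True)
```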
    callback(i, t, latents)

# 8. Compute the edit directions.
edit_direction = construct_direction(source_embedding_path, target_embedding_path).to(prompt_embeds.device)
Maybe make construct_direction a class method.
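For illustration, such a method could be as small as this sketch (the edit direction is the difference of the mean source and target embeddings; the actual helper may differ):

```python
@staticmethod
def construct_direction(embs_source: torch.Tensor, embs_target: torch.Tensor) -> torch.Tensor:
    """Compute the edit direction from the mean source/target embeddings."""
    return (embs_target.mean(0) - embs_source.mean(0)).unsqueeze(0)
```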
    callback(i, t, latents)

# 8. Compute the edit directions.
edit_direction = construct_direction(source_embedding_path, target_embedding_path).to(prompt_embeds.device)
Let's check the dtype of edit_direction here. I think if we execute this function on CPU, the dtype will be fp32, and if we are doing inference in fp16 it might fail. Could we also do everything on the GPU?
The example code (#2334 (comment)) is run on a GPU and works without failing. Am I missing anything?
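An explicit cast would cover both the device and dtype concerns; a sketch, assuming the embeddings are passed as tensors per the earlier review:

```python
edit_direction = construct_direction(source_embeds, target_embeds).to(
    device=prompt_embeds.device, dtype=prompt_embeds.dtype
)
```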
with torch.enable_grad():
    # initialize loss
    loss = Pix2PixZeroL2Loss()
I think we are only using the class to store the running loss; we could also do it here directly. Would be cleaner IMO.
The loss object is responsible for the computation as well, via compute_loss(). My reasoning is noted here: #2334 (comment).
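For reference, a minimal version of such a loss helper (a sketch, not necessarily the PR's exact implementation):

```python
class Pix2PixZeroL2Loss:
    def __init__(self):
        self.loss = 0.0

    def compute_loss(self, predictions, targets):
        # Accumulate an L2 penalty between the edited and reference attention maps.
        self.loss += ((predictions - targets) ** 2).sum((1, 2)).mean(0)
```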
noise_pred = self.unet(
    x_in.detach(),
    t,
    encoder_hidden_states=prompt_embeds_edit,
    cross_attention_kwargs={"timestep": None},
).sample

latents = x_in.detach().chunk(2)[0]
Since x_in is not used later in the code, maybe we could detach it once and reuse it.
Since it's used only once, I think it's okay.
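For reference, the suggested refactor would look something like this sketch:

```python
# Detach once and reuse the result.
x_in = x_in.detach()
noise_pred = self.unet(
    x_in,
    t,
    encoder_hidden_states=prompt_embeds_edit,
    cross_attention_kwargs={"timestep": None},
).sample
latents = x_in.chunk(2)[0]
```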
Here's a decorator that skips a whole test suite on MPS.
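A sketch of how such a decorator is applied, assuming diffusers' skip_mps test utility:

```python
import unittest

from diffusers.utils.testing_utils import skip_mps


@skip_mps
class StableDiffusionPix2PixZeroPipelineFastTests(unittest.TestCase):
    def test_inference(self):
        ...  # pipeline checks that should not run on MPS
```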
Thanks a lot @sayakpaul!
* add: support for BLIP generation.
* add: support for editing synthetic images.
* remove unnecessary comments.
* add inits and run make fix-copies.
* version change of diffusers.
* fix: condition for loading the captioner.
* default conditions_input_image to False.
* guidance_amount -> cross_attention_guidance_amount
* fix inputs to check_inputs()
* fix: attribute.
* fix: prepare_attention_mask() call.
* debugging.
* better placement of references.
* remove torch.no_grad() decorations.
* put torch.no_grad() context before the first denoising loop.
* detach() latents before decoding them.
* put deocding in a torch.no_grad() context.
* add reconstructed image for debugging.
* no_grad(0
* apply formatting.
* address one-off suggestions from the draft PR.
* back to torch.no_grad() and add more elaborate comments.
* refactor prepare_unet() per Patrick's suggestions.
* more elaborate description for .
* formatting.
* add docstrings to the methods specific to pix2pix zero.
* suspecting a redundant noise prediction.
* needed for gradient computation chain.
* less hacks.
* fix: attention mask handling within the processor.
* remove attention reference map computation.
* fix: cross attn args.
* fix: prcoessor.
* store attention maps.
* fix: attention processor.
* update docs and better treatment to xa args.
* update the final noise computation call.
* change xa args call.
* remove xa args option from the pipeline.
* add: docs.
* first test.
* fix: url call.
* fix: argument call.
* remove image conditioning for now.
* 🚨 add: fast tests.
* explicit placement of the xa attn weights.
* add: slow tests 🐢
* fix: tests.
* edited direction embedding should be on the same device as prompt_embeds.
* debugging message.
* debugging.
* add pix2pix zero pipeline for a non-deterministic test.
* debugging/
* remove debugging message.
* make caption generation _
* address comments (part I).
* address PR comments (part II)
* fix: DDPM test assertion.
* refactor doc.
* address PR comments (part III).
* fix: type annotation for the scheduler.
* apply styling.
* skip_mps and add note on embeddings in the docs.
Materials
Unlike most of the other pipelines, there's a mini training loop running inside this pipeline: Pix2Pix Zero optimizes the attention maps during inference to steer image generation along the edited semantic direction obtained from the source and target embeddings.
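A condensed sketch of that inner loop (names are taken from the hunks quoted above; the actual pipeline code may differ in detail):

```python
with torch.enable_grad():
    # Optimize the latents directly; the custom cross-attention processors
    # compare the current attention maps against the stored reference maps
    # and accumulate an L2 loss during the forward pass.
    x_in = latent_model_input.detach().requires_grad_(True)
    opt = torch.optim.SGD([x_in], lr=cross_attention_guidance_amount)
    loss = Pix2PixZeroL2Loss()
    self.unet(
        x_in,
        t,
        encoder_hidden_states=prompt_embeds_edit,
        cross_attention_kwargs={"timestep": t, "loss": loss},
    ).sample
    loss.loss.backward()
    opt.step()
```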
Sample code for inference
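A minimal inference sketch; argument names such as source_embeds, target_embeds, and cross_attention_guidance_amount follow the review discussion above, but the released pipeline may differ in detail:

```python
import torch
from diffusers import DDIMScheduler, StableDiffusionPix2PixZeroPipeline

pipe = StableDiffusionPix2PixZeroPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16
)
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
pipe = pipe.to("cuda")

prompt = "a high resolution painting of a cat in the style of van gogh"
# Pre-computed embeddings for the source and target concepts
# (hypothetical local files).
source_embeds = torch.load("cat_embeds.pt").to("cuda", dtype=torch.float16)
target_embeds = torch.load("dog_embeds.pt").to("cuda", dtype=torch.float16)

image = pipe(
    prompt,
    source_embeds=source_embeds,
    target_embeds=target_embeds,
    num_inference_steps=50,
    cross_attention_guidance_amount=0.15,
).images[0]
image.save("cat_to_dog.png")
```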
Results
Image generated by the prompt
Edited image with Pix2Pix Zero
TODOs