
Conversation

@sayakpaul
Member

@sayakpaul sayakpaul commented Feb 13, 2023

Materials

Unlike most of the other pipelines, there's a mini training loop running inside this pipeline, since Pix2Pix Zero optimizes the attention maps to steer the image generation along the semantic edit direction obtained from the source and target embeddings.
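
For intuition, here is a minimal sketch of that inner loop (illustrative only, not the pipeline's actual code; unet, get_attn_maps, and reference_maps are hypothetical placeholders, and lr plays the role of cross_attention_guidance_amount): at each denoising step, the latents take one gradient step so that the cross-attention maps produced with the edited prompt stay close to reference maps recorded for the source prompt.

import torch

def guided_step(latents, t, edit_embeds, reference_maps, unet, get_attn_maps, lr=0.15):
    # Enable gradients on the latents for this step only.
    latents = latents.detach().requires_grad_(True)

    # Forward pass with the edited prompt; the attention processors record
    # the cross-attention maps, which get_attn_maps retrieves.
    unet(latents, t, encoder_hidden_states=edit_embeds)

    # L2 distance between the current and the reference cross-attention maps.
    loss = sum(
        ((cur - ref) ** 2).mean()
        for cur, ref in zip(get_attn_maps(unet), reference_maps)
    )

    # One gradient step on the latents before the usual scheduler update.
    grad = torch.autograd.grad(loss, latents)[0]
    return (latents - lr * grad).detach()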

Sample code for inference

# Hyperparameters from
# https://github.com/pix2pixzero/pix2pix-zero/blob/main/src/edit_synthetic.py
import requests
import torch

from diffusers import DDIMScheduler, StableDiffusionPix2PixZeroPipeline


def download(embedding_url, local_filepath):
    r = requests.get(embedding_url)
    with open(local_filepath, "wb") as f:
        f.write(r.content)


model_ckpt = "CompVis/stable-diffusion-v1-4"
pipeline = StableDiffusionPix2PixZeroPipeline.from_pretrained(
    model_ckpt, conditions_input_image=False, torch_dtype=torch.float16
)
pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)
pipeline.to("cuda")


prompt = "a high resolution painting of a cat in the style of van gough"
source_embedding_url = "https://github.com/pix2pixzero/pix2pix-zero/raw/main/assets/embeddings_sd_1.4/cat.pt"
target_embedding_url = "https://github.com/pix2pixzero/pix2pix-zero/raw/main/assets/embeddings_sd_1.4/dog.pt"

for url in [source_embedding_url, target_embedding_url]:
    download(url, url.split("/")[-1])

source_embeds = torch.load(source_embedding_url.split("/")[-1])
target_embeds = torch.load(target_embedding_url.split("/")[-1])

images = pipeline(
    prompt,
    source_embeds=source_embeds,
    target_embeds=target_embeds,
    num_inference_steps=50,
    cross_attention_guidance_amount=0.15,
).images
images[0].save("edited_image_dog.png")

Results

Image generated by the prompt:

[image]

Edited image with Pix2Pix Zero:

[image]

TODOs

  • Add support for DDIM Inversion
  • Add documentation
  • Add tests

@HuggingFaceDocBuilderDev

HuggingFaceDocBuilderDev commented Feb 13, 2023

The documentation is not available anymore as the PR was closed or merged.

@williamberman
Contributor

Wow, this is super cool. I certainly did a double take when I saw the futzing with gradients inside the inference pipeline. Not necessary, but if you have time, a quick blurb in the docstring on why Pix2Pix Zero uses gradients during inference would be really cool/helpful for future readers.
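
Something along these lines, perhaps (a hypothetical blurb, not the merged docstring):

def __call__(self, *args, **kwargs):
    """
    ...
    Note: unlike most pipelines, Pix2Pix Zero uses gradients at inference
    time. At each denoising step the latents are optimized with an L2 loss
    so that the cross-attention maps produced with the edited prompt stay
    close to the reference maps recorded for the source prompt, which is
    what applies the semantic edit while preserving image structure.
    """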

@patrickvonplaten
Contributor

patrickvonplaten commented Feb 15, 2023

Great work!

I'm in favor of not passing a path to the call function, but instead forcing the user to load the tensor beforehand and just pass a torch tensor.

Apart from this it looks great, thanks for iterating so quickly :-)

Contributor

@patil-suraj patil-suraj left a comment

Very cool! Thanks a lot for adding this pipeline so quickly. Agree with Patrick here that we should just pass the embeds directly instead of a path.

I left some comments; let's make sure the pipeline works in fp16. Some computation here might happen in fp32, and we don't cast it explicitly to the dtype of the params. Would be cool to handle this.
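
For instance, a one-line sketch of the kind of explicit cast meant here (variable names follow the diff; exact placement is a suggestion):

edit_direction = edit_direction.to(device=prompt_embeds.device, dtype=prompt_embeds.dtype)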

    pix2pix_zero_attn_procs[name] = Pix2PixZeroCrossAttnProcessor(is_pix2pix_zero=True)
    module.requires_grad_(True)
else:
    pix2pix_zero_attn_procs[name] = Pix2PixZeroCrossAttnProcessor(is_pix2pix_zero=False)

Contributor

Very cool!

    text_encoder: CLIPTextModel,
    tokenizer: CLIPTokenizer,
    unet: UNet2DConditionModel,
    scheduler: KarrasDiffusionSchedulers,

Contributor

In the example, we are using only DDIM. Does it work with all schedulers? If not, we should change the type annotation and mention it in the docs.

Member Author

I have tested it with the following schedulers (all except DDPM appear in the test script):

  • DDIM (default)
  • EulerAncestralDiscreteScheduler
  • LMSDiscreteScheduler
  • DDPMScheduler

I will change the type annotation.
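
One way to narrow the annotation to the verified schedulers (a sketch; the final signature may differ):

from typing import Union

from diffusers.schedulers import (
    DDIMScheduler,
    DDPMScheduler,
    EulerAncestralDiscreteScheduler,
    LMSDiscreteScheduler,
)

scheduler: Union[
    DDIMScheduler,
    DDPMScheduler,
    EulerAncestralDiscreteScheduler,
    LMSDiscreteScheduler,
]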

logger.info("Loading caption generator since `conditions_input_image` is True.")
checkpoint = "Salesforce/blip-image-captioning-base"
captioner_processor = AutoProcessor.from_pretrained(checkpoint)
captioner = BlipForConditionalGeneration.from_pretrained(checkpoint)

Contributor

We should make sure that the captioner has the same weight dtype as the other models.

# 2. Generate a caption for the input image if we are conditioning the
# pipeline based on some input image.
if self.conditions_input_image:
    caption, preprocessed_image = generate_caption(image, self._captioner, self._captioner_processor)

Contributor

Instead of passing the model and processor here, could this become a class method?

callback(i, t, latents)

# 8. Compute the edit directions.
edit_direction = construct_direction(source_embedding_path, target_embedding_path).to(prompt_embeds.device)

Contributor

Maybe make construct_direction a class method?
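
For reference, a sketch of what construct_direction likely computes (the edit direction is the difference of the mean target and mean source text embeddings), which could live on the pipeline as a static method:

import torch

@torch.no_grad()
def construct_direction(embs_source: torch.Tensor, embs_target: torch.Tensor) -> torch.Tensor:
    """Difference of the mean target and mean source embeddings."""
    return (embs_target.mean(0) - embs_source.mean(0)).unsqueeze(0)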

callback(i, t, latents)

# 8. Compute the edit directions.
edit_direction = construct_direction(source_embedding_path, target_embedding_path).to(prompt_embeds.device)

Contributor

Let's check the dtype of edit_direction here; I think if we execute this function on CPU, the dtype will be fp32, and if we are doing inference in fp16 it might fail. Could we also do everything on the GPU?

Member Author

The example code (#2334 (comment)) is run on a GPU and works without fail. Am I missing anything?


with torch.enable_grad():
    # initialize loss
    loss = Pix2PixZeroL2Loss()

Contributor

I think we are only using the class to store the running loss; we could also do it here directly. Would be cleaner IMO.

Member Author

The loss object is responsible for the computation as well, via compute_loss(). My reasoning is noted here: #2334 (comment).
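
For context, a minimal version of the class under discussion (a sketch consistent with the thread: it both accumulates the running loss and computes it via compute_loss()):

import torch

class Pix2PixZeroL2Loss:
    def __init__(self):
        self.loss = 0.0

    def compute_loss(self, predictions: torch.Tensor, targets: torch.Tensor):
        # Sum the squared error over the attention-map dims, average over
        # the batch, and accumulate into the running loss.
        self.loss += ((predictions - targets) ** 2).sum((1, 2)).mean(0)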

Comment on lines +809 to +817
noise_pred = self.unet(
    x_in.detach(),
    t,
    encoder_hidden_states=prompt_embeds_edit,
    cross_attention_kwargs={"timestep": None},
).sample

latents = x_in.detach().chunk(2)[0]

Contributor

Since x_in is not used later in the code, maybe we could detach it once and reuse it.

Member Author

Since it's used only once, I think it's okay.

@patrickvonplaten
Contributor

From CI:

E NotImplementedError: The operator 'aten::native_group_norm_backward' is not currently implemented for the MPS device. If you want this op to be added in priority during the prototype phase of this feature, please comment on pytorch/pytorch#77764. As a temporary fix, you can set the environment variable PYTORCH_ENABLE_MPS_FALLBACK=1 to use the CPU as a fallback for this op. WARNING: this will be slower than running natively on MPS.

What can we do about it? I guess it sort of depends on the PyTorch team for now?

Here is a decorator that skips a whole test suite for MPS:
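
For example (a sketch, assuming the skip_mps helper in diffusers.utils.testing_utils; the test-class name is illustrative):

import unittest

from diffusers.utils.testing_utils import skip_mps

@skip_mps
class Pix2PixZeroPipelineFastTests(unittest.TestCase):
    # Every test in this suite is skipped when running on the MPS device.
    ...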

@sayakpaul sayakpaul requested review from patil-suraj and patrickvonplaten and removed request for patil-suraj February 15, 2023 12:10
@sayakpaul sayakpaul self-assigned this Feb 15, 2023
@patrickvonplaten patrickvonplaten merged commit fd3d550 into main Feb 16, 2023
@patrickvonplaten
Contributor

Thanks a lot @sayakpaul !

@patrickvonplaten patrickvonplaten deleted the pix2pix-zero branch February 16, 2023 10:24
mengfei25 pushed a commit to mengfei25/diffusers that referenced this pull request Mar 27, 2023
* add: support for BLIP generation.

* add: support for editing synthetic images.

* remove unnecessary comments.

* add inits and run make fix-copies.

* version change of diffusers.

* fix: condition for loading the captioner.

* default conditions_input_image to False.

* guidance_amount -> cross_attention_guidance_amount

* fix inputs to check_inputs()

* fix: attribute.

* fix: prepare_attention_mask() call.

* debugging.

* better placement of references.

* remove torch.no_grad() decorations.

* put torch.no_grad() context before the first denoising loop.

* detach() latents before decoding them.

* put decoding in a torch.no_grad() context.

* add reconstructed image for debugging.

* no_grad(0

* apply formatting.

* address one-off suggestions from the draft PR.

* back to torch.no_grad() and add more elaborate comments.

* refactor prepare_unet() per Patrick's suggestions.

* more elaborate description for .

* formatting.

* add docstrings to the methods specific to pix2pix zero.

* suspecting a redundant noise prediction.

* needed for gradient computation chain.

* less hacks.

* fix: attention mask handling within the processor.

* remove attention reference map computation.

* fix: cross attn args.

* fix: processor.

* store attention maps.

* fix: attention processor.

* update docs and better treatment to xa args.

* update the final noise computation call.

* change xa args call.

* remove xa args option from the pipeline.

* add: docs.

* first test.

* fix: url call.

* fix: argument call.

* remove image conditioning for now.

* 🚨 add: fast tests.

* explicit placement of the xa attn weights.

* add: slow tests 🐢

* fix: tests.

* edited direction embedding should be on the same device as prompt_embeds.

* debugging message.

* debugging.

* add pix2pix zero pipeline for a non-deterministic test.

* debugging/

* remove debugging message.

* make caption generation _

* address comments (part I).

* address PR comments (part II)

* fix: DDPM test assertion.

* refactor doc.

* address PR comments (part III).

* fix: type annotation for the scheduler.

* apply styling.

* skip_mps and add note on embeddings in the docs.
yoonseokjin pushed a commit to yoonseokjin/diffusers that referenced this pull request Dec 25, 2023
AmericanPresidentJimmyCarter pushed a commit to AmericanPresidentJimmyCarter/diffusers that referenced this pull request Apr 26, 2024