From 71618a485e59fbadd30f944dcd8f3304fcfb616f Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Mon, 13 Feb 2023 11:45:18 +0530 Subject: [PATCH 01/63] add: support for BLIP generation. --- .../pipeline_stable_diffusion_pix2pix_zero.py | 711 ++++++++++++++++++ 1 file changed, 711 insertions(+) create mode 100644 src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py new file mode 100644 index 000000000000..8d6b5771b55b --- /dev/null +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py @@ -0,0 +1,711 @@ +# Copyright 2023 Pix2Pix Zero Authors and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable, Dict, List, Optional, Union + +import PIL +import torch +from packaging import version +from transformers import ( + AutoProcessor, + BlipForConditionalGeneration, + CLIPFeatureExtractor, + CLIPTextModel, + CLIPTokenizer, +) +from transformers.utils import check_min_version + +from ...configuration_utils import FrozenDict +from ...models import AutoencoderKL, UNet2DConditionModel +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import deprecate, is_accelerate_available, logging, randn_tensor +from ..pipeline_utils import DiffusionPipeline +from . import StableDiffusionPipelineOutput +from .safety_checker import StableDiffusionSafetyChecker + + +check_min_version("4.26.0") + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +# def preprocess_image(processor, image): +# return processor(images=image, return_tensors="pt")["pixel_values"] + + +def generate_caption(image, captioner, processor, return_image=True): + inputs = processor(images=image, return_tensors="pt") + outputs = captioner.generate(inputs) + caption = processor.batch_deocde(outputs, skip_special_tokens=True)[0] + if return_image: + return caption, inputs["pixel_values"] + else: + return caption + + +class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline): + r""" + Pipeline for text-to-image generation using Stable Diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. + feature_extractor ([`CLIPFeatureExtractor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ + _optional_components = ["safety_checker", "feature_extractor"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPFeatureExtractor, + conditions_input_image: bool = True, + requires_safety_checker: bool = True, + ): + super().__init__() + + if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file" + ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + + if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." + " `clip_sample` should be set to False in the configuration file. Please make sure to update the" + " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" + " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" + " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" + ) + deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["clip_sample"] = False + scheduler._internal_dict = FrozenDict(new_config) + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if conditions_input_image is None: + logger.info("Loading caption generator since `conditions_input_image` is True.") + checkpoint = "Salesforce/blip-image-captioning-base" + captioner_processor = AutoProcessor.from_pretrained(checkpoint) + captioner = BlipForConditionalGeneration.from_pretrained(checkpoint) + + else: + captioner_processor = None + captioner = None + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( + version.parse(unet.config._diffusers_version).base_version + ) < version.parse("0.9.0.dev0") + is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: + deprecation_message = ( + "The configuration file of the unet has set the default `sample_size` to smaller than" + " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the" + " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" + " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" + " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" + " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" + " in the config might lead to incorrect results in future versions. If you have downloaded this" + " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" + " the `unet/config.json` file" + ) + deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(unet.config) + new_config["sample_size"] = 64 + unet._internal_dict = FrozenDict(new_config) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + captioner_processor=captioner_processor, + captioner=captioner, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.register_to_config( + conditions_input_image=conditions_input_image, requires_safety_checker=requires_safety_checker + ) + + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. + + When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several + steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + def enable_sequential_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, + text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a + `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. + """ + if is_accelerate_available(): + from accelerate import cpu_offload + else: + raise ImportError("Please install accelerate via `pip install accelerate`") + + device = torch.device(f"cuda:{gpu_id}") + + for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: + cpu_offload(cpu_offloaded_model, device) + + if self.safety_checker is not None: + cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True) + + @property + def _execution_device(self): + r""" + Returns the device on which the pipeline's models will be executed. After calling + `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module + hooks. + """ + if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"): + return self.device + for module in self.unet.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device + + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. + Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + """ + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + prompt_embeds = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + return prompt_embeds + + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is not None: + safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + else: + has_nsfw_concept = None + return image, has_nsfw_concept + + def decode_latents(self, latents): + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents).sample + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + conditions_input_image, + image, + callback_steps, + prompt_embeds=None, + ): + # if height % 8 != 0 or width % 8 != 0: + # raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is None and not conditions_input_image: + raise ValueError(f"`prompt` cannot be None when `conditions_input_image` is {conditions_input_image}") + + elif prompt is not None and conditions_input_image: + raise ValueError( + f"`prompt` should not be provided when `conditions_input_image` is {conditions_input_image}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if conditions_input_image: + if image is None: + raise ValueError("`image` cannot be None when `conditions_input_image` is True.") + elif isinstance(image, (torch.FloatTensor, PIL.Image.Image)): + raise ValueError("Invalid image provided. Supported formats: torch.FloatTensor, PIL.Image.Image.}") + + # if negative_prompt is not None and negative_prompt_embeds is not None: + # raise ValueError( + # f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + # f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + # ) + + # if prompt_embeds is not None and negative_prompt_embeds is not None: + # if prompt_embeds.shape != negative_prompt_embeds.shape: + # raise ValueError( + # "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + # f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + # f" {negative_prompt_embeds.shape}." + # ) + + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + @torch.no_grad() + # @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Optional[Union[str, List[str]]] = None, + image: Union[torch.FloatTensor, PIL.Image.Image] = None, + # height: Optional[int] = None, + # width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: Optional[int] = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. + Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + # # 0. Default height and width to unet + # height = height or self.unet.config.sample_size * self.vae_scale_factor + # width = width or self.unet.config.sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + self.conditions_input_image, + image, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + ) + if self.conditions_image_input and prompt_embeds: + logger.warning( + f"You have set `conditions_image_input` to {self.conditions_image_input} and" + " passed `prompt_embeds`. `prompt_embeds` will be ignored. " + ) + + # 2. Generate a caption for the input image if we are conditioning the + # pipeline based on some input image. + if self.conditions_image_input: + caption, preprocessed_image = generate_caption(image, self.captioner, self.captioner_processor) + height, width = preprocessed_image.shape[-2:] + prompt = caption + else: + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + prompt_embeds = self._encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + ) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.unet.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + ).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + if output_type == "latent": + image = latents + has_nsfw_concept = None + elif output_type == "pil": + # 8. Post-processing + image = self.decode_latents(latents) + + # 9. Run safety checker + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + + # 10. Convert to PIL + image = self.numpy_to_pil(image) + else: + # 8. Post-processing + image = self.decode_latents(latents) + + # 9. Run safety checker + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) From f206b1dc4fe8521a531e3857201785c813e557a3 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Mon, 13 Feb 2023 15:19:00 +0530 Subject: [PATCH 02/63] add: support for editing synthetic images. --- .../pipeline_stable_diffusion_pix2pix_zero.py | 221 ++++++++++++++---- 1 file changed, 179 insertions(+), 42 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py index 8d6b5771b55b..290004a96d5f 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py @@ -29,6 +29,7 @@ from ...configuration_utils import FrozenDict from ...models import AutoencoderKL, UNet2DConditionModel +from ...models.attention import CrossAttention from ...schedulers import KarrasDiffusionSchedulers from ...utils import deprecate, is_accelerate_available, logging, randn_tensor from ..pipeline_utils import DiffusionPipeline @@ -40,9 +41,6 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name -# def preprocess_image(processor, image): -# return processor(images=image, return_tensors="pt")["pixel_values"] - def generate_caption(image, captioner, processor, return_image=True): inputs = processor(images=image, return_tensors="pt") @@ -54,9 +52,62 @@ def generate_caption(image, captioner, processor, return_image=True): return caption +def prepare_unet(unet: UNet2DConditionModel): + # set the gradients for cross-attention maps to be true + for name, params in unet.named_parameters(): + if "attn2" in name: + params.requires_grad = True + else: + params.requires_grad = False + # replace the forward function + for name, module in unet.named_modules(): + module_name = type(module).__name__ + if module_name == "CrossAttention": + module.set_processor(Pix2PixZeroCrossAttnProcessor()) + return unet + + +def costruct_direction(source_embedding_path: str, target_embedding_path: str): + embs_source = torch.load(source_embedding_path) + embs_target = torch.load(target_embedding_path) + return (embs_target.mean(0) - embs_source.mean(0)).unsqueeze(0) + + +class Pix2PixZeroCrossAttnProcessor: + def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None): + _, sequence_length, _ = hidden_states.shape + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length) + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.cross_attention_norm: + encoder_hidden_states = attn.norm_cross(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + query = attn.head_to_batch_dim(query) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + + attention_probs = attn.get_attention_scores(query, key, attention_mask) + # new bookkeeping to save the attention weights. + attn.attn_probs = attention_probs + hidden_states = torch.bmm(attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + return hidden_states + + class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline): r""" - Pipeline for text-to-image generation using Stable Diffusion. + Pipeline for pixel-levl image editing using Pix2Pix Zero. Based on Stable Diffusion. This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) @@ -495,8 +546,8 @@ def __call__( self, prompt: Optional[Union[str, List[str]]] = None, image: Union[torch.FloatTensor, PIL.Image.Image] = None, - # height: Optional[int] = None, - # width: Optional[int] = None, + height: Optional[int] = None, + width: Optional[int] = None, num_inference_steps: int = 50, guidance_scale: float = 7.5, negative_prompt: Optional[Union[str, List[str]]] = None, @@ -506,6 +557,9 @@ def __call__( latents: Optional[torch.FloatTensor] = None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, + guidance_amount: float = 0.1, + source_embedding_path: str = None, + target_embedding_path: str = None, output_type: Optional[str] = "pil", return_dict: bool = True, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, @@ -581,9 +635,9 @@ def __call__( list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) content, according to the `safety_checker`. """ - # # 0. Default height and width to unet - # height = height or self.unet.config.sample_size * self.vae_scale_factor - # width = width or self.unet.config.sample_size * self.vae_scale_factor + # 0. Define the spatial resolutions. + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor # 1. Check inputs. Raise error if not correct self.check_inputs( @@ -607,17 +661,18 @@ def __call__( caption, preprocessed_image = generate_caption(image, self.captioner, self.captioner_processor) height, width = preprocessed_image.shape[-2:] prompt = caption + logger.info(f"Generated caption for the input image: {caption}.") else: - height = height or self.unet.config.sample_size * self.vae_scale_factor - width = width or self.unet.config.sample_size * self.vae_scale_factor + pass - # 2. Define call parameters + # 3. Define call parameters if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] + ref_xa_maps = {} # reference cross attention maps device = self._execution_device # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) @@ -640,23 +695,34 @@ def __call__( self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps = self.scheduler.timesteps - # 5. Prepare latent variables - num_channels_latents = self.unet.in_channels - latents = self.prepare_latents( - batch_size * num_images_per_prompt, - num_channels_latents, - height, - width, - prompt_embeds.dtype, - device, - generator, - latents, - ) + # 5. Generate the inverted noise from the input image or any other image + # generated from the input prompt. + if self.conditions_input_image: + # TODO (sayakpaul): Generate this using DDIM inversion. + latents = 1 + # latents = self.prepare_inverted_noise() + else: + num_channels_latents = self.unet.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + latents_init = latents.clone() # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) - # 7. Denoising loop + # 8. Rejig the UNet so that we can obtain the cross-attenion maps and + # use them for guiding the subsequent image generation. + self.unet = prepare_unet(self.unet) + + # 7. Denoising loop where we obtain the cross-attention maps. num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): @@ -672,6 +738,14 @@ def __call__( cross_attention_kwargs=cross_attention_kwargs, ).sample + # add the cross attention map to the dictionary + ref_xa_maps[t.item()] = {} + for name, module in self.unet.named_modules(): + module_name = type(module).__name__ + if module_name == "CrossAttention" and "attn2" in name: + attn_mask = module.attn_probs # size is (num_channels, s*s, max_prompt_length) + ref_xa_maps[t.item()][name] = attn_mask.detach().cpu() + # perform guidance if do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) @@ -686,26 +760,89 @@ def __call__( if callback is not None and i % callback_steps == 0: callback(i, t, latents) - if output_type == "latent": - image = latents - has_nsfw_concept = None - elif output_type == "pil": - # 8. Post-processing - image = self.decode_latents(latents) + # # 8. Post-process the reconstructed image. + # reconstructed_image = self.decode_latents(latents) + # reconstructed_image, recon_has_nsfw_concept = self.run_safety_checker( + # reconstructed_image, device, prompt_embeds.dtype + # ) + # if output_type == "pil": + # reconstructed_image = self.numpy_to_pil(reconstructed_image) + + # 8. Compute the edit directions. + edit_direction = costruct_direction( + source_embedding_path, target_embedding_path + ) # TODO (sayakpaul): compute the edit directions + + # 9. Edit the prompt embeddings as per the edit directions discovered. + prompt_embeds_edit = prompt_embeds.clone() + prompt_embeds_edit[1:2] += edit_direction + + # 10. Second denoising loop to generate the edited image. + latents = latents_init + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - # 9. Run safety checker - image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + x_in = latent_model_input.detach().clone() + x_in.requires_grad = True - # 10. Convert to PIL - image = self.numpy_to_pil(image) - else: - # 8. Post-processing - image = self.decode_latents(latents) + opt = torch.optim.SGD([x_in], lr=guidance_amount) + + # predict the noise residual + noise_pred = self.unet( + x_in, + t, + encoder_hidden_states=prompt_embeds_edit.detach(), + cross_attention_kwargs=cross_attention_kwargs, + ).sample + + loss = 0.0 + for name, module in self.unet.named_modules(): + module_name = type(module).__name__ + if module_name == "CrossAttention" and "attn2" in name: + curr = module.attn_probs # size is num_channel,s*s,77 + ref = ref_xa_maps[t.item()][name].detach().cuda() + loss += ((curr - ref) ** 2).sum((1, 2)).mean(0) + loss.backward(retain_graph=False) + opt.step() + + # recompute the noise + with torch.no_grad(): + noise_pred = self.unet( + x_in.detach(), + t, + encoder_hidden_states=prompt_embeds_edit, + cross_attention_kwargs=cross_attention_kwargs, + ).sample + + latents = x_in.detach().chunk(2)[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + # 11. Post-process the latents. + edited_image = self.decode_latents(latents) + + # 12. Run the safety checker. + edited_image, has_nsfw_concept = self.run_safety_checker(edited_image, device, prompt_embeds.dtype) - # 9. Run safety checker - image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + # 13. Convert to PIL. + if output_type == "pil": + edited_image = self.numpy_to_pil(edited_image) if not return_dict: - return (image, has_nsfw_concept) + return (edited_image, has_nsfw_concept) - return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) + return StableDiffusionPipelineOutput(images=edited_image, nsfw_content_detected=has_nsfw_concept) From 6ed7c204ac18fc1840d868562af84afde0d70385 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Mon, 13 Feb 2023 15:20:58 +0530 Subject: [PATCH 03/63] remove unnecessary comments. --- .../pipeline_stable_diffusion_pix2pix_zero.py | 22 ------------------- 1 file changed, 22 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py index 290004a96d5f..6c274307df2f 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py @@ -509,20 +509,6 @@ def check_inputs( elif isinstance(image, (torch.FloatTensor, PIL.Image.Image)): raise ValueError("Invalid image provided. Supported formats: torch.FloatTensor, PIL.Image.Image.}") - # if negative_prompt is not None and negative_prompt_embeds is not None: - # raise ValueError( - # f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" - # f" {negative_prompt_embeds}. Please make sure to only forward one of the two." - # ) - - # if prompt_embeds is not None and negative_prompt_embeds is not None: - # if prompt_embeds.shape != negative_prompt_embeds.shape: - # raise ValueError( - # "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" - # f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" - # f" {negative_prompt_embeds.shape}." - # ) - def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) if isinstance(generator, list) and len(generator) != batch_size: @@ -760,14 +746,6 @@ def __call__( if callback is not None and i % callback_steps == 0: callback(i, t, latents) - # # 8. Post-process the reconstructed image. - # reconstructed_image = self.decode_latents(latents) - # reconstructed_image, recon_has_nsfw_concept = self.run_safety_checker( - # reconstructed_image, device, prompt_embeds.dtype - # ) - # if output_type == "pil": - # reconstructed_image = self.numpy_to_pil(reconstructed_image) - # 8. Compute the edit directions. edit_direction = costruct_direction( source_embedding_path, target_embedding_path From f95a3fce95cb325d273f7fb9a27da3102e2efd08 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Mon, 13 Feb 2023 15:24:37 +0530 Subject: [PATCH 04/63] add inits and run make fix-copies. --- src/diffusers/__init__.py | 1 + src/diffusers/pipelines/__init__.py | 1 + .../pipelines/stable_diffusion/__init__.py | 1 + .../utils/dummy_torch_and_transformers_objects.py | 15 +++++++++++++++ 4 files changed, 18 insertions(+) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index bc6057eaf2da..a407c8b3e358 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -118,6 +118,7 @@ StableDiffusionLatentUpscalePipeline, StableDiffusionPipeline, StableDiffusionPipelineSafe, + StableDiffusionPix2PixZeroPipeline, StableDiffusionUpscalePipeline, UnCLIPImageVariationPipeline, UnCLIPPipeline, diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index dfb2fd83cb71..a61e2b5668f2 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -54,6 +54,7 @@ StableDiffusionInstructPix2PixPipeline, StableDiffusionLatentUpscalePipeline, StableDiffusionPipeline, + StableDiffusionPix2PixZeroPipeline, StableDiffusionUpscalePipeline, ) from .stable_diffusion_safe import StableDiffusionPipelineSafe diff --git a/src/diffusers/pipelines/stable_diffusion/__init__.py b/src/diffusers/pipelines/stable_diffusion/__init__.py index bf07127cde5b..bd4a5588ff64 100644 --- a/src/diffusers/pipelines/stable_diffusion/__init__.py +++ b/src/diffusers/pipelines/stable_diffusion/__init__.py @@ -44,6 +44,7 @@ class StableDiffusionPipelineOutput(BaseOutput): from .pipeline_stable_diffusion_inpaint_legacy import StableDiffusionInpaintPipelineLegacy from .pipeline_stable_diffusion_instruct_pix2pix import StableDiffusionInstructPix2PixPipeline from .pipeline_stable_diffusion_latent_upscale import StableDiffusionLatentUpscalePipeline + from .pipeline_stable_diffusion_pix2pix_zero import StableDiffusionPix2PixZeroPipeline from .pipeline_stable_diffusion_upscale import StableDiffusionUpscalePipeline from .safety_checker import StableDiffusionSafetyChecker diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 79755c27e6fe..0f5532efa8e0 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -212,6 +212,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class StableDiffusionPix2PixZeroPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class StableDiffusionUpscalePipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] From 3881408d2a4402ca44330a0da722a6bf491e8c4c Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Mon, 13 Feb 2023 15:38:19 +0530 Subject: [PATCH 05/63] version change of diffusers. --- .../stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py index 6c274307df2f..513ab598c5cd 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py @@ -37,7 +37,7 @@ from .safety_checker import StableDiffusionSafetyChecker -check_min_version("4.26.0") +check_min_version("4.26.1") logger = logging.get_logger(__name__) # pylint: disable=invalid-name From 9fe140bf390cab50db7931a00c00f02f74ecd838 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Mon, 13 Feb 2023 15:42:25 +0530 Subject: [PATCH 06/63] fix: condition for loading the captioner. --- .../stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py index 513ab598c5cd..a30136854f6d 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py @@ -185,7 +185,7 @@ def __init__( " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." ) - if conditions_input_image is None: + if conditions_input_image: logger.info("Loading caption generator since `conditions_input_image` is True.") checkpoint = "Salesforce/blip-image-captioning-base" captioner_processor = AutoProcessor.from_pretrained(checkpoint) From d8de3e0eb2a2656426d9285899a3c1d62c728335 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Mon, 13 Feb 2023 15:46:44 +0530 Subject: [PATCH 07/63] default conditions_input_image to False. --- .../stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py index a30136854f6d..9797c1f8413d 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py @@ -143,7 +143,7 @@ def __init__( scheduler: KarrasDiffusionSchedulers, safety_checker: StableDiffusionSafetyChecker, feature_extractor: CLIPFeatureExtractor, - conditions_input_image: bool = True, + conditions_input_image: bool = False, requires_safety_checker: bool = True, ): super().__init__() From e3f51c56aa54672ad2a614fe6cda08b4e9d5b802 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Mon, 13 Feb 2023 15:49:16 +0530 Subject: [PATCH 08/63] guidance_amount -> cross_attention_guidance_amount --- .../pipeline_stable_diffusion_pix2pix_zero.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py index 9797c1f8413d..94c555df1117 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py @@ -543,7 +543,7 @@ def __call__( latents: Optional[torch.FloatTensor] = None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, - guidance_amount: float = 0.1, + cross_attention_guidance_amount: float = 0.1, source_embedding_path: str = None, target_embedding_path: str = None, output_type: Optional[str] = "pil", @@ -767,7 +767,7 @@ def __call__( x_in = latent_model_input.detach().clone() x_in.requires_grad = True - opt = torch.optim.SGD([x_in], lr=guidance_amount) + opt = torch.optim.SGD([x_in], lr=cross_attention_guidance_amount) # predict the noise residual noise_pred = self.unet( From e08b254961be5cf9677ddcb5953b4999705b1d3b Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Mon, 13 Feb 2023 15:59:47 +0530 Subject: [PATCH 09/63] fix inputs to check_inputs() --- .../stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py index 94c555df1117..8d9f66e969f8 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py @@ -631,9 +631,7 @@ def __call__( self.conditions_input_image, image, callback_steps, - negative_prompt, prompt_embeds, - negative_prompt_embeds, ) if self.conditions_image_input and prompt_embeds: logger.warning( From 45b173980e3f292146757f0ecc3b1b49dcf3ec14 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Mon, 13 Feb 2023 16:03:25 +0530 Subject: [PATCH 10/63] fix: attribute. --- .../pipeline_stable_diffusion_pix2pix_zero.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py index 8d9f66e969f8..14130fd08c10 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py @@ -234,9 +234,8 @@ def __init__( feature_extractor=feature_extractor, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) - self.register_to_config( - conditions_input_image=conditions_input_image, requires_safety_checker=requires_safety_checker - ) + self.conditions_input_image = conditions_input_image + self.register_to_config(requires_safety_checker=requires_safety_checker) def enable_vae_slicing(self): r""" @@ -633,15 +632,15 @@ def __call__( callback_steps, prompt_embeds, ) - if self.conditions_image_input and prompt_embeds: + if self.conditions_input_image and prompt_embeds: logger.warning( - f"You have set `conditions_image_input` to {self.conditions_image_input} and" + f"You have set `conditions_input_image` to {self.conditions_input_image} and" " passed `prompt_embeds`. `prompt_embeds` will be ignored. " ) # 2. Generate a caption for the input image if we are conditioning the # pipeline based on some input image. - if self.conditions_image_input: + if self.conditions_input_image: caption, preprocessed_image = generate_caption(image, self.captioner, self.captioner_processor) height, width = preprocessed_image.shape[-2:] prompt = caption From d44d6992c8cc32b83c41bd981fa936ab3565375e Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Mon, 13 Feb 2023 16:18:18 +0530 Subject: [PATCH 11/63] fix: prepare_attention_mask() call. --- .../pipeline_stable_diffusion_pix2pix_zero.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py index 14130fd08c10..64f6f2770dd0 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py @@ -75,8 +75,8 @@ def costruct_direction(source_embedding_path: str, target_embedding_path: str): class Pix2PixZeroCrossAttnProcessor: def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None): - _, sequence_length, _ = hidden_states.shape - attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length) + batch_size, sequence_length, _ = hidden_states.shape + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) query = attn.to_q(hidden_states) if encoder_hidden_states is None: From 2f71a8c163a0dbeb1ae6b9923c98f86d828ad9c3 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Mon, 13 Feb 2023 16:34:03 +0530 Subject: [PATCH 12/63] debugging. --- .../pipeline_stable_diffusion_pix2pix_zero.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py index 64f6f2770dd0..0c1a8350f453 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py @@ -778,7 +778,9 @@ def __call__( for name, module in self.unet.named_modules(): module_name = type(module).__name__ if module_name == "CrossAttention" and "attn2" in name: - curr = module.attn_probs # size is num_channel,s*s,77 + for _, param in module.named_parameters(): + print(f"{module_name}: {param.requires_grad}") + curr = module.attn_probs ref = ref_xa_maps[t.item()][name].detach().cuda() loss += ((curr - ref) ** 2).sum((1, 2)).mean(0) loss.backward(retain_graph=False) From f7576634d167da4fc469c54aeb026a6146757734 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Mon, 13 Feb 2023 16:56:15 +0530 Subject: [PATCH 13/63] better placement of references. --- .../pipeline_stable_diffusion_pix2pix_zero.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py index 0c1a8350f453..2f4bf5431d62 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py @@ -67,7 +67,7 @@ def prepare_unet(unet: UNet2DConditionModel): return unet -def costruct_direction(source_embedding_path: str, target_embedding_path: str): +def construct_direction(source_embedding_path: str, target_embedding_path: str): embs_source = torch.load(source_embedding_path) embs_target = torch.load(target_embedding_path) return (embs_target.mean(0) - embs_source.mean(0)).unsqueeze(0) @@ -744,9 +744,7 @@ def __call__( callback(i, t, latents) # 8. Compute the edit directions. - edit_direction = costruct_direction( - source_embedding_path, target_embedding_path - ) # TODO (sayakpaul): compute the edit directions + edit_direction = construct_direction(source_embedding_path, target_embedding_path) # 9. Edit the prompt embeddings as per the edit directions discovered. prompt_embeds_edit = prompt_embeds.clone() @@ -778,10 +776,11 @@ def __call__( for name, module in self.unet.named_modules(): module_name = type(module).__name__ if module_name == "CrossAttention" and "attn2" in name: - for _, param in module.named_parameters(): - print(f"{module_name}: {param.requires_grad}") + # for _, param in module.named_parameters(): + # print(f"{module_name}: {param.requires_grad}") curr = module.attn_probs - ref = ref_xa_maps[t.item()][name].detach().cuda() + ref = ref_xa_maps[t.item()][name].detach().to(device) + print(curr.requires_grad, ref.requires_grad) loss += ((curr - ref) ** 2).sum((1, 2)).mean(0) loss.backward(retain_graph=False) opt.step() From 5d44832fcbc482f38486e5e7dd1acb7d9dd218b6 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Mon, 13 Feb 2023 17:33:53 +0530 Subject: [PATCH 14/63] remove torch.no_grad() decorations. --- .../pipeline_stable_diffusion_pix2pix_zero.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py index 2f4bf5431d62..ca739af765d4 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py @@ -525,7 +525,6 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype latents = latents * self.scheduler.init_noise_sigma return latents - @torch.no_grad() # @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, @@ -776,11 +775,8 @@ def __call__( for name, module in self.unet.named_modules(): module_name = type(module).__name__ if module_name == "CrossAttention" and "attn2" in name: - # for _, param in module.named_parameters(): - # print(f"{module_name}: {param.requires_grad}") curr = module.attn_probs ref = ref_xa_maps[t.item()][name].detach().to(device) - print(curr.requires_grad, ref.requires_grad) loss += ((curr - ref) ** 2).sum((1, 2)).mean(0) loss.backward(retain_graph=False) opt.step() From c4fb4ad0b5e05b574f5bb4469800d83f23f1bb2c Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Mon, 13 Feb 2023 17:37:58 +0530 Subject: [PATCH 15/63] put torch.no_grad() context before the first denoising loop. --- .../pipeline_stable_diffusion_pix2pix_zero.py | 69 ++++++++++--------- 1 file changed, 35 insertions(+), 34 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py index ca739af765d4..2b216a3dcfe1 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py @@ -706,41 +706,42 @@ def __call__( # 7. Denoising loop where we obtain the cross-attention maps. num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order - with self.progress_bar(total=num_inference_steps) as progress_bar: - for i, t in enumerate(timesteps): - # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - - # predict the noise residual - noise_pred = self.unet( - latent_model_input, - t, - encoder_hidden_states=prompt_embeds, - cross_attention_kwargs=cross_attention_kwargs, - ).sample - - # add the cross attention map to the dictionary - ref_xa_maps[t.item()] = {} - for name, module in self.unet.named_modules(): - module_name = type(module).__name__ - if module_name == "CrossAttention" and "attn2" in name: - attn_mask = module.attn_probs # size is (num_channels, s*s, max_prompt_length) - ref_xa_maps[t.item()][name] = attn_mask.detach().cpu() - - # perform guidance - if do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) - - # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + with torch.no_grad(): + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + ).sample - # call the callback, if provided - if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): - progress_bar.update() - if callback is not None and i % callback_steps == 0: - callback(i, t, latents) + # add the cross attention map to the dictionary + ref_xa_maps[t.item()] = {} + for name, module in self.unet.named_modules(): + module_name = type(module).__name__ + if module_name == "CrossAttention" and "attn2" in name: + attn_mask = module.attn_probs # size is (num_channels, s*s, max_prompt_length) + ref_xa_maps[t.item()][name] = attn_mask.detach().cpu() + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) # 8. Compute the edit directions. edit_direction = construct_direction(source_embedding_path, target_embedding_path) From ade51d76a0a9b8052bdc132da04b8bc7c09d7aac Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Mon, 13 Feb 2023 17:40:36 +0530 Subject: [PATCH 16/63] detach() latents before decoding them. --- .../stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py index 2b216a3dcfe1..22a00964c00d 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py @@ -806,7 +806,7 @@ def __call__( progress_bar.update() # 11. Post-process the latents. - edited_image = self.decode_latents(latents) + edited_image = self.decode_latents(latents.detach()) # 12. Run the safety checker. edited_image, has_nsfw_concept = self.run_safety_checker(edited_image, device, prompt_embeds.dtype) From e23fe7b860691536d8fe96fe7d282da536ba5d28 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Mon, 13 Feb 2023 17:44:26 +0530 Subject: [PATCH 17/63] put deocding in a torch.no_grad() context. --- .../stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py index 22a00964c00d..c230da02ca7e 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py @@ -806,7 +806,8 @@ def __call__( progress_bar.update() # 11. Post-process the latents. - edited_image = self.decode_latents(latents.detach()) + with torch.no_grad(): + edited_image = self.decode_latents(latents) # 12. Run the safety checker. edited_image, has_nsfw_concept = self.run_safety_checker(edited_image, device, prompt_embeds.dtype) From c76c16228850823b2117a0396b01d9b6d9f9ad1e Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Mon, 13 Feb 2023 17:47:38 +0530 Subject: [PATCH 18/63] add reconstructed image for debugging. --- .../pipeline_stable_diffusion_pix2pix_zero.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py index c230da02ca7e..4fceefc69d81 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py @@ -746,6 +746,10 @@ def __call__( # 8. Compute the edit directions. edit_direction = construct_direction(source_embedding_path, target_embedding_path) + # make the reference image (reconstruction) + image_rec = self.numpy_to_pil(self.decode_latents(latents.detach())) + image_rec[0].save("reconstructed_image.png") + # 9. Edit the prompt embeddings as per the edit directions discovered. prompt_embeds_edit = prompt_embeds.clone() prompt_embeds_edit[1:2] += edit_direction From b02016bb1fb0fa24d655573a7827e1f717a0f5f7 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Mon, 13 Feb 2023 17:48:48 +0530 Subject: [PATCH 19/63] no_grad(0 --- .../stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py index 4fceefc69d81..06f6f7d1faa4 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py @@ -747,7 +747,8 @@ def __call__( edit_direction = construct_direction(source_embedding_path, target_embedding_path) # make the reference image (reconstruction) - image_rec = self.numpy_to_pil(self.decode_latents(latents.detach())) + with torch.no_grad(): + image_rec = self.numpy_to_pil(self.decode_latents(latents)) image_rec[0].save("reconstructed_image.png") # 9. Edit the prompt embeddings as per the edit directions discovered. From 55414e02f0de6962bde74dd4ebf6dfd770b02cf2 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Mon, 13 Feb 2023 18:36:04 +0530 Subject: [PATCH 20/63] apply formatting. --- .../pipeline_stable_diffusion_pix2pix_zero.py | 41 +++++++------------ 1 file changed, 14 insertions(+), 27 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py index 06f6f7d1faa4..ef5014c89801 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py @@ -235,24 +235,13 @@ def __init__( ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.conditions_input_image = conditions_input_image - self.register_to_config(requires_safety_checker=requires_safety_checker) - - def enable_vae_slicing(self): - r""" - Enable sliced VAE decoding. - - When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several - steps. This is useful to save some memory and allow larger batch sizes. - """ - self.vae.enable_slicing() - - def disable_vae_slicing(self): - r""" - Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to - computing decoding in one step. - """ - self.vae.disable_slicing() + self.register_to_config( + captioner=captioner, + captioner_processor=captioner_processor, + requires_safety_checker=requires_safety_checker, + ) + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload def enable_sequential_cpu_offload(self, gpu_id=0): r""" Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, @@ -273,6 +262,7 @@ def enable_sequential_cpu_offload(self, gpu_id=0): cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True) @property + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device def _execution_device(self): r""" Returns the device on which the pipeline's models will be executed. After calling @@ -290,6 +280,7 @@ def _execution_device(self): return torch.device(module._hf_hook.execution_device) return self.device + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt def _encode_prompt( self, prompt, @@ -428,6 +419,7 @@ def _encode_prompt( return prompt_embeds + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker def run_safety_checker(self, image, device, dtype): if self.safety_checker is not None: safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device) @@ -438,6 +430,7 @@ def run_safety_checker(self, image, device, dtype): has_nsfw_concept = None return image, has_nsfw_concept + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents def decode_latents(self, latents): latents = 1 / self.vae.config.scaling_factor * latents image = self.vae.decode(latents).sample @@ -446,6 +439,7 @@ def decode_latents(self, latents): image = image.cpu().permute(0, 2, 3, 1).float().numpy() return image + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs def prepare_extra_step_kwargs(self, generator, eta): # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. @@ -471,9 +465,6 @@ def check_inputs( callback_steps, prompt_embeds=None, ): - # if height % 8 != 0 or width % 8 != 0: - # raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") - if (callback_steps is None) or ( callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) ): @@ -508,6 +499,7 @@ def check_inputs( elif isinstance(image, (torch.FloatTensor, PIL.Image.Image)): raise ValueError("Invalid image provided. Supported formats: torch.FloatTensor, PIL.Image.Image.}") + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) if isinstance(generator, list) and len(generator) != batch_size: @@ -744,12 +736,7 @@ def __call__( callback(i, t, latents) # 8. Compute the edit directions. - edit_direction = construct_direction(source_embedding_path, target_embedding_path) - - # make the reference image (reconstruction) - with torch.no_grad(): - image_rec = self.numpy_to_pil(self.decode_latents(latents)) - image_rec[0].save("reconstructed_image.png") + edit_direction = construct_direction(source_embedding_path, target_embedding_path) # 9. Edit the prompt embeddings as per the edit directions discovered. prompt_embeds_edit = prompt_embeds.clone() @@ -781,7 +768,7 @@ def __call__( for name, module in self.unet.named_modules(): module_name = type(module).__name__ if module_name == "CrossAttention" and "attn2" in name: - curr = module.attn_probs + curr = module.attn_probs ref = ref_xa_maps[t.item()][name].detach().to(device) loss += ((curr - ref) ** 2).sum((1, 2)).mean(0) loss.backward(retain_graph=False) From b07d2cd1788519f1b2d330c5ade1eb4bb92a262a Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 14 Feb 2023 08:26:46 +0530 Subject: [PATCH 21/63] address one-off suggestions from the draft PR. --- .../pipeline_stable_diffusion_pix2pix_zero.py | 60 ++----------------- 1 file changed, 6 insertions(+), 54 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py index ef5014c89801..402e8dad3401 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py @@ -17,7 +17,6 @@ import PIL import torch -from packaging import version from transformers import ( AutoProcessor, BlipForConditionalGeneration, @@ -25,20 +24,16 @@ CLIPTextModel, CLIPTokenizer, ) -from transformers.utils import check_min_version -from ...configuration_utils import FrozenDict from ...models import AutoencoderKL, UNet2DConditionModel from ...models.attention import CrossAttention from ...schedulers import KarrasDiffusionSchedulers -from ...utils import deprecate, is_accelerate_available, logging, randn_tensor +from ...utils import is_accelerate_available, logging, randn_tensor from ..pipeline_utils import DiffusionPipeline from . import StableDiffusionPipelineOutput from .safety_checker import StableDiffusionSafetyChecker -check_min_version("4.26.1") - logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -131,6 +126,11 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline): Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. feature_extractor ([`CLIPFeatureExtractor`]): Model that extracts features from generated images to be used as inputs for the `safety_checker`. + conditions_input_image (bool): + Whether to execute the pipeline with an input image. + requires_safety_checker (bool): + Whether the pipeline requires a safety checker. We recommend setting it to True if you're using the + pipeline publicly. """ _optional_components = ["safety_checker", "feature_extractor"] @@ -148,33 +148,6 @@ def __init__( ): super().__init__() - if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: - deprecation_message = ( - f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" - f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " - "to update the config accordingly as leaving `steps_offset` might led to incorrect results" - " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," - " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" - " file" - ) - deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) - new_config = dict(scheduler.config) - new_config["steps_offset"] = 1 - scheduler._internal_dict = FrozenDict(new_config) - - if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: - deprecation_message = ( - f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." - " `clip_sample` should be set to False in the configuration file. Please make sure to update the" - " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" - " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" - " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" - ) - deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) - new_config = dict(scheduler.config) - new_config["clip_sample"] = False - scheduler._internal_dict = FrozenDict(new_config) - if safety_checker is None and requires_safety_checker: logger.warning( f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" @@ -201,27 +174,6 @@ def __init__( " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." ) - is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( - version.parse(unet.config._diffusers_version).base_version - ) < version.parse("0.9.0.dev0") - is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 - if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: - deprecation_message = ( - "The configuration file of the unet has set the default `sample_size` to smaller than" - " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the" - " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" - " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" - " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" - " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" - " in the config might lead to incorrect results in future versions. If you have downloaded this" - " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" - " the `unet/config.json` file" - ) - deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) - new_config = dict(unet.config) - new_config["sample_size"] = 64 - unet._internal_dict = FrozenDict(new_config) - self.register_modules( vae=vae, text_encoder=text_encoder, From 9bdade105de1bccf1f5d22027d27e42ce04b395a Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 14 Feb 2023 08:38:08 +0530 Subject: [PATCH 22/63] back to torch.no_grad() and add more elaborate comments. --- .../pipeline_stable_diffusion_pix2pix_zero.py | 115 +++++++++--------- 1 file changed, 60 insertions(+), 55 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py index 402e8dad3401..3b9efdbad19b 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py @@ -469,6 +469,7 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype latents = latents * self.scheduler.init_noise_sigma return latents + @torch.no_grad() # @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, @@ -650,42 +651,41 @@ def __call__( # 7. Denoising loop where we obtain the cross-attention maps. num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order - with torch.no_grad(): - with self.progress_bar(total=num_inference_steps) as progress_bar: - for i, t in enumerate(timesteps): - # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - # predict the noise residual - noise_pred = self.unet( - latent_model_input, - t, - encoder_hidden_states=prompt_embeds, - cross_attention_kwargs=cross_attention_kwargs, - ).sample + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + ).sample - # add the cross attention map to the dictionary - ref_xa_maps[t.item()] = {} - for name, module in self.unet.named_modules(): - module_name = type(module).__name__ - if module_name == "CrossAttention" and "attn2" in name: - attn_mask = module.attn_probs # size is (num_channels, s*s, max_prompt_length) - ref_xa_maps[t.item()][name] = attn_mask.detach().cpu() + # add the cross attention map to the dictionary + ref_xa_maps[t.item()] = {} + for name, module in self.unet.named_modules(): + module_name = type(module).__name__ + if module_name == "CrossAttention" and "attn2" in name: + attn_mask = module.attn_probs # size is (num_channels, s*s, max_prompt_length) + ref_xa_maps[t.item()][name] = attn_mask.detach().cpu() - # perform guidance - if do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) - # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample - # call the callback, if provided - if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): - progress_bar.update() - if callback is not None and i % callback_steps == 0: - callback(i, t, latents) + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) # 8. Compute the edit directions. edit_direction = construct_direction(source_embedding_path, target_embedding_path) @@ -703,38 +703,44 @@ def __call__( latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + # we want to learn the latent such that it steers the generation + # process towards the edited direction, so make the make initial + # noise learnable x_in = latent_model_input.detach().clone() x_in.requires_grad = True + # optimizer opt = torch.optim.SGD([x_in], lr=cross_attention_guidance_amount) - # predict the noise residual - noise_pred = self.unet( - x_in, - t, - encoder_hidden_states=prompt_embeds_edit.detach(), - cross_attention_kwargs=cross_attention_kwargs, - ).sample - - loss = 0.0 - for name, module in self.unet.named_modules(): - module_name = type(module).__name__ - if module_name == "CrossAttention" and "attn2" in name: - curr = module.attn_probs - ref = ref_xa_maps[t.item()][name].detach().to(device) - loss += ((curr - ref) ** 2).sum((1, 2)).mean(0) - loss.backward(retain_graph=False) - opt.step() - - # recompute the noise - with torch.no_grad(): + with torch.enable_grad(): + # predict the noise residual noise_pred = self.unet( - x_in.detach(), + x_in, t, - encoder_hidden_states=prompt_embeds_edit, + encoder_hidden_states=prompt_embeds_edit.detach(), cross_attention_kwargs=cross_attention_kwargs, ).sample + # obtain the cross-attention maps with the current latents, + # compute loss, backpropagate + loss = 0.0 + for name, module in self.unet.named_modules(): + module_name = type(module).__name__ + if module_name == "CrossAttention" and "attn2" in name: + curr = module.attn_probs + ref = ref_xa_maps[t.item()][name].detach().to(device) + loss += ((curr - ref) ** 2).sum((1, 2)).mean(0) + loss.backward(retain_graph=False) + opt.step() + + # recompute the noise + noise_pred = self.unet( + x_in.detach(), + t, + encoder_hidden_states=prompt_embeds_edit, + cross_attention_kwargs=cross_attention_kwargs, + ).sample + latents = x_in.detach().chunk(2)[0] # perform guidance @@ -750,8 +756,7 @@ def __call__( progress_bar.update() # 11. Post-process the latents. - with torch.no_grad(): - edited_image = self.decode_latents(latents) + edited_image = self.decode_latents(latents) # 12. Run the safety checker. edited_image, has_nsfw_concept = self.run_safety_checker(edited_image, device, prompt_embeds.dtype) From 8e9736652618f40b0467523b8dd7ff4394493c37 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 14 Feb 2023 10:19:22 +0530 Subject: [PATCH 23/63] refactor prepare_unet() per Patrick's suggestions. --- .../pipeline_stable_diffusion_pix2pix_zero.py | 21 ++++++++++--------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py index 3b9efdbad19b..46f35d9987a4 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py @@ -26,7 +26,7 @@ ) from ...models import AutoencoderKL, UNet2DConditionModel -from ...models.attention import CrossAttention +from ...models.cross_attention import CrossAttention, CrossAttnProcessor from ...schedulers import KarrasDiffusionSchedulers from ...utils import is_accelerate_available, logging, randn_tensor from ..pipeline_utils import DiffusionPipeline @@ -48,17 +48,18 @@ def generate_caption(image, captioner, processor, return_image=True): def prepare_unet(unet: UNet2DConditionModel): - # set the gradients for cross-attention maps to be true - for name, params in unet.named_parameters(): + pix2pix_zero_attn_procs = {} + for name in unet.attn_processors.keys(): + module_name = name.replace(".processor", "") + module = unet.get_submodule(module_name) if "attn2" in name: - params.requires_grad = True + pix2pix_zero_attn_procs[name] = Pix2PixZeroCrossAttnProcessor() + module.requires_grad_(True) else: - params.requires_grad = False - # replace the forward function - for name, module in unet.named_modules(): - module_name = type(module).__name__ - if module_name == "CrossAttention": - module.set_processor(Pix2PixZeroCrossAttnProcessor()) + pix2pix_zero_attn_procs[name] = CrossAttnProcessor() + module.requires_grad_(False) + + unet.set_attn_processor(pix2pix_zero_attn_procs) return unet From 7f61cdb2fa9f25de922b06fa566d06402d7d1fc3 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 14 Feb 2023 10:20:25 +0530 Subject: [PATCH 24/63] more elaborate description for . --- .../stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py index 46f35d9987a4..cd710c8a1b56 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py @@ -128,7 +128,7 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline): feature_extractor ([`CLIPFeatureExtractor`]): Model that extracts features from generated images to be used as inputs for the `safety_checker`. conditions_input_image (bool): - Whether to execute the pipeline with an input image. + Whether to condition the pipeline with an input image to compute an inverted noise latent. requires_safety_checker (bool): Whether the pipeline requires a safety checker. We recommend setting it to True if you're using the pipeline publicly. From b99e508c722dfc200d769eb7284f8768a9602665 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 14 Feb 2023 11:18:01 +0530 Subject: [PATCH 25/63] formatting. --- .../stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py index cd710c8a1b56..ca28ff278905 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py @@ -164,7 +164,6 @@ def __init__( checkpoint = "Salesforce/blip-image-captioning-base" captioner_processor = AutoProcessor.from_pretrained(checkpoint) captioner = BlipForConditionalGeneration.from_pretrained(checkpoint) - else: captioner_processor = None captioner = None From db2136ac41c32240fb149079cdd30a68996a814b Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 14 Feb 2023 11:31:53 +0530 Subject: [PATCH 26/63] add docstrings to the methods specific to pix2pix zero. --- .../stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py index ca28ff278905..f7f999acbd78 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py @@ -38,6 +38,7 @@ def generate_caption(image, captioner, processor, return_image=True): + """Generates caption for a given image.""" inputs = processor(images=image, return_tensors="pt") outputs = captioner.generate(inputs) caption = processor.batch_deocde(outputs, skip_special_tokens=True)[0] @@ -48,6 +49,7 @@ def generate_caption(image, captioner, processor, return_image=True): def prepare_unet(unet: UNet2DConditionModel): + """Modifies the UNet (`unet`) to perform Pix2Pix Zero optimizations.""" pix2pix_zero_attn_procs = {} for name in unet.attn_processors.keys(): module_name = name.replace(".processor", "") @@ -64,6 +66,7 @@ def prepare_unet(unet: UNet2DConditionModel): def construct_direction(source_embedding_path: str, target_embedding_path: str): + """Constructs the edit direction to steer the image generation process semantically.""" embs_source = torch.load(source_embedding_path) embs_target = torch.load(target_embedding_path) return (embs_target.mean(0) - embs_source.mean(0)).unsqueeze(0) From a64d6467576ab1346fa65e4dd68ce969ab8a38d3 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 14 Feb 2023 16:09:28 +0530 Subject: [PATCH 27/63] suspecting a redundant noise prediction. --- .../pipeline_stable_diffusion_pix2pix_zero.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py index f7f999acbd78..69461fda163e 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py @@ -717,12 +717,12 @@ def __call__( with torch.enable_grad(): # predict the noise residual - noise_pred = self.unet( - x_in, - t, - encoder_hidden_states=prompt_embeds_edit.detach(), - cross_attention_kwargs=cross_attention_kwargs, - ).sample + # noise_pred = self.unet( + # x_in, + # t, + # encoder_hidden_states=prompt_embeds_edit.detach(), + # cross_attention_kwargs=cross_attention_kwargs, + # ).sample # obtain the cross-attention maps with the current latents, # compute loss, backpropagate From 1afd2b8c908df11a8c204a5f0b19015ff476a697 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 14 Feb 2023 16:11:29 +0530 Subject: [PATCH 28/63] needed for gradient computation chain. --- .../pipeline_stable_diffusion_pix2pix_zero.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py index 69461fda163e..f7f999acbd78 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py @@ -717,12 +717,12 @@ def __call__( with torch.enable_grad(): # predict the noise residual - # noise_pred = self.unet( - # x_in, - # t, - # encoder_hidden_states=prompt_embeds_edit.detach(), - # cross_attention_kwargs=cross_attention_kwargs, - # ).sample + noise_pred = self.unet( + x_in, + t, + encoder_hidden_states=prompt_embeds_edit.detach(), + cross_attention_kwargs=cross_attention_kwargs, + ).sample # obtain the cross-attention maps with the current latents, # compute loss, backpropagate From b6655f9d3b9c9c25d923685eb59b325b16a02949 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 14 Feb 2023 16:49:00 +0530 Subject: [PATCH 29/63] less hacks. --- .../pipeline_stable_diffusion_pix2pix_zero.py | 66 +++++++++++++------ 1 file changed, 47 insertions(+), 19 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py index f7f999acbd78..aae9d57e3927 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py @@ -26,7 +26,7 @@ ) from ...models import AutoencoderKL, UNet2DConditionModel -from ...models.cross_attention import CrossAttention, CrossAttnProcessor +from ...models.cross_attention import CrossAttention from ...schedulers import KarrasDiffusionSchedulers from ...utils import is_accelerate_available, logging, randn_tensor from ..pipeline_utils import DiffusionPipeline @@ -55,10 +55,10 @@ def prepare_unet(unet: UNet2DConditionModel): module_name = name.replace(".processor", "") module = unet.get_submodule(module_name) if "attn2" in name: - pix2pix_zero_attn_procs[name] = Pix2PixZeroCrossAttnProcessor() + pix2pix_zero_attn_procs[name] = Pix2PixZeroCrossAttnProcessor(is_pix2pix_zero=True) module.requires_grad_(True) else: - pix2pix_zero_attn_procs[name] = CrossAttnProcessor() + pix2pix_zero_attn_procs[name] = Pix2PixZeroCrossAttnProcessor(is_pix2pix_zero=False) module.requires_grad_(False) unet.set_attn_processor(pix2pix_zero_attn_procs) @@ -72,16 +72,34 @@ def construct_direction(source_embedding_path: str, target_embedding_path: str): return (embs_target.mean(0) - embs_source.mean(0)).unsqueeze(0) +class Pix2PixZeroL2Loss: + def __init__(self): + self.loss = 0.0 + + def compute_loss(self, predictions, targets): + self.loss += ((predictions - targets) ** 2).sum((1, 2)).mean(0) + + class Pix2PixZeroCrossAttnProcessor: - def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None): + def __init__(self, is_pix2pix_zero=False): + self.is_pix2pix_zero = is_pix2pix_zero + if self.is_pix2pix_zero: + self.xa_map = {} + + def __call__( + self, + attn: CrossAttention, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + timestep=None, + loss=None, + ): batch_size, sequence_length, _ = hidden_states.shape - attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length) query = attn.to_q(hidden_states) - if encoder_hidden_states is None: - encoder_hidden_states = hidden_states - elif attn.cross_attention_norm: - encoder_hidden_states = attn.norm_cross(encoder_hidden_states) + encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states key = attn.to_k(encoder_hidden_states) value = attn.to_v(encoder_hidden_states) @@ -91,8 +109,15 @@ def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=No value = attn.head_to_batch_dim(value) attention_probs = attn.get_attention_scores(query, key, attention_mask) - # new bookkeeping to save the attention weights. - attn.attn_probs = attention_probs + if self.is_pix2pix_zero: + # new bookkeeping to save the attention weights. + if loss is None: + self.xa_map[timestep] = attention_probs + # compute loss + elif loss is not None: + prev_attn_probs = self.xa_map.pop(timestep) + loss.compute_loss(attention_probs, prev_attn_probs) + hidden_states = torch.bmm(attention_probs, value) hidden_states = attn.batch_to_head_dim(hidden_states) @@ -716,6 +741,9 @@ def __call__( opt = torch.optim.SGD([x_in], lr=cross_attention_guidance_amount) with torch.enable_grad(): + # initialize loss + loss = Pix2PixZeroL2Loss() + # predict the noise residual noise_pred = self.unet( x_in, @@ -726,14 +754,14 @@ def __call__( # obtain the cross-attention maps with the current latents, # compute loss, backpropagate - loss = 0.0 - for name, module in self.unet.named_modules(): - module_name = type(module).__name__ - if module_name == "CrossAttention" and "attn2" in name: - curr = module.attn_probs - ref = ref_xa_maps[t.item()][name].detach().to(device) - loss += ((curr - ref) ** 2).sum((1, 2)).mean(0) - loss.backward(retain_graph=False) + # loss = 0.0 + # for name, module in self.unet.named_modules(): + # module_name = type(module).__name__ + # if module_name == "CrossAttention" and "attn2" in name: + # curr = module.attn_probs + # ref = ref_xa_maps[t.item()][name].detach().to(device) + # loss += ((curr - ref) ** 2).sum((1, 2)).mean(0) + loss.loss.backward(retain_graph=False) opt.step() # recompute the noise From 30cc5ef0669eef8be5edcaf49ca7b73405c62c26 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 14 Feb 2023 16:52:06 +0530 Subject: [PATCH 30/63] fix: attention mask handling within the processor. --- .../pipeline_stable_diffusion_pix2pix_zero.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py index aae9d57e3927..5bb8f85f937c 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py @@ -96,10 +96,13 @@ def __call__( loss=None, ): batch_size, sequence_length, _ = hidden_states.shape - attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length) + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) query = attn.to_q(hidden_states) - encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.cross_attention_norm: + encoder_hidden_states = attn.norm_cross(encoder_hidden_states) key = attn.to_k(encoder_hidden_states) value = attn.to_v(encoder_hidden_states) From ee59e57ad24dab804dcc522e56260bac48a4859c Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 14 Feb 2023 16:54:44 +0530 Subject: [PATCH 31/63] remove attention reference map computation. --- .../pipeline_stable_diffusion_pix2pix_zero.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py index 5bb8f85f937c..ed246eb5597a 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py @@ -696,13 +696,13 @@ def __call__( cross_attention_kwargs=cross_attention_kwargs, ).sample - # add the cross attention map to the dictionary - ref_xa_maps[t.item()] = {} - for name, module in self.unet.named_modules(): - module_name = type(module).__name__ - if module_name == "CrossAttention" and "attn2" in name: - attn_mask = module.attn_probs # size is (num_channels, s*s, max_prompt_length) - ref_xa_maps[t.item()][name] = attn_mask.detach().cpu() + # # add the cross attention map to the dictionary + # ref_xa_maps[t.item()] = {} + # for name, module in self.unet.named_modules(): + # module_name = type(module).__name__ + # if module_name == "CrossAttention" and "attn2" in name: + # attn_mask = module.attn_probs # size is (num_channels, s*s, max_prompt_length) + # ref_xa_maps[t.item()][name] = attn_mask.detach().cpu() # perform guidance if do_classifier_free_guidance: From d9cc312f8f468efe514c3b2a53cb6538f7b87f12 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 14 Feb 2023 17:01:00 +0530 Subject: [PATCH 32/63] fix: cross attn args. --- .../pipeline_stable_diffusion_pix2pix_zero.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py index ed246eb5597a..19ac2951b279 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py @@ -620,8 +620,6 @@ def __call__( height, width = preprocessed_image.shape[-2:] prompt = caption logger.info(f"Generated caption for the input image: {caption}.") - else: - pass # 3. Define call parameters if prompt is not None and isinstance(prompt, str): @@ -630,7 +628,7 @@ def __call__( batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] - ref_xa_maps = {} # reference cross attention maps + # ref_xa_maps = {} # reference cross attention maps device = self._execution_device # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) @@ -752,7 +750,7 @@ def __call__( x_in, t, encoder_hidden_states=prompt_embeds_edit.detach(), - cross_attention_kwargs=cross_attention_kwargs, + cross_attention_kwargs={"timestep": t, "loss": loss}, ).sample # obtain the cross-attention maps with the current latents, From e54ff0f88b253a7ab0d8584ff86cd100ea39af6c Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 14 Feb 2023 17:07:17 +0530 Subject: [PATCH 33/63] fix: prcoessor. --- .../pipeline_stable_diffusion_pix2pix_zero.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py index 19ac2951b279..daeb2cf4f849 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py @@ -115,11 +115,11 @@ def __call__( if self.is_pix2pix_zero: # new bookkeeping to save the attention weights. if loss is None: - self.xa_map[timestep] = attention_probs + self.xa_map[timestep.item()] = attention_probs # compute loss elif loss is not None: - prev_attn_probs = self.xa_map.pop(timestep) - loss.compute_loss(attention_probs, prev_attn_probs) + prev_attn_probs = self.xa_map.pop(timestep.item()) + loss.compute_loss(attention_probs, prev_attn_probs.detach()) hidden_states = torch.bmm(attention_probs, value) hidden_states = attn.batch_to_head_dim(hidden_states) From a5d95d9f9447606e4da3f81ca2a64e30b923b5a5 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 14 Feb 2023 17:09:40 +0530 Subject: [PATCH 34/63] store attention maps. --- .../stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py index daeb2cf4f849..21b9452c6787 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py @@ -691,7 +691,7 @@ def __call__( latent_model_input, t, encoder_hidden_states=prompt_embeds, - cross_attention_kwargs=cross_attention_kwargs, + cross_attention_kwargs={"timestep": t}, ).sample # # add the cross attention map to the dictionary From 02faceee6f93ae94da39332ebbc447d5ea41dac2 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 14 Feb 2023 17:13:32 +0530 Subject: [PATCH 35/63] fix: attention processor. --- .../pipeline_stable_diffusion_pix2pix_zero.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py index 21b9452c6787..8521dfd751ff 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py @@ -112,7 +112,7 @@ def __call__( value = attn.head_to_batch_dim(value) attention_probs = attn.get_attention_scores(query, key, attention_mask) - if self.is_pix2pix_zero: + if self.is_pix2pix_zero and timestep is not None: # new bookkeeping to save the attention weights. if loss is None: self.xa_map[timestep.item()] = attention_probs @@ -770,7 +770,7 @@ def __call__( x_in.detach(), t, encoder_hidden_states=prompt_embeds_edit, - cross_attention_kwargs=cross_attention_kwargs, + cross_attention_kwargs={"timestep": None}, ).sample latents = x_in.detach().chunk(2)[0] From 26be9a1627b16f59a4abf861336d6da4656cca1b Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Wed, 15 Feb 2023 08:31:24 +0530 Subject: [PATCH 36/63] update docs and better treatment to xa args. --- .../pipeline_stable_diffusion_pix2pix_zero.py | 92 ++++++++++++++----- 1 file changed, 68 insertions(+), 24 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py index 8521dfd751ff..a164f385b367 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py @@ -28,7 +28,7 @@ from ...models import AutoencoderKL, UNet2DConditionModel from ...models.cross_attention import CrossAttention from ...schedulers import KarrasDiffusionSchedulers -from ...utils import is_accelerate_available, logging, randn_tensor +from ...utils import is_accelerate_available, logging, randn_tensor, replace_example_docstring from ..pipeline_utils import DiffusionPipeline from . import StableDiffusionPipelineOutput from .safety_checker import StableDiffusionSafetyChecker @@ -36,6 +36,45 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import requests + >>> import torch + + >>> from diffusers import DDIMScheduler, StableDiffusionPix2PixZeroPipeline + + + >>> def download(embedding_url, local_filepath): + ... r = requests.get(embedding_url) + ... with open(local_filepath, "wb") as f: + ... f.write(r.content) + + + >>> model_ckpt = "CompVis/stable-diffusion-v1-4" + >>> pipeline = StableDiffusionPix2PixZeroPipeline.from_pretrained( + ... model_ckpt, conditions_input_image=False, torch_dtype=torch.float16 + ... ) + >>> pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config) + >>> pipeline.to("cuda") + + >>> prompt = "a high resolution painting of a cat in the style of van gough" + >>> source_emb_url = "https://github.com/pix2pixzero/pix2pix-zero/raw/main/assets/embeddings_sd_1.4/cat.pt" + >>> target_emb_url = "https://github.com/pix2pixzero/pix2pix-zero/raw/main/assets/embeddings_sd_1.4/dog.pt" + + >>> for url in [source_emb_url, target_emb_url]: + ... download(url, url.split("/")[-1]) + >>> images = pipeline( + ... prompt, + ... source_embedding_path=source_emb_url.split("/")[-1], + ... target_embedding_path=target_emb_url.split("/")[-1], + ... num_inference_steps=50, + ... cross_attention_guidance_amount=0.15, + ... ).images + >>> images[0].save("edited_image_dog.png") + ``` +""" + def generate_caption(image, captioner, processor, return_image=True): """Generates caption for a given image.""" @@ -81,6 +120,9 @@ def compute_loss(self, predictions, targets): class Pix2PixZeroCrossAttnProcessor: + """An attention processor class to store the attention weights. + In Pix2Pix Zero, it happens during computations in the cross-attention blocks.""" + def __init__(self, is_pix2pix_zero=False): self.is_pix2pix_zero = is_pix2pix_zero if self.is_pix2pix_zero: @@ -445,6 +487,8 @@ def check_inputs( prompt, conditions_input_image, image, + source_embedding_path, + target_embedding_path, callback_steps, prompt_embeds=None, ): @@ -455,6 +499,8 @@ def check_inputs( f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}." ) + if source_embedding_path is None and target_embedding_path is None: + raise ValueError("`source_embedding_path` and `target_embedding_path` cannot be undefined.") if prompt is None and not conditions_input_image: raise ValueError(f"`prompt` cannot be None when `conditions_input_image` is {conditions_input_image}") @@ -501,7 +547,7 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype return latents @torch.no_grad() - # @replace_example_docstring(EXAMPLE_DOC_STRING) + @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, prompt: Optional[Union[str, List[str]]] = None, @@ -533,6 +579,8 @@ def __call__( prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. instead. + image (`PIL.Image.Image`): + `Image`, or tensor representing an image batch which will be used for conditioning. height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The height in pixels of the generated image. width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): @@ -569,6 +617,14 @@ def __call__( Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. + cross_attention_guidance_amount (`float`, defaults to 0.1): + Amount of guidance needed from the reference cross-attention maps. + source_embedding_path (`str`, defaults to None): + Local filepath to the embeddings of the source concept. Generation of the embeddings as per the + [original paper](https://arxiv.org/abs/2302.03027). Used in discovering the edit direction. + target_embedding_path (`str`, defaults to None): + Local filepath to the embeddings of the target concept. Generation of the embeddings as per the + [original paper](https://arxiv.org/abs/2302.03027). Used in discovering the edit direction. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. @@ -604,6 +660,8 @@ def __call__( prompt, self.conditions_input_image, image, + source_embedding_path, + target_embedding_path, callback_steps, prompt_embeds, ) @@ -628,7 +686,8 @@ def __call__( batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] - # ref_xa_maps = {} # reference cross attention maps + if cross_attention_kwargs is None: + cross_attention_kwargs = {} device = self._execution_device # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) @@ -655,8 +714,10 @@ def __call__( # generated from the input prompt. if self.conditions_input_image: # TODO (sayakpaul): Generate this using DDIM inversion. - latents = 1 - # latents = self.prepare_inverted_noise() + # We need to get the inverted noise from the input image and this requires + # us to do a sort of `inverse_step()` in DDIM and then regularize the + # noise to enforce the statistical properties of Gaussian. + raise NotImplementedError else: num_channels_latents = self.unet.in_channels latents = self.prepare_latents( @@ -691,17 +752,9 @@ def __call__( latent_model_input, t, encoder_hidden_states=prompt_embeds, - cross_attention_kwargs={"timestep": t}, + cross_attention_kwargs=cross_attention_kwargs.update({"timestep": t}), ).sample - # # add the cross attention map to the dictionary - # ref_xa_maps[t.item()] = {} - # for name, module in self.unet.named_modules(): - # module_name = type(module).__name__ - # if module_name == "CrossAttention" and "attn2" in name: - # attn_mask = module.attn_probs # size is (num_channels, s*s, max_prompt_length) - # ref_xa_maps[t.item()][name] = attn_mask.detach().cpu() - # perform guidance if do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) @@ -750,18 +803,9 @@ def __call__( x_in, t, encoder_hidden_states=prompt_embeds_edit.detach(), - cross_attention_kwargs={"timestep": t, "loss": loss}, + cross_attention_kwargs=cross_attention_kwargs.update({"timestep": t, "loss": loss}), ).sample - # obtain the cross-attention maps with the current latents, - # compute loss, backpropagate - # loss = 0.0 - # for name, module in self.unet.named_modules(): - # module_name = type(module).__name__ - # if module_name == "CrossAttention" and "attn2" in name: - # curr = module.attn_probs - # ref = ref_xa_maps[t.item()][name].detach().to(device) - # loss += ((curr - ref) ** 2).sum((1, 2)).mean(0) loss.loss.backward(retain_graph=False) opt.step() From 57e1709e4a9c006715e139e9c5f4f29ce291a88b Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Wed, 15 Feb 2023 08:32:36 +0530 Subject: [PATCH 37/63] update the final noise computation call. --- .../stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py index a164f385b367..14e409b1967d 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py @@ -814,7 +814,7 @@ def __call__( x_in.detach(), t, encoder_hidden_states=prompt_embeds_edit, - cross_attention_kwargs={"timestep": None}, + cross_attention_kwargs=cross_attention_kwargs.update({"timestep": None}), ).sample latents = x_in.detach().chunk(2)[0] From f37e25a70f46e8d2f6b6bc2872332cf45f4e76ac Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Wed, 15 Feb 2023 08:37:54 +0530 Subject: [PATCH 38/63] change xa args call. --- .../pipeline_stable_diffusion_pix2pix_zero.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py index 14e409b1967d..2f899d2e4d6b 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py @@ -752,7 +752,7 @@ def __call__( latent_model_input, t, encoder_hidden_states=prompt_embeds, - cross_attention_kwargs=cross_attention_kwargs.update({"timestep": t}), + cross_attention_kwargs={"timestep": t}, ).sample # perform guidance @@ -803,7 +803,7 @@ def __call__( x_in, t, encoder_hidden_states=prompt_embeds_edit.detach(), - cross_attention_kwargs=cross_attention_kwargs.update({"timestep": t, "loss": loss}), + cross_attention_kwargs={"timestep": t, "loss": loss}, ).sample loss.loss.backward(retain_graph=False) @@ -814,7 +814,7 @@ def __call__( x_in.detach(), t, encoder_hidden_states=prompt_embeds_edit, - cross_attention_kwargs=cross_attention_kwargs.update({"timestep": None}), + cross_attention_kwargs={"timestep": None}, ).sample latents = x_in.detach().chunk(2)[0] From 2d551628bee99f7a41d4f1bf40090172c7571c51 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Wed, 15 Feb 2023 08:39:50 +0530 Subject: [PATCH 39/63] remove xa args option from the pipeline. --- .../pipeline_stable_diffusion_pix2pix_zero.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py index 2f899d2e4d6b..077e9463f607 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py @@ -637,10 +637,6 @@ def __call__( callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. - cross_attention_kwargs (`dict`, *optional*): - A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under - `self.processor` in - [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). Examples: From 8a17a17bbc40c3c4a4994b858e3132e6ade343c9 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Wed, 15 Feb 2023 08:58:38 +0530 Subject: [PATCH 40/63] add: docs. --- docs/source/en/_toctree.yml | 2 + .../pipelines/stable_diffusion/overview.mdx | 1 + .../stable_diffusion/pix2pix_zero.mdx | 92 +++++++++++++++++++ 3 files changed, 95 insertions(+) create mode 100644 docs/source/en/api/pipelines/stable_diffusion/pix2pix_zero.mdx diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index dd08e9f95d99..51bcea22b61e 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -149,6 +149,8 @@ title: Stable-Diffusion-Latent-Upscaler - local: api/pipelines/stable_diffusion/pix2pix title: InstructPix2Pix + - local: api/pipelines/stable_diffusion/pix2pix_zero + title: Pix2Pix Zero title: Stable Diffusion - local: api/pipelines/stable_diffusion_2 title: Stable Diffusion 2 diff --git a/docs/source/en/api/pipelines/stable_diffusion/overview.mdx b/docs/source/en/api/pipelines/stable_diffusion/overview.mdx index 5d3fb77c7aad..3c01c9d3c0e0 100644 --- a/docs/source/en/api/pipelines/stable_diffusion/overview.mdx +++ b/docs/source/en/api/pipelines/stable_diffusion/overview.mdx @@ -33,6 +33,7 @@ For more details about how Stable Diffusion works and how it differs from the ba | [StableDiffusionUpscalePipeline](./upscale) | **Experimental** – *Text-Guided Image Super-Resolution * | | Coming soon | [StableDiffusionLatentUpscalePipeline](./latent_upscale) | **Experimental** – *Text-Guided Image Super-Resolution * | | Coming soon | [StableDiffusionInstructPix2PixPipeline](./pix2pix) | **Experimental** – *Text-Based Image Editing * | | [InstructPix2Pix: Learning to Follow Image Editing Instructions](https://huggingface.co/spaces/timbrooks/instruct-pix2pix) +| [StableDiffusionPix2PixZeroPipeline](./pix2pix_zero) | **Experimental** – *Text-Based Image Editing * | | [Zero-shot Image-to-Image Translation](https://arxiv.org/abs/2302.03027) diff --git a/docs/source/en/api/pipelines/stable_diffusion/pix2pix_zero.mdx b/docs/source/en/api/pipelines/stable_diffusion/pix2pix_zero.mdx new file mode 100644 index 000000000000..eafc6b77ec2f --- /dev/null +++ b/docs/source/en/api/pipelines/stable_diffusion/pix2pix_zero.mdx @@ -0,0 +1,92 @@ + + +# Zero-shot Image-to-Image Translation + +## Overview + +[Zero-shot Image-to-Image Translation](https://arxiv.org/abs/2302.03027) by Gaurav Parmar, Krishna Kumar Singh, Richard Zhang, Yijun Li, Jingwan Lu, and Jun-Yan Zhu. + +The abstract of the paper is the following: + +*Large-scale text-to-image generative models have shown their remarkable ability to synthesize diverse and high-quality images. However, it is still challenging to directly apply these models for editing real images for two reasons. First, it is hard for users to come up with a perfect text prompt that accurately describes every visual detail in the input image. Second, while existing models can introduce desirable changes in certain regions, they often dramatically alter the input content and introduce unexpected changes in unwanted regions. In this work, we propose pix2pix-zero, an image-to-image translation method that can preserve the content of the original image without manual prompting. We first automatically discover editing directions that reflect desired edits in the text embedding space. To preserve the general content structure after editing, we further propose cross-attention guidance, which aims to retain the cross-attention maps of the input image throughout the diffusion process. In addition, our method does not need additional training for these edits and can directly use the existing pre-trained text-to-image diffusion model. We conduct extensive experiments and show that our method outperforms existing and concurrent works for both real and synthetic image editing.* + +Resources: + +* [Project Page](https://pix2pixzero.github.io/). +* [Paper](https://arxiv.org/abs/2302.03027). +* [Original Code](https://github.com/pix2pixzero/pix2pix-zero). + +## Tips + +* The pipeline exposes two arguments namely `source_embedding_path` and `target_embedding_path` +that let you control the direction of the semantic edits in the final image to be generated. Let's say, +you wanted to translate from "cat" to "dog". In this case, the edit direction will be "cat -> dog". To reflect +this in the pipeline, you simply have to set the embeddings related to the phrases including "cat" to +`source_embedding_path` and "dog" to `target_embedding_path`. +* When you're using this pipeline from a prompt, specify the _source_ concept in the prompt. Taking +the above example, a valid input prompt would be: "a high resolution painting of a **cat** in the style of van gough". +* If you wanted to reverse the direction i.e., "dog -> cat", then it's recommended to: + * Swap the `source_embedding_path` and `target_embedding_path`. + * Change the input prompt to include "dog". + +## Available Pipelines: + +| Pipeline | Tasks | Demo +|---|---|:---:| +| [StableDiffusionPix2PixZeroPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py) | *Text-Based Image Editing* | [🤗 Space] (soon) | + + + +## Usage example + +```python +import requests +import torch + +from diffusers import DDIMScheduler, StableDiffusionPix2PixZeroPipeline + + +def download(embedding_url, local_filepath): + r = requests.get(embedding_url) + with open(local_filepath, "wb") as f: + f.write(r.content) + + +model_ckpt = "CompVis/stable-diffusion-v1-4" +pipeline = StableDiffusionPix2PixZeroPipeline.from_pretrained( + model_ckpt, conditions_input_image=False, torch_dtype=torch.float16 +) +pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config) +pipeline.to("cuda") + +prompt = "a high resolution painting of a cat in the style of van gough" +source_embedding_url = "https://github.com/pix2pixzero/pix2pix-zero/raw/main/assets/embeddings_sd_1.4/cat.pt" +target_embedding_url = "https://github.com/pix2pixzero/pix2pix-zero/raw/main/assets/embeddings_sd_1.4/dog.pt" + +for url in [source_embedding_url, target_embedding_url]: + download(url, url.split("/")[-1]) + +images = pipeline( + prompt, + source_embedding_path=source_embedding_url.split("/")[-1], + target_embedding_path=target_embedding_url.split("/")[-1], + num_inference_steps=50, + cross_attention_guidance_amount=0.15, +).images +images[0].save("edited_image_dog.png") +``` + +## StableDiffusionPix2PixZeroPipeline +[[autodoc]] StableDiffusionPix2PixZeroPipeline + - __call__ + - all From 7bd9a6dfde6d802dd3a1787b5530b00a4a9a28c6 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Wed, 15 Feb 2023 09:19:01 +0530 Subject: [PATCH 41/63] first test. --- .../pipeline_stable_diffusion_pix2pix_zero.py | 4 +- .../test_stable_diffusion_pix2pix_zero.py | 397 ++++++++++++++++++ 2 files changed, 399 insertions(+), 2 deletions(-) create mode 100644 tests/pipelines/stable_diffusion/test_stable_diffusion_pix2pix_zero.py diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py index 077e9463f607..f7543ea7d59a 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py @@ -551,7 +551,7 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype def __call__( self, prompt: Optional[Union[str, List[str]]] = None, - image: Union[torch.FloatTensor, PIL.Image.Image] = None, + image: Optional[Union[torch.FloatTensor, PIL.Image.Image]] = None, height: Optional[int] = None, width: Optional[int] = None, num_inference_steps: int = 50, @@ -579,7 +579,7 @@ def __call__( prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. instead. - image (`PIL.Image.Image`): + image (`PIL.Image.Image`, *optional*): `Image`, or tensor representing an image batch which will be used for conditioning. height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The height in pixels of the generated image. diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_pix2pix_zero.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_pix2pix_zero.py new file mode 100644 index 000000000000..d5b9c1b72df4 --- /dev/null +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_pix2pix_zero.py @@ -0,0 +1,397 @@ +# coding=utf-8 +# Copyright 2023 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import random +import unittest +import requests + +import numpy as np +import torch +from PIL import Image +from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer + +from diffusers import ( + AutoencoderKL, + DDIMScheduler, + EulerAncestralDiscreteScheduler, + LMSDiscreteScheduler, + PNDMScheduler, + StableDiffusionPix2PixZeroPipeline, + UNet2DConditionModel, +) +from diffusers.utils import floats_tensor, load_image, slow, torch_device +from diffusers.utils.testing_utils import require_torch_gpu + +from ...test_pipelines_common import PipelineTesterMixin + + +torch.backends.cuda.matmul.allow_tf32 = False + + +class StableDiffusionPix2PixZeroPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = StableDiffusionPix2PixZeroPipeline + + def get_dummy_components(self): + torch.manual_seed(0) + unet = UNet2DConditionModel( + block_out_channels=(32, 64), + layers_per_block=2, + sample_size=32, + in_channels=8, + out_channels=4, + down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"), + up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"), + cross_attention_dim=32, + ) + scheduler = DDIMScheduler() + torch.manual_seed(0) + vae = AutoencoderKL( + block_out_channels=[32, 64], + in_channels=3, + out_channels=3, + down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"], + up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"], + latent_channels=4, + ) + torch.manual_seed(0) + text_encoder_config = CLIPTextConfig( + bos_token_id=0, + eos_token_id=2, + hidden_size=32, + intermediate_size=37, + layer_norm_eps=1e-05, + num_attention_heads=4, + num_hidden_layers=5, + pad_token_id=1, + vocab_size=1000, + ) + text_encoder = CLIPTextModel(text_encoder_config) + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + components = { + "unet": unet, + "scheduler": scheduler, + "vae": vae, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + "safety_checker": None, + "feature_extractor": None, + } + return components + + def download_from_url(self, embedding_url, local_filepath): + r = requests.get(embedding_url) + with open(local_filepath, "wb") as f: + f.write(r.content) + + def get_dummy_inputs(self, device, seed=0): + image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device) + image = image.cpu().permute(0, 2, 3, 1)[0] + image = Image.fromarray(np.uint8(image)).convert("RGB") + + src_emb_url = "https://github.com/pix2pixzero/pix2pix-zero/raw/main/assets/embeddings_sd_1.4/cat.pt" + tgt_emb_url = "https://github.com/pix2pixzero/pix2pix-zero/raw/main/assets/embeddings_sd_1.4/dog.pt" + + for url in [src_emb_url, tgt_emb_url]: + self.download_from_url(url, url.split["/"][-1]) + + generator = torch.manual_seed(seed) + + inputs = { + "prompt": "A painting of a squirrel eating a burger", + "image": image, + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 6.0, + "cross_attention_guidance_amount": 0.15, + "source_embedding_file": src_emb_url.split("/")[-1], + "target_embedding_file": tgt_emb_url.split("/")[-1], + "output_type": "numpy", + } + return inputs + + def test_stable_diffusion_pix2pix_zero_default_case(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + components = self.get_dummy_components() + sd_pipe = StableDiffusionPix2PixZeroPipeline(**components) + sd_pipe = sd_pipe.to(device) + sd_pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + image = sd_pipe(**inputs).images + image_slice = image[0, -3:, -3:, -1] + assert image.shape == (1, 32, 32, 3) + print(image_slice.flatten()) + # expected_slice = np.array([0.7318, 0.3723, 0.4662, 0.623, 0.5770, 0.5014, 0.4281, 0.5550, 0.4813]) + + # assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 + +# def test_stable_diffusion_pix2pix_negative_prompt(self): +# device = "cpu" # ensure determinism for the device-dependent torch.Generator +# components = self.get_dummy_components() +# sd_pipe = StableDiffusionInstructPix2PixPipeline(**components) +# sd_pipe = sd_pipe.to(device) +# sd_pipe.set_progress_bar_config(disable=None) + +# inputs = self.get_dummy_inputs(device) +# negative_prompt = "french fries" +# output = sd_pipe(**inputs, negative_prompt=negative_prompt) +# image = output.images +# image_slice = image[0, -3:, -3:, -1] + +# assert image.shape == (1, 32, 32, 3) +# expected_slice = np.array([0.7323, 0.3688, 0.4611, 0.6255, 0.5746, 0.5017, 0.433, 0.5553, 0.4827]) + +# assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 + +# def test_stable_diffusion_pix2pix_multiple_init_images(self): +# device = "cpu" # ensure determinism for the device-dependent torch.Generator +# components = self.get_dummy_components() +# sd_pipe = StableDiffusionInstructPix2PixPipeline(**components) +# sd_pipe = sd_pipe.to(device) +# sd_pipe.set_progress_bar_config(disable=None) + +# inputs = self.get_dummy_inputs(device) +# inputs["prompt"] = [inputs["prompt"]] * 2 + +# image = np.array(inputs["image"]).astype(np.float32) / 255.0 +# image = torch.from_numpy(image).unsqueeze(0).to(device) +# image = image.permute(0, 3, 1, 2) +# inputs["image"] = image.repeat(2, 1, 1, 1) + +# image = sd_pipe(**inputs).images +# image_slice = image[-1, -3:, -3:, -1] + +# assert image.shape == (2, 32, 32, 3) +# expected_slice = np.array([0.606, 0.5712, 0.5099, 0.598, 0.5805, 0.7205, 0.6793, 0.554, 0.5607]) + +# assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 + +# def test_stable_diffusion_pix2pix_euler(self): +# device = "cpu" # ensure determinism for the device-dependent torch.Generator +# components = self.get_dummy_components() +# components["scheduler"] = EulerAncestralDiscreteScheduler( +# beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear" +# ) +# sd_pipe = StableDiffusionInstructPix2PixPipeline(**components) +# sd_pipe = sd_pipe.to(device) +# sd_pipe.set_progress_bar_config(disable=None) + +# inputs = self.get_dummy_inputs(device) +# image = sd_pipe(**inputs).images +# image_slice = image[0, -3:, -3:, -1] + +# slice = [round(x, 4) for x in image_slice.flatten().tolist()] +# print(",".join([str(x) for x in slice])) + +# assert image.shape == (1, 32, 32, 3) +# expected_slice = np.array([0.726, 0.3902, 0.4868, 0.585, 0.5672, 0.511, 0.3906, 0.551, 0.4846]) + +# assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 + +# def test_stable_diffusion_pix2pix_num_images_per_prompt(self): +# device = "cpu" # ensure determinism for the device-dependent torch.Generator +# components = self.get_dummy_components() +# sd_pipe = StableDiffusionInstructPix2PixPipeline(**components) +# sd_pipe = sd_pipe.to(device) +# sd_pipe.set_progress_bar_config(disable=None) + +# # test num_images_per_prompt=1 (default) +# inputs = self.get_dummy_inputs(device) +# images = sd_pipe(**inputs).images + +# assert images.shape == (1, 32, 32, 3) + +# # test num_images_per_prompt=1 (default) for batch of prompts +# batch_size = 2 +# inputs = self.get_dummy_inputs(device) +# inputs["prompt"] = [inputs["prompt"]] * batch_size +# images = sd_pipe(**inputs).images + +# assert images.shape == (batch_size, 32, 32, 3) + +# # test num_images_per_prompt for single prompt +# num_images_per_prompt = 2 +# inputs = self.get_dummy_inputs(device) +# images = sd_pipe(**inputs, num_images_per_prompt=num_images_per_prompt).images + +# assert images.shape == (num_images_per_prompt, 32, 32, 3) + +# # test num_images_per_prompt for batch of prompts +# batch_size = 2 +# inputs = self.get_dummy_inputs(device) +# inputs["prompt"] = [inputs["prompt"]] * batch_size +# images = sd_pipe(**inputs, num_images_per_prompt=num_images_per_prompt).images + +# assert images.shape == (batch_size * num_images_per_prompt, 32, 32, 3) + + +# @slow +# @require_torch_gpu +# class StableDiffusionInstructPix2PixPipelineSlowTests(unittest.TestCase): +# def tearDown(self): +# super().tearDown() +# gc.collect() +# torch.cuda.empty_cache() + +# def get_inputs(self, seed=0): +# generator = torch.manual_seed(seed) +# image = load_image( +# "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main/stable_diffusion_pix2pix/example.jpg" +# ) +# inputs = { +# "prompt": "turn him into a cyborg", +# "image": image, +# "generator": generator, +# "num_inference_steps": 3, +# "guidance_scale": 7.5, +# "image_guidance_scale": 1.0, +# "output_type": "numpy", +# } +# return inputs + +# def test_stable_diffusion_pix2pix_default(self): +# pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained( +# "timbrooks/instruct-pix2pix", safety_checker=None +# ) +# pipe.to(torch_device) +# pipe.set_progress_bar_config(disable=None) +# pipe.enable_attention_slicing() + +# inputs = self.get_inputs() +# image = pipe(**inputs).images +# image_slice = image[0, -3:, -3:, -1].flatten() + +# assert image.shape == (1, 512, 512, 3) +# expected_slice = np.array([0.5902, 0.6015, 0.6027, 0.5983, 0.6092, 0.6061, 0.5765, 0.5785, 0.5555]) + +# assert np.abs(expected_slice - image_slice).max() < 1e-3 + +# def test_stable_diffusion_pix2pix_k_lms(self): +# pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained( +# "timbrooks/instruct-pix2pix", safety_checker=None +# ) +# pipe.scheduler = LMSDiscreteScheduler.from_config(pipe.scheduler.config) +# pipe.to(torch_device) +# pipe.set_progress_bar_config(disable=None) +# pipe.enable_attention_slicing() + +# inputs = self.get_inputs() +# image = pipe(**inputs).images +# image_slice = image[0, -3:, -3:, -1].flatten() + +# assert image.shape == (1, 512, 512, 3) +# expected_slice = np.array([0.6578, 0.6817, 0.6972, 0.6761, 0.6856, 0.6916, 0.6428, 0.6516, 0.6301]) + +# assert np.abs(expected_slice - image_slice).max() < 1e-3 + +# def test_stable_diffusion_pix2pix_ddim(self): +# pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained( +# "timbrooks/instruct-pix2pix", safety_checker=None +# ) +# pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config) +# pipe.to(torch_device) +# pipe.set_progress_bar_config(disable=None) +# pipe.enable_attention_slicing() + +# inputs = self.get_inputs() +# image = pipe(**inputs).images +# image_slice = image[0, -3:, -3:, -1].flatten() + +# assert image.shape == (1, 512, 512, 3) +# expected_slice = np.array([0.3828, 0.3834, 0.3818, 0.3792, 0.3865, 0.3752, 0.3792, 0.3847, 0.3753]) + +# assert np.abs(expected_slice - image_slice).max() < 1e-3 + +# def test_stable_diffusion_pix2pix_intermediate_state(self): +# number_of_steps = 0 + +# def callback_fn(step: int, timestep: int, latents: torch.FloatTensor) -> None: +# callback_fn.has_been_called = True +# nonlocal number_of_steps +# number_of_steps += 1 +# if step == 1: +# latents = latents.detach().cpu().numpy() +# assert latents.shape == (1, 4, 64, 64) +# latents_slice = latents[0, -3:, -3:, -1] +# expected_slice = np.array([-0.2463, -0.4644, -0.9756, 1.5176, 1.4414, 0.7866, 0.9897, 0.8521, 0.7983]) + +# assert np.abs(latents_slice.flatten() - expected_slice).max() < 5e-2 +# elif step == 2: +# latents = latents.detach().cpu().numpy() +# assert latents.shape == (1, 4, 64, 64) +# latents_slice = latents[0, -3:, -3:, -1] +# expected_slice = np.array([-0.2644, -0.4626, -0.9653, 1.5176, 1.4551, 0.7686, 0.9805, 0.8452, 0.8115]) + +# assert np.abs(latents_slice.flatten() - expected_slice).max() < 5e-2 + +# callback_fn.has_been_called = False + +# pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained( +# "timbrooks/instruct-pix2pix", safety_checker=None, torch_dtype=torch.float16 +# ) +# pipe = pipe.to(torch_device) +# pipe.set_progress_bar_config(disable=None) +# pipe.enable_attention_slicing() + +# inputs = self.get_inputs() +# pipe(**inputs, callback=callback_fn, callback_steps=1) +# assert callback_fn.has_been_called +# assert number_of_steps == 3 + +# def test_stable_diffusion_pipeline_with_sequential_cpu_offloading(self): +# torch.cuda.empty_cache() +# torch.cuda.reset_max_memory_allocated() +# torch.cuda.reset_peak_memory_stats() + +# pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained( +# "timbrooks/instruct-pix2pix", safety_checker=None, torch_dtype=torch.float16 +# ) +# pipe = pipe.to(torch_device) +# pipe.set_progress_bar_config(disable=None) +# pipe.enable_attention_slicing(1) +# pipe.enable_sequential_cpu_offload() + +# inputs = self.get_inputs() +# _ = pipe(**inputs) + +# mem_bytes = torch.cuda.max_memory_allocated() +# # make sure that less than 2.2 GB is allocated +# assert mem_bytes < 2.2 * 10**9 + +# def test_stable_diffusion_pix2pix_pipeline_multiple_of_8(self): +# inputs = self.get_inputs() +# # resize to resolution that is divisible by 8 but not 16 or 32 +# inputs["image"] = inputs["image"].resize((504, 504)) + +# model_id = "timbrooks/instruct-pix2pix" +# pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained( +# model_id, +# safety_checker=None, +# ) +# pipe.to(torch_device) +# pipe.set_progress_bar_config(disable=None) +# pipe.enable_attention_slicing() + +# output = pipe(**inputs) +# image = output.images[0] + +# image_slice = image[255:258, 383:386, -1] + +# assert image.shape == (504, 504, 3) +# expected_slice = np.array([0.2726, 0.2529, 0.2664, 0.2655, 0.2641, 0.2642, 0.2591, 0.2649, 0.2590]) + +# assert np.abs(image_slice.flatten() - expected_slice).max() < 5e-3 From 164aef43a580aa7dc52469bb62bede973a04046c Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Wed, 15 Feb 2023 09:21:58 +0530 Subject: [PATCH 42/63] fix: url call. --- .../stable_diffusion/test_stable_diffusion_pix2pix_zero.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_pix2pix_zero.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_pix2pix_zero.py index d5b9c1b72df4..7474c189bc05 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_pix2pix_zero.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_pix2pix_zero.py @@ -106,7 +106,7 @@ def get_dummy_inputs(self, device, seed=0): tgt_emb_url = "https://github.com/pix2pixzero/pix2pix-zero/raw/main/assets/embeddings_sd_1.4/dog.pt" for url in [src_emb_url, tgt_emb_url]: - self.download_from_url(url, url.split["/"][-1]) + self.download_from_url(url, url.split("/")[-1]) generator = torch.manual_seed(seed) From cc1efb2fffb9482fdc03025f04aa48ff5d1929d0 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Wed, 15 Feb 2023 09:24:18 +0530 Subject: [PATCH 43/63] fix: argument call. --- .../stable_diffusion/test_stable_diffusion_pix2pix_zero.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_pix2pix_zero.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_pix2pix_zero.py index 7474c189bc05..a52b06ba1b51 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_pix2pix_zero.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_pix2pix_zero.py @@ -117,8 +117,8 @@ def get_dummy_inputs(self, device, seed=0): "num_inference_steps": 2, "guidance_scale": 6.0, "cross_attention_guidance_amount": 0.15, - "source_embedding_file": src_emb_url.split("/")[-1], - "target_embedding_file": tgt_emb_url.split("/")[-1], + "source_embedding_path": src_emb_url.split("/")[-1], + "target_embedding_path": tgt_emb_url.split("/")[-1], "output_type": "numpy", } return inputs From 17dda7c00cfbcd8d17989daaa6b47830ff0c143c Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Wed, 15 Feb 2023 09:27:30 +0530 Subject: [PATCH 44/63] remove image conditioning for now. --- .../stable_diffusion/test_stable_diffusion_pix2pix_zero.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_pix2pix_zero.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_pix2pix_zero.py index a52b06ba1b51..acd01c470124 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_pix2pix_zero.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_pix2pix_zero.py @@ -98,10 +98,6 @@ def download_from_url(self, embedding_url, local_filepath): f.write(r.content) def get_dummy_inputs(self, device, seed=0): - image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device) - image = image.cpu().permute(0, 2, 3, 1)[0] - image = Image.fromarray(np.uint8(image)).convert("RGB") - src_emb_url = "https://github.com/pix2pixzero/pix2pix-zero/raw/main/assets/embeddings_sd_1.4/cat.pt" tgt_emb_url = "https://github.com/pix2pixzero/pix2pix-zero/raw/main/assets/embeddings_sd_1.4/dog.pt" @@ -112,7 +108,6 @@ def get_dummy_inputs(self, device, seed=0): inputs = { "prompt": "A painting of a squirrel eating a burger", - "image": image, "generator": generator, "num_inference_steps": 2, "guidance_scale": 6.0, From 6a3091ebdcea2e5c5657a98bcf7bea56ab172df2 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Wed, 15 Feb 2023 10:24:05 +0530 Subject: [PATCH 45/63] =?UTF-8?q?=F0=9F=9A=A8=20add:=20fast=20tests.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../test_stable_diffusion_pix2pix_zero.py | 164 +++++++----------- 1 file changed, 61 insertions(+), 103 deletions(-) diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_pix2pix_zero.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_pix2pix_zero.py index acd01c470124..6dd1048818ea 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_pix2pix_zero.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_pix2pix_zero.py @@ -13,27 +13,20 @@ # See the License for the specific language governing permissions and # limitations under the License. -import gc -import random import unittest -import requests import numpy as np +import requests import torch -from PIL import Image from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer from diffusers import ( AutoencoderKL, DDIMScheduler, EulerAncestralDiscreteScheduler, - LMSDiscreteScheduler, - PNDMScheduler, StableDiffusionPix2PixZeroPipeline, UNet2DConditionModel, ) -from diffusers.utils import floats_tensor, load_image, slow, torch_device -from diffusers.utils.testing_utils import require_torch_gpu from ...test_pipelines_common import PipelineTesterMixin @@ -50,7 +43,7 @@ def get_dummy_components(self): block_out_channels=(32, 64), layers_per_block=2, sample_size=32, - in_channels=8, + in_channels=4, out_channels=4, down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"), up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"), @@ -97,15 +90,15 @@ def download_from_url(self, embedding_url, local_filepath): with open(local_filepath, "wb") as f: f.write(r.content) - def get_dummy_inputs(self, device, seed=0): - src_emb_url = "https://github.com/pix2pixzero/pix2pix-zero/raw/main/assets/embeddings_sd_1.4/cat.pt" - tgt_emb_url = "https://github.com/pix2pixzero/pix2pix-zero/raw/main/assets/embeddings_sd_1.4/dog.pt" + def get_dummy_inputs(self, seed=0): + src_emb_url = "https://hf.co/datasets/sayakpaul/sample-datasets/resolve/main/src_emb_0.pt" + tgt_emb_url = "https://hf.co/datasets/sayakpaul/sample-datasets/resolve/main/tgt_emb_0.pt" for url in [src_emb_url, tgt_emb_url]: self.download_from_url(url, url.split("/")[-1]) - + generator = torch.manual_seed(seed) - + inputs = { "prompt": "A painting of a squirrel eating a burger", "generator": generator, @@ -125,113 +118,78 @@ def test_stable_diffusion_pix2pix_zero_default_case(self): sd_pipe = sd_pipe.to(device) sd_pipe.set_progress_bar_config(disable=None) - inputs = self.get_dummy_inputs(device) + inputs = self.get_dummy_inputs() image = sd_pipe(**inputs).images image_slice = image[0, -3:, -3:, -1] - assert image.shape == (1, 32, 32, 3) - print(image_slice.flatten()) - # expected_slice = np.array([0.7318, 0.3723, 0.4662, 0.623, 0.5770, 0.5014, 0.4281, 0.5550, 0.4813]) - - # assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 - -# def test_stable_diffusion_pix2pix_negative_prompt(self): -# device = "cpu" # ensure determinism for the device-dependent torch.Generator -# components = self.get_dummy_components() -# sd_pipe = StableDiffusionInstructPix2PixPipeline(**components) -# sd_pipe = sd_pipe.to(device) -# sd_pipe.set_progress_bar_config(disable=None) - -# inputs = self.get_dummy_inputs(device) -# negative_prompt = "french fries" -# output = sd_pipe(**inputs, negative_prompt=negative_prompt) -# image = output.images -# image_slice = image[0, -3:, -3:, -1] - -# assert image.shape == (1, 32, 32, 3) -# expected_slice = np.array([0.7323, 0.3688, 0.4611, 0.6255, 0.5746, 0.5017, 0.433, 0.5553, 0.4827]) + assert image.shape == (1, 64, 64, 3) + expected_slice = np.array([0.5184, 0.503, 0.4917, 0.4022, 0.3455, 0.464, 0.5324, 0.5323, 0.4894]) -# assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 -# def test_stable_diffusion_pix2pix_multiple_init_images(self): -# device = "cpu" # ensure determinism for the device-dependent torch.Generator -# components = self.get_dummy_components() -# sd_pipe = StableDiffusionInstructPix2PixPipeline(**components) -# sd_pipe = sd_pipe.to(device) -# sd_pipe.set_progress_bar_config(disable=None) - -# inputs = self.get_dummy_inputs(device) -# inputs["prompt"] = [inputs["prompt"]] * 2 - -# image = np.array(inputs["image"]).astype(np.float32) / 255.0 -# image = torch.from_numpy(image).unsqueeze(0).to(device) -# image = image.permute(0, 3, 1, 2) -# inputs["image"] = image.repeat(2, 1, 1, 1) - -# image = sd_pipe(**inputs).images -# image_slice = image[-1, -3:, -3:, -1] - -# assert image.shape == (2, 32, 32, 3) -# expected_slice = np.array([0.606, 0.5712, 0.5099, 0.598, 0.5805, 0.7205, 0.6793, 0.554, 0.5607]) - -# assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 + def test_stable_diffusion_pix2pix_zero_negative_prompt(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + components = self.get_dummy_components() + sd_pipe = StableDiffusionPix2PixZeroPipeline(**components) + sd_pipe = sd_pipe.to(device) + sd_pipe.set_progress_bar_config(disable=None) -# def test_stable_diffusion_pix2pix_euler(self): -# device = "cpu" # ensure determinism for the device-dependent torch.Generator -# components = self.get_dummy_components() -# components["scheduler"] = EulerAncestralDiscreteScheduler( -# beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear" -# ) -# sd_pipe = StableDiffusionInstructPix2PixPipeline(**components) -# sd_pipe = sd_pipe.to(device) -# sd_pipe.set_progress_bar_config(disable=None) + inputs = self.get_dummy_inputs() + negative_prompt = "french fries" + output = sd_pipe(**inputs, negative_prompt=negative_prompt) + image = output.images + image_slice = image[0, -3:, -3:, -1] -# inputs = self.get_dummy_inputs(device) -# image = sd_pipe(**inputs).images -# image_slice = image[0, -3:, -3:, -1] + assert image.shape == (1, 64, 64, 3) + expected_slice = np.array([0.5464, 0.5072, 0.5012, 0.4124, 0.3624, 0.466, 0.5413, 0.5468, 0.4927]) -# slice = [round(x, 4) for x in image_slice.flatten().tolist()] -# print(",".join([str(x) for x in slice])) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 -# assert image.shape == (1, 32, 32, 3) -# expected_slice = np.array([0.726, 0.3902, 0.4868, 0.585, 0.5672, 0.511, 0.3906, 0.551, 0.4846]) + def test_stable_diffusion_pix2pix_zero_euler(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + components = self.get_dummy_components() + components["scheduler"] = EulerAncestralDiscreteScheduler( + beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear" + ) + sd_pipe = StableDiffusionPix2PixZeroPipeline(**components) + sd_pipe = sd_pipe.to(device) + sd_pipe.set_progress_bar_config(disable=None) -# assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 + inputs = self.get_dummy_inputs() + image = sd_pipe(**inputs).images + image_slice = image[0, -3:, -3:, -1] -# def test_stable_diffusion_pix2pix_num_images_per_prompt(self): -# device = "cpu" # ensure determinism for the device-dependent torch.Generator -# components = self.get_dummy_components() -# sd_pipe = StableDiffusionInstructPix2PixPipeline(**components) -# sd_pipe = sd_pipe.to(device) -# sd_pipe.set_progress_bar_config(disable=None) + assert image.shape == (1, 32, 32, 3) + expected_slice = np.array([0.5114, 0.5051, 0.5222, 0.5279, 0.5037, 0.5156, 0.4604, 0.4966, 0.504]) -# # test num_images_per_prompt=1 (default) -# inputs = self.get_dummy_inputs(device) -# images = sd_pipe(**inputs).images + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 -# assert images.shape == (1, 32, 32, 3) + def test_stable_diffusion_pix2pix_zero_num_images_per_prompt(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + components = self.get_dummy_components() + sd_pipe = StableDiffusionPix2PixZeroPipeline(**components) + sd_pipe = sd_pipe.to(device) + sd_pipe.set_progress_bar_config(disable=None) -# # test num_images_per_prompt=1 (default) for batch of prompts -# batch_size = 2 -# inputs = self.get_dummy_inputs(device) -# inputs["prompt"] = [inputs["prompt"]] * batch_size -# images = sd_pipe(**inputs).images + # test num_images_per_prompt=1 (default) + inputs = self.get_dummy_inputs() + images = sd_pipe(**inputs).images -# assert images.shape == (batch_size, 32, 32, 3) + assert images.shape == (1, 64, 64, 3) -# # test num_images_per_prompt for single prompt -# num_images_per_prompt = 2 -# inputs = self.get_dummy_inputs(device) -# images = sd_pipe(**inputs, num_images_per_prompt=num_images_per_prompt).images + # test num_images_per_prompt=2 for a single prompt + num_images_per_prompt = 2 + inputs = self.get_dummy_inputs() + images = sd_pipe(**inputs, num_images_per_prompt=num_images_per_prompt).images -# assert images.shape == (num_images_per_prompt, 32, 32, 3) + assert images.shape == (num_images_per_prompt, 32, 32, 3) -# # test num_images_per_prompt for batch of prompts -# batch_size = 2 -# inputs = self.get_dummy_inputs(device) -# inputs["prompt"] = [inputs["prompt"]] * batch_size -# images = sd_pipe(**inputs, num_images_per_prompt=num_images_per_prompt).images + # test num_images_per_prompt for batch of prompts + batch_size = 2 + inputs = self.get_dummy_inputs() + inputs["prompt"] = [inputs["prompt"]] * batch_size + images = sd_pipe(**inputs, num_images_per_prompt=num_images_per_prompt).images -# assert images.shape == (batch_size * num_images_per_prompt, 32, 32, 3) + assert images.shape == (batch_size * num_images_per_prompt, 32, 32, 3) # @slow From df99c5341af97e946d0f5d25b0b0888196411c45 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Wed, 15 Feb 2023 10:54:54 +0530 Subject: [PATCH 46/63] explicit placement of the xa attn weights. --- .../pipeline_stable_diffusion_pix2pix_zero.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py index f7543ea7d59a..690aaee7f7ea 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py @@ -157,11 +157,11 @@ def __call__( if self.is_pix2pix_zero and timestep is not None: # new bookkeeping to save the attention weights. if loss is None: - self.xa_map[timestep.item()] = attention_probs + self.xa_map[timestep.item()] = attention_probs.detach().cpu() # compute loss elif loss is not None: prev_attn_probs = self.xa_map.pop(timestep.item()) - loss.compute_loss(attention_probs, prev_attn_probs.detach()) + loss.compute_loss(attention_probs, prev_attn_probs.to(attention_probs.device)) hidden_states = torch.bmm(attention_probs, value) hidden_states = attn.batch_to_head_dim(hidden_states) From 9a58071ea5505d12997302f25810834dfe7fc731 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Wed, 15 Feb 2023 11:20:22 +0530 Subject: [PATCH 47/63] =?UTF-8?q?add:=20slow=20tests=20=F0=9F=90=A2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../test_stable_diffusion_pix2pix_zero.py | 299 ++++++++---------- 1 file changed, 137 insertions(+), 162 deletions(-) diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_pix2pix_zero.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_pix2pix_zero.py index 6dd1048818ea..ab0d84b57cfe 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_pix2pix_zero.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_pix2pix_zero.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import gc import unittest import numpy as np @@ -24,9 +25,12 @@ AutoencoderKL, DDIMScheduler, EulerAncestralDiscreteScheduler, + LMSDiscreteScheduler, StableDiffusionPix2PixZeroPipeline, UNet2DConditionModel, ) +from diffusers.utils import slow, torch_device +from diffusers.utils.testing_utils import require_torch_gpu from ...test_pipelines_common import PipelineTesterMixin @@ -34,6 +38,12 @@ torch.backends.cuda.matmul.allow_tf32 = False +def download_from_url(embedding_url, local_filepath): + r = requests.get(embedding_url) + with open(local_filepath, "wb") as f: + f.write(r.content) + + class StableDiffusionPix2PixZeroPipelineFastTests(PipelineTesterMixin, unittest.TestCase): pipeline_class = StableDiffusionPix2PixZeroPipeline @@ -85,17 +95,12 @@ def get_dummy_components(self): } return components - def download_from_url(self, embedding_url, local_filepath): - r = requests.get(embedding_url) - with open(local_filepath, "wb") as f: - f.write(r.content) - def get_dummy_inputs(self, seed=0): src_emb_url = "https://hf.co/datasets/sayakpaul/sample-datasets/resolve/main/src_emb_0.pt" tgt_emb_url = "https://hf.co/datasets/sayakpaul/sample-datasets/resolve/main/tgt_emb_0.pt" for url in [src_emb_url, tgt_emb_url]: - self.download_from_url(url, url.split("/")[-1]) + download_from_url(url, url.split("/")[-1]) generator = torch.manual_seed(seed) @@ -192,159 +197,129 @@ def test_stable_diffusion_pix2pix_zero_num_images_per_prompt(self): assert images.shape == (batch_size * num_images_per_prompt, 32, 32, 3) -# @slow -# @require_torch_gpu -# class StableDiffusionInstructPix2PixPipelineSlowTests(unittest.TestCase): -# def tearDown(self): -# super().tearDown() -# gc.collect() -# torch.cuda.empty_cache() - -# def get_inputs(self, seed=0): -# generator = torch.manual_seed(seed) -# image = load_image( -# "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main/stable_diffusion_pix2pix/example.jpg" -# ) -# inputs = { -# "prompt": "turn him into a cyborg", -# "image": image, -# "generator": generator, -# "num_inference_steps": 3, -# "guidance_scale": 7.5, -# "image_guidance_scale": 1.0, -# "output_type": "numpy", -# } -# return inputs - -# def test_stable_diffusion_pix2pix_default(self): -# pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained( -# "timbrooks/instruct-pix2pix", safety_checker=None -# ) -# pipe.to(torch_device) -# pipe.set_progress_bar_config(disable=None) -# pipe.enable_attention_slicing() - -# inputs = self.get_inputs() -# image = pipe(**inputs).images -# image_slice = image[0, -3:, -3:, -1].flatten() - -# assert image.shape == (1, 512, 512, 3) -# expected_slice = np.array([0.5902, 0.6015, 0.6027, 0.5983, 0.6092, 0.6061, 0.5765, 0.5785, 0.5555]) - -# assert np.abs(expected_slice - image_slice).max() < 1e-3 - -# def test_stable_diffusion_pix2pix_k_lms(self): -# pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained( -# "timbrooks/instruct-pix2pix", safety_checker=None -# ) -# pipe.scheduler = LMSDiscreteScheduler.from_config(pipe.scheduler.config) -# pipe.to(torch_device) -# pipe.set_progress_bar_config(disable=None) -# pipe.enable_attention_slicing() - -# inputs = self.get_inputs() -# image = pipe(**inputs).images -# image_slice = image[0, -3:, -3:, -1].flatten() - -# assert image.shape == (1, 512, 512, 3) -# expected_slice = np.array([0.6578, 0.6817, 0.6972, 0.6761, 0.6856, 0.6916, 0.6428, 0.6516, 0.6301]) - -# assert np.abs(expected_slice - image_slice).max() < 1e-3 - -# def test_stable_diffusion_pix2pix_ddim(self): -# pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained( -# "timbrooks/instruct-pix2pix", safety_checker=None -# ) -# pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config) -# pipe.to(torch_device) -# pipe.set_progress_bar_config(disable=None) -# pipe.enable_attention_slicing() - -# inputs = self.get_inputs() -# image = pipe(**inputs).images -# image_slice = image[0, -3:, -3:, -1].flatten() - -# assert image.shape == (1, 512, 512, 3) -# expected_slice = np.array([0.3828, 0.3834, 0.3818, 0.3792, 0.3865, 0.3752, 0.3792, 0.3847, 0.3753]) - -# assert np.abs(expected_slice - image_slice).max() < 1e-3 - -# def test_stable_diffusion_pix2pix_intermediate_state(self): -# number_of_steps = 0 - -# def callback_fn(step: int, timestep: int, latents: torch.FloatTensor) -> None: -# callback_fn.has_been_called = True -# nonlocal number_of_steps -# number_of_steps += 1 -# if step == 1: -# latents = latents.detach().cpu().numpy() -# assert latents.shape == (1, 4, 64, 64) -# latents_slice = latents[0, -3:, -3:, -1] -# expected_slice = np.array([-0.2463, -0.4644, -0.9756, 1.5176, 1.4414, 0.7866, 0.9897, 0.8521, 0.7983]) - -# assert np.abs(latents_slice.flatten() - expected_slice).max() < 5e-2 -# elif step == 2: -# latents = latents.detach().cpu().numpy() -# assert latents.shape == (1, 4, 64, 64) -# latents_slice = latents[0, -3:, -3:, -1] -# expected_slice = np.array([-0.2644, -0.4626, -0.9653, 1.5176, 1.4551, 0.7686, 0.9805, 0.8452, 0.8115]) - -# assert np.abs(latents_slice.flatten() - expected_slice).max() < 5e-2 - -# callback_fn.has_been_called = False - -# pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained( -# "timbrooks/instruct-pix2pix", safety_checker=None, torch_dtype=torch.float16 -# ) -# pipe = pipe.to(torch_device) -# pipe.set_progress_bar_config(disable=None) -# pipe.enable_attention_slicing() - -# inputs = self.get_inputs() -# pipe(**inputs, callback=callback_fn, callback_steps=1) -# assert callback_fn.has_been_called -# assert number_of_steps == 3 - -# def test_stable_diffusion_pipeline_with_sequential_cpu_offloading(self): -# torch.cuda.empty_cache() -# torch.cuda.reset_max_memory_allocated() -# torch.cuda.reset_peak_memory_stats() - -# pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained( -# "timbrooks/instruct-pix2pix", safety_checker=None, torch_dtype=torch.float16 -# ) -# pipe = pipe.to(torch_device) -# pipe.set_progress_bar_config(disable=None) -# pipe.enable_attention_slicing(1) -# pipe.enable_sequential_cpu_offload() - -# inputs = self.get_inputs() -# _ = pipe(**inputs) - -# mem_bytes = torch.cuda.max_memory_allocated() -# # make sure that less than 2.2 GB is allocated -# assert mem_bytes < 2.2 * 10**9 - -# def test_stable_diffusion_pix2pix_pipeline_multiple_of_8(self): -# inputs = self.get_inputs() -# # resize to resolution that is divisible by 8 but not 16 or 32 -# inputs["image"] = inputs["image"].resize((504, 504)) - -# model_id = "timbrooks/instruct-pix2pix" -# pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained( -# model_id, -# safety_checker=None, -# ) -# pipe.to(torch_device) -# pipe.set_progress_bar_config(disable=None) -# pipe.enable_attention_slicing() - -# output = pipe(**inputs) -# image = output.images[0] - -# image_slice = image[255:258, 383:386, -1] - -# assert image.shape == (504, 504, 3) -# expected_slice = np.array([0.2726, 0.2529, 0.2664, 0.2655, 0.2641, 0.2642, 0.2591, 0.2649, 0.2590]) - -# assert np.abs(image_slice.flatten() - expected_slice).max() < 5e-3 +@slow +@require_torch_gpu +class StableDiffusionPix2PixZeroPipelineSlowTests(unittest.TestCase): + def tearDown(self): + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + def get_inputs(self, seed=0): + generator = torch.manual_seed(seed) + + src_emb_url = "https://github.com/pix2pixzero/pix2pix-zero/raw/main/assets/embeddings_sd_1.4/cat.pt" + tgt_emb_url = "https://github.com/pix2pixzero/pix2pix-zero/raw/main/assets/embeddings_sd_1.4/dog.pt" + + for url in [src_emb_url, tgt_emb_url]: + download_from_url(url, url.split("/")[-1]) + + inputs = { + "prompt": "turn him into a cyborg", + "generator": generator, + "num_inference_steps": 3, + "guidance_scale": 7.5, + "cross_attention_guidance_amount": 0.15, + "source_embedding_path": src_emb_url.split("/")[-1], + "target_embedding_path": tgt_emb_url.split("/")[-1], + "output_type": "numpy", + } + return inputs + + def test_stable_diffusion_pix2pix_zero_default(self): + pipe = StableDiffusionPix2PixZeroPipeline.from_pretrained( + "CompVis/stable-diffusion-v1-4", safety_checker=None, torch_dtype=torch.float16 + ) + pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config) + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + pipe.enable_attention_slicing() + + inputs = self.get_inputs() + image = pipe(**inputs).images + image_slice = image[0, -3:, -3:, -1].flatten() + + assert image.shape == (1, 512, 512, 3) + expected_slice = np.array([0.4705, 0.4771, 0.4832, 0.4783, 0.4495, 0.447, 0.4658, 0.4568, 0.438]) + + assert np.abs(expected_slice - image_slice).max() < 1e-3 + + def test_stable_diffusion_pix2pix_zero_k_lms(self): + pipe = StableDiffusionPix2PixZeroPipeline.from_pretrained( + "CompVis/stable-diffusion-v1-4", safety_checker=None, torch_dtype=torch.float16 + ) + pipe.scheduler = LMSDiscreteScheduler.from_config(pipe.scheduler.config) + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + pipe.enable_attention_slicing() + + inputs = self.get_inputs() + image = pipe(**inputs).images + image_slice = image[0, -3:, -3:, -1].flatten() + + assert image.shape == (1, 512, 512, 3) + expected_slice = np.array([0.6514, 0.5571, 0.5244, 0.5591, 0.4998, 0.4834, 0.502, 0.468, 0.4663]) + + assert np.abs(expected_slice - image_slice).max() < 1e-3 + + def test_stable_diffusion_pix2pix_zero_intermediate_state(self): + number_of_steps = 0 + + def callback_fn(step: int, timestep: int, latents: torch.FloatTensor) -> None: + callback_fn.has_been_called = True + nonlocal number_of_steps + number_of_steps += 1 + if step == 1: + latents = latents.detach().cpu().numpy() + assert latents.shape == (1, 4, 64, 64) + latents_slice = latents[0, -3:, -3:, -1] + expected_slice = np.array( + [-0.5176, 0.0669, -0.1963, -0.1653, -0.7856, -0.2871, -0.5562, -0.0096, -0.012] + ) + + assert np.abs(latents_slice.flatten() - expected_slice).max() < 5e-2 + elif step == 2: + latents = latents.detach().cpu().numpy() + assert latents.shape == (1, 4, 64, 64) + latents_slice = latents[0, -3:, -3:, -1] + expected_slice = np.array( + [-0.5127, 0.0613, -0.1937, -0.1622, -0.7856, -0.2849, -0.5601, -0.0111, -0.0137] + ) + + assert np.abs(latents_slice.flatten() - expected_slice).max() < 5e-2 + + callback_fn.has_been_called = False + + pipe = StableDiffusionPix2PixZeroPipeline.from_pretrained( + "CompVis/stable-diffusion-v1-4", safety_checker=None, torch_dtype=torch.float16 + ) + pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + pipe.enable_attention_slicing() + + inputs = self.get_inputs() + pipe(**inputs, callback=callback_fn, callback_steps=1) + assert callback_fn.has_been_called + assert number_of_steps == 3 + + def test_stable_diffusion_pipeline_with_sequential_cpu_offloading(self): + torch.cuda.empty_cache() + torch.cuda.reset_max_memory_allocated() + torch.cuda.reset_peak_memory_stats() + + pipe = StableDiffusionPix2PixZeroPipeline.from_pretrained( + "CompVis/stable-diffusion-v1-4", safety_checker=None, torch_dtype=torch.float16 + ) + pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + pipe.enable_attention_slicing(1) + pipe.enable_sequential_cpu_offload() + + inputs = self.get_inputs() + _ = pipe(**inputs) + + mem_bytes = torch.cuda.max_memory_allocated() + # make sure that less than 8.2 GB is allocated + assert mem_bytes < 8.2 * 10**9 From 9df89992b8ce7b34e748eb692d325f4677a5a122 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Wed, 15 Feb 2023 11:28:34 +0530 Subject: [PATCH 48/63] fix: tests. --- .../test_stable_diffusion_pix2pix_zero.py | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_pix2pix_zero.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_pix2pix_zero.py index ab0d84b57cfe..ad8d4b176476 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_pix2pix_zero.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_pix2pix_zero.py @@ -95,7 +95,7 @@ def get_dummy_components(self): } return components - def get_dummy_inputs(self, seed=0): + def get_dummy_inputs(self, device, seed=0): src_emb_url = "https://hf.co/datasets/sayakpaul/sample-datasets/resolve/main/src_emb_0.pt" tgt_emb_url = "https://hf.co/datasets/sayakpaul/sample-datasets/resolve/main/tgt_emb_0.pt" @@ -123,7 +123,7 @@ def test_stable_diffusion_pix2pix_zero_default_case(self): sd_pipe = sd_pipe.to(device) sd_pipe.set_progress_bar_config(disable=None) - inputs = self.get_dummy_inputs() + inputs = self.get_dummy_inputs(device) image = sd_pipe(**inputs).images image_slice = image[0, -3:, -3:, -1] assert image.shape == (1, 64, 64, 3) @@ -138,7 +138,7 @@ def test_stable_diffusion_pix2pix_zero_negative_prompt(self): sd_pipe = sd_pipe.to(device) sd_pipe.set_progress_bar_config(disable=None) - inputs = self.get_dummy_inputs() + inputs = self.get_dummy_inputs(device) negative_prompt = "french fries" output = sd_pipe(**inputs, negative_prompt=negative_prompt) image = output.images @@ -159,11 +159,11 @@ def test_stable_diffusion_pix2pix_zero_euler(self): sd_pipe = sd_pipe.to(device) sd_pipe.set_progress_bar_config(disable=None) - inputs = self.get_dummy_inputs() + inputs = self.get_dummy_inputs(device) image = sd_pipe(**inputs).images image_slice = image[0, -3:, -3:, -1] - assert image.shape == (1, 32, 32, 3) + assert image.shape == (1, 64, 64, 3) expected_slice = np.array([0.5114, 0.5051, 0.5222, 0.5279, 0.5037, 0.5156, 0.4604, 0.4966, 0.504]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 @@ -176,25 +176,25 @@ def test_stable_diffusion_pix2pix_zero_num_images_per_prompt(self): sd_pipe.set_progress_bar_config(disable=None) # test num_images_per_prompt=1 (default) - inputs = self.get_dummy_inputs() + inputs = self.get_dummy_inputs(device) images = sd_pipe(**inputs).images assert images.shape == (1, 64, 64, 3) # test num_images_per_prompt=2 for a single prompt num_images_per_prompt = 2 - inputs = self.get_dummy_inputs() + inputs = self.get_dummy_inputs(device) images = sd_pipe(**inputs, num_images_per_prompt=num_images_per_prompt).images - assert images.shape == (num_images_per_prompt, 32, 32, 3) + assert images.shape == (num_images_per_prompt, 64, 64, 3) # test num_images_per_prompt for batch of prompts batch_size = 2 - inputs = self.get_dummy_inputs() + inputs = self.get_dummy_inputs(device) inputs["prompt"] = [inputs["prompt"]] * batch_size images = sd_pipe(**inputs, num_images_per_prompt=num_images_per_prompt).images - assert images.shape == (batch_size * num_images_per_prompt, 32, 32, 3) + assert images.shape == (batch_size * num_images_per_prompt, 64, 64, 3) @slow From c20870fbf470bd7e7bcc17b8c3a1c58d37d6757e Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Wed, 15 Feb 2023 11:35:19 +0530 Subject: [PATCH 49/63] edited direction embedding should be on the same device as prompt_embeds. --- .../stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py index 690aaee7f7ea..970b046090d1 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py @@ -766,7 +766,7 @@ def __call__( callback(i, t, latents) # 8. Compute the edit directions. - edit_direction = construct_direction(source_embedding_path, target_embedding_path) + edit_direction = construct_direction(source_embedding_path, target_embedding_path).to(prompt_embeds.device) # 9. Edit the prompt embeddings as per the edit directions discovered. prompt_embeds_edit = prompt_embeds.clone() From 731267bec4baf45e8182a452f524d2cfdd59c4df Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Wed, 15 Feb 2023 11:46:28 +0530 Subject: [PATCH 50/63] debugging message. --- tests/test_pipelines_common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_pipelines_common.py b/tests/test_pipelines_common.py index a1d3122f875c..e4297fc06a79 100644 --- a/tests/test_pipelines_common.py +++ b/tests/test_pipelines_common.py @@ -245,7 +245,7 @@ def _test_inference_batch_single_identical( if self.pipeline_class.__name__ != "DanceDiffusionPipeline": batched_inputs["output_type"] = "np" - + print(f"From common tests: {batched_inputs}") output_batch = pipe(**batched_inputs) assert output_batch[0].shape[0] == batch_size From 87d5c1588e38b0c98ec6c41ac472335bfd1f2a49 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Wed, 15 Feb 2023 11:52:18 +0530 Subject: [PATCH 51/63] debugging. --- tests/test_pipelines_common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_pipelines_common.py b/tests/test_pipelines_common.py index e4297fc06a79..ddc205dc7431 100644 --- a/tests/test_pipelines_common.py +++ b/tests/test_pipelines_common.py @@ -245,7 +245,7 @@ def _test_inference_batch_single_identical( if self.pipeline_class.__name__ != "DanceDiffusionPipeline": batched_inputs["output_type"] = "np" - print(f"From common tests: {batched_inputs}") + print(f"From common tests: {len(batched_inputs)}") output_batch = pipe(**batched_inputs) assert output_batch[0].shape[0] == batch_size From d81e110bc5ea9bbeaf1f0c543e45a25b90936408 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Wed, 15 Feb 2023 11:56:57 +0530 Subject: [PATCH 52/63] add pix2pix zero pipeline for a non-deterministic test. --- tests/test_pipelines_common.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/tests/test_pipelines_common.py b/tests/test_pipelines_common.py index ddc205dc7431..ba0bb3869ac0 100644 --- a/tests/test_pipelines_common.py +++ b/tests/test_pipelines_common.py @@ -192,10 +192,16 @@ def test_inference_batch_single_identical(self): def _test_inference_batch_single_identical( self, test_max_difference=None, test_mean_pixel_difference=None, relax_max_difference=False ): - if self.pipeline_class.__name__ in ["CycleDiffusionPipeline", "RePaintPipeline"]: + if self.pipeline_class.__name__ in [ + "CycleDiffusionPipeline", + "RePaintPipeline", + "StableDiffusionPix2PixZeroPipeline", + ]: # RePaint can hardly be made deterministic since the scheduler is currently always # nondeterministic # CycleDiffusion is also slightly nondeterministic + # There's a training loop inside Pix2PixZero and is guided by edit directions. This is + # why the slight non-determinism. return if test_max_difference is None: @@ -245,7 +251,7 @@ def _test_inference_batch_single_identical( if self.pipeline_class.__name__ != "DanceDiffusionPipeline": batched_inputs["output_type"] = "np" - print(f"From common tests: {len(batched_inputs)}") + output_batch = pipe(**batched_inputs) assert output_batch[0].shape[0] == batch_size From 48e8f8ce7e74194443ce12632d66745a15bf8007 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Wed, 15 Feb 2023 12:58:30 +0530 Subject: [PATCH 53/63] debugging/ --- src/diffusers/pipelines/pipeline_utils.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index b6cf92abfcdf..f7e377e70ff8 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -802,6 +802,9 @@ def components(self) -> Dict[str, Any]: components = { k: getattr(self, k) for k in self.config.keys() if not k.startswith("_") and k not in optional_parameters } + print(f"Components: {components.keys()}") + print(f"Optional parameters: {optional_parameters}") + print(f"Expected modules: {expected_modules}") if set(components.keys()) != expected_modules: raise ValueError( From 522e8820b281bbd9af0189020136d8a2245f0756 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Wed, 15 Feb 2023 13:06:05 +0530 Subject: [PATCH 54/63] remove debugging message. --- src/diffusers/pipelines/pipeline_utils.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index f7e377e70ff8..b6cf92abfcdf 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -802,9 +802,6 @@ def components(self) -> Dict[str, Any]: components = { k: getattr(self, k) for k in self.config.keys() if not k.startswith("_") and k not in optional_parameters } - print(f"Components: {components.keys()}") - print(f"Optional parameters: {optional_parameters}") - print(f"Expected modules: {expected_modules}") if set(components.keys()) != expected_modules: raise ValueError( From 8884e10dcc4c6c1f1c1ff4dc82c79e3306d91e27 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Wed, 15 Feb 2023 13:17:48 +0530 Subject: [PATCH 55/63] make caption generation _ --- .../pipeline_stable_diffusion_pix2pix_zero.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py index 970b046090d1..c6bc789aab03 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py @@ -252,8 +252,8 @@ def __init__( text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, - captioner_processor=captioner_processor, - captioner=captioner, + _captioner_processor=captioner_processor, + _captioner=captioner, scheduler=scheduler, safety_checker=safety_checker, feature_extractor=feature_extractor, @@ -261,8 +261,8 @@ def __init__( self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.conditions_input_image = conditions_input_image self.register_to_config( - captioner=captioner, - captioner_processor=captioner_processor, + _captioner=captioner, + _captioner_processor=captioner_processor, requires_safety_checker=requires_safety_checker, ) @@ -670,7 +670,7 @@ def __call__( # 2. Generate a caption for the input image if we are conditioning the # pipeline based on some input image. if self.conditions_input_image: - caption, preprocessed_image = generate_caption(image, self.captioner, self.captioner_processor) + caption, preprocessed_image = generate_caption(image, self._captioner, self._captioner_processor) height, width = preprocessed_image.shape[-2:] prompt = caption logger.info(f"Generated caption for the input image: {caption}.") From d23357fb662321677ff3accfb7d02246c67e9f17 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Wed, 15 Feb 2023 16:47:17 +0530 Subject: [PATCH 56/63] address comments (part I). --- .../pipelines/stable_diffusion/__init__.py | 2 +- .../pipeline_stable_diffusion_pix2pix_zero.py | 47 ++++++++++--------- src/diffusers/utils/import_utils.py | 6 +-- .../test_stable_diffusion_pix2pix_zero.py | 14 ++++-- 4 files changed, 38 insertions(+), 31 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/__init__.py b/src/diffusers/pipelines/stable_diffusion/__init__.py index 634de711806e..90cff3142ad4 100644 --- a/src/diffusers/pipelines/stable_diffusion/__init__.py +++ b/src/diffusers/pipelines/stable_diffusion/__init__.py @@ -44,7 +44,6 @@ class StableDiffusionPipelineOutput(BaseOutput): from .pipeline_stable_diffusion_inpaint_legacy import StableDiffusionInpaintPipelineLegacy from .pipeline_stable_diffusion_instruct_pix2pix import StableDiffusionInstructPix2PixPipeline from .pipeline_stable_diffusion_latent_upscale import StableDiffusionLatentUpscalePipeline - from .pipeline_stable_diffusion_pix2pix_zero import StableDiffusionPix2PixZeroPipeline from .pipeline_stable_diffusion_upscale import StableDiffusionUpscalePipeline from .pipeline_stable_unclip import StableUnCLIPPipeline from .pipeline_stable_unclip_img2img import StableUnCLIPImg2ImgPipeline @@ -67,6 +66,7 @@ class StableDiffusionPipelineOutput(BaseOutput): from ...utils.dummy_torch_and_transformers_objects import StableDiffusionDepth2ImgPipeline else: from .pipeline_stable_diffusion_depth2img import StableDiffusionDepth2ImgPipeline + from .pipeline_stable_diffusion_pix2pix_zero import StableDiffusionPix2PixZeroPipeline try: diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py index c6bc789aab03..8f59b135bde2 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py @@ -64,10 +64,13 @@ >>> for url in [source_emb_url, target_emb_url]: ... download(url, url.split("/")[-1]) + + >>> src_embeds = torch.load(source_emb_url.split("/")[-1]) + >>> target_embeds = torch.load(target_emb_url.split("/")[-1]) >>> images = pipeline( ... prompt, - ... source_embedding_path=source_emb_url.split("/")[-1], - ... target_embedding_path=target_emb_url.split("/")[-1], + ... source_embeds=src_embeds, + ... target_embeds=target_embeds, ... num_inference_steps=50, ... cross_attention_guidance_amount=0.15, ... ).images @@ -104,10 +107,8 @@ def prepare_unet(unet: UNet2DConditionModel): return unet -def construct_direction(source_embedding_path: str, target_embedding_path: str): +def construct_direction(embs_source: torch.Tensor, embs_target: torch.Tensor): """Constructs the edit direction to steer the image generation process semantically.""" - embs_source = torch.load(source_embedding_path) - embs_target = torch.load(target_embedding_path) return (embs_target.mean(0) - embs_source.mean(0)).unsqueeze(0) @@ -126,7 +127,7 @@ class Pix2PixZeroCrossAttnProcessor: def __init__(self, is_pix2pix_zero=False): self.is_pix2pix_zero = is_pix2pix_zero if self.is_pix2pix_zero: - self.xa_map = {} + self.reference_cross_attn_map = {} def __call__( self, @@ -157,10 +158,10 @@ def __call__( if self.is_pix2pix_zero and timestep is not None: # new bookkeeping to save the attention weights. if loss is None: - self.xa_map[timestep.item()] = attention_probs.detach().cpu() + self.reference_cross_attn_map[timestep.item()] = attention_probs.detach().cpu() # compute loss elif loss is not None: - prev_attn_probs = self.xa_map.pop(timestep.item()) + prev_attn_probs = self.reference_cross_attn_map.pop(timestep.item()) loss.compute_loss(attention_probs, prev_attn_probs.to(attention_probs.device)) hidden_states = torch.bmm(attention_probs, value) @@ -487,8 +488,8 @@ def check_inputs( prompt, conditions_input_image, image, - source_embedding_path, - target_embedding_path, + source_embeds, + target_embeds, callback_steps, prompt_embeds=None, ): @@ -499,8 +500,8 @@ def check_inputs( f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}." ) - if source_embedding_path is None and target_embedding_path is None: - raise ValueError("`source_embedding_path` and `target_embedding_path` cannot be undefined.") + if source_embeds is None and target_embeds is None: + raise ValueError("`source_embeds` and `target_embeds` cannot be undefined.") if prompt is None and not conditions_input_image: raise ValueError(f"`prompt` cannot be None when `conditions_input_image` is {conditions_input_image}") @@ -552,6 +553,8 @@ def __call__( self, prompt: Optional[Union[str, List[str]]] = None, image: Optional[Union[torch.FloatTensor, PIL.Image.Image]] = None, + source_embeds: torch.Tensor = None, + target_embeds: torch.Tensor = None, height: Optional[int] = None, width: Optional[int] = None, num_inference_steps: int = 50, @@ -564,8 +567,6 @@ def __call__( prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, cross_attention_guidance_amount: float = 0.1, - source_embedding_path: str = None, - target_embedding_path: str = None, output_type: Optional[str] = "pil", return_dict: bool = True, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, @@ -581,6 +582,12 @@ def __call__( instead. image (`PIL.Image.Image`, *optional*): `Image`, or tensor representing an image batch which will be used for conditioning. + source_embeds (`torch.Tensor`): + Source concept embeddings. Generation of the embeddings as per the [original + paper](https://arxiv.org/abs/2302.03027). Used in discovering the edit direction. + target_embeds (`torch.Tensor`): + Target concept embeddings. Generation of the embeddings as per the [original + paper](https://arxiv.org/abs/2302.03027). Used in discovering the edit direction. height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The height in pixels of the generated image. width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): @@ -619,12 +626,6 @@ def __call__( argument. cross_attention_guidance_amount (`float`, defaults to 0.1): Amount of guidance needed from the reference cross-attention maps. - source_embedding_path (`str`, defaults to None): - Local filepath to the embeddings of the source concept. Generation of the embeddings as per the - [original paper](https://arxiv.org/abs/2302.03027). Used in discovering the edit direction. - target_embedding_path (`str`, defaults to None): - Local filepath to the embeddings of the target concept. Generation of the embeddings as per the - [original paper](https://arxiv.org/abs/2302.03027). Used in discovering the edit direction. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. @@ -656,8 +657,8 @@ def __call__( prompt, self.conditions_input_image, image, - source_embedding_path, - target_embedding_path, + source_embeds, + target_embeds, callback_steps, prompt_embeds, ) @@ -766,7 +767,7 @@ def __call__( callback(i, t, latents) # 8. Compute the edit directions. - edit_direction = construct_direction(source_embedding_path, target_embedding_path).to(prompt_embeds.device) + edit_direction = construct_direction(source_embeds, target_embeds).to(prompt_embeds.device) # 9. Edit the prompt embeddings as per the edit directions discovered. prompt_embeds_edit = prompt_embeds.clone() diff --git a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py index cc607138758f..7afa1af2e7dc 100644 --- a/src/diffusers/utils/import_utils.py +++ b/src/diffusers/utils/import_utils.py @@ -408,9 +408,9 @@ def requires_backends(obj, backends): " --upgrade transformers \n```" ) - if name in [ - "StableDiffusionDepth2ImgPipeline", - ] and is_transformers_version("<", "4.26.0"): + if name in ["StableDiffusionDepth2ImgPipeline", "StableDiffusionPix2PixZeroPipeline"] and is_transformers_version( + "<", "4.26.0" + ): raise ImportError( f"You need to install `transformers>=4.26` in order to use {name}: \n```\n pip install" " --upgrade transformers \n```" diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_pix2pix_zero.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_pix2pix_zero.py index ad8d4b176476..7a67fc60883d 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_pix2pix_zero.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_pix2pix_zero.py @@ -102,6 +102,9 @@ def get_dummy_inputs(self, device, seed=0): for url in [src_emb_url, tgt_emb_url]: download_from_url(url, url.split("/")[-1]) + src_embeds = torch.load(src_emb_url.split("/")[-1]) + target_embeds = torch.load(tgt_emb_url.split("/")[-1]) + generator = torch.manual_seed(seed) inputs = { @@ -110,8 +113,8 @@ def get_dummy_inputs(self, device, seed=0): "num_inference_steps": 2, "guidance_scale": 6.0, "cross_attention_guidance_amount": 0.15, - "source_embedding_path": src_emb_url.split("/")[-1], - "target_embedding_path": tgt_emb_url.split("/")[-1], + "source_embeds": src_embeds, + "target_embeds": target_embeds, "output_type": "numpy", } return inputs @@ -214,14 +217,17 @@ def get_inputs(self, seed=0): for url in [src_emb_url, tgt_emb_url]: download_from_url(url, url.split("/")[-1]) + src_embeds = torch.load(src_emb_url.split("/1")[-1]) + target_embeds = torch.load(tgt_emb_url.split("/1")[-1]) + inputs = { "prompt": "turn him into a cyborg", "generator": generator, "num_inference_steps": 3, "guidance_scale": 7.5, "cross_attention_guidance_amount": 0.15, - "source_embedding_path": src_emb_url.split("/")[-1], - "target_embedding_path": tgt_emb_url.split("/")[-1], + "source_embeds": src_embeds, + "target_embeds": target_embeds, "output_type": "numpy", } return inputs From ca8855e8daeabcef71c0c2be7786b4161d54c151 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Wed, 15 Feb 2023 17:10:13 +0530 Subject: [PATCH 57/63] address PR comments (part II) --- .../pipeline_stable_diffusion_pix2pix_zero.py | 4 ++-- .../test_stable_diffusion_pix2pix_zero.py | 22 +++++++++++++++++-- 2 files changed, 22 insertions(+), 4 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py index 8f59b135bde2..714059ead719 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py @@ -59,8 +59,8 @@ >>> pipeline.to("cuda") >>> prompt = "a high resolution painting of a cat in the style of van gough" - >>> source_emb_url = "https://github.com/pix2pixzero/pix2pix-zero/raw/main/assets/embeddings_sd_1.4/cat.pt" - >>> target_emb_url = "https://github.com/pix2pixzero/pix2pix-zero/raw/main/assets/embeddings_sd_1.4/dog.pt" + >>> source_emb_url = "https://hf.co/datasets/sayakpaul/sample-datasets/resolve/main/cat.pt" + >>> target_emb_url = "https://hf.co/datasets/sayakpaul/sample-datasets/resolve/main/dog.pt" >>> for url in [source_emb_url, target_emb_url]: ... download(url, url.split("/")[-1]) diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_pix2pix_zero.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_pix2pix_zero.py index 7a67fc60883d..c77c2a277b69 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_pix2pix_zero.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_pix2pix_zero.py @@ -24,6 +24,7 @@ from diffusers import ( AutoencoderKL, DDIMScheduler, + DDPMScheduler, EulerAncestralDiscreteScheduler, LMSDiscreteScheduler, StableDiffusionPix2PixZeroPipeline, @@ -171,6 +172,23 @@ def test_stable_diffusion_pix2pix_zero_euler(self): assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 + def test_stable_diffusion_pix2pix_zero_ddpm(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + components = self.get_dummy_components() + components["scheduler"] = DDPMScheduler() + sd_pipe = StableDiffusionPix2PixZeroPipeline(**components) + sd_pipe = sd_pipe.to(device) + sd_pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + image = sd_pipe(**inputs).images + image_slice = image[0, -3:, -3:, -1] + + assert image.shape == (1, 64, 64, 3) + expected_slice = np.array([0.5106, 0.5788, 0.5447, 0.5566, 0.5276, 0.5851, 0.4967, 0.4903, 0.5216]) + + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 + def test_stable_diffusion_pix2pix_zero_num_images_per_prompt(self): device = "cpu" # ensure determinism for the device-dependent torch.Generator components = self.get_dummy_components() @@ -211,8 +229,8 @@ def tearDown(self): def get_inputs(self, seed=0): generator = torch.manual_seed(seed) - src_emb_url = "https://github.com/pix2pixzero/pix2pix-zero/raw/main/assets/embeddings_sd_1.4/cat.pt" - tgt_emb_url = "https://github.com/pix2pixzero/pix2pix-zero/raw/main/assets/embeddings_sd_1.4/dog.pt" + src_emb_url = "https://hf.co/datasets/sayakpaul/sample-datasets/resolve/main/cat.pt" + tgt_emb_url = "https://hf.co/datasets/sayakpaul/sample-datasets/resolve/main/dog.pt" for url in [src_emb_url, tgt_emb_url]: download_from_url(url, url.split("/")[-1]) From 46cdfc8589ff2e62e0cdffbab68198e999746cee Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Wed, 15 Feb 2023 17:22:03 +0530 Subject: [PATCH 58/63] fix: DDPM test assertion. --- .../stable_diffusion/test_stable_diffusion_pix2pix_zero.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_pix2pix_zero.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_pix2pix_zero.py index c77c2a277b69..524b69defbad 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_pix2pix_zero.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_pix2pix_zero.py @@ -185,7 +185,7 @@ def test_stable_diffusion_pix2pix_zero_ddpm(self): image_slice = image[0, -3:, -3:, -1] assert image.shape == (1, 64, 64, 3) - expected_slice = np.array([0.5106, 0.5788, 0.5447, 0.5566, 0.5276, 0.5851, 0.4967, 0.4903, 0.5216]) + expected_slice = np.array([0.5185, 0.5027, 0.492, 0.401, 0.3445, 0.464, 0.5321, 0.5327, 0.4892]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 From 75f918ec29b4905f13ac5a5f2adb27f0d4cc11b9 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Wed, 15 Feb 2023 17:28:07 +0530 Subject: [PATCH 59/63] refactor doc. --- .../stable_diffusion/pix2pix_zero.mdx | 27 ++++++++++++------- 1 file changed, 18 insertions(+), 9 deletions(-) diff --git a/docs/source/en/api/pipelines/stable_diffusion/pix2pix_zero.mdx b/docs/source/en/api/pipelines/stable_diffusion/pix2pix_zero.mdx index eafc6b77ec2f..bf6888e5a2cd 100644 --- a/docs/source/en/api/pipelines/stable_diffusion/pix2pix_zero.mdx +++ b/docs/source/en/api/pipelines/stable_diffusion/pix2pix_zero.mdx @@ -28,15 +28,15 @@ Resources: ## Tips -* The pipeline exposes two arguments namely `source_embedding_path` and `target_embedding_path` +* The pipeline exposes two arguments namely `source_embeds` and `target_embeds` that let you control the direction of the semantic edits in the final image to be generated. Let's say, you wanted to translate from "cat" to "dog". In this case, the edit direction will be "cat -> dog". To reflect this in the pipeline, you simply have to set the embeddings related to the phrases including "cat" to -`source_embedding_path` and "dog" to `target_embedding_path`. +`source_embeds` and "dog" to `target_embeds`. Refer to the code example below for more details. * When you're using this pipeline from a prompt, specify the _source_ concept in the prompt. Taking the above example, a valid input prompt would be: "a high resolution painting of a **cat** in the style of van gough". -* If you wanted to reverse the direction i.e., "dog -> cat", then it's recommended to: - * Swap the `source_embedding_path` and `target_embedding_path`. +* If you wanted to reverse the direction in the example above, i.e., "dog -> cat", then it's recommended to: + * Swap the `source_embeds` and `target_embeds`. * Change the input prompt to include "dog". ## Available Pipelines: @@ -49,6 +49,8 @@ the above example, a valid input prompt would be: "a high resolution painting of ## Usage example +**Based on an image generated with the input prompt** + ```python import requests import torch @@ -70,22 +72,29 @@ pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config) pipeline.to("cuda") prompt = "a high resolution painting of a cat in the style of van gough" -source_embedding_url = "https://github.com/pix2pixzero/pix2pix-zero/raw/main/assets/embeddings_sd_1.4/cat.pt" -target_embedding_url = "https://github.com/pix2pixzero/pix2pix-zero/raw/main/assets/embeddings_sd_1.4/dog.pt" +src_embs_url = "https://github.com/pix2pixzero/pix2pix-zero/raw/main/assets/embeddings_sd_1.4/cat.pt" +target_embs_url = "https://github.com/pix2pixzero/pix2pix-zero/raw/main/assets/embeddings_sd_1.4/dog.pt" -for url in [source_embedding_url, target_embedding_url]: +for url in [src_embs_url, target_embs_url]: download(url, url.split("/")[-1]) +src_embeds = torch.load(src_embs_url.split("/")[-1]) +target_embeds = torch.load(target_embs_url.split("/")[-1]) + images = pipeline( prompt, - source_embedding_path=source_embedding_url.split("/")[-1], - target_embedding_path=target_embedding_url.split("/")[-1], + source_embeds=src_embeds, + target_embeds=target_embeds, num_inference_steps=50, cross_attention_guidance_amount=0.15, ).images images[0].save("edited_image_dog.png") ``` +**Based on an input image** + +_Coming soon_ + ## StableDiffusionPix2PixZeroPipeline [[autodoc]] StableDiffusionPix2PixZeroPipeline - __call__ From a86277e6ba392ecf6fb1c433a8f609925e3874eb Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Wed, 15 Feb 2023 17:37:14 +0530 Subject: [PATCH 60/63] address PR comments (part III). --- .../pipeline_stable_diffusion_pix2pix_zero.py | 56 ++++++++----------- 1 file changed, 24 insertions(+), 32 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py index 714059ead719..34fb22de99ab 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py @@ -17,13 +17,7 @@ import PIL import torch -from transformers import ( - AutoProcessor, - BlipForConditionalGeneration, - CLIPFeatureExtractor, - CLIPTextModel, - CLIPTokenizer, -) +from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer from ...models import AutoencoderKL, UNet2DConditionModel from ...models.cross_attention import CrossAttention @@ -79,17 +73,6 @@ """ -def generate_caption(image, captioner, processor, return_image=True): - """Generates caption for a given image.""" - inputs = processor(images=image, return_tensors="pt") - outputs = captioner.generate(inputs) - caption = processor.batch_deocde(outputs, skip_special_tokens=True)[0] - if return_image: - return caption, inputs["pixel_values"] - else: - return caption - - def prepare_unet(unet: UNet2DConditionModel): """Modifies the UNet (`unet`) to perform Pix2Pix Zero optimizations.""" pix2pix_zero_attn_procs = {} @@ -107,11 +90,6 @@ def prepare_unet(unet: UNet2DConditionModel): return unet -def construct_direction(embs_source: torch.Tensor, embs_target: torch.Tensor): - """Constructs the edit direction to steer the image generation process semantically.""" - return (embs_target.mean(0) - embs_source.mean(0)).unsqueeze(0) - - class Pix2PixZeroL2Loss: def __init__(self): self.loss = 0.0 @@ -234,10 +212,11 @@ def __init__( ) if conditions_input_image: - logger.info("Loading caption generator since `conditions_input_image` is True.") - checkpoint = "Salesforce/blip-image-captioning-base" - captioner_processor = AutoProcessor.from_pretrained(checkpoint) - captioner = BlipForConditionalGeneration.from_pretrained(checkpoint) + raise NotImplementedError + # logger.info("Loading caption generator since `conditions_input_image` is True.") + # checkpoint = "Salesforce/blip-image-captioning-base" + # captioner_processor = AutoProcessor.from_pretrained(checkpoint) + # captioner = BlipForConditionalGeneration.from_pretrained(checkpoint, dtype=unet.dtype) else: captioner_processor = None captioner = None @@ -547,6 +526,20 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype latents = latents * self.scheduler.init_noise_sigma return latents + def generate_caption(self, image, return_image=True): + """Generates caption for a given image.""" + inputs = self._captioner_processor(images=image, return_tensors="pt") + outputs = self._captioner.generate(inputs) + caption = self._captioner_processor.batch_deocde(outputs, skip_special_tokens=True)[0] + if return_image: + return caption, inputs["pixel_values"] + else: + return caption + + def construct_direction(self, embs_source: torch.Tensor, embs_target: torch.Tensor): + """Constructs the edit direction to steer the image generation process semantically.""" + return (embs_target.mean(0) - embs_source.mean(0)).unsqueeze(0) + @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( @@ -671,10 +664,9 @@ def __call__( # 2. Generate a caption for the input image if we are conditioning the # pipeline based on some input image. if self.conditions_input_image: - caption, preprocessed_image = generate_caption(image, self._captioner, self._captioner_processor) + prompt, preprocessed_image = self.generate_caption(image) height, width = preprocessed_image.shape[-2:] - prompt = caption - logger.info(f"Generated caption for the input image: {caption}.") + logger.info(f"Generated prompt for the input image: {prompt}.") # 3. Define call parameters if prompt is not None and isinstance(prompt, str): @@ -714,7 +706,7 @@ def __call__( # We need to get the inverted noise from the input image and this requires # us to do a sort of `inverse_step()` in DDIM and then regularize the # noise to enforce the statistical properties of Gaussian. - raise NotImplementedError + pass else: num_channels_latents = self.unet.in_channels latents = self.prepare_latents( @@ -767,7 +759,7 @@ def __call__( callback(i, t, latents) # 8. Compute the edit directions. - edit_direction = construct_direction(source_embeds, target_embeds).to(prompt_embeds.device) + edit_direction = self.construct_direction(source_embeds, target_embeds).to(prompt_embeds.device) # 9. Edit the prompt embeddings as per the edit directions discovered. prompt_embeds_edit = prompt_embeds.clone() From 696d802723969c78ece27bf5f6160344148c5217 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Wed, 15 Feb 2023 17:43:36 +0530 Subject: [PATCH 61/63] fix: type annotation for the scheduler. --- .../pipeline_stable_diffusion_pix2pix_zero.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py index 34fb22de99ab..5dbd77429b7a 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py @@ -21,7 +21,7 @@ from ...models import AutoencoderKL, UNet2DConditionModel from ...models.cross_attention import CrossAttention -from ...schedulers import KarrasDiffusionSchedulers +from ...schedulers import DDPMScheduler, DDIMScheduler, EulerAncestralDiscreteScheduler, LMSDiscreteScheduler from ...utils import is_accelerate_available, logging, randn_tensor, replace_example_docstring from ..pipeline_utils import DiffusionPipeline from . import StableDiffusionPipelineOutput @@ -173,7 +173,7 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline): unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of - [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + [`DDIMScheduler`], [`LMSDiscreteScheduler`], [`EulerAncestralDiscreteScheduler`], or [`DDPMScheduler`]. safety_checker ([`StableDiffusionSafetyChecker`]): Classification module that estimates whether generated images could be considered offensive or harmful. Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. @@ -193,7 +193,7 @@ def __init__( text_encoder: CLIPTextModel, tokenizer: CLIPTokenizer, unet: UNet2DConditionModel, - scheduler: KarrasDiffusionSchedulers, + scheduler: Union[DDPMScheduler, DDIMScheduler, EulerAncestralDiscreteScheduler, LMSDiscreteScheduler], safety_checker: StableDiffusionSafetyChecker, feature_extractor: CLIPFeatureExtractor, conditions_input_image: bool = False, From d166a476504f238b89b293ece0c69bfe948a3695 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Wed, 15 Feb 2023 17:44:00 +0530 Subject: [PATCH 62/63] apply styling. --- .../stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py index 5dbd77429b7a..f0f09b8739b2 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py @@ -21,7 +21,7 @@ from ...models import AutoencoderKL, UNet2DConditionModel from ...models.cross_attention import CrossAttention -from ...schedulers import DDPMScheduler, DDIMScheduler, EulerAncestralDiscreteScheduler, LMSDiscreteScheduler +from ...schedulers import DDIMScheduler, DDPMScheduler, EulerAncestralDiscreteScheduler, LMSDiscreteScheduler from ...utils import is_accelerate_available, logging, randn_tensor, replace_example_docstring from ..pipeline_utils import DiffusionPipeline from . import StableDiffusionPipelineOutput From 4ff6f12bbd16970b0d0714f20aa5f2a25807f88e Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Thu, 16 Feb 2023 08:40:48 +0530 Subject: [PATCH 63/63] skip_mps and add note on embeddings in the docs. --- docs/source/en/api/pipelines/stable_diffusion/pix2pix_zero.mdx | 2 ++ .../stable_diffusion/test_stable_diffusion_pix2pix_zero.py | 3 ++- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/docs/source/en/api/pipelines/stable_diffusion/pix2pix_zero.mdx b/docs/source/en/api/pipelines/stable_diffusion/pix2pix_zero.mdx index bf6888e5a2cd..e4c26a182f5e 100644 --- a/docs/source/en/api/pipelines/stable_diffusion/pix2pix_zero.mdx +++ b/docs/source/en/api/pipelines/stable_diffusion/pix2pix_zero.mdx @@ -38,6 +38,8 @@ the above example, a valid input prompt would be: "a high resolution painting of * If you wanted to reverse the direction in the example above, i.e., "dog -> cat", then it's recommended to: * Swap the `source_embeds` and `target_embeds`. * Change the input prompt to include "dog". +* To learn more about how the source and target embeddings are generated, refer to the [original +paper](https://arxiv.org/abs/2302.03027). ## Available Pipelines: diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_pix2pix_zero.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_pix2pix_zero.py index 524b69defbad..97012bd4be73 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_pix2pix_zero.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_pix2pix_zero.py @@ -31,7 +31,7 @@ UNet2DConditionModel, ) from diffusers.utils import slow, torch_device -from diffusers.utils.testing_utils import require_torch_gpu +from diffusers.utils.testing_utils import require_torch_gpu, skip_mps from ...test_pipelines_common import PipelineTesterMixin @@ -45,6 +45,7 @@ def download_from_url(embedding_url, local_filepath): f.write(r.content) +@skip_mps class StableDiffusionPix2PixZeroPipelineFastTests(PipelineTesterMixin, unittest.TestCase): pipeline_class = StableDiffusionPix2PixZeroPipeline