From 2771b98a00f28c8ec09472eee1bab1bbcf4ea4ba Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 22 Jul 2024 11:57:48 +0200 Subject: [PATCH 01/22] copy hunyuandit pipeline --- src/diffusers/__init__.py | 2 + src/diffusers/pipelines/__init__.py | 2 + src/diffusers/pipelines/pag/__init__.py | 2 + .../pipelines/pag/pipeline_pag_hunyuandit.py | 900 ++++++++++++++++++ 4 files changed, 906 insertions(+) create mode 100644 src/diffusers/pipelines/pag/pipeline_pag_hunyuandit.py diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 6a6607cc376f..837e2433a219 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -244,6 +244,7 @@ "CLIPImageProjection", "CycleDiffusionPipeline", "HunyuanDiTControlNetPipeline", + "HunyuanDiTPAGPipeline", "HunyuanDiTPipeline", "I2VGenXLPipeline", "IFImg2ImgPipeline", @@ -656,6 +657,7 @@ CLIPImageProjection, CycleDiffusionPipeline, HunyuanDiTControlNetPipeline, + HunyuanDiTPAGPipeline, HunyuanDiTPipeline, I2VGenXLPipeline, IFImg2ImgPipeline, diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 1d5fd5c2d094..3899ea79c598 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -141,6 +141,7 @@ ) _import_structure["pag"].extend( [ + "HunyuanDiTPAGPipeline", "StableDiffusionPAGPipeline", "StableDiffusionControlNetPAGPipeline", "StableDiffusionXLPAGPipeline", @@ -515,6 +516,7 @@ ) from .musicldm import MusicLDMPipeline from .pag import ( + HunyuanDiTPAGPipeline, StableDiffusionControlNetPAGPipeline, StableDiffusionPAGPipeline, StableDiffusionXLControlNetPAGPipeline, diff --git a/src/diffusers/pipelines/pag/__init__.py b/src/diffusers/pipelines/pag/__init__.py index bf14821f3fdb..62a672ddd8c1 100644 --- a/src/diffusers/pipelines/pag/__init__.py +++ b/src/diffusers/pipelines/pag/__init__.py @@ -24,6 +24,7 @@ else: _import_structure["pipeline_pag_controlnet_sd"] = ["StableDiffusionControlNetPAGPipeline"] _import_structure["pipeline_pag_controlnet_sd_xl"] = ["StableDiffusionXLControlNetPAGPipeline"] + _import_structure["pipeline_pag_hunyuandit"] = ["HunyuanDiTPAGPipeline"] _import_structure["pipeline_pag_sd"] = ["StableDiffusionPAGPipeline"] _import_structure["pipeline_pag_sd_xl"] = ["StableDiffusionXLPAGPipeline"] _import_structure["pipeline_pag_sd_xl_img2img"] = ["StableDiffusionXLPAGImg2ImgPipeline"] @@ -39,6 +40,7 @@ else: from .pipeline_pag_controlnet_sd import StableDiffusionControlNetPAGPipeline from .pipeline_pag_controlnet_sd_xl import StableDiffusionXLControlNetPAGPipeline + from .pipeline_pag_hunyuandit import HunyuanDiTPAGPipeline from .pipeline_pag_sd import StableDiffusionPAGPipeline from .pipeline_pag_sd_xl import StableDiffusionXLPAGPipeline from .pipeline_pag_sd_xl_img2img import StableDiffusionXLPAGImg2ImgPipeline diff --git a/src/diffusers/pipelines/pag/pipeline_pag_hunyuandit.py b/src/diffusers/pipelines/pag/pipeline_pag_hunyuandit.py new file mode 100644 index 000000000000..868e103d15d9 --- /dev/null +++ b/src/diffusers/pipelines/pag/pipeline_pag_hunyuandit.py @@ -0,0 +1,900 @@ +# Copyright 2024 HunyuanDiT Authors and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +from transformers import BertModel, BertTokenizer, CLIPImageProcessor, MT5Tokenizer, T5EncoderModel + +from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import VaeImageProcessor +from ...models import AutoencoderKL, HunyuanDiT2DModel +from ...models.embeddings import get_2d_rotary_pos_embed +from ...pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker +from ...schedulers import DDPMScheduler +from ...utils import ( + is_torch_xla_available, + logging, + replace_example_docstring, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import HunyuanDiTPipeline + + >>> pipe = HunyuanDiTPipeline.from_pretrained( + ... "Tencent-Hunyuan/HunyuanDiT-Diffusers", torch_dtype=torch.float16 + ... ) + >>> pipe.to("cuda") + + >>> # You may also use English prompt as HunyuanDiT supports both English and Chinese + >>> # prompt = "An astronaut riding a horse" + >>> prompt = "一个宇航员在骑马" + >>> image = pipe(prompt).images[0] + ``` +""" + +STANDARD_RATIO = np.array( + [ + 1.0, # 1:1 + 4.0 / 3.0, # 4:3 + 3.0 / 4.0, # 3:4 + 16.0 / 9.0, # 16:9 + 9.0 / 16.0, # 9:16 + ] +) +STANDARD_SHAPE = [ + [(1024, 1024), (1280, 1280)], # 1:1 + [(1024, 768), (1152, 864), (1280, 960)], # 4:3 + [(768, 1024), (864, 1152), (960, 1280)], # 3:4 + [(1280, 768)], # 16:9 + [(768, 1280)], # 9:16 +] +STANDARD_AREA = [np.array([w * h for w, h in shapes]) for shapes in STANDARD_SHAPE] +SUPPORTED_SHAPE = [ + (1024, 1024), + (1280, 1280), # 1:1 + (1024, 768), + (1152, 864), + (1280, 960), # 4:3 + (768, 1024), + (864, 1152), + (960, 1280), # 3:4 + (1280, 768), # 16:9 + (768, 1280), # 9:16 +] + + +def map_to_standard_shapes(target_width, target_height): + target_ratio = target_width / target_height + closest_ratio_idx = np.argmin(np.abs(STANDARD_RATIO - target_ratio)) + closest_area_idx = np.argmin(np.abs(STANDARD_AREA[closest_ratio_idx] - target_width * target_height)) + width, height = STANDARD_SHAPE[closest_ratio_idx][closest_area_idx] + return width, height + + +def get_resize_crop_region_for_grid(src, tgt_size): + th = tw = tgt_size + h, w = src + + r = h / w + + # resize + if r > 1: + resize_height = th + resize_width = int(round(th / h * w)) + else: + resize_width = tw + resize_height = int(round(tw / w * h)) + + crop_top = int(round((th - resize_height) / 2.0)) + crop_left = int(round((tw - resize_width) / 2.0)) + + return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width) + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + """ + Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + +class HunyuanDiTPAGPipeline(DiffusionPipeline): + r""" + Pipeline for English/Chinese-to-image generation using HunyuanDiT. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + HunyuanDiT uses two text encoders: [mT5](https://huggingface.co/google/mt5-base) and [bilingual CLIP](fine-tuned by + ourselves) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. We use + `sdxl-vae-fp16-fix`. + text_encoder (Optional[`~transformers.BertModel`, `~transformers.CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + HunyuanDiT uses a fine-tuned [bilingual CLIP]. + tokenizer (Optional[`~transformers.BertTokenizer`, `~transformers.CLIPTokenizer`]): + A `BertTokenizer` or `CLIPTokenizer` to tokenize text. + transformer ([`HunyuanDiT2DModel`]): + The HunyuanDiT model designed by Tencent Hunyuan. + text_encoder_2 (`T5EncoderModel`): + The mT5 embedder. Specifically, it is 't5-v1_1-xxl'. + tokenizer_2 (`MT5Tokenizer`): + The tokenizer for the mT5 embedder. + scheduler ([`DDPMScheduler`]): + A scheduler to be used in combination with HunyuanDiT to denoise the encoded image latents. + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae" + _optional_components = [ + "safety_checker", + "feature_extractor", + "text_encoder_2", + "tokenizer_2", + "text_encoder", + "tokenizer", + ] + _exclude_from_cpu_offload = ["safety_checker"] + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + "prompt_embeds_2", + "negative_prompt_embeds_2", + ] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: BertModel, + tokenizer: BertTokenizer, + transformer: HunyuanDiT2DModel, + scheduler: DDPMScheduler, + safety_checker: Optional[StableDiffusionSafetyChecker] = None, + feature_extractor: Optional[CLIPImageProcessor] = None, + requires_safety_checker: bool = True, + text_encoder_2: Optional[T5EncoderModel] = None, + tokenizer_2: Optional[MT5Tokenizer] = None, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + transformer=transformer, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + text_encoder_2=text_encoder_2, + ) + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + self.vae_scale_factor = ( + 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 + ) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.register_to_config(requires_safety_checker=requires_safety_checker) + self.default_sample_size = ( + self.transformer.config.sample_size + if hasattr(self, "transformer") and self.transformer is not None + else 128 + ) + + def encode_prompt( + self, + prompt: str, + device: torch.device = None, + dtype: torch.dtype = None, + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt: Optional[str] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + prompt_attention_mask: Optional[torch.Tensor] = None, + negative_prompt_attention_mask: Optional[torch.Tensor] = None, + max_sequence_length: Optional[int] = None, + text_encoder_index: int = 0, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + dtype (`torch.dtype`): + torch dtype + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + prompt_attention_mask (`torch.Tensor`, *optional*): + Attention mask for the prompt. Required when `prompt_embeds` is passed directly. + negative_prompt_attention_mask (`torch.Tensor`, *optional*): + Attention mask for the negative prompt. Required when `negative_prompt_embeds` is passed directly. + max_sequence_length (`int`, *optional*): maximum sequence length to use for the prompt. + text_encoder_index (`int`, *optional*): + Index of the text encoder to use. `0` for clip and `1` for T5. + """ + if dtype is None: + if self.text_encoder_2 is not None: + dtype = self.text_encoder_2.dtype + elif self.transformer is not None: + dtype = self.transformer.dtype + else: + dtype = None + + if device is None: + device = self._execution_device + + tokenizers = [self.tokenizer, self.tokenizer_2] + text_encoders = [self.text_encoder, self.text_encoder_2] + + tokenizer = tokenizers[text_encoder_index] + text_encoder = text_encoders[text_encoder_index] + + if max_sequence_length is None: + if text_encoder_index == 0: + max_length = 77 + if text_encoder_index == 1: + max_length = 256 + else: + max_length = max_sequence_length + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=max_length, + truncation=True, + return_attention_mask=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {tokenizer.model_max_length} tokens: {removed_text}" + ) + + prompt_attention_mask = text_inputs.attention_mask.to(device) + prompt_embeds = text_encoder( + text_input_ids.to(device), + attention_mask=prompt_attention_mask, + ) + prompt_embeds = prompt_embeds[0] + prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1) + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + max_length = prompt_embeds.shape[1] + uncond_input = tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + negative_prompt_attention_mask = uncond_input.attention_mask.to(device) + negative_prompt_embeds = text_encoder( + uncond_input.input_ids.to(device), + attention_mask=negative_prompt_attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1) + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + return prompt_embeds, negative_prompt_embeds, prompt_attention_mask, negative_prompt_attention_mask + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + height, + width, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_attention_mask=None, + negative_prompt_attention_mask=None, + prompt_embeds_2=None, + negative_prompt_embeds_2=None, + prompt_attention_mask_2=None, + negative_prompt_attention_mask_2=None, + callback_on_step_end_tensor_inputs=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is None and prompt_embeds_2 is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds_2`. Cannot leave both `prompt` and `prompt_embeds_2` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt_embeds is not None and prompt_attention_mask is None: + raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.") + + if prompt_embeds_2 is not None and prompt_attention_mask_2 is None: + raise ValueError("Must provide `prompt_attention_mask_2` when specifying `prompt_embeds_2`.") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if negative_prompt_embeds is not None and negative_prompt_attention_mask is None: + raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.") + + if negative_prompt_embeds_2 is not None and negative_prompt_attention_mask_2 is None: + raise ValueError( + "Must provide `negative_prompt_attention_mask_2` when specifying `negative_prompt_embeds_2`." + ) + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + if prompt_embeds_2 is not None and negative_prompt_embeds_2 is not None: + if prompt_embeds_2.shape != negative_prompt_embeds_2.shape: + raise ValueError( + "`prompt_embeds_2` and `negative_prompt_embeds_2` must have the same shape when passed directly, but" + f" got: `prompt_embeds_2` {prompt_embeds_2.shape} != `negative_prompt_embeds_2`" + f" {negative_prompt_embeds_2.shape}." + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def guidance_rescale(self): + return self._guidance_rescale + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: Optional[int] = 50, + guidance_scale: Optional[float] = 5.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: Optional[float] = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + prompt_embeds_2: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds_2: Optional[torch.Tensor] = None, + prompt_attention_mask: Optional[torch.Tensor] = None, + prompt_attention_mask_2: Optional[torch.Tensor] = None, + negative_prompt_attention_mask: Optional[torch.Tensor] = None, + negative_prompt_attention_mask_2: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + guidance_rescale: float = 0.0, + original_size: Optional[Tuple[int, int]] = (1024, 1024), + target_size: Optional[Tuple[int, int]] = None, + crops_coords_top_left: Tuple[int, int] = (0, 0), + use_resolution_binning: bool = True, + ): + r""" + The call function to the pipeline for generation with HunyuanDiT. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + height (`int`): + The height in pixels of the generated image. + width (`int`): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. This parameter is modulated by `strength`. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies + to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + prompt_embeds_2 (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + negative_prompt_embeds_2 (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + prompt_attention_mask (`torch.Tensor`, *optional*): + Attention mask for the prompt. Required when `prompt_embeds` is passed directly. + prompt_attention_mask_2 (`torch.Tensor`, *optional*): + Attention mask for the prompt. Required when `prompt_embeds_2` is passed directly. + negative_prompt_attention_mask (`torch.Tensor`, *optional*): + Attention mask for the negative prompt. Required when `negative_prompt_embeds` is passed directly. + negative_prompt_attention_mask_2 (`torch.Tensor`, *optional*): + Attention mask for the negative prompt. Required when `negative_prompt_embeds_2` is passed directly. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback_on_step_end (`Callable[[int, int, Dict], None]`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A callback function or a list of callback functions to be called at the end of each denoising step. + callback_on_step_end_tensor_inputs (`List[str]`, *optional*): + A list of tensor inputs that should be passed to the callback function. If not defined, all tensor + inputs will be passed. + guidance_rescale (`float`, *optional*, defaults to 0.0): + Rescale the noise_cfg according to `guidance_rescale`. Based on findings of [Common Diffusion Noise + Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 + original_size (`Tuple[int, int]`, *optional*, defaults to `(1024, 1024)`): + The original size of the image. Used to calculate the time ids. + target_size (`Tuple[int, int]`, *optional*): + The target size of the image. Used to calculate the time ids. + crops_coords_top_left (`Tuple[int, int]`, *optional*, defaults to `(0, 0)`): + The top left coordinates of the crop. Used to calculate the time ids. + use_resolution_binning (`bool`, *optional*, defaults to `True`): + Whether to use resolution binning or not. If `True`, the input resolution will be mapped to the closest + standard resolution. Supported resolutions are 1024x1024, 1280x1280, 1024x768, 1152x864, 1280x960, + 768x1024, 864x1152, 960x1280, 1280x768, and 768x1280. It is recommended to set this to `True`. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images and the + second element is a list of `bool`s indicating whether the corresponding generated image contains + "not-safe-for-work" (nsfw) content. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 0. default height and width + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + height = int((height // 16) * 16) + width = int((width // 16) * 16) + + if use_resolution_binning and (height, width) not in SUPPORTED_SHAPE: + width, height = map_to_standard_shapes(width, height) + height = int(height) + width = int(width) + logger.warning(f"Reshaped to (height, width)=({height}, {width}), Supported shapes are {SUPPORTED_SHAPE}") + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + prompt_attention_mask, + negative_prompt_attention_mask, + prompt_embeds_2, + negative_prompt_embeds_2, + prompt_attention_mask_2, + negative_prompt_attention_mask_2, + callback_on_step_end_tensor_inputs, + ) + self._guidance_scale = guidance_scale + self._guidance_rescale = guidance_rescale + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. Encode input prompt + + ( + prompt_embeds, + negative_prompt_embeds, + prompt_attention_mask, + negative_prompt_attention_mask, + ) = self.encode_prompt( + prompt=prompt, + device=device, + dtype=self.transformer.dtype, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + max_sequence_length=77, + text_encoder_index=0, + ) + ( + prompt_embeds_2, + negative_prompt_embeds_2, + prompt_attention_mask_2, + negative_prompt_attention_mask_2, + ) = self.encode_prompt( + prompt=prompt, + device=device, + dtype=self.transformer.dtype, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds_2, + negative_prompt_embeds=negative_prompt_embeds_2, + prompt_attention_mask=prompt_attention_mask_2, + negative_prompt_attention_mask=negative_prompt_attention_mask_2, + max_sequence_length=256, + text_encoder_index=1, + ) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7 create image_rotary_emb, style embedding & time ids + grid_height = height // 8 // self.transformer.config.patch_size + grid_width = width // 8 // self.transformer.config.patch_size + base_size = 512 // 8 // self.transformer.config.patch_size + grid_crops_coords = get_resize_crop_region_for_grid((grid_height, grid_width), base_size) + image_rotary_emb = get_2d_rotary_pos_embed( + self.transformer.inner_dim // self.transformer.num_heads, grid_crops_coords, (grid_height, grid_width) + ) + + style = torch.tensor([0], device=device) + + target_size = target_size or (height, width) + add_time_ids = list(original_size + target_size + crops_coords_top_left) + add_time_ids = torch.tensor([add_time_ids], dtype=prompt_embeds.dtype) + + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask]) + prompt_embeds_2 = torch.cat([negative_prompt_embeds_2, prompt_embeds_2]) + prompt_attention_mask_2 = torch.cat([negative_prompt_attention_mask_2, prompt_attention_mask_2]) + add_time_ids = torch.cat([add_time_ids] * 2, dim=0) + style = torch.cat([style] * 2, dim=0) + + prompt_embeds = prompt_embeds.to(device=device) + prompt_attention_mask = prompt_attention_mask.to(device=device) + prompt_embeds_2 = prompt_embeds_2.to(device=device) + prompt_attention_mask_2 = prompt_attention_mask_2.to(device=device) + add_time_ids = add_time_ids.to(dtype=prompt_embeds.dtype, device=device).repeat( + batch_size * num_images_per_prompt, 1 + ) + style = style.to(device=device).repeat(batch_size * num_images_per_prompt) + + # 8. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # expand scalar t to 1-D tensor to match the 1st dim of latent_model_input + t_expand = torch.tensor([t] * latent_model_input.shape[0], device=device).to( + dtype=latent_model_input.dtype + ) + + # predict the noise residual + noise_pred = self.transformer( + latent_model_input, + t_expand, + encoder_hidden_states=prompt_embeds, + text_embedding_mask=prompt_attention_mask, + encoder_hidden_states_t5=prompt_embeds_2, + text_embedding_mask_t5=prompt_attention_mask_2, + image_meta_size=add_time_ids, + style=style, + image_rotary_emb=image_rotary_emb, + return_dict=False, + )[0] + + noise_pred, _ = noise_pred.chunk(2, dim=1) + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + if self.do_classifier_free_guidance and guidance_rescale > 0.0: + # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + prompt_embeds_2 = callback_outputs.pop("prompt_embeds_2", prompt_embeds_2) + negative_prompt_embeds_2 = callback_outputs.pop( + "negative_prompt_embeds_2", negative_prompt_embeds_2 + ) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) From 659310aeaebca319fbdd8cbdb55dcae861082c9f Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 22 Jul 2024 19:21:38 +0200 Subject: [PATCH 02/22] pag variant of hunyuan dit --- src/diffusers/models/attention_processor.py | 243 ++++++++++++++++++ src/diffusers/pipelines/auto_pipeline.py | 2 + src/diffusers/pipelines/pag/pag_utils.py | 89 ++++++- .../pipelines/pag/pipeline_pag_hunyuandit.py | 84 ++++-- 4 files changed, 400 insertions(+), 18 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index cadf793953c6..4dee0c54a7c7 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -1708,6 +1708,249 @@ def __call__( return hidden_states +class PAGHunyuanAttnProcessor2_0: + r""" + Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is + used in the HunyuanDiT model. It applies a normalization layer and rotary embedding on query and key vector. This + variant of the processor employs [Pertubed Attention Guidance](https://arxiv.org/abs/2403.17377). + """ + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("PAGHunyuanAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + temb: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + from .embeddings import apply_rotary_emb + + residual = hidden_states + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + # chunk + hidden_states_org, hidden_states_ptb = hidden_states.chunk(2) + + # 1. Original Path + batch_size, sequence_length, _ = ( + hidden_states_org.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if attn.group_norm is not None: + hidden_states_org = attn.group_norm(hidden_states_org.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states_org) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states_org + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # Apply RoPE if needed + if image_rotary_emb is not None: + query = apply_rotary_emb(query, image_rotary_emb) + if not attn.is_cross_attention: + key = apply_rotary_emb(key, image_rotary_emb) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states_org = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states_org = hidden_states_org.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states_org = hidden_states_org.to(query.dtype) + + # linear proj + hidden_states_org = attn.to_out[0](hidden_states_org) + # dropout + hidden_states_org = attn.to_out[1](hidden_states_org) + + if input_ndim == 4: + hidden_states_org = hidden_states_org.transpose(-1, -2).reshape(batch_size, channel, height, width) + + # 2. Perturbed Path + if attn.group_norm is not None: + hidden_states_ptb = attn.group_norm(hidden_states_ptb.transpose(1, 2)).transpose(1, 2) + + hidden_states_ptb = attn.to_v(hidden_states_ptb) + hidden_states_ptb = hidden_states_ptb.to(query.dtype) + + # linear proj + hidden_states_ptb = attn.to_out[0](hidden_states_ptb) + # dropout + hidden_states_ptb = attn.to_out[1](hidden_states_ptb) + + if input_ndim == 4: + hidden_states_ptb = hidden_states_ptb.transpose(-1, -2).reshape(batch_size, channel, height, width) + + # cat + hidden_states = torch.cat([hidden_states_org, hidden_states_ptb]) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +class PAGCFGHunyuanAttnProcessor2_0: + r""" + Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is + used in the HunyuanDiT model. It applies a normalization layer and rotary embedding on query and key vector. This + variant of the processor employs [Pertubed Attention Guidance](https://arxiv.org/abs/2403.17377). + """ + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("PAGCFGHunyuanAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + temb: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + from .embeddings import apply_rotary_emb + + residual = hidden_states + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + # chunk + hidden_states_uncond, hidden_states_org, hidden_states_ptb = hidden_states.chunk(3) + hidden_states_org = torch.cat([hidden_states_uncond, hidden_states_org]) + + # 1. Original Path + batch_size, sequence_length, _ = ( + hidden_states_org.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if attn.group_norm is not None: + hidden_states_org = attn.group_norm(hidden_states_org.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states_org) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states_org + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # Apply RoPE if needed + if image_rotary_emb is not None: + query = apply_rotary_emb(query, image_rotary_emb) + if not attn.is_cross_attention: + key = apply_rotary_emb(key, image_rotary_emb) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states_org = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states_org = hidden_states_org.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states_org = hidden_states_org.to(query.dtype) + + # linear proj + hidden_states_org = attn.to_out[0](hidden_states_org) + # dropout + hidden_states_org = attn.to_out[1](hidden_states_org) + + if input_ndim == 4: + hidden_states_org = hidden_states_org.transpose(-1, -2).reshape(batch_size, channel, height, width) + + # 2. Perturbed Path + if attn.group_norm is not None: + hidden_states_ptb = attn.group_norm(hidden_states_ptb.transpose(1, 2)).transpose(1, 2) + + hidden_states_ptb = attn.to_v(hidden_states_ptb) + hidden_states_ptb = hidden_states_ptb.to(query.dtype) + + # linear proj + hidden_states_ptb = attn.to_out[0](hidden_states_ptb) + # dropout + hidden_states_ptb = attn.to_out[1](hidden_states_ptb) + + if input_ndim == 4: + hidden_states_ptb = hidden_states_ptb.transpose(-1, -2).reshape(batch_size, channel, height, width) + + # cat + hidden_states = torch.cat([hidden_states_org, hidden_states_ptb]) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + class LuminaAttnProcessor2_0: r""" Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is diff --git a/src/diffusers/pipelines/auto_pipeline.py b/src/diffusers/pipelines/auto_pipeline.py index 2df09f62c880..acfb369daf1c 100644 --- a/src/diffusers/pipelines/auto_pipeline.py +++ b/src/diffusers/pipelines/auto_pipeline.py @@ -49,6 +49,7 @@ from .kolors import KolorsImg2ImgPipeline, KolorsPipeline from .latent_consistency_models import LatentConsistencyModelImg2ImgPipeline, LatentConsistencyModelPipeline from .pag import ( + HunyuanDiTPAGPipeline, StableDiffusionControlNetPAGPipeline, StableDiffusionPAGPipeline, StableDiffusionXLControlNetPAGPipeline, @@ -83,6 +84,7 @@ ("stable-diffusion-3", StableDiffusion3Pipeline), ("if", IFPipeline), ("hunyuan", HunyuanDiTPipeline), + ("hunyuan-pag", HunyuanDiTPAGPipeline), ("kandinsky", KandinskyCombinedPipeline), ("kandinsky22", KandinskyV22CombinedPipeline), ("kandinsky3", Kandinsky3Pipeline), diff --git a/src/diffusers/pipelines/pag/pag_utils.py b/src/diffusers/pipelines/pag/pag_utils.py index 2009024e4e47..7bc5b12b3d23 100644 --- a/src/diffusers/pipelines/pag/pag_utils.py +++ b/src/diffusers/pipelines/pag/pag_utils.py @@ -15,7 +15,9 @@ import torch from ...models.attention_processor import ( + PAGCFGHunyuanAttnProcessor2_0, PAGCFGIdentitySelfAttnProcessor2_0, + PAGHunyuanAttnProcessor2_0, PAGIdentitySelfAttnProcessor2_0, ) from ...utils import logging @@ -255,6 +257,91 @@ def pag_attn_processors(self): processors = {} for name, proc in self.unet.attn_processors.items(): - if proc.__class__ in (PAGCFGIdentitySelfAttnProcessor2_0, PAGIdentitySelfAttnProcessor2_0): + if proc.__class__ in (PAGCFGHunyuanAttnProcessor2_0, PAGHunyuanAttnProcessor2_0): + processors[name] = proc + return processors + + +class HunyuanDiTPAGMixin(PAGMixin): + r"""Mixin class for PAG applied to HunyuanDiT.""" + + @staticmethod + def _check_input_pag_applied_layer(layer): + r""" + Check if each layer input in `applied_pag_layers` is valid. It should be the block index: {block_index}. + """ + + # Check if the layer index is valid (should be int or str of int) + if isinstance(layer, int): + return # Valid layer index + + if isinstance(layer, str): + if layer.isdigit(): + return # Valid layer index + + # If it is not a valid layer index, raise a ValueError + raise ValueError(f"Pag layer should only contain block index. Accept number string like '3', got {layer}") + + def _set_pag_attn_processor(self, pag_applied_layers, do_classifier_free_guidance): + r""" + Set the attention processor for the PAG layers. + """ + if do_classifier_free_guidance: + pag_attn_proc = PAGCFGHunyuanAttnProcessor2_0() + else: + pag_attn_proc = PAGHunyuanAttnProcessor2_0() + + def is_self_attn(module_name): + r""" + Check if the module is self-attention module based on its name. + """ + # include blocks.1.attn1 + # exclude blocks.18.attn1.to_q, blocks.1.attn1.norm_k, ... + return "attn1" in module_name and len(module_name.split(".")) == 3 + + def get_block_index(module_name): + r""" + Get the block index from the module name. can be "block_0", "block_1", ... If there is only one block (e.g. + mid_block) and index is ommited from the name, it will be "block_0". + """ + # blocks.23.attn1 -> "23" + return module_name.split(".")[1] + + for pag_layer_input in pag_applied_layers: + # for each PAG layer input, we find corresponding self-attention layers in the transformer model + target_modules = [] + pag_layer_input_splits = str(pag_layer_input).split(".") + + if len(pag_layer_input_splits) == 1: + # 20, "20" -> "20" + block_index = pag_layer_input_splits[0] + elif len(pag_layer_input_splits) >= 2: + # "blocks.20" -> "20" + # "blocks.20.attn1" -> "20" + block_index = pag_layer_input_splits[1] + + for name, module in self.transformer.named_modules(): + if is_self_attn(name) and get_block_index(name) == block_index: + target_modules.append(module) + + if len(target_modules) == 0: + raise ValueError(f"Cannot find pag layer to set attention processor for: {pag_layer_input}") + + for module in target_modules: + module.processor = pag_attn_proc + + @property + # Copied from diffusers.pipelines.pag.pag_utils.PAGMixin.pag_attn_processors with unet->transformer + def pag_attn_processors(self): + r""" + Returns: + `dict`: + A dictionary contains all PAG attention processors used in the model with the key as the name of the + layer. + """ + + processors = {} + for name, proc in self.transformer.attn_processors.items(): + if proc.__class__ in (PAGCFGHunyuanAttnProcessor2_0, PAGHunyuanAttnProcessor2_0): processors[name] = proc return processors diff --git a/src/diffusers/pipelines/pag/pipeline_pag_hunyuandit.py b/src/diffusers/pipelines/pag/pipeline_pag_hunyuandit.py index 868e103d15d9..d2c64854aada 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_hunyuandit.py @@ -34,6 +34,7 @@ ) from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline +from .pag_utils import HunyuanDiTPAGMixin if is_torch_xla_available(): @@ -50,17 +51,18 @@ Examples: ```py >>> import torch - >>> from diffusers import HunyuanDiTPipeline + >>> from diffusers import AutoPipelineForText2Image - >>> pipe = HunyuanDiTPipeline.from_pretrained( - ... "Tencent-Hunyuan/HunyuanDiT-Diffusers", torch_dtype=torch.float16 - ... ) - >>> pipe.to("cuda") + >>> pipe = AutoPipelineForText2Image.from_pretrained( + ... "Tencent-Hunyuan/HunyuanDiT-v1.2-Diffusers", + ... torch_dtype=torch.float16, + ... enable_pag=True, + ... pag_applied_layers=[16, 17, 18, 19] + >>> ).to("cuda") - >>> # You may also use English prompt as HunyuanDiT supports both English and Chinese - >>> # prompt = "An astronaut riding a horse" + >>> # prompt = "an astronaut riding a horse" >>> prompt = "一个宇航员在骑马" - >>> image = pipe(prompt).images[0] + >>> image = pipe(, guidance_scale=4, pag_scale=3).images[0] ``` """ @@ -138,9 +140,10 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): return noise_cfg -class HunyuanDiTPAGPipeline(DiffusionPipeline): +class HunyuanDiTPAGPipeline(DiffusionPipeline, HunyuanDiTPAGMixin): r""" - Pipeline for English/Chinese-to-image generation using HunyuanDiT. + Pipeline for English/Chinese-to-image generation using HunyuanDiT and [Perturbed Attention + Guidance](https://huggingface.co/docs/diffusers/en/using-diffusers/pag). This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) @@ -197,6 +200,7 @@ def __init__( requires_safety_checker: bool = True, text_encoder_2: Optional[T5EncoderModel] = None, tokenizer_2: Optional[MT5Tokenizer] = None, + pag_applied_layers: Union[str, List[str]] = [], # "blocks.16.attn1", "blocks.16", "16", 16 ): super().__init__() @@ -239,6 +243,9 @@ def __init__( else 128 ) + self.set_pag_applied_layers(pag_applied_layers) + + # Copied from diffusers.pipelines.hunyuandit.pipeline_hunyuandit.HunyuanDiTPipeline.encode_prompt def encode_prompt( self, prompt: str, @@ -437,6 +444,7 @@ def prepare_extra_step_kwargs(self, generator, eta): extra_step_kwargs["generator"] = generator return extra_step_kwargs + # Copied from diffusers.pipelines.hunyuandit.pipeline_hunyuandit.HunyuanDiTPipeline.check_inputs def check_inputs( self, prompt, @@ -592,6 +600,8 @@ def __call__( target_size: Optional[Tuple[int, int]] = None, crops_coords_top_left: Tuple[int, int] = (0, 0), use_resolution_binning: bool = True, + pag_scale: float = 3.0, + pag_adaptive_scale: float = 0.0, ): r""" The call function to the pipeline for generation with HunyuanDiT. @@ -663,6 +673,12 @@ def __call__( Whether to use resolution binning or not. If `True`, the input resolution will be mapped to the closest standard resolution. Supported resolutions are 1024x1024, 1280x1280, 1024x768, 1152x864, 1280x960, 768x1024, 864x1152, 960x1280, 1280x768, and 768x1280. It is recommended to set this to `True`. + pag_scale (`float`, *optional*, defaults to 3.0): + The scale factor for the perturbed attention guidance. If it is set to 0.0, the perturbed attention + guidance will not be used. + pag_adaptive_scale (`float`, *optional*, defaults to 0.0): + The adaptive scale factor for the perturbed attention guidance. If it is set to 0.0, `pag_scale` is + used. Examples: @@ -677,7 +693,7 @@ def __call__( if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs - # 0. default height and width + # 0. Default height and width height = height or self.default_sample_size * self.vae_scale_factor width = width or self.default_sample_size * self.vae_scale_factor height = int((height // 16) * 16) @@ -708,6 +724,8 @@ def __call__( self._guidance_scale = guidance_scale self._guidance_rescale = guidance_rescale self._interrupt = False + self._pag_scale = pag_scale + self._pag_adaptive_scale = pag_adaptive_scale # 2. Define call parameters if prompt is not None and isinstance(prompt, str): @@ -720,7 +738,6 @@ def __call__( device = self._execution_device # 3. Encode input prompt - ( prompt_embeds, negative_prompt_embeds, @@ -780,7 +797,7 @@ def __call__( # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) - # 7 create image_rotary_emb, style embedding & time ids + # 7. Create image_rotary_emb, style embedding & time ids grid_height = height // 8 // self.transformer.config.patch_size grid_width = width // 8 // self.transformer.config.patch_size base_size = 512 // 8 // self.transformer.config.patch_size @@ -795,7 +812,25 @@ def __call__( add_time_ids = list(original_size + target_size + crops_coords_top_left) add_time_ids = torch.tensor([add_time_ids], dtype=prompt_embeds.dtype) - if self.do_classifier_free_guidance: + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if self.do_perturbed_attention_guidance: + prompt_embeds = self._prepare_perturbed_attention_guidance( + prompt_embeds, negative_prompt_embeds, self.do_classifier_free_guidance + ) + prompt_attention_mask = self._prepare_perturbed_attention_guidance( + prompt_attention_mask, negative_prompt_attention_mask, self.do_classifier_free_guidance + ) + prompt_embeds_2 = self._prepare_perturbed_attention_guidance( + prompt_embeds_2, negative_prompt_embeds_2, self.do_classifier_free_guidance + ) + prompt_attention_mask_2 = self._prepare_perturbed_attention_guidance( + prompt_attention_mask_2, negative_prompt_attention_mask_2, self.do_classifier_free_guidance + ) + add_time_ids = torch.cat([add_time_ids] * 3, dim=0) + style = torch.cat([style] * 3, dim=0) + elif self.do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask]) prompt_embeds_2 = torch.cat([negative_prompt_embeds_2, prompt_embeds_2]) @@ -815,13 +850,21 @@ def __call__( # 8. Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order self._num_timesteps = len(timesteps) + + if self.do_perturbed_attention_guidance: + original_attn_proc = self.transformer.attn_processors + self._set_pag_attn_processor( + pag_applied_layers=self.pag_applied_layers, + do_classifier_free_guidance=self.do_classifier_free_guidance, + ) + with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): if self.interrupt: continue # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + latent_model_input = torch.cat([latents] * (prompt_embeds.shape[0] // latents.shape[0])) latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # expand scalar t to 1-D tensor to match the 1st dim of latent_model_input @@ -846,7 +889,11 @@ def __call__( noise_pred, _ = noise_pred.chunk(2, dim=1) # perform guidance - if self.do_classifier_free_guidance: + if self.do_perturbed_attention_guidance: + noise_pred = self._apply_perturbed_attention_guidance( + noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t + ) + elif self.do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) @@ -891,9 +938,12 @@ def __call__( image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) - # Offload all models + # 9. Offload all models self.maybe_free_model_hooks() + if self.do_perturbed_attention_guidance: + self.transformer.set_attn_processor(original_attn_proc) + if not return_dict: return (image, has_nsfw_concept) From 21af0a7a68d3dcf6a150adcdfe6747dfffa232be Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 22 Jul 2024 19:21:51 +0200 Subject: [PATCH 03/22] add tests --- tests/pipelines/pag/test_pag_hunyuan_dit.py | 364 ++++++++++++++++++++ 1 file changed, 364 insertions(+) create mode 100644 tests/pipelines/pag/test_pag_hunyuan_dit.py diff --git a/tests/pipelines/pag/test_pag_hunyuan_dit.py b/tests/pipelines/pag/test_pag_hunyuan_dit.py new file mode 100644 index 000000000000..f31a6e97a7e8 --- /dev/null +++ b/tests/pipelines/pag/test_pag_hunyuan_dit.py @@ -0,0 +1,364 @@ +# coding=utf-8 +# Copyright 2024 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import tempfile +import unittest + +import numpy as np +import torch +from transformers import AutoTokenizer, BertModel, T5EncoderModel + +from diffusers import ( + AutoencoderKL, + DDPMScheduler, + HunyuanDiT2DModel, + HunyuanDiTPAGPipeline, + HunyuanDiTPipeline, +) +from diffusers.utils.testing_utils import ( + enable_full_determinism, + numpy_cosine_similarity_distance, + require_torch_gpu, + slow, + torch_device, +) + +from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS +from ..test_pipelines_common import PipelineTesterMixin, to_np + + +enable_full_determinism() + + +class HunyuanDiTPAGPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = HunyuanDiTPAGPipeline + params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"} + batch_params = TEXT_TO_IMAGE_BATCH_PARAMS + image_params = TEXT_TO_IMAGE_IMAGE_PARAMS + image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS + + required_optional_params = PipelineTesterMixin.required_optional_params + + def get_dummy_components(self): + torch.manual_seed(0) + transformer = HunyuanDiT2DModel( + sample_size=16, + num_layers=2, + patch_size=2, + attention_head_dim=8, + num_attention_heads=3, + in_channels=4, + cross_attention_dim=32, + cross_attention_dim_t5=32, + pooled_projection_dim=16, + hidden_size=24, + activation_fn="gelu-approximate", + ) + torch.manual_seed(0) + vae = AutoencoderKL() + + scheduler = DDPMScheduler() + text_encoder = BertModel.from_pretrained("hf-internal-testing/tiny-random-BertModel") + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-BertModel") + text_encoder_2 = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5") + tokenizer_2 = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5") + + components = { + "transformer": transformer.eval(), + "vae": vae.eval(), + "scheduler": scheduler, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + "text_encoder_2": text_encoder_2, + "tokenizer_2": tokenizer_2, + "safety_checker": None, + "feature_extractor": None, + "pag_applied_layers": [1], + } + return components + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + inputs = { + "prompt": "A painting of a squirrel eating a burger", + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 5.0, + "output_type": "np", + "use_resolution_binning": False, + "pag_scale": 0.0, + } + return inputs + + def test_inference(self): + device = "cpu" + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + image = pipe(**inputs).images + image_slice = image[0, -3:, -3:, -1] + + self.assertEqual(image.shape, (1, 16, 16, 3)) + expected_slice = np.array( + [0.56939435, 0.34541583, 0.35915792, 0.46489206, 0.38775963, 0.45004836, 0.5957267, 0.59481275, 0.33287364] + ) + max_diff = np.abs(image_slice.flatten() - expected_slice).max() + self.assertLessEqual(max_diff, 1e-3) + + def test_sequential_cpu_offload_forward_pass(self): + # TODO(YiYi) need to fix later + pass + + def test_sequential_offload_forward_pass_twice(self): + # TODO(YiYi) need to fix later + pass + + def test_inference_batch_single_identical(self): + self._test_inference_batch_single_identical( + expected_max_diff=1e-3, + ) + + def test_save_load_optional_components(self): + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(torch_device) + + prompt = inputs["prompt"] + generator = inputs["generator"] + num_inference_steps = inputs["num_inference_steps"] + output_type = inputs["output_type"] + + ( + prompt_embeds, + negative_prompt_embeds, + prompt_attention_mask, + negative_prompt_attention_mask, + ) = pipe.encode_prompt(prompt, device=torch_device, dtype=torch.float32, text_encoder_index=0) + + ( + prompt_embeds_2, + negative_prompt_embeds_2, + prompt_attention_mask_2, + negative_prompt_attention_mask_2, + ) = pipe.encode_prompt( + prompt, + device=torch_device, + dtype=torch.float32, + text_encoder_index=1, + ) + + # inputs with prompt converted to embeddings + inputs = { + "prompt_embeds": prompt_embeds, + "prompt_attention_mask": prompt_attention_mask, + "negative_prompt_embeds": negative_prompt_embeds, + "negative_prompt_attention_mask": negative_prompt_attention_mask, + "prompt_embeds_2": prompt_embeds_2, + "prompt_attention_mask_2": prompt_attention_mask_2, + "negative_prompt_embeds_2": negative_prompt_embeds_2, + "negative_prompt_attention_mask_2": negative_prompt_attention_mask_2, + "generator": generator, + "num_inference_steps": num_inference_steps, + "output_type": output_type, + "use_resolution_binning": False, + } + + # set all optional components to None + for optional_component in pipe._optional_components: + setattr(pipe, optional_component, None) + + output = pipe(**inputs)[0] + + with tempfile.TemporaryDirectory() as tmpdir: + pipe.save_pretrained(tmpdir) + pipe_loaded = self.pipeline_class.from_pretrained(tmpdir, pag_applied_layers=components["pag_applied_layers"]) + pipe_loaded.to(torch_device) + pipe_loaded.set_progress_bar_config(disable=None) + + for optional_component in pipe._optional_components: + self.assertTrue( + getattr(pipe_loaded, optional_component) is None, + f"`{optional_component}` did not stay set to None after loading.", + ) + + inputs = self.get_dummy_inputs(torch_device) + + generator = inputs["generator"] + num_inference_steps = inputs["num_inference_steps"] + output_type = inputs["output_type"] + + # inputs with prompt converted to embeddings + inputs = { + "prompt_embeds": prompt_embeds, + "prompt_attention_mask": prompt_attention_mask, + "negative_prompt_embeds": negative_prompt_embeds, + "negative_prompt_attention_mask": negative_prompt_attention_mask, + "prompt_embeds_2": prompt_embeds_2, + "prompt_attention_mask_2": prompt_attention_mask_2, + "negative_prompt_embeds_2": negative_prompt_embeds_2, + "negative_prompt_attention_mask_2": negative_prompt_attention_mask_2, + "generator": generator, + "num_inference_steps": num_inference_steps, + "output_type": output_type, + "use_resolution_binning": False, + } + + output_loaded = pipe_loaded(**inputs)[0] + + max_diff = np.abs(to_np(output) - to_np(output_loaded)).max() + self.assertLess(max_diff, 1e-4) + + def test_feed_forward_chunking(self): + device = "cpu" + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + image = pipe(**inputs).images + image_slice_no_chunking = image[0, -3:, -3:, -1] + + pipe.transformer.enable_forward_chunking(chunk_size=1, dim=0) + inputs = self.get_dummy_inputs(device) + image = pipe(**inputs).images + image_slice_chunking = image[0, -3:, -3:, -1] + + max_diff = np.abs(to_np(image_slice_no_chunking) - to_np(image_slice_chunking)).max() + self.assertLess(max_diff, 1e-4) + + def test_fused_qkv_projections(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe = pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + inputs["return_dict"] = False + image = pipe(**inputs)[0] + original_image_slice = image[0, -3:, -3:, -1] + + pipe.transformer.fuse_qkv_projections() + inputs = self.get_dummy_inputs(device) + inputs["return_dict"] = False + image_fused = pipe(**inputs)[0] + image_slice_fused = image_fused[0, -3:, -3:, -1] + + pipe.transformer.unfuse_qkv_projections() + inputs = self.get_dummy_inputs(device) + inputs["return_dict"] = False + image_disabled = pipe(**inputs)[0] + image_slice_disabled = image_disabled[0, -3:, -3:, -1] + + assert np.allclose( + original_image_slice, image_slice_fused, atol=1e-2, rtol=1e-2 + ), "Fusion of QKV projections shouldn't affect the outputs." + assert np.allclose( + image_slice_fused, image_slice_disabled, atol=1e-2, rtol=1e-2 + ), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled." + assert np.allclose( + original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2 + ), "Original outputs should match when fused QKV projections are disabled." + + def test_pag_disable_enable(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + components = self.get_dummy_components() + + # base pipeline (expect same output when pag is disabled) + components.pop("pag_applied_layers", None) + pipe_sd = HunyuanDiTPipeline(**components) + pipe_sd = pipe_sd.to(device) + pipe_sd.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + del inputs["pag_scale"] + assert ( + "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters + ), f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}." + out = pipe_sd(**inputs).images[0, -3:, -3:, -1] + + components = self.get_dummy_components() + + # pag disabled with pag_scale=0.0 + pipe_pag = self.pipeline_class(**components) + pipe_pag = pipe_pag.to(device) + pipe_pag.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + inputs["pag_scale"] = 0.0 + out_pag_disabled = pipe_pag(**inputs).images[0, -3:, -3:, -1] + + # pag enabled + pipe_pag = self.pipeline_class(**components) + pipe_pag = pipe_pag.to(device) + pipe_pag.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + inputs["pag_scale"] = 3.0 + out_pag_enabled = pipe_pag(**inputs).images[0, -3:, -3:, -1] + + assert np.abs(out.flatten() - out_pag_disabled.flatten()).max() < 1e-3 + assert np.abs(out.flatten() - out_pag_enabled.flatten()).max() > 1e-3 + + def test_pag_applied_layers(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + components = self.get_dummy_components() + + # base pipeline + components.pop("pag_applied_layers", None) + pipe = self.pipeline_class(**components) + pipe = pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + all_self_attn_layers = [k for k in pipe.transformer.attn_processors.keys() if "attn1" in k] + original_attn_procs = pipe.transformer.attn_processors + pag_layers = [0, 1] + pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False) + assert set(pipe.pag_attn_processors) == set(all_self_attn_layers) + + # blocks.0 + block_0_self_attn = ["blocks.0.attn1.processor"] + pipe.transformer.set_attn_processor(original_attn_procs.copy()) + pag_layers = [0] + pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False) + assert set(pipe.pag_attn_processors) == set(block_0_self_attn) + + pipe.transformer.set_attn_processor(original_attn_procs.copy()) + pag_layers = ["blocks.0.attn1"] + pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False) + assert set(pipe.pag_attn_processors) == set(block_0_self_attn) + + pipe.transformer.set_attn_processor(original_attn_procs.copy()) + pag_layers = [0, "1"] + pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False) + assert (len(pipe.pag_attn_processors)) == 2 + + pipe.transformer.set_attn_processor(original_attn_procs.copy()) + pag_layers = ["0", "blocks.1"] + pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False) + assert len(pipe.pag_attn_processors) == 2 From e26bce5a4f2803a6cf83b62a3038c0682e02dd2b Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 22 Jul 2024 19:22:12 +0200 Subject: [PATCH 04/22] update docs --- docs/source/en/api/pipelines/pag.md | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/docs/source/en/api/pipelines/pag.md b/docs/source/en/api/pipelines/pag.md index abfeb930d5ba..849d89120ff2 100644 --- a/docs/source/en/api/pipelines/pag.md +++ b/docs/source/en/api/pipelines/pag.md @@ -20,6 +20,11 @@ The abstract from the paper is: *Recent studies have demonstrated that diffusion models are capable of generating high-quality samples, but their quality heavily depends on sampling guidance techniques, such as classifier guidance (CG) and classifier-free guidance (CFG). These techniques are often not applicable in unconditional generation or in various downstream tasks such as image restoration. In this paper, we propose a novel sampling guidance, called Perturbed-Attention Guidance (PAG), which improves diffusion sample quality across both unconditional and conditional settings, achieving this without requiring additional training or the integration of external modules. PAG is designed to progressively enhance the structure of samples throughout the denoising process. It involves generating intermediate samples with degraded structure by substituting selected self-attention maps in diffusion U-Net with an identity matrix, by considering the self-attention mechanisms' ability to capture structural information, and guiding the denoising process away from these degraded samples. In both ADM and Stable Diffusion, PAG surprisingly improves sample quality in conditional and even unconditional scenarios. Moreover, PAG significantly improves the baseline performance in various downstream tasks where existing guidances such as CG or CFG cannot be fully utilized, including ControlNet with empty prompts and image restoration such as inpainting and deblurring.* +## HunyuanDiTPAGPipeline +[[autodoc]] HunyuanDiTPAGPipeline + - all + - __call__ + ## StableDiffusionPAGPipeline [[autodoc]] StableDiffusionPAGPipeline - all From 8c7ca9edf0d643a4b0ab368e1407830b4ae61b81 Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 22 Jul 2024 19:26:03 +0200 Subject: [PATCH 05/22] make style --- src/diffusers/models/attention_processor.py | 20 +++++++++++-------- src/diffusers/pipelines/pag/pag_utils.py | 4 ++-- .../pipelines/pag/pipeline_pag_hunyuandit.py | 10 +++++----- tests/pipelines/pag/test_pag_hunyuan_dit.py | 11 +++++----- 4 files changed, 24 insertions(+), 21 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 4dee0c54a7c7..68d7c088a66d 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -1717,7 +1717,9 @@ class PAGHunyuanAttnProcessor2_0: def __init__(self): if not hasattr(F, "scaled_dot_product_attention"): - raise ImportError("PAGHunyuanAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + raise ImportError( + "PAGHunyuanAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." + ) def __call__( self, @@ -1739,7 +1741,7 @@ def __call__( if input_ndim == 4: batch_size, channel, height, width = hidden_states.shape hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) - + # chunk hidden_states_org, hidden_states_ptb = hidden_states.chunk(2) @@ -1802,7 +1804,7 @@ def __call__( if input_ndim == 4: hidden_states_org = hidden_states_org.transpose(-1, -2).reshape(batch_size, channel, height, width) - + # 2. Perturbed Path if attn.group_norm is not None: hidden_states_ptb = attn.group_norm(hidden_states_ptb.transpose(1, 2)).transpose(1, 2) @@ -1817,7 +1819,7 @@ def __call__( if input_ndim == 4: hidden_states_ptb = hidden_states_ptb.transpose(-1, -2).reshape(batch_size, channel, height, width) - + # cat hidden_states = torch.cat([hidden_states_org, hidden_states_ptb]) @@ -1838,7 +1840,9 @@ class PAGCFGHunyuanAttnProcessor2_0: def __init__(self): if not hasattr(F, "scaled_dot_product_attention"): - raise ImportError("PAGCFGHunyuanAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + raise ImportError( + "PAGCFGHunyuanAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." + ) def __call__( self, @@ -1860,7 +1864,7 @@ def __call__( if input_ndim == 4: batch_size, channel, height, width = hidden_states.shape hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) - + # chunk hidden_states_uncond, hidden_states_org, hidden_states_ptb = hidden_states.chunk(3) hidden_states_org = torch.cat([hidden_states_uncond, hidden_states_org]) @@ -1924,7 +1928,7 @@ def __call__( if input_ndim == 4: hidden_states_org = hidden_states_org.transpose(-1, -2).reshape(batch_size, channel, height, width) - + # 2. Perturbed Path if attn.group_norm is not None: hidden_states_ptb = attn.group_norm(hidden_states_ptb.transpose(1, 2)).transpose(1, 2) @@ -1939,7 +1943,7 @@ def __call__( if input_ndim == 4: hidden_states_ptb = hidden_states_ptb.transpose(-1, -2).reshape(batch_size, channel, height, width) - + # cat hidden_states = torch.cat([hidden_states_org, hidden_states_ptb]) diff --git a/src/diffusers/pipelines/pag/pag_utils.py b/src/diffusers/pipelines/pag/pag_utils.py index 7bc5b12b3d23..eade3af3a2dc 100644 --- a/src/diffusers/pipelines/pag/pag_utils.py +++ b/src/diffusers/pipelines/pag/pag_utils.py @@ -319,11 +319,11 @@ def get_block_index(module_name): # "blocks.20" -> "20" # "blocks.20.attn1" -> "20" block_index = pag_layer_input_splits[1] - + for name, module in self.transformer.named_modules(): if is_self_attn(name) and get_block_index(name) == block_index: target_modules.append(module) - + if len(target_modules) == 0: raise ValueError(f"Cannot find pag layer to set attention processor for: {pag_layer_input}") diff --git a/src/diffusers/pipelines/pag/pipeline_pag_hunyuandit.py b/src/diffusers/pipelines/pag/pipeline_pag_hunyuandit.py index d2c64854aada..7a6deeebd127 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_hunyuandit.py @@ -49,7 +49,7 @@ EXAMPLE_DOC_STRING = """ Examples: - ```py + ```python >>> import torch >>> from diffusers import AutoPipelineForText2Image @@ -57,12 +57,12 @@ ... "Tencent-Hunyuan/HunyuanDiT-v1.2-Diffusers", ... torch_dtype=torch.float16, ... enable_pag=True, - ... pag_applied_layers=[16, 17, 18, 19] - >>> ).to("cuda") + ... pag_applied_layers=[16, 17, 18, 19], + ... ).to("cuda") >>> # prompt = "an astronaut riding a horse" >>> prompt = "一个宇航员在骑马" - >>> image = pipe(, guidance_scale=4, pag_scale=3).images[0] + >>> image = pipe(prompt, guidance_scale=4, pag_scale=3).images[0] ``` """ @@ -200,7 +200,7 @@ def __init__( requires_safety_checker: bool = True, text_encoder_2: Optional[T5EncoderModel] = None, tokenizer_2: Optional[MT5Tokenizer] = None, - pag_applied_layers: Union[str, List[str]] = [], # "blocks.16.attn1", "blocks.16", "16", 16 + pag_applied_layers: Union[str, List[str]] = [], # "blocks.16.attn1", "blocks.16", "16", 16 ): super().__init__() diff --git a/tests/pipelines/pag/test_pag_hunyuan_dit.py b/tests/pipelines/pag/test_pag_hunyuan_dit.py index f31a6e97a7e8..549d8e8a7630 100644 --- a/tests/pipelines/pag/test_pag_hunyuan_dit.py +++ b/tests/pipelines/pag/test_pag_hunyuan_dit.py @@ -30,9 +30,6 @@ ) from diffusers.utils.testing_utils import ( enable_full_determinism, - numpy_cosine_similarity_distance, - require_torch_gpu, - slow, torch_device, ) @@ -194,7 +191,9 @@ def test_save_load_optional_components(self): with tempfile.TemporaryDirectory() as tmpdir: pipe.save_pretrained(tmpdir) - pipe_loaded = self.pipeline_class.from_pretrained(tmpdir, pag_applied_layers=components["pag_applied_layers"]) + pipe_loaded = self.pipeline_class.from_pretrained( + tmpdir, pag_applied_layers=components["pag_applied_layers"] + ) pipe_loaded.to(torch_device) pipe_loaded.set_progress_bar_config(disable=None) @@ -284,7 +283,7 @@ def test_fused_qkv_projections(self): assert np.allclose( original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2 ), "Original outputs should match when fused QKV projections are disabled." - + def test_pag_disable_enable(self): device = "cpu" # ensure determinism for the device-dependent torch.Generator components = self.get_dummy_components() @@ -324,7 +323,7 @@ def test_pag_disable_enable(self): assert np.abs(out.flatten() - out_pag_disabled.flatten()).max() < 1e-3 assert np.abs(out.flatten() - out_pag_enabled.flatten()).max() > 1e-3 - + def test_pag_applied_layers(self): device = "cpu" # ensure determinism for the device-dependent torch.Generator components = self.get_dummy_components() From f83a4638901b846731f405c5774b1ef5cfd17854 Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 22 Jul 2024 19:26:49 +0200 Subject: [PATCH 06/22] make fix-copies --- src/diffusers/pipelines/pag/pag_utils.py | 5 ++--- .../utils/dummy_torch_and_transformers_objects.py | 15 +++++++++++++++ 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/src/diffusers/pipelines/pag/pag_utils.py b/src/diffusers/pipelines/pag/pag_utils.py index eade3af3a2dc..30e0bdd04fbc 100644 --- a/src/diffusers/pipelines/pag/pag_utils.py +++ b/src/diffusers/pipelines/pag/pag_utils.py @@ -335,9 +335,8 @@ def get_block_index(module_name): def pag_attn_processors(self): r""" Returns: - `dict`: - A dictionary contains all PAG attention processors used in the model with the key as the name of the - layer. + `dict` of PAG attention processors: A dictionary contains all PAG attention processors used in the model + with the key as the name of the layer. """ processors = {} diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 399656d8c185..b16774d091e0 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -272,6 +272,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class HunyuanDiTPAGPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class HunyuanDiTPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] From 196e687c05ab8d82a36b1c4f9fa256c1056f6f9e Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 23 Jul 2024 05:27:40 +0530 Subject: [PATCH 07/22] Update src/diffusers/pipelines/pag/pag_utils.py --- src/diffusers/pipelines/pag/pag_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/pag/pag_utils.py b/src/diffusers/pipelines/pag/pag_utils.py index 30e0bdd04fbc..9510a5b1f970 100644 --- a/src/diffusers/pipelines/pag/pag_utils.py +++ b/src/diffusers/pipelines/pag/pag_utils.py @@ -257,7 +257,7 @@ def pag_attn_processors(self): processors = {} for name, proc in self.unet.attn_processors.items(): - if proc.__class__ in (PAGCFGHunyuanAttnProcessor2_0, PAGHunyuanAttnProcessor2_0): + if proc.__class__ in (PAGCFGIdentitySelfAttnProcessor2_0, PAGIdentitySelfAttnProcessor2_0): processors[name] = proc return processors From 7998502104bea903f5e629e00ab44b13ec9a11dd Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 23 Jul 2024 11:01:57 +0200 Subject: [PATCH 08/22] remove incorrect copied from --- src/diffusers/pipelines/pag/pag_utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/pipelines/pag/pag_utils.py b/src/diffusers/pipelines/pag/pag_utils.py index 9510a5b1f970..4b810d6b19c1 100644 --- a/src/diffusers/pipelines/pag/pag_utils.py +++ b/src/diffusers/pipelines/pag/pag_utils.py @@ -331,7 +331,6 @@ def get_block_index(module_name): module.processor = pag_attn_proc @property - # Copied from diffusers.pipelines.pag.pag_utils.PAGMixin.pag_attn_processors with unet->transformer def pag_attn_processors(self): r""" Returns: From 096ded39e2c6673d4c2a26a4afc268975ea227ea Mon Sep 17 00:00:00 2001 From: Aryan Date: Fri, 26 Jul 2024 13:09:23 +0200 Subject: [PATCH 09/22] remove pag hunyuan attn procs to resolve conflicts --- src/diffusers/models/attention_processor.py | 247 -------------------- 1 file changed, 247 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 68d7c088a66d..cadf793953c6 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -1708,253 +1708,6 @@ def __call__( return hidden_states -class PAGHunyuanAttnProcessor2_0: - r""" - Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is - used in the HunyuanDiT model. It applies a normalization layer and rotary embedding on query and key vector. This - variant of the processor employs [Pertubed Attention Guidance](https://arxiv.org/abs/2403.17377). - """ - - def __init__(self): - if not hasattr(F, "scaled_dot_product_attention"): - raise ImportError( - "PAGHunyuanAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." - ) - - def __call__( - self, - attn: Attention, - hidden_states: torch.Tensor, - encoder_hidden_states: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - temb: Optional[torch.Tensor] = None, - image_rotary_emb: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - from .embeddings import apply_rotary_emb - - residual = hidden_states - if attn.spatial_norm is not None: - hidden_states = attn.spatial_norm(hidden_states, temb) - - input_ndim = hidden_states.ndim - - if input_ndim == 4: - batch_size, channel, height, width = hidden_states.shape - hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) - - # chunk - hidden_states_org, hidden_states_ptb = hidden_states.chunk(2) - - # 1. Original Path - batch_size, sequence_length, _ = ( - hidden_states_org.shape if encoder_hidden_states is None else encoder_hidden_states.shape - ) - - if attention_mask is not None: - attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) - # scaled_dot_product_attention expects attention_mask shape to be - # (batch, heads, source_length, target_length) - attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) - - if attn.group_norm is not None: - hidden_states_org = attn.group_norm(hidden_states_org.transpose(1, 2)).transpose(1, 2) - - query = attn.to_q(hidden_states_org) - - if encoder_hidden_states is None: - encoder_hidden_states = hidden_states_org - elif attn.norm_cross: - encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) - - key = attn.to_k(encoder_hidden_states) - value = attn.to_v(encoder_hidden_states) - - inner_dim = key.shape[-1] - head_dim = inner_dim // attn.heads - - query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - if attn.norm_q is not None: - query = attn.norm_q(query) - if attn.norm_k is not None: - key = attn.norm_k(key) - - # Apply RoPE if needed - if image_rotary_emb is not None: - query = apply_rotary_emb(query, image_rotary_emb) - if not attn.is_cross_attention: - key = apply_rotary_emb(key, image_rotary_emb) - - # the output of sdp = (batch, num_heads, seq_len, head_dim) - # TODO: add support for attn.scale when we move to Torch 2.1 - hidden_states_org = F.scaled_dot_product_attention( - query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False - ) - - hidden_states_org = hidden_states_org.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) - hidden_states_org = hidden_states_org.to(query.dtype) - - # linear proj - hidden_states_org = attn.to_out[0](hidden_states_org) - # dropout - hidden_states_org = attn.to_out[1](hidden_states_org) - - if input_ndim == 4: - hidden_states_org = hidden_states_org.transpose(-1, -2).reshape(batch_size, channel, height, width) - - # 2. Perturbed Path - if attn.group_norm is not None: - hidden_states_ptb = attn.group_norm(hidden_states_ptb.transpose(1, 2)).transpose(1, 2) - - hidden_states_ptb = attn.to_v(hidden_states_ptb) - hidden_states_ptb = hidden_states_ptb.to(query.dtype) - - # linear proj - hidden_states_ptb = attn.to_out[0](hidden_states_ptb) - # dropout - hidden_states_ptb = attn.to_out[1](hidden_states_ptb) - - if input_ndim == 4: - hidden_states_ptb = hidden_states_ptb.transpose(-1, -2).reshape(batch_size, channel, height, width) - - # cat - hidden_states = torch.cat([hidden_states_org, hidden_states_ptb]) - - if attn.residual_connection: - hidden_states = hidden_states + residual - - hidden_states = hidden_states / attn.rescale_output_factor - - return hidden_states - - -class PAGCFGHunyuanAttnProcessor2_0: - r""" - Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is - used in the HunyuanDiT model. It applies a normalization layer and rotary embedding on query and key vector. This - variant of the processor employs [Pertubed Attention Guidance](https://arxiv.org/abs/2403.17377). - """ - - def __init__(self): - if not hasattr(F, "scaled_dot_product_attention"): - raise ImportError( - "PAGCFGHunyuanAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." - ) - - def __call__( - self, - attn: Attention, - hidden_states: torch.Tensor, - encoder_hidden_states: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - temb: Optional[torch.Tensor] = None, - image_rotary_emb: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - from .embeddings import apply_rotary_emb - - residual = hidden_states - if attn.spatial_norm is not None: - hidden_states = attn.spatial_norm(hidden_states, temb) - - input_ndim = hidden_states.ndim - - if input_ndim == 4: - batch_size, channel, height, width = hidden_states.shape - hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) - - # chunk - hidden_states_uncond, hidden_states_org, hidden_states_ptb = hidden_states.chunk(3) - hidden_states_org = torch.cat([hidden_states_uncond, hidden_states_org]) - - # 1. Original Path - batch_size, sequence_length, _ = ( - hidden_states_org.shape if encoder_hidden_states is None else encoder_hidden_states.shape - ) - - if attention_mask is not None: - attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) - # scaled_dot_product_attention expects attention_mask shape to be - # (batch, heads, source_length, target_length) - attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) - - if attn.group_norm is not None: - hidden_states_org = attn.group_norm(hidden_states_org.transpose(1, 2)).transpose(1, 2) - - query = attn.to_q(hidden_states_org) - - if encoder_hidden_states is None: - encoder_hidden_states = hidden_states_org - elif attn.norm_cross: - encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) - - key = attn.to_k(encoder_hidden_states) - value = attn.to_v(encoder_hidden_states) - - inner_dim = key.shape[-1] - head_dim = inner_dim // attn.heads - - query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - if attn.norm_q is not None: - query = attn.norm_q(query) - if attn.norm_k is not None: - key = attn.norm_k(key) - - # Apply RoPE if needed - if image_rotary_emb is not None: - query = apply_rotary_emb(query, image_rotary_emb) - if not attn.is_cross_attention: - key = apply_rotary_emb(key, image_rotary_emb) - - # the output of sdp = (batch, num_heads, seq_len, head_dim) - # TODO: add support for attn.scale when we move to Torch 2.1 - hidden_states_org = F.scaled_dot_product_attention( - query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False - ) - - hidden_states_org = hidden_states_org.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) - hidden_states_org = hidden_states_org.to(query.dtype) - - # linear proj - hidden_states_org = attn.to_out[0](hidden_states_org) - # dropout - hidden_states_org = attn.to_out[1](hidden_states_org) - - if input_ndim == 4: - hidden_states_org = hidden_states_org.transpose(-1, -2).reshape(batch_size, channel, height, width) - - # 2. Perturbed Path - if attn.group_norm is not None: - hidden_states_ptb = attn.group_norm(hidden_states_ptb.transpose(1, 2)).transpose(1, 2) - - hidden_states_ptb = attn.to_v(hidden_states_ptb) - hidden_states_ptb = hidden_states_ptb.to(query.dtype) - - # linear proj - hidden_states_ptb = attn.to_out[0](hidden_states_ptb) - # dropout - hidden_states_ptb = attn.to_out[1](hidden_states_ptb) - - if input_ndim == 4: - hidden_states_ptb = hidden_states_ptb.transpose(-1, -2).reshape(batch_size, channel, height, width) - - # cat - hidden_states = torch.cat([hidden_states_org, hidden_states_ptb]) - - if attn.residual_connection: - hidden_states = hidden_states + residual - - hidden_states = hidden_states / attn.rescale_output_factor - - return hidden_states - - class LuminaAttnProcessor2_0: r""" Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is From d397ed4588a68a423664fb0d32e65011df626855 Mon Sep 17 00:00:00 2001 From: Aryan Date: Fri, 26 Jul 2024 13:10:01 +0200 Subject: [PATCH 10/22] add pag attn procs again --- src/diffusers/models/attention_processor.py | 247 ++++++++++++++++++++ 1 file changed, 247 insertions(+) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index ad00147ab31f..54f5cfdc8b1e 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -1825,6 +1825,253 @@ def __call__( return hidden_states +class PAGHunyuanAttnProcessor2_0: + r""" + Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is + used in the HunyuanDiT model. It applies a normalization layer and rotary embedding on query and key vector. This + variant of the processor employs [Pertubed Attention Guidance](https://arxiv.org/abs/2403.17377). + """ + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "PAGHunyuanAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." + ) + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + temb: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + from .embeddings import apply_rotary_emb + + residual = hidden_states + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + # chunk + hidden_states_org, hidden_states_ptb = hidden_states.chunk(2) + + # 1. Original Path + batch_size, sequence_length, _ = ( + hidden_states_org.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if attn.group_norm is not None: + hidden_states_org = attn.group_norm(hidden_states_org.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states_org) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states_org + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # Apply RoPE if needed + if image_rotary_emb is not None: + query = apply_rotary_emb(query, image_rotary_emb) + if not attn.is_cross_attention: + key = apply_rotary_emb(key, image_rotary_emb) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states_org = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states_org = hidden_states_org.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states_org = hidden_states_org.to(query.dtype) + + # linear proj + hidden_states_org = attn.to_out[0](hidden_states_org) + # dropout + hidden_states_org = attn.to_out[1](hidden_states_org) + + if input_ndim == 4: + hidden_states_org = hidden_states_org.transpose(-1, -2).reshape(batch_size, channel, height, width) + + # 2. Perturbed Path + if attn.group_norm is not None: + hidden_states_ptb = attn.group_norm(hidden_states_ptb.transpose(1, 2)).transpose(1, 2) + + hidden_states_ptb = attn.to_v(hidden_states_ptb) + hidden_states_ptb = hidden_states_ptb.to(query.dtype) + + # linear proj + hidden_states_ptb = attn.to_out[0](hidden_states_ptb) + # dropout + hidden_states_ptb = attn.to_out[1](hidden_states_ptb) + + if input_ndim == 4: + hidden_states_ptb = hidden_states_ptb.transpose(-1, -2).reshape(batch_size, channel, height, width) + + # cat + hidden_states = torch.cat([hidden_states_org, hidden_states_ptb]) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +class PAGCFGHunyuanAttnProcessor2_0: + r""" + Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is + used in the HunyuanDiT model. It applies a normalization layer and rotary embedding on query and key vector. This + variant of the processor employs [Pertubed Attention Guidance](https://arxiv.org/abs/2403.17377). + """ + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "PAGCFGHunyuanAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." + ) + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + temb: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + from .embeddings import apply_rotary_emb + + residual = hidden_states + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + # chunk + hidden_states_uncond, hidden_states_org, hidden_states_ptb = hidden_states.chunk(3) + hidden_states_org = torch.cat([hidden_states_uncond, hidden_states_org]) + + # 1. Original Path + batch_size, sequence_length, _ = ( + hidden_states_org.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if attn.group_norm is not None: + hidden_states_org = attn.group_norm(hidden_states_org.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states_org) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states_org + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # Apply RoPE if needed + if image_rotary_emb is not None: + query = apply_rotary_emb(query, image_rotary_emb) + if not attn.is_cross_attention: + key = apply_rotary_emb(key, image_rotary_emb) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states_org = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states_org = hidden_states_org.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states_org = hidden_states_org.to(query.dtype) + + # linear proj + hidden_states_org = attn.to_out[0](hidden_states_org) + # dropout + hidden_states_org = attn.to_out[1](hidden_states_org) + + if input_ndim == 4: + hidden_states_org = hidden_states_org.transpose(-1, -2).reshape(batch_size, channel, height, width) + + # 2. Perturbed Path + if attn.group_norm is not None: + hidden_states_ptb = attn.group_norm(hidden_states_ptb.transpose(1, 2)).transpose(1, 2) + + hidden_states_ptb = attn.to_v(hidden_states_ptb) + hidden_states_ptb = hidden_states_ptb.to(query.dtype) + + # linear proj + hidden_states_ptb = attn.to_out[0](hidden_states_ptb) + # dropout + hidden_states_ptb = attn.to_out[1](hidden_states_ptb) + + if input_ndim == 4: + hidden_states_ptb = hidden_states_ptb.transpose(-1, -2).reshape(batch_size, channel, height, width) + + # cat + hidden_states = torch.cat([hidden_states_org, hidden_states_ptb]) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + class LuminaAttnProcessor2_0: r""" Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is From d9638d9ea3d988293cfbf8277661ef9fe230d0d2 Mon Sep 17 00:00:00 2001 From: Aryan Date: Fri, 26 Jul 2024 22:33:24 +0200 Subject: [PATCH 11/22] new implementation for pag_utils --- src/diffusers/models/attention_processor.py | 2 + src/diffusers/pipelines/pag/pag_utils.py | 283 +++++------------- .../pipelines/pag/pipeline_pag_hunyuandit.py | 11 +- 3 files changed, 91 insertions(+), 205 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 7ff7c0ac436c..7d9dd3db0d82 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -3388,4 +3388,6 @@ def __init__(self): CustomDiffusionAttnProcessor2_0, PAGCFGIdentitySelfAttnProcessor2_0, PAGIdentitySelfAttnProcessor2_0, + PAGCFGHunyuanAttnProcessor2_0, + PAGHunyuanAttnProcessor2_0, ] diff --git a/src/diffusers/pipelines/pag/pag_utils.py b/src/diffusers/pipelines/pag/pag_utils.py index 4b810d6b19c1..afc8d1a2cecb 100644 --- a/src/diffusers/pipelines/pag/pag_utils.py +++ b/src/diffusers/pipelines/pag/pag_utils.py @@ -12,12 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. +import re +from typing import List, Optional, Tuple, Union + import torch +import torch.nn as nn from ...models.attention_processor import ( - PAGCFGHunyuanAttnProcessor2_0, + AttentionProcessor, PAGCFGIdentitySelfAttnProcessor2_0, - PAGHunyuanAttnProcessor2_0, PAGIdentitySelfAttnProcessor2_0, ) from ...utils import logging @@ -27,123 +30,60 @@ class PAGMixin: - r"""Mixin class for PAG.""" - - @staticmethod - def _check_input_pag_applied_layer(layer): - r""" - Check if each layer input in `applied_pag_layers` is valid. It should be either one of these 3 formats: - "{block_type}", "{block_type}.{block_index}", or "{block_type}.{block_index}.{attention_index}". `block_type` - can be "down", "mid", "up". `block_index` should be in the format of "block_{i}". `attention_index` should be - in the format of "attentions_{j}". - """ - - layer_splits = layer.split(".") - - if len(layer_splits) > 3: - raise ValueError(f"pag layer should only contains block_type, block_index and attention_index{layer}.") - - if len(layer_splits) >= 1: - if layer_splits[0] not in ["down", "mid", "up"]: - raise ValueError( - f"Invalid block_type in pag layer {layer}. Accept 'down', 'mid', 'up', got {layer_splits[0]}" - ) - - if len(layer_splits) >= 2: - if not layer_splits[1].startswith("block_"): - raise ValueError(f"Invalid block_index in pag layer: {layer}. Should start with 'block_'") - - if len(layer_splits) == 3: - if not layer_splits[2].startswith("attentions_"): - raise ValueError(f"Invalid attention_index in pag layer: {layer}. Should start with 'attentions_'") + r"""Mixin class for [Pertubed Attention Guidance](https://arxiv.org/abs/2403.17377v1).""" def _set_pag_attn_processor(self, pag_applied_layers, do_classifier_free_guidance): r""" Set the attention processor for the PAG layers. """ - if do_classifier_free_guidance: - pag_attn_proc = PAGCFGIdentitySelfAttnProcessor2_0() + pag_attn_processors = getattr(self, "_pag_attn_processors", None) + if pag_attn_processors is None: + # If this hasn't been set by the user, we default to the original PAG identity processors + pag_attn_processors = (PAGCFGIdentitySelfAttnProcessor2_0(), PAGIdentitySelfAttnProcessor2_0()) + + pag_attn_proc = pag_attn_processors[0] if do_classifier_free_guidance else pag_attn_processors[1] + + if hasattr(self, "unet"): + model: nn.Module = self.unet else: - pag_attn_proc = PAGIdentitySelfAttnProcessor2_0() + model: nn.Module = self.transformer - def is_self_attn(module_name): + def is_self_attn(module_name: str) -> bool: r""" Check if the module is self-attention module based on its name. """ - return "attn1" in module_name and "to" not in name + attn_id = getattr(self, "_self_attn_identifier", "attn1") - def get_block_type(module_name): - r""" - Get the block type from the module name. can be "down", "mid", "up". - """ - # down_blocks.1.attentions.0.transformer_blocks.0.attn1 -> "down" - return module_name.split(".")[0].split("_")[0] + # here, we check for "to" because we want the attention processor itself and not the qkv layers + qkv_present = "to" in module_name + norm_present = "norm" in module_name + return attn_id in module_name and not qkv_present and not norm_present - def get_block_index(module_name): - r""" - Get the block index from the module name. can be "block_0", "block_1", ... If there is only one block (e.g. - mid_block) and index is ommited from the name, it will be "block_0". - """ - # down_blocks.1.attentions.0.transformer_blocks.0.attn1 -> "block_1" - # mid_block.attentions.0.transformer_blocks.0.attn1 -> "block_0" - if "attentions" in module_name.split(".")[1]: - return "block_0" - else: - return f"block_{module_name.split('.')[1]}" - - def get_attn_index(module_name): - r""" - Get the attention index from the module name. can be "attentions_0", "attentions_1", ... - """ - # down_blocks.1.attentions.0.transformer_blocks.0.attn1 -> "attentions_0" - # mid_block.attentions.0.transformer_blocks.0.attn1 -> "attentions_0" - if "attentions" in module_name.split(".")[2]: - return f"attentions_{module_name.split('.')[3]}" - elif "attentions" in module_name.split(".")[1]: - return f"attentions_{module_name.split('.')[2]}" - - for pag_layer_input in pag_applied_layers: + def is_fake_integral_match(layer_id, name): + layer_id = layer_id.split(".")[-1] + name = name.split(".")[-1] + return layer_id.isnumeric() and name.isnumeric() and layer_id == name + + for layer_id in pag_applied_layers: # for each PAG layer input, we find corresponding self-attention layers in the unet model target_modules = [] - pag_layer_input_splits = pag_layer_input.split(".") - - if len(pag_layer_input_splits) == 1: - # when the layer input only contains block_type. e.g. "mid", "down", "up" - block_type = pag_layer_input_splits[0] - for name, module in self.unet.named_modules(): - if is_self_attn(name) and get_block_type(name) == block_type: - target_modules.append(module) - - elif len(pag_layer_input_splits) == 2: - # when the layer inpput contains both block_type and block_index. e.g. "down.block_1", "mid.block_0" - block_type = pag_layer_input_splits[0] - block_index = pag_layer_input_splits[1] - for name, module in self.unet.named_modules(): - if ( - is_self_attn(name) - and get_block_type(name) == block_type - and get_block_index(name) == block_index - ): - target_modules.append(module) - - elif len(pag_layer_input_splits) == 3: - # when the layer input contains block_type, block_index and attention_index. e.g. "down.blocks_1.attentions_1" - block_type = pag_layer_input_splits[0] - block_index = pag_layer_input_splits[1] - attn_index = pag_layer_input_splits[2] - - for name, module in self.unet.named_modules(): - if ( - is_self_attn(name) - and get_block_type(name) == block_type - and get_block_index(name) == block_index - and get_attn_index(name) == attn_index - ): - target_modules.append(module) + for name, module in model.named_modules(): + # Identify the following simple cases: + # (1) Self Attention layer existing + # (2) Whether the module name matches pag layer id even partially + # (3) Make sure it's not a fake integral match if the layer_id ends with a number + # For example, blocks.1, blocks.10 should be differentiable if layer_id="blocks.1" + if ( + is_self_attn(name) + and re.search(layer_id, name) is not None + and not is_fake_integral_match(layer_id, name) + ): + print(f"applying to: {name}") + target_modules.append(module) if len(target_modules) == 0: - raise ValueError(f"Cannot find pag layer to set attention processor for: {pag_layer_input}") + raise ValueError(f"Cannot find PAG layer to set attention processor for: {layer_id}") for module in target_modules: module.processor = pag_attn_proc @@ -206,45 +146,64 @@ def _prepare_perturbed_attention_guidance(self, cond, uncond, do_classifier_free cond = torch.cat([uncond, cond], dim=0) return cond - def set_pag_applied_layers(self, pag_applied_layers): + def set_pag_applied_layers( + self, + pag_applied_layers: Union[str, List[str]], + pag_attn_processors: Optional[Tuple[AttentionProcessor, AttentionProcessor]] = None, + self_attn_identifier: str = "attn1", + ): r""" - set the the self-attention layers to apply PAG. Raise ValueError if the input is invalid. + Set the the self-attention layers to apply PAG. Raise ValueError if the input is invalid. + + Args: + pag_applied_layers (`str` or `List[str]`): + One or more strings, or simple regex, to identify layers where to apply PAG. + pag_attn_processors: (`Tuple[AttentionProcessor, AttentionProcessor]`, *optional*): + A tuple of two attention processors. The first attention processor is for PAG with Classifier-free + guidance enabled (conditional and unconditional). The second attention processor is for PAG with CFG + disabled (unconditional only). + self_attn_identifier (`str`, defaults to "attn1"): + The string to identity self-attn layers. """ if not isinstance(pag_applied_layers, list): pag_applied_layers = [pag_applied_layers] + if pag_attn_processors is not None: + if not isinstance(pag_attn_processors, tuple) or len(pag_attn_processors) != 2: + raise ValueError("Expected a tuple of two attention processors") - for pag_layer in pag_applied_layers: - self._check_input_pag_applied_layer(pag_layer) + for i in range(len(pag_applied_layers)): + if not isinstance(pag_applied_layers[i], str): + raise ValueError( + f"Expected either a string or a list of string but got type {type(pag_applied_layers[i])}" + ) self.pag_applied_layers = pag_applied_layers + self._self_attn_identifier = self_attn_identifier + + # Ensure we don't overwrite existing processors, possible from __init__, if the intention + # was to only change the layers where PAG was applied. + if pag_attn_processors is not None: + self._pag_attn_processors = pag_attn_processors @property def pag_scale(self): - """ - Get the scale factor for the perturbed attention guidance. - """ + r"""Get the scale factor for the perturbed attention guidance.""" return self._pag_scale @property def pag_adaptive_scale(self): - """ - Get the adaptive scale factor for the perturbed attention guidance. - """ + r"""Get the adaptive scale factor for the perturbed attention guidance.""" return self._pag_adaptive_scale @property def do_pag_adaptive_scaling(self): - """ - Check if the adaptive scaling is enabled for the perturbed attention guidance. - """ + r"""Check if the adaptive scaling is enabled for the perturbed attention guidance.""" return self._pag_adaptive_scale > 0 and self._pag_scale > 0 and len(self.pag_applied_layers) > 0 @property def do_perturbed_attention_guidance(self): - """ - Check if the perturbed attention guidance is enabled. - """ + r"""Check if the perturbed attention guidance is enabled.""" return self._pag_scale > 0 and len(self.pag_applied_layers) > 0 @property @@ -255,91 +214,13 @@ def pag_attn_processors(self): with the key as the name of the layer. """ - processors = {} - for name, proc in self.unet.attn_processors.items(): - if proc.__class__ in (PAGCFGIdentitySelfAttnProcessor2_0, PAGIdentitySelfAttnProcessor2_0): - processors[name] = proc - return processors - - -class HunyuanDiTPAGMixin(PAGMixin): - r"""Mixin class for PAG applied to HunyuanDiT.""" - - @staticmethod - def _check_input_pag_applied_layer(layer): - r""" - Check if each layer input in `applied_pag_layers` is valid. It should be the block index: {block_index}. - """ - - # Check if the layer index is valid (should be int or str of int) - if isinstance(layer, int): - return # Valid layer index - - if isinstance(layer, str): - if layer.isdigit(): - return # Valid layer index - - # If it is not a valid layer index, raise a ValueError - raise ValueError(f"Pag layer should only contain block index. Accept number string like '3', got {layer}") - - def _set_pag_attn_processor(self, pag_applied_layers, do_classifier_free_guidance): - r""" - Set the attention processor for the PAG layers. - """ - if do_classifier_free_guidance: - pag_attn_proc = PAGCFGHunyuanAttnProcessor2_0() + if not hasattr(self, "_pag_attn_processors") or self._pag_attn_processors is None: + valid_attn_processors = (PAGCFGIdentitySelfAttnProcessor2_0, PAGIdentitySelfAttnProcessor2_0) else: - pag_attn_proc = PAGHunyuanAttnProcessor2_0() - - def is_self_attn(module_name): - r""" - Check if the module is self-attention module based on its name. - """ - # include blocks.1.attn1 - # exclude blocks.18.attn1.to_q, blocks.1.attn1.norm_k, ... - return "attn1" in module_name and len(module_name.split(".")) == 3 - - def get_block_index(module_name): - r""" - Get the block index from the module name. can be "block_0", "block_1", ... If there is only one block (e.g. - mid_block) and index is ommited from the name, it will be "block_0". - """ - # blocks.23.attn1 -> "23" - return module_name.split(".")[1] - - for pag_layer_input in pag_applied_layers: - # for each PAG layer input, we find corresponding self-attention layers in the transformer model - target_modules = [] - pag_layer_input_splits = str(pag_layer_input).split(".") - - if len(pag_layer_input_splits) == 1: - # 20, "20" -> "20" - block_index = pag_layer_input_splits[0] - elif len(pag_layer_input_splits) >= 2: - # "blocks.20" -> "20" - # "blocks.20.attn1" -> "20" - block_index = pag_layer_input_splits[1] - - for name, module in self.transformer.named_modules(): - if is_self_attn(name) and get_block_index(name) == block_index: - target_modules.append(module) - - if len(target_modules) == 0: - raise ValueError(f"Cannot find pag layer to set attention processor for: {pag_layer_input}") - - for module in target_modules: - module.processor = pag_attn_proc - - @property - def pag_attn_processors(self): - r""" - Returns: - `dict` of PAG attention processors: A dictionary contains all PAG attention processors used in the model - with the key as the name of the layer. - """ + valid_attn_processors = tuple(x.__class__ for x in self._pag_attn_processors) processors = {} - for name, proc in self.transformer.attn_processors.items(): - if proc.__class__ in (PAGCFGHunyuanAttnProcessor2_0, PAGHunyuanAttnProcessor2_0): + for name, proc in self.unet.attn_processors.items(): + if proc.__class__ in valid_attn_processors: processors[name] = proc return processors diff --git a/src/diffusers/pipelines/pag/pipeline_pag_hunyuandit.py b/src/diffusers/pipelines/pag/pipeline_pag_hunyuandit.py index 7a6deeebd127..36d4146860a7 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_hunyuandit.py @@ -24,6 +24,7 @@ from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...image_processor import VaeImageProcessor from ...models import AutoencoderKL, HunyuanDiT2DModel +from ...models.attention_processor import PAGCFGHunyuanAttnProcessor2_0, PAGHunyuanAttnProcessor2_0 from ...models.embeddings import get_2d_rotary_pos_embed from ...pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker from ...schedulers import DDPMScheduler @@ -34,7 +35,7 @@ ) from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline -from .pag_utils import HunyuanDiTPAGMixin +from .pag_utils import PAGMixin if is_torch_xla_available(): @@ -57,7 +58,7 @@ ... "Tencent-Hunyuan/HunyuanDiT-v1.2-Diffusers", ... torch_dtype=torch.float16, ... enable_pag=True, - ... pag_applied_layers=[16, 17, 18, 19], + ... pag_applied_layers=[14], ... ).to("cuda") >>> # prompt = "an astronaut riding a horse" @@ -140,7 +141,7 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): return noise_cfg -class HunyuanDiTPAGPipeline(DiffusionPipeline, HunyuanDiTPAGMixin): +class HunyuanDiTPAGPipeline(DiffusionPipeline, PAGMixin): r""" Pipeline for English/Chinese-to-image generation using HunyuanDiT and [Perturbed Attention Guidance](https://huggingface.co/docs/diffusers/en/using-diffusers/pag). @@ -243,7 +244,9 @@ def __init__( else 128 ) - self.set_pag_applied_layers(pag_applied_layers) + self.set_pag_applied_layers( + pag_applied_layers, pag_attn_processors=(PAGCFGHunyuanAttnProcessor2_0(), PAGHunyuanAttnProcessor2_0()) + ) # Copied from diffusers.pipelines.hunyuandit.pipeline_hunyuandit.HunyuanDiTPipeline.encode_prompt def encode_prompt( From ab34d385a1b4ef89b2809ac3b1e39deb6156a93d Mon Sep 17 00:00:00 2001 From: Aryan Date: Fri, 2 Aug 2024 12:15:04 +0200 Subject: [PATCH 12/22] revert pag changes --- src/diffusers/pipelines/pag/pag_utils.py | 397 ++++++++++++++++++----- 1 file changed, 315 insertions(+), 82 deletions(-) diff --git a/src/diffusers/pipelines/pag/pag_utils.py b/src/diffusers/pipelines/pag/pag_utils.py index afc8d1a2cecb..7c9bb2d098d2 100644 --- a/src/diffusers/pipelines/pag/pag_utils.py +++ b/src/diffusers/pipelines/pag/pag_utils.py @@ -12,14 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -import re -from typing import List, Optional, Tuple, Union - import torch -import torch.nn as nn from ...models.attention_processor import ( - AttentionProcessor, PAGCFGIdentitySelfAttnProcessor2_0, PAGIdentitySelfAttnProcessor2_0, ) @@ -30,60 +25,140 @@ class PAGMixin: - r"""Mixin class for [Pertubed Attention Guidance](https://arxiv.org/abs/2403.17377v1).""" + r"""Mixin class for PAG.""" - def _set_pag_attn_processor(self, pag_applied_layers, do_classifier_free_guidance): + @staticmethod + def _check_input_pag_applied_layer(layer): r""" - Set the attention processor for the PAG layers. + Check if each layer input in `applied_pag_layers` is valid. It should be either one of these 3 formats: + "{block_type}", "{block_type}.{block_index}", or "{block_type}.{block_index}.{attention_index}". `block_type` + can be "down", "mid", "up". `block_index` should be in the format of "block_{i}". `attention_index` should be + in the format of "attentions_{j}". `motion_modules_index` should be in the format of "motion_modules_{j}" """ - pag_attn_processors = getattr(self, "_pag_attn_processors", None) - if pag_attn_processors is None: - # If this hasn't been set by the user, we default to the original PAG identity processors - pag_attn_processors = (PAGCFGIdentitySelfAttnProcessor2_0(), PAGIdentitySelfAttnProcessor2_0()) - pag_attn_proc = pag_attn_processors[0] if do_classifier_free_guidance else pag_attn_processors[1] + layer_splits = layer.split(".") + + if len(layer_splits) > 3: + raise ValueError(f"pag layer should only contains block_type, block_index and attention_index{layer}.") + + if len(layer_splits) >= 1: + if layer_splits[0] not in ["down", "mid", "up"]: + raise ValueError( + f"Invalid block_type in pag layer {layer}. Accept 'down', 'mid', 'up', got {layer_splits[0]}" + ) + + if len(layer_splits) >= 2: + if not layer_splits[1].startswith("block_"): + raise ValueError(f"Invalid block_index in pag layer: {layer}. Should start with 'block_'") + + if len(layer_splits) == 3: + layer_2 = layer_splits[2] + if not layer_2.startswith("attentions_") and not layer_2.startswith("motion_modules_"): + raise ValueError( + f"Invalid attention_index in pag layer: {layer}. Should start with 'attentions_' or 'motion_modules_'" + ) - if hasattr(self, "unet"): - model: nn.Module = self.unet + def _set_pag_attn_processor(self, pag_applied_layers, do_classifier_free_guidance): + r""" + Set the attention processor for the PAG layers. + """ + if do_classifier_free_guidance: + pag_attn_proc = PAGCFGIdentitySelfAttnProcessor2_0() else: - model: nn.Module = self.transformer + pag_attn_proc = PAGIdentitySelfAttnProcessor2_0() - def is_self_attn(module_name: str) -> bool: + def is_self_attn(module_name): r""" Check if the module is self-attention module based on its name. """ - attn_id = getattr(self, "_self_attn_identifier", "attn1") - - # here, we check for "to" because we want the attention processor itself and not the qkv layers - qkv_present = "to" in module_name - norm_present = "norm" in module_name - return attn_id in module_name and not qkv_present and not norm_present + return "attn1" in module_name and "to" not in name - def is_fake_integral_match(layer_id, name): - layer_id = layer_id.split(".")[-1] - name = name.split(".")[-1] - return layer_id.isnumeric() and name.isnumeric() and layer_id == name + def get_block_type(module_name): + r""" + Get the block type from the module name. Can be "down", "mid", "up". + """ + # down_blocks.1.attentions.0.transformer_blocks.0.attn1 -> "down" + # down_blocks.1.motion_modules.0.transformer_blocks.0.attn1 -> "down" + return module_name.split(".")[0].split("_")[0] - for layer_id in pag_applied_layers: + def get_block_index(module_name): + r""" + Get the block index from the module name. Can be "block_0", "block_1", ... If there is only one block (e.g. + mid_block) and index is ommited from the name, it will be "block_0". + """ + # down_blocks.1.attentions.0.transformer_blocks.0.attn1 -> "block_1" + # mid_block.attentions.0.transformer_blocks.0.attn1 -> "block_0" + module_name_splits = module_name.split(".") + block_index = module_name_splits[1] + if "attentions" in block_index or "motion_modules" in block_index: + return "block_0" + else: + return f"block_{block_index}" + + def get_attn_index(module_name): + r""" + Get the attention index from the module name. Can be "attentions_0", "attentions_1", "motion_modules_0", + "motion_modules_1", ... + """ + # down_blocks.1.attentions.0.transformer_blocks.0.attn1 -> "attentions_0" + # mid_block.attentions.0.transformer_blocks.0.attn1 -> "attentions_0" + # down_blocks.1.motion_modules.0.transformer_blocks.0.attn1 -> "motion_modules_0" + # mid_block.motion_modules.0.transformer_blocks.0.attn1 -> "motion_modules_0" + module_name_split = module_name.split(".") + mid_name = module_name_split[1] + down_name = module_name_split[2] + if "attentions" in down_name: + return f"attentions_{module_name_split[3]}" + if "attentions" in mid_name: + return f"attentions_{module_name_split[2]}" + if "motion_modules" in down_name: + return f"motion_modules_{module_name_split[3]}" + if "motion_modules" in mid_name: + return f"motion_modules_{module_name_split[2]}" + + for pag_layer_input in pag_applied_layers: # for each PAG layer input, we find corresponding self-attention layers in the unet model target_modules = [] - for name, module in model.named_modules(): - # Identify the following simple cases: - # (1) Self Attention layer existing - # (2) Whether the module name matches pag layer id even partially - # (3) Make sure it's not a fake integral match if the layer_id ends with a number - # For example, blocks.1, blocks.10 should be differentiable if layer_id="blocks.1" - if ( - is_self_attn(name) - and re.search(layer_id, name) is not None - and not is_fake_integral_match(layer_id, name) - ): - print(f"applying to: {name}") - target_modules.append(module) + pag_layer_input_splits = pag_layer_input.split(".") + + if len(pag_layer_input_splits) == 1: + # when the layer input only contains block_type. e.g. "mid", "down", "up" + block_type = pag_layer_input_splits[0] + for name, module in self.unet.named_modules(): + if is_self_attn(name) and get_block_type(name) == block_type: + target_modules.append(module) + + elif len(pag_layer_input_splits) == 2: + # when the layer input contains both block_type and block_index. e.g. "down.block_1", "mid.block_0" + block_type = pag_layer_input_splits[0] + block_index = pag_layer_input_splits[1] + for name, module in self.unet.named_modules(): + if ( + is_self_attn(name) + and get_block_type(name) == block_type + and get_block_index(name) == block_index + ): + target_modules.append(module) + + elif len(pag_layer_input_splits) == 3: + # when the layer input contains block_type, block_index and attention_index. + # e.g. "down.block_1.attentions_1" or "down.block_1.motion_modules_1" + block_type = pag_layer_input_splits[0] + block_index = pag_layer_input_splits[1] + attn_index = pag_layer_input_splits[2] + + for name, module in self.unet.named_modules(): + if ( + is_self_attn(name) + and get_block_type(name) == block_type + and get_block_index(name) == block_index + and get_attn_index(name) == attn_index + ): + target_modules.append(module) if len(target_modules) == 0: - raise ValueError(f"Cannot find PAG layer to set attention processor for: {layer_id}") + raise ValueError(f"Cannot find pag layer to set attention processor for: {pag_layer_input}") for module in target_modules: module.processor = pag_attn_proc @@ -146,64 +221,45 @@ def _prepare_perturbed_attention_guidance(self, cond, uncond, do_classifier_free cond = torch.cat([uncond, cond], dim=0) return cond - def set_pag_applied_layers( - self, - pag_applied_layers: Union[str, List[str]], - pag_attn_processors: Optional[Tuple[AttentionProcessor, AttentionProcessor]] = None, - self_attn_identifier: str = "attn1", - ): + def set_pag_applied_layers(self, pag_applied_layers): r""" - Set the the self-attention layers to apply PAG. Raise ValueError if the input is invalid. - - Args: - pag_applied_layers (`str` or `List[str]`): - One or more strings, or simple regex, to identify layers where to apply PAG. - pag_attn_processors: (`Tuple[AttentionProcessor, AttentionProcessor]`, *optional*): - A tuple of two attention processors. The first attention processor is for PAG with Classifier-free - guidance enabled (conditional and unconditional). The second attention processor is for PAG with CFG - disabled (unconditional only). - self_attn_identifier (`str`, defaults to "attn1"): - The string to identity self-attn layers. + set the the self-attention layers to apply PAG. Raise ValueError if the input is invalid. """ if not isinstance(pag_applied_layers, list): pag_applied_layers = [pag_applied_layers] - if pag_attn_processors is not None: - if not isinstance(pag_attn_processors, tuple) or len(pag_attn_processors) != 2: - raise ValueError("Expected a tuple of two attention processors") - for i in range(len(pag_applied_layers)): - if not isinstance(pag_applied_layers[i], str): - raise ValueError( - f"Expected either a string or a list of string but got type {type(pag_applied_layers[i])}" - ) + for pag_layer in pag_applied_layers: + self._check_input_pag_applied_layer(pag_layer) self.pag_applied_layers = pag_applied_layers - self._self_attn_identifier = self_attn_identifier - - # Ensure we don't overwrite existing processors, possible from __init__, if the intention - # was to only change the layers where PAG was applied. - if pag_attn_processors is not None: - self._pag_attn_processors = pag_attn_processors @property def pag_scale(self): - r"""Get the scale factor for the perturbed attention guidance.""" + """ + Get the scale factor for the perturbed attention guidance. + """ return self._pag_scale @property def pag_adaptive_scale(self): - r"""Get the adaptive scale factor for the perturbed attention guidance.""" + """ + Get the adaptive scale factor for the perturbed attention guidance. + """ return self._pag_adaptive_scale @property def do_pag_adaptive_scaling(self): - r"""Check if the adaptive scaling is enabled for the perturbed attention guidance.""" + """ + Check if the adaptive scaling is enabled for the perturbed attention guidance. + """ return self._pag_adaptive_scale > 0 and self._pag_scale > 0 and len(self.pag_applied_layers) > 0 @property def do_perturbed_attention_guidance(self): - r"""Check if the perturbed attention guidance is enabled.""" + """ + Check if the perturbed attention guidance is enabled. + """ return self._pag_scale > 0 and len(self.pag_applied_layers) > 0 @property @@ -214,13 +270,190 @@ def pag_attn_processors(self): with the key as the name of the layer. """ - if not hasattr(self, "_pag_attn_processors") or self._pag_attn_processors is None: - valid_attn_processors = (PAGCFGIdentitySelfAttnProcessor2_0, PAGIdentitySelfAttnProcessor2_0) + processors = {} + for name, proc in self.unet.attn_processors.items(): + if proc.__class__ in (PAGCFGIdentitySelfAttnProcessor2_0, PAGIdentitySelfAttnProcessor2_0): + processors[name] = proc + return processors + + +class PixArtPAGMixin: + @staticmethod + def _check_input_pag_applied_layer(layer): + r""" + Check if each layer input in `applied_pag_layers` is valid. It should be the block index: {block_index}. + """ + + # Check if the layer index is valid (should be int or str of int) + if isinstance(layer, int): + return # Valid layer index + + if isinstance(layer, str): + if layer.isdigit(): + return # Valid layer index + + # If it is not a valid layer index, raise a ValueError + raise ValueError(f"Pag layer should only contain block index. Accept number string like '3', got {layer}.") + + def _set_pag_attn_processor(self, pag_applied_layers, do_classifier_free_guidance): + r""" + Set the attention processor for the PAG layers. + """ + if do_classifier_free_guidance: + pag_attn_proc = PAGCFGIdentitySelfAttnProcessor2_0() else: - valid_attn_processors = tuple(x.__class__ for x in self._pag_attn_processors) + pag_attn_proc = PAGIdentitySelfAttnProcessor2_0() + + def is_self_attn(module_name): + r""" + Check if the module is self-attention module based on its name. + """ + return ( + "attn1" in module_name and len(module_name.split(".")) == 3 + ) # include transformer_blocks.1.attn1, exclude transformer_blocks.18.attn1.to_q, transformer_blocks.1.attn1.add_q_proj, ... + + def get_block_index(module_name): + r""" + Get the block index from the module name. can be "block_0", "block_1", ... If there is only one block (e.g. + mid_block) and index is ommited from the name, it will be "block_0". + """ + # transformer_blocks.23.attn -> "23" + return module_name.split(".")[1] + + for pag_layer_input in pag_applied_layers: + # for each PAG layer input, we find corresponding self-attention layers in the transformer model + target_modules = [] + + block_index = str(pag_layer_input) + + for name, module in self.transformer.named_modules(): + if is_self_attn(name) and get_block_index(name) == block_index: + target_modules.append(module) + + if len(target_modules) == 0: + raise ValueError(f"Cannot find pag layer to set attention processor for: {pag_layer_input}") + + for module in target_modules: + module.processor = pag_attn_proc + + # Copied from diffusers.pipelines.pag.pag_utils.PAGMixin.set_pag_applied_layers + def set_pag_applied_layers(self, pag_applied_layers): + r""" + set the the self-attention layers to apply PAG. Raise ValueError if the input is invalid. + """ + + if not isinstance(pag_applied_layers, list): + pag_applied_layers = [pag_applied_layers] + + for pag_layer in pag_applied_layers: + self._check_input_pag_applied_layer(pag_layer) + + self.pag_applied_layers = pag_applied_layers + + # Copied from diffusers.pipelines.pag.pag_utils.PAGMixin._get_pag_scale + def _get_pag_scale(self, t): + r""" + Get the scale factor for the perturbed attention guidance at timestep `t`. + """ + + if self.do_pag_adaptive_scaling: + signal_scale = self.pag_scale - self.pag_adaptive_scale * (1000 - t) + if signal_scale < 0: + signal_scale = 0 + return signal_scale + else: + return self.pag_scale + + # Copied from diffusers.pipelines.pag.pag_utils.PAGMixin._apply_perturbed_attention_guidance + def _apply_perturbed_attention_guidance(self, noise_pred, do_classifier_free_guidance, guidance_scale, t): + r""" + Apply perturbed attention guidance to the noise prediction. + + Args: + noise_pred (torch.Tensor): The noise prediction tensor. + do_classifier_free_guidance (bool): Whether to apply classifier-free guidance. + guidance_scale (float): The scale factor for the guidance term. + t (int): The current time step. + + Returns: + torch.Tensor: The updated noise prediction tensor after applying perturbed attention guidance. + """ + pag_scale = self._get_pag_scale(t) + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text, noise_pred_perturb = noise_pred.chunk(3) + noise_pred = ( + noise_pred_uncond + + guidance_scale * (noise_pred_text - noise_pred_uncond) + + pag_scale * (noise_pred_text - noise_pred_perturb) + ) + else: + noise_pred_text, noise_pred_perturb = noise_pred.chunk(2) + noise_pred = noise_pred_text + pag_scale * (noise_pred_text - noise_pred_perturb) + return noise_pred + + # Copied from diffusers.pipelines.pag.pag_utils.PAGMixin._prepare_perturbed_attention_guidance + def _prepare_perturbed_attention_guidance(self, cond, uncond, do_classifier_free_guidance): + """ + Prepares the perturbed attention guidance for the PAG model. + + Args: + cond (torch.Tensor): The conditional input tensor. + uncond (torch.Tensor): The unconditional input tensor. + do_classifier_free_guidance (bool): Flag indicating whether to perform classifier-free guidance. + + Returns: + torch.Tensor: The prepared perturbed attention guidance tensor. + """ + + cond = torch.cat([cond] * 2, dim=0) + + if do_classifier_free_guidance: + cond = torch.cat([uncond, cond], dim=0) + return cond + + @property + # Copied from diffusers.pipelines.pag.pag_utils.PAGMixin.pag_scale + def pag_scale(self): + """ + Get the scale factor for the perturbed attention guidance. + """ + return self._pag_scale + + @property + # Copied from diffusers.pipelines.pag.pag_utils.PAGMixin.pag_adaptive_scale + def pag_adaptive_scale(self): + """ + Get the adaptive scale factor for the perturbed attention guidance. + """ + return self._pag_adaptive_scale + + @property + # Copied from diffusers.pipelines.pag.pag_utils.PAGMixin.do_pag_adaptive_scaling + def do_pag_adaptive_scaling(self): + """ + Check if the adaptive scaling is enabled for the perturbed attention guidance. + """ + return self._pag_adaptive_scale > 0 and self._pag_scale > 0 and len(self.pag_applied_layers) > 0 + + @property + # Copied from diffusers.pipelines.pag.pag_utils.PAGMixin.do_perturbed_attention_guidance + def do_perturbed_attention_guidance(self): + """ + Check if the perturbed attention guidance is enabled. + """ + return self._pag_scale > 0 and len(self.pag_applied_layers) > 0 + + @property + # Copied from diffusers.pipelines.pag.pag_utils.PAGMixin.pag_attn_processors with unet->transformer + def pag_attn_processors(self): + r""" + Returns: + `dict` of PAG attention processors: A dictionary contains all PAG attention processors used in the model + with the key as the name of the layer. + """ processors = {} - for name, proc in self.unet.attn_processors.items(): - if proc.__class__ in valid_attn_processors: + for name, proc in self.transformer.attn_processors.items(): + if proc.__class__ in (PAGCFGIdentitySelfAttnProcessor2_0, PAGIdentitySelfAttnProcessor2_0): processors[name] = proc return processors From 199aa10c6830949740ce99c7cf8bfd52edb1fd8b Mon Sep 17 00:00:00 2001 From: Aryan Date: Fri, 2 Aug 2024 12:19:47 +0200 Subject: [PATCH 13/22] add pag refactor back; update pixart sigma --- src/diffusers/pipelines/pag/pag_utils.py | 397 ++++-------------- .../pag/pipeline_pag_pixart_sigma.py | 4 +- 2 files changed, 84 insertions(+), 317 deletions(-) diff --git a/src/diffusers/pipelines/pag/pag_utils.py b/src/diffusers/pipelines/pag/pag_utils.py index 7c9bb2d098d2..afc8d1a2cecb 100644 --- a/src/diffusers/pipelines/pag/pag_utils.py +++ b/src/diffusers/pipelines/pag/pag_utils.py @@ -12,9 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. +import re +from typing import List, Optional, Tuple, Union + import torch +import torch.nn as nn from ...models.attention_processor import ( + AttentionProcessor, PAGCFGIdentitySelfAttnProcessor2_0, PAGIdentitySelfAttnProcessor2_0, ) @@ -25,140 +30,60 @@ class PAGMixin: - r"""Mixin class for PAG.""" - - @staticmethod - def _check_input_pag_applied_layer(layer): - r""" - Check if each layer input in `applied_pag_layers` is valid. It should be either one of these 3 formats: - "{block_type}", "{block_type}.{block_index}", or "{block_type}.{block_index}.{attention_index}". `block_type` - can be "down", "mid", "up". `block_index` should be in the format of "block_{i}". `attention_index` should be - in the format of "attentions_{j}". `motion_modules_index` should be in the format of "motion_modules_{j}" - """ - - layer_splits = layer.split(".") - - if len(layer_splits) > 3: - raise ValueError(f"pag layer should only contains block_type, block_index and attention_index{layer}.") - - if len(layer_splits) >= 1: - if layer_splits[0] not in ["down", "mid", "up"]: - raise ValueError( - f"Invalid block_type in pag layer {layer}. Accept 'down', 'mid', 'up', got {layer_splits[0]}" - ) - - if len(layer_splits) >= 2: - if not layer_splits[1].startswith("block_"): - raise ValueError(f"Invalid block_index in pag layer: {layer}. Should start with 'block_'") - - if len(layer_splits) == 3: - layer_2 = layer_splits[2] - if not layer_2.startswith("attentions_") and not layer_2.startswith("motion_modules_"): - raise ValueError( - f"Invalid attention_index in pag layer: {layer}. Should start with 'attentions_' or 'motion_modules_'" - ) + r"""Mixin class for [Pertubed Attention Guidance](https://arxiv.org/abs/2403.17377v1).""" def _set_pag_attn_processor(self, pag_applied_layers, do_classifier_free_guidance): r""" Set the attention processor for the PAG layers. """ - if do_classifier_free_guidance: - pag_attn_proc = PAGCFGIdentitySelfAttnProcessor2_0() + pag_attn_processors = getattr(self, "_pag_attn_processors", None) + if pag_attn_processors is None: + # If this hasn't been set by the user, we default to the original PAG identity processors + pag_attn_processors = (PAGCFGIdentitySelfAttnProcessor2_0(), PAGIdentitySelfAttnProcessor2_0()) + + pag_attn_proc = pag_attn_processors[0] if do_classifier_free_guidance else pag_attn_processors[1] + + if hasattr(self, "unet"): + model: nn.Module = self.unet else: - pag_attn_proc = PAGIdentitySelfAttnProcessor2_0() + model: nn.Module = self.transformer - def is_self_attn(module_name): + def is_self_attn(module_name: str) -> bool: r""" Check if the module is self-attention module based on its name. """ - return "attn1" in module_name and "to" not in name + attn_id = getattr(self, "_self_attn_identifier", "attn1") - def get_block_type(module_name): - r""" - Get the block type from the module name. Can be "down", "mid", "up". - """ - # down_blocks.1.attentions.0.transformer_blocks.0.attn1 -> "down" - # down_blocks.1.motion_modules.0.transformer_blocks.0.attn1 -> "down" - return module_name.split(".")[0].split("_")[0] + # here, we check for "to" because we want the attention processor itself and not the qkv layers + qkv_present = "to" in module_name + norm_present = "norm" in module_name + return attn_id in module_name and not qkv_present and not norm_present - def get_block_index(module_name): - r""" - Get the block index from the module name. Can be "block_0", "block_1", ... If there is only one block (e.g. - mid_block) and index is ommited from the name, it will be "block_0". - """ - # down_blocks.1.attentions.0.transformer_blocks.0.attn1 -> "block_1" - # mid_block.attentions.0.transformer_blocks.0.attn1 -> "block_0" - module_name_splits = module_name.split(".") - block_index = module_name_splits[1] - if "attentions" in block_index or "motion_modules" in block_index: - return "block_0" - else: - return f"block_{block_index}" - - def get_attn_index(module_name): - r""" - Get the attention index from the module name. Can be "attentions_0", "attentions_1", "motion_modules_0", - "motion_modules_1", ... - """ - # down_blocks.1.attentions.0.transformer_blocks.0.attn1 -> "attentions_0" - # mid_block.attentions.0.transformer_blocks.0.attn1 -> "attentions_0" - # down_blocks.1.motion_modules.0.transformer_blocks.0.attn1 -> "motion_modules_0" - # mid_block.motion_modules.0.transformer_blocks.0.attn1 -> "motion_modules_0" - module_name_split = module_name.split(".") - mid_name = module_name_split[1] - down_name = module_name_split[2] - if "attentions" in down_name: - return f"attentions_{module_name_split[3]}" - if "attentions" in mid_name: - return f"attentions_{module_name_split[2]}" - if "motion_modules" in down_name: - return f"motion_modules_{module_name_split[3]}" - if "motion_modules" in mid_name: - return f"motion_modules_{module_name_split[2]}" - - for pag_layer_input in pag_applied_layers: + def is_fake_integral_match(layer_id, name): + layer_id = layer_id.split(".")[-1] + name = name.split(".")[-1] + return layer_id.isnumeric() and name.isnumeric() and layer_id == name + + for layer_id in pag_applied_layers: # for each PAG layer input, we find corresponding self-attention layers in the unet model target_modules = [] - pag_layer_input_splits = pag_layer_input.split(".") - - if len(pag_layer_input_splits) == 1: - # when the layer input only contains block_type. e.g. "mid", "down", "up" - block_type = pag_layer_input_splits[0] - for name, module in self.unet.named_modules(): - if is_self_attn(name) and get_block_type(name) == block_type: - target_modules.append(module) - - elif len(pag_layer_input_splits) == 2: - # when the layer input contains both block_type and block_index. e.g. "down.block_1", "mid.block_0" - block_type = pag_layer_input_splits[0] - block_index = pag_layer_input_splits[1] - for name, module in self.unet.named_modules(): - if ( - is_self_attn(name) - and get_block_type(name) == block_type - and get_block_index(name) == block_index - ): - target_modules.append(module) - - elif len(pag_layer_input_splits) == 3: - # when the layer input contains block_type, block_index and attention_index. - # e.g. "down.block_1.attentions_1" or "down.block_1.motion_modules_1" - block_type = pag_layer_input_splits[0] - block_index = pag_layer_input_splits[1] - attn_index = pag_layer_input_splits[2] - - for name, module in self.unet.named_modules(): - if ( - is_self_attn(name) - and get_block_type(name) == block_type - and get_block_index(name) == block_index - and get_attn_index(name) == attn_index - ): - target_modules.append(module) + for name, module in model.named_modules(): + # Identify the following simple cases: + # (1) Self Attention layer existing + # (2) Whether the module name matches pag layer id even partially + # (3) Make sure it's not a fake integral match if the layer_id ends with a number + # For example, blocks.1, blocks.10 should be differentiable if layer_id="blocks.1" + if ( + is_self_attn(name) + and re.search(layer_id, name) is not None + and not is_fake_integral_match(layer_id, name) + ): + print(f"applying to: {name}") + target_modules.append(module) if len(target_modules) == 0: - raise ValueError(f"Cannot find pag layer to set attention processor for: {pag_layer_input}") + raise ValueError(f"Cannot find PAG layer to set attention processor for: {layer_id}") for module in target_modules: module.processor = pag_attn_proc @@ -221,230 +146,67 @@ def _prepare_perturbed_attention_guidance(self, cond, uncond, do_classifier_free cond = torch.cat([uncond, cond], dim=0) return cond - def set_pag_applied_layers(self, pag_applied_layers): - r""" - set the the self-attention layers to apply PAG. Raise ValueError if the input is invalid. - """ - - if not isinstance(pag_applied_layers, list): - pag_applied_layers = [pag_applied_layers] - - for pag_layer in pag_applied_layers: - self._check_input_pag_applied_layer(pag_layer) - - self.pag_applied_layers = pag_applied_layers - - @property - def pag_scale(self): - """ - Get the scale factor for the perturbed attention guidance. - """ - return self._pag_scale - - @property - def pag_adaptive_scale(self): - """ - Get the adaptive scale factor for the perturbed attention guidance. - """ - return self._pag_adaptive_scale - - @property - def do_pag_adaptive_scaling(self): - """ - Check if the adaptive scaling is enabled for the perturbed attention guidance. - """ - return self._pag_adaptive_scale > 0 and self._pag_scale > 0 and len(self.pag_applied_layers) > 0 - - @property - def do_perturbed_attention_guidance(self): - """ - Check if the perturbed attention guidance is enabled. - """ - return self._pag_scale > 0 and len(self.pag_applied_layers) > 0 - - @property - def pag_attn_processors(self): + def set_pag_applied_layers( + self, + pag_applied_layers: Union[str, List[str]], + pag_attn_processors: Optional[Tuple[AttentionProcessor, AttentionProcessor]] = None, + self_attn_identifier: str = "attn1", + ): r""" - Returns: - `dict` of PAG attention processors: A dictionary contains all PAG attention processors used in the model - with the key as the name of the layer. - """ - - processors = {} - for name, proc in self.unet.attn_processors.items(): - if proc.__class__ in (PAGCFGIdentitySelfAttnProcessor2_0, PAGIdentitySelfAttnProcessor2_0): - processors[name] = proc - return processors - - -class PixArtPAGMixin: - @staticmethod - def _check_input_pag_applied_layer(layer): - r""" - Check if each layer input in `applied_pag_layers` is valid. It should be the block index: {block_index}. - """ - - # Check if the layer index is valid (should be int or str of int) - if isinstance(layer, int): - return # Valid layer index - - if isinstance(layer, str): - if layer.isdigit(): - return # Valid layer index - - # If it is not a valid layer index, raise a ValueError - raise ValueError(f"Pag layer should only contain block index. Accept number string like '3', got {layer}.") - - def _set_pag_attn_processor(self, pag_applied_layers, do_classifier_free_guidance): - r""" - Set the attention processor for the PAG layers. - """ - if do_classifier_free_guidance: - pag_attn_proc = PAGCFGIdentitySelfAttnProcessor2_0() - else: - pag_attn_proc = PAGIdentitySelfAttnProcessor2_0() - - def is_self_attn(module_name): - r""" - Check if the module is self-attention module based on its name. - """ - return ( - "attn1" in module_name and len(module_name.split(".")) == 3 - ) # include transformer_blocks.1.attn1, exclude transformer_blocks.18.attn1.to_q, transformer_blocks.1.attn1.add_q_proj, ... + Set the the self-attention layers to apply PAG. Raise ValueError if the input is invalid. - def get_block_index(module_name): - r""" - Get the block index from the module name. can be "block_0", "block_1", ... If there is only one block (e.g. - mid_block) and index is ommited from the name, it will be "block_0". - """ - # transformer_blocks.23.attn -> "23" - return module_name.split(".")[1] - - for pag_layer_input in pag_applied_layers: - # for each PAG layer input, we find corresponding self-attention layers in the transformer model - target_modules = [] - - block_index = str(pag_layer_input) - - for name, module in self.transformer.named_modules(): - if is_self_attn(name) and get_block_index(name) == block_index: - target_modules.append(module) - - if len(target_modules) == 0: - raise ValueError(f"Cannot find pag layer to set attention processor for: {pag_layer_input}") - - for module in target_modules: - module.processor = pag_attn_proc - - # Copied from diffusers.pipelines.pag.pag_utils.PAGMixin.set_pag_applied_layers - def set_pag_applied_layers(self, pag_applied_layers): - r""" - set the the self-attention layers to apply PAG. Raise ValueError if the input is invalid. + Args: + pag_applied_layers (`str` or `List[str]`): + One or more strings, or simple regex, to identify layers where to apply PAG. + pag_attn_processors: (`Tuple[AttentionProcessor, AttentionProcessor]`, *optional*): + A tuple of two attention processors. The first attention processor is for PAG with Classifier-free + guidance enabled (conditional and unconditional). The second attention processor is for PAG with CFG + disabled (unconditional only). + self_attn_identifier (`str`, defaults to "attn1"): + The string to identity self-attn layers. """ if not isinstance(pag_applied_layers, list): pag_applied_layers = [pag_applied_layers] + if pag_attn_processors is not None: + if not isinstance(pag_attn_processors, tuple) or len(pag_attn_processors) != 2: + raise ValueError("Expected a tuple of two attention processors") - for pag_layer in pag_applied_layers: - self._check_input_pag_applied_layer(pag_layer) + for i in range(len(pag_applied_layers)): + if not isinstance(pag_applied_layers[i], str): + raise ValueError( + f"Expected either a string or a list of string but got type {type(pag_applied_layers[i])}" + ) self.pag_applied_layers = pag_applied_layers + self._self_attn_identifier = self_attn_identifier - # Copied from diffusers.pipelines.pag.pag_utils.PAGMixin._get_pag_scale - def _get_pag_scale(self, t): - r""" - Get the scale factor for the perturbed attention guidance at timestep `t`. - """ - - if self.do_pag_adaptive_scaling: - signal_scale = self.pag_scale - self.pag_adaptive_scale * (1000 - t) - if signal_scale < 0: - signal_scale = 0 - return signal_scale - else: - return self.pag_scale - - # Copied from diffusers.pipelines.pag.pag_utils.PAGMixin._apply_perturbed_attention_guidance - def _apply_perturbed_attention_guidance(self, noise_pred, do_classifier_free_guidance, guidance_scale, t): - r""" - Apply perturbed attention guidance to the noise prediction. - - Args: - noise_pred (torch.Tensor): The noise prediction tensor. - do_classifier_free_guidance (bool): Whether to apply classifier-free guidance. - guidance_scale (float): The scale factor for the guidance term. - t (int): The current time step. - - Returns: - torch.Tensor: The updated noise prediction tensor after applying perturbed attention guidance. - """ - pag_scale = self._get_pag_scale(t) - if do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text, noise_pred_perturb = noise_pred.chunk(3) - noise_pred = ( - noise_pred_uncond - + guidance_scale * (noise_pred_text - noise_pred_uncond) - + pag_scale * (noise_pred_text - noise_pred_perturb) - ) - else: - noise_pred_text, noise_pred_perturb = noise_pred.chunk(2) - noise_pred = noise_pred_text + pag_scale * (noise_pred_text - noise_pred_perturb) - return noise_pred - - # Copied from diffusers.pipelines.pag.pag_utils.PAGMixin._prepare_perturbed_attention_guidance - def _prepare_perturbed_attention_guidance(self, cond, uncond, do_classifier_free_guidance): - """ - Prepares the perturbed attention guidance for the PAG model. - - Args: - cond (torch.Tensor): The conditional input tensor. - uncond (torch.Tensor): The unconditional input tensor. - do_classifier_free_guidance (bool): Flag indicating whether to perform classifier-free guidance. - - Returns: - torch.Tensor: The prepared perturbed attention guidance tensor. - """ - - cond = torch.cat([cond] * 2, dim=0) - - if do_classifier_free_guidance: - cond = torch.cat([uncond, cond], dim=0) - return cond + # Ensure we don't overwrite existing processors, possible from __init__, if the intention + # was to only change the layers where PAG was applied. + if pag_attn_processors is not None: + self._pag_attn_processors = pag_attn_processors @property - # Copied from diffusers.pipelines.pag.pag_utils.PAGMixin.pag_scale def pag_scale(self): - """ - Get the scale factor for the perturbed attention guidance. - """ + r"""Get the scale factor for the perturbed attention guidance.""" return self._pag_scale @property - # Copied from diffusers.pipelines.pag.pag_utils.PAGMixin.pag_adaptive_scale def pag_adaptive_scale(self): - """ - Get the adaptive scale factor for the perturbed attention guidance. - """ + r"""Get the adaptive scale factor for the perturbed attention guidance.""" return self._pag_adaptive_scale @property - # Copied from diffusers.pipelines.pag.pag_utils.PAGMixin.do_pag_adaptive_scaling def do_pag_adaptive_scaling(self): - """ - Check if the adaptive scaling is enabled for the perturbed attention guidance. - """ + r"""Check if the adaptive scaling is enabled for the perturbed attention guidance.""" return self._pag_adaptive_scale > 0 and self._pag_scale > 0 and len(self.pag_applied_layers) > 0 @property - # Copied from diffusers.pipelines.pag.pag_utils.PAGMixin.do_perturbed_attention_guidance def do_perturbed_attention_guidance(self): - """ - Check if the perturbed attention guidance is enabled. - """ + r"""Check if the perturbed attention guidance is enabled.""" return self._pag_scale > 0 and len(self.pag_applied_layers) > 0 @property - # Copied from diffusers.pipelines.pag.pag_utils.PAGMixin.pag_attn_processors with unet->transformer def pag_attn_processors(self): r""" Returns: @@ -452,8 +214,13 @@ def pag_attn_processors(self): with the key as the name of the layer. """ + if not hasattr(self, "_pag_attn_processors") or self._pag_attn_processors is None: + valid_attn_processors = (PAGCFGIdentitySelfAttnProcessor2_0, PAGIdentitySelfAttnProcessor2_0) + else: + valid_attn_processors = tuple(x.__class__ for x in self._pag_attn_processors) + processors = {} - for name, proc in self.transformer.attn_processors.items(): - if proc.__class__ in (PAGCFGIdentitySelfAttnProcessor2_0, PAGIdentitySelfAttnProcessor2_0): + for name, proc in self.unet.attn_processors.items(): + if proc.__class__ in valid_attn_processors: processors[name] = proc return processors diff --git a/src/diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py b/src/diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py index 1188ffe52ed7..ae6496900dc1 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py @@ -40,7 +40,7 @@ ASPECT_RATIO_1024_BIN, ) from ..pixart_alpha.pipeline_pixart_sigma import ASPECT_RATIO_2048_BIN -from .pag_utils import PixArtPAGMixin +from .pag_utils import PAGMixin logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -132,7 +132,7 @@ def retrieve_timesteps( return timesteps, num_inference_steps -class PixArtSigmaPAGPipeline(DiffusionPipeline, PixArtPAGMixin): +class PixArtSigmaPAGPipeline(DiffusionPipeline, PAGMixin): r""" [PAG pipeline](https://huggingface.co/docs/diffusers/main/en/using-diffusers/pag) for text-to-image generation using PixArt-Sigma. From 5768906ec9a4361191b1153592ca49da2ebf3cea Mon Sep 17 00:00:00 2001 From: Aryan Date: Fri, 2 Aug 2024 12:48:02 +0200 Subject: [PATCH 14/22] update pixart pag tests --- tests/pipelines/pag/test_pag_pixart_sigma.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/pipelines/pag/test_pag_pixart_sigma.py b/tests/pipelines/pag/test_pag_pixart_sigma.py index be86afe45be0..f485de7bb3bd 100644 --- a/tests/pipelines/pag/test_pag_pixart_sigma.py +++ b/tests/pipelines/pag/test_pag_pixart_sigma.py @@ -127,7 +127,7 @@ def test_pag_disable_enable(self): out = pipe(**inputs).images[0, -3:, -3:, -1] # pag disabled with pag_scale=0.0 - components["pag_applied_layers"] = [1] + components["pag_applied_layers"] = ["1"] pipe_pag = self.pipeline_class(**components) pipe_pag = pipe_pag.to(device) pipe_pag.set_progress_bar_config(disable=None) @@ -158,7 +158,7 @@ def test_pag_applied_layers(self): # "attn1" should apply to all self-attention layers. all_self_attn_layers = [k for k in pipe.transformer.attn_processors.keys() if "attn1" in k] - pag_layers = [0, 1] + pag_layers = ["0", "1"] pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False) assert set(pipe.pag_attn_processors) == set(all_self_attn_layers) @@ -228,7 +228,7 @@ def test_save_load_optional_components(self): with tempfile.TemporaryDirectory() as tmpdir: pipe.save_pretrained(tmpdir) - pipe_loaded = self.pipeline_class.from_pretrained(tmpdir, pag_applied_layers=[1]) + pipe_loaded = self.pipeline_class.from_pretrained(tmpdir, pag_applied_layers=["1"]) pipe_loaded.to(torch_device) pipe_loaded.set_progress_bar_config(disable=None) @@ -282,7 +282,7 @@ def test_save_load_local(self, expected_max_difference=1e-4): pipe.save_pretrained(tmpdir, safe_serialization=False) with CaptureLogger(logger) as cap_logger: - pipe_loaded = self.pipeline_class.from_pretrained(tmpdir, pag_applied_layers=[1]) + pipe_loaded = self.pipeline_class.from_pretrained(tmpdir, pag_applied_layers=["1"]) for name in pipe_loaded.components.keys(): if name not in pipe_loaded._optional_components: From e754f5025da3e9a751c2cd7c0a493c3003288f6e Mon Sep 17 00:00:00 2001 From: Aryan Date: Fri, 2 Aug 2024 13:09:18 +0200 Subject: [PATCH 15/22] apply suggestions from review Co-Authored-By: yixu310@gmail.com --- src/diffusers/pipelines/pag/pag_utils.py | 42 ++++++++++-------------- 1 file changed, 17 insertions(+), 25 deletions(-) diff --git a/src/diffusers/pipelines/pag/pag_utils.py b/src/diffusers/pipelines/pag/pag_utils.py index afc8d1a2cecb..da45d8b79be6 100644 --- a/src/diffusers/pipelines/pag/pag_utils.py +++ b/src/diffusers/pipelines/pag/pag_utils.py @@ -19,6 +19,7 @@ import torch.nn as nn from ...models.attention_processor import ( + Attention, AttentionProcessor, PAGCFGIdentitySelfAttnProcessor2_0, PAGIdentitySelfAttnProcessor2_0, @@ -36,10 +37,9 @@ def _set_pag_attn_processor(self, pag_applied_layers, do_classifier_free_guidanc r""" Set the attention processor for the PAG layers. """ - pag_attn_processors = getattr(self, "_pag_attn_processors", None) + pag_attn_processors = self._pag_attn_processors if pag_attn_processors is None: - # If this hasn't been set by the user, we default to the original PAG identity processors - pag_attn_processors = (PAGCFGIdentitySelfAttnProcessor2_0(), PAGIdentitySelfAttnProcessor2_0()) + raise ValueError("No PAG attention processors have been set. Set the attention processors by calling `set_pag_applied_layers` and passing the relevant parameters.") pag_attn_proc = pag_attn_processors[0] if do_classifier_free_guidance else pag_attn_processors[1] @@ -48,16 +48,11 @@ def _set_pag_attn_processor(self, pag_applied_layers, do_classifier_free_guidanc else: model: nn.Module = self.transformer - def is_self_attn(module_name: str) -> bool: + def is_self_attn(module: nn.Module) -> bool: r""" Check if the module is self-attention module based on its name. """ - attn_id = getattr(self, "_self_attn_identifier", "attn1") - - # here, we check for "to" because we want the attention processor itself and not the qkv layers - qkv_present = "to" in module_name - norm_present = "norm" in module_name - return attn_id in module_name and not qkv_present and not norm_present + return isinstance(module, Attention) and not module.is_cross_attention def is_fake_integral_match(layer_id, name): layer_id = layer_id.split(".")[-1] @@ -75,11 +70,11 @@ def is_fake_integral_match(layer_id, name): # (3) Make sure it's not a fake integral match if the layer_id ends with a number # For example, blocks.1, blocks.10 should be differentiable if layer_id="blocks.1" if ( - is_self_attn(name) + is_self_attn(module) and re.search(layer_id, name) is not None and not is_fake_integral_match(layer_id, name) ): - print(f"applying to: {name}") + logger.debug(f"Apply PAG to layer: {name}") target_modules.append(module) if len(target_modules) == 0: @@ -149,8 +144,7 @@ def _prepare_perturbed_attention_guidance(self, cond, uncond, do_classifier_free def set_pag_applied_layers( self, pag_applied_layers: Union[str, List[str]], - pag_attn_processors: Optional[Tuple[AttentionProcessor, AttentionProcessor]] = None, - self_attn_identifier: str = "attn1", + pag_attn_processors: Tuple[AttentionProcessor, AttentionProcessor] = (PAGCFGIdentitySelfAttnProcessor2_0(), PAGIdentitySelfAttnProcessor2_0()), ): r""" Set the the self-attention layers to apply PAG. Raise ValueError if the input is invalid. @@ -158,7 +152,7 @@ def set_pag_applied_layers( Args: pag_applied_layers (`str` or `List[str]`): One or more strings, or simple regex, to identify layers where to apply PAG. - pag_attn_processors: (`Tuple[AttentionProcessor, AttentionProcessor]`, *optional*): + pag_attn_processors: (`Tuple[AttentionProcessor, AttentionProcessor]`, defaults to `(PAGCFGIdentitySelfAttnProcessor2_0(), PAGIdentitySelfAttnProcessor2_0())`): A tuple of two attention processors. The first attention processor is for PAG with Classifier-free guidance enabled (conditional and unconditional). The second attention processor is for PAG with CFG disabled (unconditional only). @@ -166,6 +160,9 @@ def set_pag_applied_layers( The string to identity self-attn layers. """ + if not hasattr(self, "_pag_attn_processors"): + self._pag_attn_processors = None + if not isinstance(pag_applied_layers, list): pag_applied_layers = [pag_applied_layers] if pag_attn_processors is not None: @@ -179,12 +176,7 @@ def set_pag_applied_layers( ) self.pag_applied_layers = pag_applied_layers - self._self_attn_identifier = self_attn_identifier - - # Ensure we don't overwrite existing processors, possible from __init__, if the intention - # was to only change the layers where PAG was applied. - if pag_attn_processors is not None: - self._pag_attn_processors = pag_attn_processors + self._pag_attn_processors = pag_attn_processors @property def pag_scale(self): @@ -214,10 +206,10 @@ def pag_attn_processors(self): with the key as the name of the layer. """ - if not hasattr(self, "_pag_attn_processors") or self._pag_attn_processors is None: - valid_attn_processors = (PAGCFGIdentitySelfAttnProcessor2_0, PAGIdentitySelfAttnProcessor2_0) - else: - valid_attn_processors = tuple(x.__class__ for x in self._pag_attn_processors) + if self._pag_attn_processors is None: + return {} + + valid_attn_processors = tuple(x.__class__ for x in self._pag_attn_processors) processors = {} for name, proc in self.unet.attn_processors.items(): From b05da7c78b6d8af415d428ae154c48dbb41018e3 Mon Sep 17 00:00:00 2001 From: Aryan Date: Fri, 2 Aug 2024 13:09:55 +0200 Subject: [PATCH 16/22] make style --- src/diffusers/pipelines/__init__.py | 2 +- src/diffusers/pipelines/pag/pag_utils.py | 22 ++++++++++++++-------- 2 files changed, 15 insertions(+), 9 deletions(-) diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index f7d417346ff1..46e93ea7769b 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -533,8 +533,8 @@ from .musicldm import MusicLDMPipeline from .pag import ( AnimateDiffPAGPipeline, - PixArtSigmaPAGPipeline, HunyuanDiTPAGPipeline, + PixArtSigmaPAGPipeline, StableDiffusionControlNetPAGPipeline, StableDiffusionPAGPipeline, StableDiffusionXLControlNetPAGPipeline, diff --git a/src/diffusers/pipelines/pag/pag_utils.py b/src/diffusers/pipelines/pag/pag_utils.py index da45d8b79be6..7c8b9677581e 100644 --- a/src/diffusers/pipelines/pag/pag_utils.py +++ b/src/diffusers/pipelines/pag/pag_utils.py @@ -13,7 +13,7 @@ # limitations under the License. import re -from typing import List, Optional, Tuple, Union +from typing import List, Tuple, Union import torch import torch.nn as nn @@ -39,7 +39,9 @@ def _set_pag_attn_processor(self, pag_applied_layers, do_classifier_free_guidanc """ pag_attn_processors = self._pag_attn_processors if pag_attn_processors is None: - raise ValueError("No PAG attention processors have been set. Set the attention processors by calling `set_pag_applied_layers` and passing the relevant parameters.") + raise ValueError( + "No PAG attention processors have been set. Set the attention processors by calling `set_pag_applied_layers` and passing the relevant parameters." + ) pag_attn_proc = pag_attn_processors[0] if do_classifier_free_guidance else pag_attn_processors[1] @@ -144,7 +146,10 @@ def _prepare_perturbed_attention_guidance(self, cond, uncond, do_classifier_free def set_pag_applied_layers( self, pag_applied_layers: Union[str, List[str]], - pag_attn_processors: Tuple[AttentionProcessor, AttentionProcessor] = (PAGCFGIdentitySelfAttnProcessor2_0(), PAGIdentitySelfAttnProcessor2_0()), + pag_attn_processors: Tuple[AttentionProcessor, AttentionProcessor] = ( + PAGCFGIdentitySelfAttnProcessor2_0(), + PAGIdentitySelfAttnProcessor2_0(), + ), ): r""" Set the the self-attention layers to apply PAG. Raise ValueError if the input is invalid. @@ -152,10 +157,11 @@ def set_pag_applied_layers( Args: pag_applied_layers (`str` or `List[str]`): One or more strings, or simple regex, to identify layers where to apply PAG. - pag_attn_processors: (`Tuple[AttentionProcessor, AttentionProcessor]`, defaults to `(PAGCFGIdentitySelfAttnProcessor2_0(), PAGIdentitySelfAttnProcessor2_0())`): - A tuple of two attention processors. The first attention processor is for PAG with Classifier-free - guidance enabled (conditional and unconditional). The second attention processor is for PAG with CFG - disabled (unconditional only). + pag_attn_processors: + (`Tuple[AttentionProcessor, AttentionProcessor]`, defaults to `(PAGCFGIdentitySelfAttnProcessor2_0(), + PAGIdentitySelfAttnProcessor2_0())`): A tuple of two attention processors. The first attention + processor is for PAG with Classifier-free guidance enabled (conditional and unconditional). The second + attention processor is for PAG with CFG disabled (unconditional only). self_attn_identifier (`str`, defaults to "attn1"): The string to identity self-attn layers. """ @@ -208,7 +214,7 @@ def pag_attn_processors(self): if self._pag_attn_processors is None: return {} - + valid_attn_processors = tuple(x.__class__ for x in self._pag_attn_processors) processors = {} From 69c3250b001ee2a6c1ba438d0b905d72cb6a19df Mon Sep 17 00:00:00 2001 From: Aryan Date: Sun, 4 Aug 2024 15:42:14 +0200 Subject: [PATCH 17/22] update docs, fix tests --- docs/source/en/api/pipelines/pag.md | 13 +++++++++ src/diffusers/pipelines/pag/pag_utils.py | 25 +++++++++-------- .../pag/pipeline_pag_pixart_sigma.py | 4 +-- tests/pipelines/pag/test_pag_animatediff.py | 28 +++++++++++-------- tests/pipelines/pag/test_pag_hunyuan_dit.py | 12 ++++---- tests/pipelines/pag/test_pag_pixart_sigma.py | 8 +++--- 6 files changed, 55 insertions(+), 35 deletions(-) diff --git a/docs/source/en/api/pipelines/pag.md b/docs/source/en/api/pipelines/pag.md index 8bed30e5e27f..a1349cd0f9ef 100644 --- a/docs/source/en/api/pipelines/pag.md +++ b/docs/source/en/api/pipelines/pag.md @@ -20,6 +20,19 @@ The abstract from the paper is: *Recent studies have demonstrated that diffusion models are capable of generating high-quality samples, but their quality heavily depends on sampling guidance techniques, such as classifier guidance (CG) and classifier-free guidance (CFG). These techniques are often not applicable in unconditional generation or in various downstream tasks such as image restoration. In this paper, we propose a novel sampling guidance, called Perturbed-Attention Guidance (PAG), which improves diffusion sample quality across both unconditional and conditional settings, achieving this without requiring additional training or the integration of external modules. PAG is designed to progressively enhance the structure of samples throughout the denoising process. It involves generating intermediate samples with degraded structure by substituting selected self-attention maps in diffusion U-Net with an identity matrix, by considering the self-attention mechanisms' ability to capture structural information, and guiding the denoising process away from these degraded samples. In both ADM and Stable Diffusion, PAG surprisingly improves sample quality in conditional and even unconditional scenarios. Moreover, PAG significantly improves the baseline performance in various downstream tasks where existing guidances such as CG or CFG cannot be fully utilized, including ControlNet with empty prompts and image restoration such as inpainting and deblurring.* +PAG can be used by specifying the `pag_applied_layers` as a parameter when instantiating a PAG pipeline. It can be a single string or a list of strings. Each string can be a unique layer identifier or a regular expression to identify one or more layers. + +- Full identifier as a normal string: `down_blocks.2.attentions.0.transformer_blocks.0.attn1.processor` +- Full identifier as a RegEx: `down_blocks.2.(attentions|motion_modules).0.transformer_blocks.0.attn1.processor` +- Partial identifier as a RegEx: `down_blocks.2`, or `attn1` +- List of identifiers (can be combo of strings and ReGex): `["blocks.1", "blocks.(14|20)", r"down_blocks\.(2,3)"]` + + + +Since RegEx is supported as a way for matching layer identifiers, it is crucial to use it correctly otherwise there might be unexpected behaviour. The recommended way to use PAG is by specifying layers as `blocks.{layer_index}` and `blocks.({layer_index_1|layer_index_2|...})`. Using it in any other way, while doable, may bypass our basic validation checks and give you unexpected results. + + + ## AnimateDiffPAGPipeline [[autodoc]] AnimateDiffPAGPipeline - all diff --git a/src/diffusers/pipelines/pag/pag_utils.py b/src/diffusers/pipelines/pag/pag_utils.py index 7c8b9677581e..b5eb319b16ef 100644 --- a/src/diffusers/pipelines/pag/pag_utils.py +++ b/src/diffusers/pipelines/pag/pag_utils.py @@ -13,7 +13,7 @@ # limitations under the License. import re -from typing import List, Tuple, Union +from typing import Dict, List, Tuple, Union import torch import torch.nn as nn @@ -76,7 +76,7 @@ def is_fake_integral_match(layer_id, name): and re.search(layer_id, name) is not None and not is_fake_integral_match(layer_id, name) ): - logger.debug(f"Apply PAG to layer: {name}") + logger.debug(f"Applying PAG to layer: {name}") target_modules.append(module) if len(target_modules) == 0: @@ -156,14 +156,17 @@ def set_pag_applied_layers( Args: pag_applied_layers (`str` or `List[str]`): - One or more strings, or simple regex, to identify layers where to apply PAG. + One or more strings identifying the layer names, or a simple regex for matching multiple layers, where + PAG is to be applied. A few ways of expected usage are as follows: + - Single layers specified as - "blocks.{layer_index}" + - Multiple layers as a list - ["blocks.{layers_index_1}", "blocks.{layer_index_2}", ...] + - Multiple layers as a block name - "mid" + - Multiple layers as regex - "blocks.({layer_index_1}|{layer_index_2})" pag_attn_processors: (`Tuple[AttentionProcessor, AttentionProcessor]`, defaults to `(PAGCFGIdentitySelfAttnProcessor2_0(), PAGIdentitySelfAttnProcessor2_0())`): A tuple of two attention processors. The first attention processor is for PAG with Classifier-free guidance enabled (conditional and unconditional). The second attention processor is for PAG with CFG disabled (unconditional only). - self_attn_identifier (`str`, defaults to "attn1"): - The string to identity self-attn layers. """ if not hasattr(self, "_pag_attn_processors"): @@ -185,27 +188,27 @@ def set_pag_applied_layers( self._pag_attn_processors = pag_attn_processors @property - def pag_scale(self): + def pag_scale(self) -> float: r"""Get the scale factor for the perturbed attention guidance.""" return self._pag_scale @property - def pag_adaptive_scale(self): + def pag_adaptive_scale(self) -> float: r"""Get the adaptive scale factor for the perturbed attention guidance.""" return self._pag_adaptive_scale @property - def do_pag_adaptive_scaling(self): + def do_pag_adaptive_scaling(self) -> bool: r"""Check if the adaptive scaling is enabled for the perturbed attention guidance.""" return self._pag_adaptive_scale > 0 and self._pag_scale > 0 and len(self.pag_applied_layers) > 0 @property - def do_perturbed_attention_guidance(self): + def do_perturbed_attention_guidance(self) -> bool: r"""Check if the perturbed attention guidance is enabled.""" return self._pag_scale > 0 and len(self.pag_applied_layers) > 0 @property - def pag_attn_processors(self): + def pag_attn_processors(self) -> Dict[str, AttentionProcessor]: r""" Returns: `dict` of PAG attention processors: A dictionary contains all PAG attention processors used in the model @@ -215,7 +218,7 @@ def pag_attn_processors(self): if self._pag_attn_processors is None: return {} - valid_attn_processors = tuple(x.__class__ for x in self._pag_attn_processors) + valid_attn_processors = {x.__class__ for x in self._pag_attn_processors} processors = {} for name, proc in self.unet.attn_processors.items(): diff --git a/src/diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py b/src/diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py index ae6496900dc1..8e5e6cbaf5ad 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py @@ -61,7 +61,7 @@ >>> pipe = AutoPipelineForText2Image.from_pretrained( ... "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS", ... torch_dtype=torch.float16, - ... pag_applied_layers=[14], + ... pag_applied_layers=["blocks.14"], ... enable_pag=True, ... ) >>> pipe = pipe.to("cuda") @@ -164,7 +164,7 @@ def __init__( vae: AutoencoderKL, transformer: PixArtTransformer2DModel, scheduler: KarrasDiffusionSchedulers, - pag_applied_layers: Union[str, List[str]] = "1", # 1st transformer block + pag_applied_layers: Union[str, List[str]] = "blocks.1", # 1st transformer block ): super().__init__() diff --git a/tests/pipelines/pag/test_pag_animatediff.py b/tests/pipelines/pag/test_pag_animatediff.py index 8f637b991056..2cd80921e932 100644 --- a/tests/pipelines/pag/test_pag_animatediff.py +++ b/tests/pipelines/pag/test_pag_animatediff.py @@ -429,7 +429,10 @@ def test_pag_applied_layers(self): pipe.set_progress_bar_config(disable=None) # pag_applied_layers = ["mid","up","down"] should apply to all self-attention layers - all_self_attn_layers = [k for k in pipe.unet.attn_processors.keys() if "attn1" in k] + # Note that for motion modules in AnimateDiff, both attn1 and attn2 are self-attention + all_self_attn_layers = [ + k for k in pipe.unet.attn_processors.keys() if "attn1" in k or ("motion_modules" in k and "attn2" in k) + ] original_attn_procs = pipe.unet.attn_processors pag_layers = [ "down", @@ -439,12 +442,13 @@ def test_pag_applied_layers(self): pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False) assert set(pipe.pag_attn_processors) == set(all_self_attn_layers) - # pag_applied_layers = ["mid"], or ["mid.block_0"] or ["mid.block_0.motion_modules_0"] should apply to all self-attention layers in mid_block, i.e. + # pag_applied_layers = ["mid"], or ["mid_block.0"] should apply to all self-attention layers in mid_block, i.e. # mid_block.motion_modules.0.transformer_blocks.0.attn1.processor # mid_block.attentions.0.transformer_blocks.0.attn1.processor all_self_attn_mid_layers = [ - "mid_block.motion_modules.0.transformer_blocks.0.attn1.processor", "mid_block.attentions.0.transformer_blocks.0.attn1.processor", + "mid_block.motion_modules.0.transformer_blocks.0.attn1.processor", + "mid_block.motion_modules.0.transformer_blocks.0.attn2.processor", ] pipe.unet.set_attn_processor(original_attn_procs.copy()) pag_layers = ["mid"] @@ -452,17 +456,17 @@ def test_pag_applied_layers(self): assert set(pipe.pag_attn_processors) == set(all_self_attn_mid_layers) pipe.unet.set_attn_processor(original_attn_procs.copy()) - pag_layers = ["mid.block_0"] + pag_layers = ["mid_block"] pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False) assert set(pipe.pag_attn_processors) == set(all_self_attn_mid_layers) pipe.unet.set_attn_processor(original_attn_procs.copy()) - pag_layers = ["mid.block_0.attentions_0", "mid.block_0.motion_modules_0"] + pag_layers = ["mid_block.(attentions|motion_modules)"] pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False) assert set(pipe.pag_attn_processors) == set(all_self_attn_mid_layers) pipe.unet.set_attn_processor(original_attn_procs.copy()) - pag_layers = ["mid.block_0.attentions_1"] + pag_layers = ["mid_block.attentions.1"] with self.assertRaises(ValueError): pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False) @@ -474,19 +478,19 @@ def test_pag_applied_layers(self): pipe.unet.set_attn_processor(original_attn_procs.copy()) pag_layers = ["down"] pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False) - assert len(pipe.pag_attn_processors) == 6 + assert len(pipe.pag_attn_processors) == 10 pipe.unet.set_attn_processor(original_attn_procs.copy()) - pag_layers = ["down.block_0"] + pag_layers = ["down_blocks.0"] pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False) - assert (len(pipe.pag_attn_processors)) == 4 + assert (len(pipe.pag_attn_processors)) == 6 pipe.unet.set_attn_processor(original_attn_procs.copy()) - pag_layers = ["down.block_1"] + pag_layers = ["blocks.1"] pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False) - assert len(pipe.pag_attn_processors) == 2 + assert len(pipe.pag_attn_processors) == 10 pipe.unet.set_attn_processor(original_attn_procs.copy()) - pag_layers = ["down.block_1.motion_modules_2"] + pag_layers = ["motion_modules.42"] with self.assertRaises(ValueError): pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False) diff --git a/tests/pipelines/pag/test_pag_hunyuan_dit.py b/tests/pipelines/pag/test_pag_hunyuan_dit.py index 549d8e8a7630..e6317e63cc64 100644 --- a/tests/pipelines/pag/test_pag_hunyuan_dit.py +++ b/tests/pipelines/pag/test_pag_hunyuan_dit.py @@ -83,7 +83,7 @@ def get_dummy_components(self): "tokenizer_2": tokenizer_2, "safety_checker": None, "feature_extractor": None, - "pag_applied_layers": [1], + "pag_applied_layers": ["blocks.1"], } return components @@ -336,28 +336,28 @@ def test_pag_applied_layers(self): all_self_attn_layers = [k for k in pipe.transformer.attn_processors.keys() if "attn1" in k] original_attn_procs = pipe.transformer.attn_processors - pag_layers = [0, 1] + pag_layers = ["blocks.0", "blocks.1"] pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False) assert set(pipe.pag_attn_processors) == set(all_self_attn_layers) # blocks.0 block_0_self_attn = ["blocks.0.attn1.processor"] pipe.transformer.set_attn_processor(original_attn_procs.copy()) - pag_layers = [0] + pag_layers = ["blocks.0"] pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False) assert set(pipe.pag_attn_processors) == set(block_0_self_attn) pipe.transformer.set_attn_processor(original_attn_procs.copy()) - pag_layers = ["blocks.0.attn1"] + pag_layers = "blocks.0.attn1" pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False) assert set(pipe.pag_attn_processors) == set(block_0_self_attn) pipe.transformer.set_attn_processor(original_attn_procs.copy()) - pag_layers = [0, "1"] + pag_layers = "blocks.(0|1)" pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False) assert (len(pipe.pag_attn_processors)) == 2 pipe.transformer.set_attn_processor(original_attn_procs.copy()) - pag_layers = ["0", "blocks.1"] + pag_layers = ["blocks.0", r"blocks\.1"] pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False) assert len(pipe.pag_attn_processors) == 2 diff --git a/tests/pipelines/pag/test_pag_pixart_sigma.py b/tests/pipelines/pag/test_pag_pixart_sigma.py index f485de7bb3bd..70b528dede56 100644 --- a/tests/pipelines/pag/test_pag_pixart_sigma.py +++ b/tests/pipelines/pag/test_pag_pixart_sigma.py @@ -127,7 +127,7 @@ def test_pag_disable_enable(self): out = pipe(**inputs).images[0, -3:, -3:, -1] # pag disabled with pag_scale=0.0 - components["pag_applied_layers"] = ["1"] + components["pag_applied_layers"] = ["blocks.1"] pipe_pag = self.pipeline_class(**components) pipe_pag = pipe_pag.to(device) pipe_pag.set_progress_bar_config(disable=None) @@ -158,7 +158,7 @@ def test_pag_applied_layers(self): # "attn1" should apply to all self-attention layers. all_self_attn_layers = [k for k in pipe.transformer.attn_processors.keys() if "attn1" in k] - pag_layers = ["0", "1"] + pag_layers = ["blocks.0", "blocks.1"] pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False) assert set(pipe.pag_attn_processors) == set(all_self_attn_layers) @@ -228,7 +228,7 @@ def test_save_load_optional_components(self): with tempfile.TemporaryDirectory() as tmpdir: pipe.save_pretrained(tmpdir) - pipe_loaded = self.pipeline_class.from_pretrained(tmpdir, pag_applied_layers=["1"]) + pipe_loaded = self.pipeline_class.from_pretrained(tmpdir, pag_applied_layers=["blocks.1"]) pipe_loaded.to(torch_device) pipe_loaded.set_progress_bar_config(disable=None) @@ -282,7 +282,7 @@ def test_save_load_local(self, expected_max_difference=1e-4): pipe.save_pretrained(tmpdir, safe_serialization=False) with CaptureLogger(logger) as cap_logger: - pipe_loaded = self.pipeline_class.from_pretrained(tmpdir, pag_applied_layers=["1"]) + pipe_loaded = self.pipeline_class.from_pretrained(tmpdir, pag_applied_layers=["blocks.1"]) for name in pipe_loaded.components.keys(): if name not in pipe_loaded._optional_components: From 753a023018f549c9f2c5a5f7d69115ecaddf1d9a Mon Sep 17 00:00:00 2001 From: Aryan Date: Sun, 4 Aug 2024 21:54:47 +0200 Subject: [PATCH 18/22] fix tests --- docs/source/en/api/pipelines/pag.md | 2 +- .../pipelines/pag/pipeline_pag_sd_animatediff.py | 2 +- tests/pipelines/pag/test_pag_sd.py | 12 ++++++------ tests/pipelines/pag/test_pag_sdxl.py | 12 ++++++------ 4 files changed, 14 insertions(+), 14 deletions(-) diff --git a/docs/source/en/api/pipelines/pag.md b/docs/source/en/api/pipelines/pag.md index a1349cd0f9ef..ac12bdb5578d 100644 --- a/docs/source/en/api/pipelines/pag.md +++ b/docs/source/en/api/pipelines/pag.md @@ -77,4 +77,4 @@ Since RegEx is supported as a way for matching layer identifiers, it is crucial ## PixArtSigmaPAGPipeline [[autodoc]] PixArtSigmaPAGPipeline - all - - __call__ \ No newline at end of file + - __call__ diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py b/src/diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py index e37506a60c61..b3b103742061 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py @@ -129,7 +129,7 @@ def __init__( scheduler: KarrasDiffusionSchedulers, feature_extractor: CLIPImageProcessor = None, image_encoder: CLIPVisionModelWithProjection = None, - pag_applied_layers: Union[str, List[str]] = "mid", # ["mid"], ["down.block_1"], ["up.block_0.attentions_0"] + pag_applied_layers: Union[str, List[str]] = "mid_block.*attn1", # ["mid"], ["down_blocks.1"] ): super().__init__() if isinstance(unet, UNet2DConditionModel): diff --git a/tests/pipelines/pag/test_pag_sd.py b/tests/pipelines/pag/test_pag_sd.py index a0930245b375..e9adb3ac447e 100644 --- a/tests/pipelines/pag/test_pag_sd.py +++ b/tests/pipelines/pag/test_pag_sd.py @@ -213,18 +213,18 @@ def test_pag_applied_layers(self): assert set(pipe.pag_attn_processors) == set(all_self_attn_mid_layers) pipe.unet.set_attn_processor(original_attn_procs.copy()) - pag_layers = ["mid.block_0"] + pag_layers = ["mid_block"] pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False) assert set(pipe.pag_attn_processors) == set(all_self_attn_mid_layers) pipe.unet.set_attn_processor(original_attn_procs.copy()) - pag_layers = ["mid.block_0.attentions_0"] + pag_layers = ["mid_block.attentions.0"] pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False) assert set(pipe.pag_attn_processors) == set(all_self_attn_mid_layers) # pag_applied_layers = ["mid.block_0.attentions_1"] does not exist in the model pipe.unet.set_attn_processor(original_attn_procs.copy()) - pag_layers = ["mid.block_0.attentions_1"] + pag_layers = ["mid_block.attentions.1"] with self.assertRaises(ValueError): pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False) @@ -239,17 +239,17 @@ def test_pag_applied_layers(self): assert len(pipe.pag_attn_processors) == 2 pipe.unet.set_attn_processor(original_attn_procs.copy()) - pag_layers = ["down.block_0"] + pag_layers = ["down_blocks.0"] with self.assertRaises(ValueError): pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False) pipe.unet.set_attn_processor(original_attn_procs.copy()) - pag_layers = ["down.block_1"] + pag_layers = ["down_blocks.1"] pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False) assert len(pipe.pag_attn_processors) == 2 pipe.unet.set_attn_processor(original_attn_procs.copy()) - pag_layers = ["down.block_1.attentions_1"] + pag_layers = ["down_blocks.1.attentions.1"] pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False) assert len(pipe.pag_attn_processors) == 1 diff --git a/tests/pipelines/pag/test_pag_sdxl.py b/tests/pipelines/pag/test_pag_sdxl.py index 5ec3dc5555f1..589573385677 100644 --- a/tests/pipelines/pag/test_pag_sdxl.py +++ b/tests/pipelines/pag/test_pag_sdxl.py @@ -225,18 +225,18 @@ def test_pag_applied_layers(self): assert set(pipe.pag_attn_processors) == set(all_self_attn_mid_layers) pipe.unet.set_attn_processor(original_attn_procs.copy()) - pag_layers = ["mid.block_0"] + pag_layers = ["mid_block"] pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False) assert set(pipe.pag_attn_processors) == set(all_self_attn_mid_layers) pipe.unet.set_attn_processor(original_attn_procs.copy()) - pag_layers = ["mid.block_0.attentions_0"] + pag_layers = ["mid_block.attentions.0"] pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False) assert set(pipe.pag_attn_processors) == set(all_self_attn_mid_layers) # pag_applied_layers = ["mid.block_0.attentions_1"] does not exist in the model pipe.unet.set_attn_processor(original_attn_procs.copy()) - pag_layers = ["mid.block_0.attentions_1"] + pag_layers = ["mid_block.attentions.1"] with self.assertRaises(ValueError): pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False) @@ -251,17 +251,17 @@ def test_pag_applied_layers(self): assert len(pipe.pag_attn_processors) == 4 pipe.unet.set_attn_processor(original_attn_procs.copy()) - pag_layers = ["down.block_0"] + pag_layers = ["down_blocks.0"] with self.assertRaises(ValueError): pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False) pipe.unet.set_attn_processor(original_attn_procs.copy()) - pag_layers = ["down.block_1"] + pag_layers = ["down_blocks.1"] pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False) assert len(pipe.pag_attn_processors) == 4 pipe.unet.set_attn_processor(original_attn_procs.copy()) - pag_layers = ["down.block_1.attentions_1"] + pag_layers = ["down_blocks.1.attentions.1"] pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False) assert len(pipe.pag_attn_processors) == 2 From 6f549d12589532da76cab3f51e2fe39406b62136 Mon Sep 17 00:00:00 2001 From: Aryan Date: Sun, 4 Aug 2024 22:12:54 +0200 Subject: [PATCH 19/22] fix test_components_function since list not accepted as valid __init__ param --- src/diffusers/pipelines/pag/pipeline_pag_hunyuandit.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/pag/pipeline_pag_hunyuandit.py b/src/diffusers/pipelines/pag/pipeline_pag_hunyuandit.py index 36d4146860a7..3a42beb357ae 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_hunyuandit.py @@ -201,7 +201,7 @@ def __init__( requires_safety_checker: bool = True, text_encoder_2: Optional[T5EncoderModel] = None, tokenizer_2: Optional[MT5Tokenizer] = None, - pag_applied_layers: Union[str, List[str]] = [], # "blocks.16.attn1", "blocks.16", "16", 16 + pag_applied_layers: Union[str, List[str]] = "blocks.21", # "blocks.16.attn1", "blocks.16", "16", 16 ): super().__init__() From de99a59ec66744c1ee9665ceafaae53e24282b65 Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 5 Aug 2024 12:28:00 +0200 Subject: [PATCH 20/22] apply patch to fix broken tests Co-Authored-By: Sayak Paul --- src/diffusers/pipelines/pag/pag_utils.py | 12 +++++++++++- .../pipelines/pag/pipeline_pag_hunyuandit.py | 2 +- tests/pipelines/pag/test_pag_hunyuan_dit.py | 5 ++--- 3 files changed, 14 insertions(+), 5 deletions(-) diff --git a/src/diffusers/pipelines/pag/pag_utils.py b/src/diffusers/pipelines/pag/pag_utils.py index b5eb319b16ef..7a6e79a2c9ff 100644 --- a/src/diffusers/pipelines/pag/pag_utils.py +++ b/src/diffusers/pipelines/pag/pag_utils.py @@ -221,7 +221,17 @@ def pag_attn_processors(self) -> Dict[str, AttentionProcessor]: valid_attn_processors = {x.__class__ for x in self._pag_attn_processors} processors = {} - for name, proc in self.unet.attn_processors.items(): + # We could have iterated through the self.components.items() and checked if a component is + # `ModelMixin` subclassed but that can include a VAE too. + if hasattr(self, "unet"): + denoiser_module = self.unet + elif hasattr(self, "transformer"): + denoiser_module = self.transformer + else: + raise ValueError("No denoiser module found.") + + for name, proc in denoiser_module.attn_processors.items(): if proc.__class__ in valid_attn_processors: processors[name] = proc + return processors diff --git a/src/diffusers/pipelines/pag/pipeline_pag_hunyuandit.py b/src/diffusers/pipelines/pag/pipeline_pag_hunyuandit.py index 3a42beb357ae..63126cc5aae9 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_hunyuandit.py @@ -201,7 +201,7 @@ def __init__( requires_safety_checker: bool = True, text_encoder_2: Optional[T5EncoderModel] = None, tokenizer_2: Optional[MT5Tokenizer] = None, - pag_applied_layers: Union[str, List[str]] = "blocks.21", # "blocks.16.attn1", "blocks.16", "16", 16 + pag_applied_layers: Union[str, List[str]] = "blocks.1", # "blocks.16.attn1", "blocks.16", "16", 16 ): super().__init__() diff --git a/tests/pipelines/pag/test_pag_hunyuan_dit.py b/tests/pipelines/pag/test_pag_hunyuan_dit.py index e6317e63cc64..29950984b794 100644 --- a/tests/pipelines/pag/test_pag_hunyuan_dit.py +++ b/tests/pipelines/pag/test_pag_hunyuan_dit.py @@ -83,7 +83,6 @@ def get_dummy_components(self): "tokenizer_2": tokenizer_2, "safety_checker": None, "feature_extractor": None, - "pag_applied_layers": ["blocks.1"], } return components @@ -348,12 +347,12 @@ def test_pag_applied_layers(self): assert set(pipe.pag_attn_processors) == set(block_0_self_attn) pipe.transformer.set_attn_processor(original_attn_procs.copy()) - pag_layers = "blocks.0.attn1" + pag_layers = ["blocks.0.attn1"] pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False) assert set(pipe.pag_attn_processors) == set(block_0_self_attn) pipe.transformer.set_attn_processor(original_attn_procs.copy()) - pag_layers = "blocks.(0|1)" + pag_layers = ["blocks.(0|1)"] pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False) assert (len(pipe.pag_attn_processors)) == 2 From 473645da0ee7600ef447ca0da6ef6961f51d7f81 Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 5 Aug 2024 12:28:33 +0200 Subject: [PATCH 21/22] make style --- src/diffusers/pipelines/pag/pag_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/pag/pag_utils.py b/src/diffusers/pipelines/pag/pag_utils.py index 7a6e79a2c9ff..728f730c9904 100644 --- a/src/diffusers/pipelines/pag/pag_utils.py +++ b/src/diffusers/pipelines/pag/pag_utils.py @@ -229,9 +229,9 @@ def pag_attn_processors(self) -> Dict[str, AttentionProcessor]: denoiser_module = self.transformer else: raise ValueError("No denoiser module found.") - + for name, proc in denoiser_module.attn_processors.items(): if proc.__class__ in valid_attn_processors: processors[name] = proc - + return processors From 29f54e46de80a6631e93aff6574a222994f2cff6 Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 5 Aug 2024 12:44:30 +0200 Subject: [PATCH 22/22] fix hunyuan tests --- tests/pipelines/pag/test_pag_hunyuan_dit.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/tests/pipelines/pag/test_pag_hunyuan_dit.py b/tests/pipelines/pag/test_pag_hunyuan_dit.py index 29950984b794..db0e257760ed 100644 --- a/tests/pipelines/pag/test_pag_hunyuan_dit.py +++ b/tests/pipelines/pag/test_pag_hunyuan_dit.py @@ -190,9 +190,7 @@ def test_save_load_optional_components(self): with tempfile.TemporaryDirectory() as tmpdir: pipe.save_pretrained(tmpdir) - pipe_loaded = self.pipeline_class.from_pretrained( - tmpdir, pag_applied_layers=components["pag_applied_layers"] - ) + pipe_loaded = self.pipeline_class.from_pretrained(tmpdir) pipe_loaded.to(torch_device) pipe_loaded.set_progress_bar_config(disable=None) @@ -288,7 +286,6 @@ def test_pag_disable_enable(self): components = self.get_dummy_components() # base pipeline (expect same output when pag is disabled) - components.pop("pag_applied_layers", None) pipe_sd = HunyuanDiTPipeline(**components) pipe_sd = pipe_sd.to(device) pipe_sd.set_progress_bar_config(disable=None) @@ -328,7 +325,6 @@ def test_pag_applied_layers(self): components = self.get_dummy_components() # base pipeline - components.pop("pag_applied_layers", None) pipe = self.pipeline_class(**components) pipe = pipe.to(device) pipe.set_progress_bar_config(disable=None)