From a6a0429089f0fde250de0fa98b6b62d4a1c77b72 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Mon, 13 May 2024 08:40:19 +0200 Subject: [PATCH 01/44] first draft --- .../pipeline_stable_diffusion_xl.py | 429 +++++++++++++++++- 1 file changed, 422 insertions(+), 7 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py index 52d0b07fb315..0e4d5725ff07 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py @@ -825,6 +825,34 @@ def num_timesteps(self): def interrupt(self): return self._interrupt + @property + def pag_scale(self): + return self._pag_scale + + @property + def do_adversarial_guidance(self): + return self._pag_scale > 0 + + @property + def pag_adaptive_scaling(self): + return self._pag_adaptive_scaling + + @property + def do_pag_adaptive_scaling(self): + return self._pag_adaptive_scaling > 0 + + @property + def pag_drop_rate(self): + return self._pag_drop_rate + + @property + def pag_applied_layers(self): + return self._pag_applied_layers + + @property + def pag_applied_layers_index(self): + return self._pag_applied_layers_index + @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( @@ -838,6 +866,11 @@ def __call__( sigmas: List[float] = None, denoising_end: Optional[float] = None, guidance_scale: float = 5.0, + pag_scale: float = 0.0, + pag_adaptive_scaling: float = 0.0, + pag_drop_rate: float = 0.5, + pag_applied_layers: List[str] = ['mid'], #['down', 'mid', 'up'] + pag_applied_layers_index: List[str] = None, #['d4', 'd5', 'm0'] negative_prompt: Optional[Union[str, List[str]]] = None, negative_prompt_2: Optional[Union[str, List[str]]] = None, num_images_per_prompt: Optional[int] = 1, @@ -1058,6 +1091,12 @@ def __call__( self._denoising_end = denoising_end self._interrupt = False + self._pag_scale = pag_scale + self._pag_adaptive_scaling = pag_adaptive_scaling + self._pag_drop_rate = pag_drop_rate + self._pag_applied_layers = pag_applied_layers + self._pag_applied_layers_index = pag_applied_layers_index + # 2. Define call parameters if prompt is not None and isinstance(prompt, str): batch_size = 1 @@ -1140,11 +1179,22 @@ def __call__( else: negative_add_time_ids = add_time_ids - if self.do_classifier_free_guidance: + #cfg + if self.do_classifier_free_guidance and not self.do_adversarial_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0) - + #pag + elif not self.do_classifier_free_guidance and self.do_adversarial_guidance: + prompt_embeds = torch.cat([prompt_embeds, prompt_embeds], dim=0) + add_text_embeds = torch.cat([add_text_embeds, add_text_embeds], dim=0) + add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0) + #both + elif self.do_classifier_free_guidance and self.do_adversarial_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds, prompt_embeds], dim=0) + add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds, add_text_embeds], dim=0) + add_time_ids = torch.cat([negative_add_time_ids, add_time_ids, add_time_ids], dim=0) + prompt_embeds = prompt_embeds.to(device) add_text_embeds = add_text_embeds.to(device) add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) @@ -1185,15 +1235,87 @@ def __call__( guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim ).to(device=device, dtype=latents.dtype) + # 10. Create down mid and up layer lists + if self.do_adversarial_guidance: + down_layers = [] + mid_layers = [] + up_layers = [] + for name, module in self.unet.named_modules(): + if 'attn1' in name and 'to' not in name: + layer_type = name.split('.')[0].split('_')[0] + if layer_type == 'down': + down_layers.append(module) + elif layer_type == 'mid': + mid_layers.append(module) + elif layer_type == 'up': + up_layers.append(module) + else: + raise ValueError(f"Invalid layer type: {layer_type}") + self._num_timesteps = len(timesteps) with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): if self.interrupt: continue - # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents - + #cfg + if self.do_classifier_free_guidance and not self.do_adversarial_guidance: + latent_model_input = torch.cat([latents] * 2) + #pag + elif not self.do_classifier_free_guidance and self.do_adversarial_guidance: + latent_model_input = torch.cat([latents] * 2) + #both + elif self.do_classifier_free_guidance and self.do_adversarial_guidance: + latent_model_input = torch.cat([latents] * 3) + #no + else: + latent_model_input = latents + + # change attention layer in UNet if use PAG + if self.do_adversarial_guidance: + + if self.do_classifier_free_guidance: + replace_processor = PAGCFGIdentitySelfAttnProcessor() + else: + replace_processor = PAGIdentitySelfAttnProcessor() + + if(self.pag_applied_layers_index): + drop_layers = self.pag_applied_layers_index + for drop_layer in drop_layers: + layer_number = int(drop_layer[1:]) + try: + if drop_layer[0] == 'd': + down_layers[layer_number].processor = replace_processor + elif drop_layer[0] == 'm': + mid_layers[layer_number].processor = replace_processor + elif drop_layer[0] == 'u': + up_layers[layer_number].processor = replace_processor + else: + raise ValueError(f"Invalid layer type: {drop_layer[0]}") + except IndexError: + raise ValueError( + f"Invalid layer index: {drop_layer}. Available layers: {len(down_layers)} down layers, {len(mid_layers)} mid layers, {len(up_layers)} up layers." + ) + elif(self.pag_applied_layers): + drop_full_layers = self.pag_applied_layers + for drop_full_layer in drop_full_layers: + try: + if drop_full_layer == "down": + for down_layer in down_layers: + down_layer.processor = replace_processor + elif drop_full_layer == "mid": + for mid_layer in mid_layers: + mid_layer.processor = replace_processor + elif drop_full_layer == "up": + for up_layer in up_layers: + up_layer.processor = replace_processor + else: + raise ValueError(f"Invalid layer type: {drop_full_layer}") + except IndexError: + raise ValueError( + f"Invalid layer index: {drop_full_layer}. Available layers are: down, mid and up. If you need to specify each layer index, you can use `pag_applied_layers_index`" + ) + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # predict the noise residual @@ -1211,10 +1333,34 @@ def __call__( )[0] # perform guidance - if self.do_classifier_free_guidance: + if self.do_classifier_free_guidance and not self.do_adversarial_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) - + # pag + elif not self.do_classifier_free_guidance and self.do_adversarial_guidance: + noise_pred_original, noise_pred_perturb = noise_pred.chunk(2) + + signal_scale = self.pag_scale + if self.do_pag_adaptive_scaling: + signal_scale = self.pag_scale - self.pag_adaptive_scaling * (1000-t) + if signal_scale<0: + signal_scale = 0 + + noise_pred = noise_pred_original + signal_scale * (noise_pred_original - noise_pred_perturb) + + # both + elif self.do_classifier_free_guidance and self.do_adversarial_guidance: + + noise_pred_uncond, noise_pred_text, noise_pred_text_perturb = noise_pred.chunk(3) + + signal_scale = self.pag_scale + if self.do_pag_adaptive_scaling: + signal_scale = self.pag_scale - self.pag_adaptive_scaling * (1000-t) + if signal_scale<0: + signal_scale = 0 + + noise_pred = noise_pred_text + (self.guidance_scale-1.0) * (noise_pred_text - noise_pred_uncond) + signal_scale * (noise_pred_text - noise_pred_text_perturb) + if self.do_classifier_free_guidance and self.guidance_rescale > 0.0: # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale) @@ -1301,4 +1447,273 @@ def __call__( if not return_dict: return (image,) + #Change the attention layers back to original ones after PAG was applied + if self.do_adversarial_guidance: + if(self.pag_applied_layers_index): + drop_layers = self.pag_applied_layers_index + for drop_layer in drop_layers: + layer_number = int(drop_layer[1:]) + try: + if drop_layer[0] == 'd': + down_layers[layer_number].processor = AttnProcessor2_0() + elif drop_layer[0] == 'm': + mid_layers[layer_number].processor = AttnProcessor2_0() + elif drop_layer[0] == 'u': + up_layers[layer_number].processor = AttnProcessor2_0() + else: + raise ValueError(f"Invalid layer type: {drop_layer[0]}") + except IndexError: + raise ValueError( + f"Invalid layer index: {drop_layer}. Available layers: {len(down_layers)} down layers, {len(mid_layers)} mid layers, {len(up_layers)} up layers." + ) + elif(self.pag_applied_layers): + drop_full_layers = self.pag_applied_layers + for drop_full_layer in drop_full_layers: + try: + if drop_full_layer == "down": + for down_layer in down_layers: + down_layer.processor = AttnProcessor2_0() + elif drop_full_layer == "mid": + for mid_layer in mid_layers: + mid_layer.processor = AttnProcessor2_0() + elif drop_full_layer == "up": + for up_layer in up_layers: + up_layer.processor = AttnProcessor2_0() + else: + raise ValueError(f"Invalid layer type: {drop_full_layer}") + except IndexError: + raise ValueError( + f"Invalid layer index: {drop_full_layer}. Available layers are: down, mid and up. If you need to specify each layer index, you can use `pag_applied_layers_index`" + ) return StableDiffusionXLPipelineOutput(images=image) + + + +from diffusers.models.attention_processor import Attention, AttnProcessor2_0 +import torch.nn.functional as F + +class PAGIdentitySelfAttnProcessor: + r""" + Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). + """ + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + + def __call__( + self, + attn: Attention, + hidden_states: torch.FloatTensor, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + temb: Optional[torch.FloatTensor] = None, + *args, + **kwargs, + ) -> torch.FloatTensor: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + deprecate("scale", "1.0.0", deprecation_message) + + residual = hidden_states + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + # chunk + hidden_states_org, hidden_states_ptb = hidden_states.chunk(2) + + # original path + batch_size, sequence_length, _ = hidden_states_org.shape + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if attn.group_norm is not None: + hidden_states_org = attn.group_norm(hidden_states_org.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states_org) + key = attn.to_k(hidden_states_org) + value = attn.to_v(hidden_states_org) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states_org = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states_org = hidden_states_org.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states_org = hidden_states_org.to(query.dtype) + + # linear proj + hidden_states_org = attn.to_out[0](hidden_states_org) + # dropout + hidden_states_org = attn.to_out[1](hidden_states_org) + + if input_ndim == 4: + hidden_states_org = hidden_states_org.transpose(-1, -2).reshape(batch_size, channel, height, width) + + # perturbed path (identity attention) + batch_size, sequence_length, _ = hidden_states_ptb.shape + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if attn.group_norm is not None: + hidden_states_ptb = attn.group_norm(hidden_states_ptb.transpose(1, 2)).transpose(1, 2) + + value = attn.to_v(hidden_states_ptb) + + # hidden_states_ptb = torch.zeros(value.shape).to(value.get_device()) + hidden_states_ptb = value + + hidden_states_ptb = hidden_states_ptb.to(query.dtype) + + # linear proj + hidden_states_ptb = attn.to_out[0](hidden_states_ptb) + # dropout + hidden_states_ptb = attn.to_out[1](hidden_states_ptb) + + if input_ndim == 4: + hidden_states_ptb = hidden_states_ptb.transpose(-1, -2).reshape(batch_size, channel, height, width) + + # cat + hidden_states = torch.cat([hidden_states_org, hidden_states_ptb]) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + + +class PAGCFGIdentitySelfAttnProcessor: + r""" + Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). + """ + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + + def __call__( + self, + attn: Attention, + hidden_states: torch.FloatTensor, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + temb: Optional[torch.FloatTensor] = None, + *args, + **kwargs, + ) -> torch.FloatTensor: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + deprecate("scale", "1.0.0", deprecation_message) + + residual = hidden_states + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + # chunk + hidden_states_uncond, hidden_states_org, hidden_states_ptb = hidden_states.chunk(3) + hidden_states_org = torch.cat([hidden_states_uncond, hidden_states_org]) + + # original path + batch_size, sequence_length, _ = hidden_states_org.shape + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if attn.group_norm is not None: + hidden_states_org = attn.group_norm(hidden_states_org.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states_org) + key = attn.to_k(hidden_states_org) + value = attn.to_v(hidden_states_org) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states_org = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states_org = hidden_states_org.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states_org = hidden_states_org.to(query.dtype) + + # linear proj + hidden_states_org = attn.to_out[0](hidden_states_org) + # dropout + hidden_states_org = attn.to_out[1](hidden_states_org) + + if input_ndim == 4: + hidden_states_org = hidden_states_org.transpose(-1, -2).reshape(batch_size, channel, height, width) + + # perturbed path (identity attention) + batch_size, sequence_length, _ = hidden_states_ptb.shape + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if attn.group_norm is not None: + hidden_states_ptb = attn.group_norm(hidden_states_ptb.transpose(1, 2)).transpose(1, 2) + + value = attn.to_v(hidden_states_ptb) + hidden_states_ptb = value + hidden_states_ptb = hidden_states_ptb.to(query.dtype) + + # linear proj + hidden_states_ptb = attn.to_out[0](hidden_states_ptb) + # dropout + hidden_states_ptb = attn.to_out[1](hidden_states_ptb) + + if input_ndim == 4: + hidden_states_ptb = hidden_states_ptb.transpose(-1, -2).reshape(batch_size, channel, height, width) + + # cat + hidden_states = torch.cat([hidden_states_org, hidden_states_ptb]) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states \ No newline at end of file From 3605df9c0c952d0b82957527e56b32bbaef4fe14 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Tue, 14 May 2024 03:05:35 +0200 Subject: [PATCH 02/44] refactor --- src/diffusers/models/attention_processor.py | 228 ++++++++++++ src/diffusers/pipelines/pag_utils.py | 180 +++++++++ .../pipeline_stable_diffusion_xl.py | 349 +----------------- 3 files changed, 414 insertions(+), 343 deletions(-) create mode 100644 src/diffusers/pipelines/pag_utils.py diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index cbb07eafa37f..20f5dbebf3ed 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2651,6 +2651,232 @@ def __call__( return hidden_states +class PAGIdentitySelfAttnProcessor2_0: + r""" + Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). + """ + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + + def __call__( + self, + attn: Attention, + hidden_states: torch.FloatTensor, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + temb: Optional[torch.FloatTensor] = None, + *args, + **kwargs, + ) -> torch.FloatTensor: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + deprecate("scale", "1.0.0", deprecation_message) + + residual = hidden_states + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + # chunk + hidden_states_org, hidden_states_ptb = hidden_states.chunk(2) + + # original path + batch_size, sequence_length, _ = hidden_states_org.shape + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if attn.group_norm is not None: + hidden_states_org = attn.group_norm(hidden_states_org.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states_org) + key = attn.to_k(hidden_states_org) + value = attn.to_v(hidden_states_org) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states_org = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states_org = hidden_states_org.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states_org = hidden_states_org.to(query.dtype) + + # linear proj + hidden_states_org = attn.to_out[0](hidden_states_org) + # dropout + hidden_states_org = attn.to_out[1](hidden_states_org) + + if input_ndim == 4: + hidden_states_org = hidden_states_org.transpose(-1, -2).reshape(batch_size, channel, height, width) + + # perturbed path (identity attention) + batch_size, sequence_length, _ = hidden_states_ptb.shape + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if attn.group_norm is not None: + hidden_states_ptb = attn.group_norm(hidden_states_ptb.transpose(1, 2)).transpose(1, 2) + + value = attn.to_v(hidden_states_ptb) + + # hidden_states_ptb = torch.zeros(value.shape).to(value.get_device()) + hidden_states_ptb = value + + hidden_states_ptb = hidden_states_ptb.to(query.dtype) + + # linear proj + hidden_states_ptb = attn.to_out[0](hidden_states_ptb) + # dropout + hidden_states_ptb = attn.to_out[1](hidden_states_ptb) + + if input_ndim == 4: + hidden_states_ptb = hidden_states_ptb.transpose(-1, -2).reshape(batch_size, channel, height, width) + + # cat + hidden_states = torch.cat([hidden_states_org, hidden_states_ptb]) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +class PAGCFGIdentitySelfAttnProcessor2_0: + r""" + Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). + """ + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + + def __call__( + self, + attn: Attention, + hidden_states: torch.FloatTensor, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + temb: Optional[torch.FloatTensor] = None, + *args, + **kwargs, + ) -> torch.FloatTensor: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + deprecate("scale", "1.0.0", deprecation_message) + + residual = hidden_states + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + # chunk + hidden_states_uncond, hidden_states_org, hidden_states_ptb = hidden_states.chunk(3) + hidden_states_org = torch.cat([hidden_states_uncond, hidden_states_org]) + + # original path + batch_size, sequence_length, _ = hidden_states_org.shape + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if attn.group_norm is not None: + hidden_states_org = attn.group_norm(hidden_states_org.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states_org) + key = attn.to_k(hidden_states_org) + value = attn.to_v(hidden_states_org) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states_org = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states_org = hidden_states_org.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states_org = hidden_states_org.to(query.dtype) + + # linear proj + hidden_states_org = attn.to_out[0](hidden_states_org) + # dropout + hidden_states_org = attn.to_out[1](hidden_states_org) + + if input_ndim == 4: + hidden_states_org = hidden_states_org.transpose(-1, -2).reshape(batch_size, channel, height, width) + + # perturbed path (identity attention) + batch_size, sequence_length, _ = hidden_states_ptb.shape + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if attn.group_norm is not None: + hidden_states_ptb = attn.group_norm(hidden_states_ptb.transpose(1, 2)).transpose(1, 2) + + value = attn.to_v(hidden_states_ptb) + hidden_states_ptb = value + hidden_states_ptb = hidden_states_ptb.to(query.dtype) + + # linear proj + hidden_states_ptb = attn.to_out[0](hidden_states_ptb) + # dropout + hidden_states_ptb = attn.to_out[1](hidden_states_ptb) + + if input_ndim == 4: + hidden_states_ptb = hidden_states_ptb.transpose(-1, -2).reshape(batch_size, channel, height, width) + + # cat + hidden_states = torch.cat([hidden_states_org, hidden_states_ptb]) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + LORA_ATTENTION_PROCESSORS = ( LoRAAttnProcessor, LoRAAttnProcessor2_0, @@ -2691,6 +2917,8 @@ def __call__( CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor, CustomDiffusionAttnProcessor2_0, + PAGCFGIdentitySelfAttnProcessor2_0, + PAGIdentitySelfAttnProcessor2_0, # deprecated LoRAAttnProcessor, LoRAAttnProcessor2_0, diff --git a/src/diffusers/pipelines/pag_utils.py b/src/diffusers/pipelines/pag_utils.py new file mode 100644 index 000000000000..ee2f677984a8 --- /dev/null +++ b/src/diffusers/pipelines/pag_utils.py @@ -0,0 +1,180 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import Tuple, Union, List + +import torch +import torch.fft as fft + +from ..utils.torch_utils import randn_tensor +from ..models.attention_processor import PAGCFGIdentitySelfAttnProcessor2_0, PAGIdentitySelfAttnProcessor2_0, AttnProcessor2_0 + + +class PAGMixin: + r"""Mixin class for PAG.""" + + def enable_pag( + self, + pag_scale: float = 0.0, + pag_adaptive_scaling: float = 0.0, + pag_drop_rate: float = 0.5, + pag_applied_layers: List[str] = ['mid'], #['down', 'mid', 'up'] + pag_applied_layers_index: List[str] = None, #['d4', 'd5', 'm0'] + ): + """Enables the FreeInit mechanism as in https://arxiv.org/abs/2312.07537. + + This implementation has been adapted from the [official repository](https://github.com/TianxingWu/FreeInit). + + Args: + pag_scale (`float`, *optional*, defaults to `0.0`): + Guidance scale of PAG. + """ + self._pag_scale = pag_scale + self._pag_adaptive_scaling = pag_adaptive_scaling + self._pag_drop_rate = pag_drop_rate + self._pag_applied_layers = pag_applied_layers + self._pag_applied_layers_index = pag_applied_layers_index + + def _get_self_attn_layers(self): + down_layers = [] + mid_layers = [] + up_layers = [] + for name, module in self.unet.named_modules(): + if 'attn1' in name and 'to' not in name: + layer_type = name.split('.')[0].split('_')[0] + if layer_type == 'down': + down_layers.append(module) + elif layer_type == 'mid': + mid_layers.append(module) + elif layer_type == 'up': + up_layers.append(module) + else: + raise ValueError(f"Invalid layer type: {layer_type}") + return up_layers, mid_layers, down_layers + def set_pag_attn_processor(self): + up_layers, mid_layers, down_layers = self._get_self_attn_layers() + + if self.do_classifier_free_guidance: + replace_processor = PAGCFGIdentitySelfAttnProcessor2_0() + else: + replace_processor = PAGIdentitySelfAttnProcessor2_0() + + if(self.pag_applied_layers_index): + drop_layers = self.pag_applied_layers_index + for drop_layer in drop_layers: + layer_number = int(drop_layer[1:]) + try: + if drop_layer[0] == 'd': + down_layers[layer_number].processor = replace_processor + elif drop_layer[0] == 'm': + mid_layers[layer_number].processor = replace_processor + elif drop_layer[0] == 'u': + up_layers[layer_number].processor = replace_processor + else: + raise ValueError(f"Invalid layer type: {drop_layer[0]}") + except IndexError: + raise ValueError( + f"Invalid layer index: {drop_layer}. Available layers: {len(down_layers)} down layers, {len(mid_layers)} mid layers, {len(up_layers)} up layers." + ) + elif(self.pag_applied_layers): + drop_full_layers = self.pag_applied_layers + for drop_full_layer in drop_full_layers: + try: + if drop_full_layer == "down": + for down_layer in down_layers: + down_layer.processor = replace_processor + elif drop_full_layer == "mid": + for mid_layer in mid_layers: + mid_layer.processor = replace_processor + elif drop_full_layer == "up": + for up_layer in up_layers: + up_layer.processor = replace_processor + else: + raise ValueError(f"Invalid layer type: {drop_full_layer}") + except IndexError: + raise ValueError( + f"Invalid layer index: {drop_full_layer}. Available layers are: down, mid and up. If you need to specify each layer index, you can use `pag_applied_layers_index`" + ) + + def disable_pag(self): + """Disables the PAG mechanism if enabled.""" + up_layers, mid_layers, down_layers = self._get_self_attn_layers() + if(self.pag_applied_layers_index): + drop_layers = self.pag_applied_layers_index + for drop_layer in drop_layers: + layer_number = int(drop_layer[1:]) + try: + if drop_layer[0] == 'd': + down_layers[layer_number].processor = AttnProcessor2_0() + elif drop_layer[0] == 'm': + mid_layers[layer_number].processor = AttnProcessor2_0() + elif drop_layer[0] == 'u': + up_layers[layer_number].processor = AttnProcessor2_0() + else: + raise ValueError(f"Invalid layer type: {drop_layer[0]}") + except IndexError: + raise ValueError( + f"Invalid layer index: {drop_layer}. Available layers: {len(down_layers)} down layers, {len(mid_layers)} mid layers, {len(up_layers)} up layers." + ) + elif(self.pag_applied_layers): + drop_full_layers = self.pag_applied_layers + for drop_full_layer in drop_full_layers: + try: + if drop_full_layer == "down": + for down_layer in down_layers: + down_layer.processor = AttnProcessor2_0() + elif drop_full_layer == "mid": + for mid_layer in mid_layers: + mid_layer.processor = AttnProcessor2_0() + elif drop_full_layer == "up": + for up_layer in up_layers: + up_layer.processor = AttnProcessor2_0() + else: + raise ValueError(f"Invalid layer type: {drop_full_layer}") + except IndexError: + raise ValueError( + f"Invalid layer index: {drop_full_layer}. Available layers are: down, mid and up. If you need to specify each layer index, you can use `pag_applied_layers_index`" + ) + + + + @property + def pag_scale(self): + return self._pag_scale + + @property + def do_adversarial_guidance(self): + return self._pag_scale > 0 + + @property + def pag_adaptive_scaling(self): + return self._pag_adaptive_scaling + + @property + def do_pag_adaptive_scaling(self): + return self._pag_adaptive_scaling > 0 + + @property + def pag_drop_rate(self): + return self._pag_drop_rate + + @property + def pag_applied_layers(self): + return self._pag_applied_layers + + @property + def pag_applied_layers_index(self): + return self._pag_applied_layers_index + diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py index 0e4d5725ff07..01a1a5f01738 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py @@ -54,6 +54,7 @@ from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin from .pipeline_output import StableDiffusionXLPipelineOutput +from ..pag_utils import PAGMixin if is_invisible_watermark_available(): @@ -168,6 +169,7 @@ class StableDiffusionXLPipeline( StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin, IPAdapterMixin, + PAGMixin, ): r""" Pipeline for text-to-image generation using Stable Diffusion XL. @@ -866,11 +868,6 @@ def __call__( sigmas: List[float] = None, denoising_end: Optional[float] = None, guidance_scale: float = 5.0, - pag_scale: float = 0.0, - pag_adaptive_scaling: float = 0.0, - pag_drop_rate: float = 0.5, - pag_applied_layers: List[str] = ['mid'], #['down', 'mid', 'up'] - pag_applied_layers_index: List[str] = None, #['d4', 'd5', 'm0'] negative_prompt: Optional[Union[str, List[str]]] = None, negative_prompt_2: Optional[Union[str, List[str]]] = None, num_images_per_prompt: Optional[int] = 1, @@ -1090,12 +1087,6 @@ def __call__( self._cross_attention_kwargs = cross_attention_kwargs self._denoising_end = denoising_end self._interrupt = False - - self._pag_scale = pag_scale - self._pag_adaptive_scaling = pag_adaptive_scaling - self._pag_drop_rate = pag_drop_rate - self._pag_applied_layers = pag_applied_layers - self._pag_applied_layers_index = pag_applied_layers_index # 2. Define call parameters if prompt is not None and isinstance(prompt, str): @@ -1180,6 +1171,9 @@ def __call__( negative_add_time_ids = add_time_ids #cfg + if self.do_adversarial_guidance: + self.set_pag_attn_processor() + if self.do_classifier_free_guidance and not self.do_adversarial_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) @@ -1234,23 +1228,6 @@ def __call__( timestep_cond = self.get_guidance_scale_embedding( guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim ).to(device=device, dtype=latents.dtype) - - # 10. Create down mid and up layer lists - if self.do_adversarial_guidance: - down_layers = [] - mid_layers = [] - up_layers = [] - for name, module in self.unet.named_modules(): - if 'attn1' in name and 'to' not in name: - layer_type = name.split('.')[0].split('_')[0] - if layer_type == 'down': - down_layers.append(module) - elif layer_type == 'mid': - mid_layers.append(module) - elif layer_type == 'up': - up_layers.append(module) - else: - raise ValueError(f"Invalid layer type: {layer_type}") self._num_timesteps = len(timesteps) with self.progress_bar(total=num_inference_steps) as progress_bar: @@ -1270,51 +1247,6 @@ def __call__( #no else: latent_model_input = latents - - # change attention layer in UNet if use PAG - if self.do_adversarial_guidance: - - if self.do_classifier_free_guidance: - replace_processor = PAGCFGIdentitySelfAttnProcessor() - else: - replace_processor = PAGIdentitySelfAttnProcessor() - - if(self.pag_applied_layers_index): - drop_layers = self.pag_applied_layers_index - for drop_layer in drop_layers: - layer_number = int(drop_layer[1:]) - try: - if drop_layer[0] == 'd': - down_layers[layer_number].processor = replace_processor - elif drop_layer[0] == 'm': - mid_layers[layer_number].processor = replace_processor - elif drop_layer[0] == 'u': - up_layers[layer_number].processor = replace_processor - else: - raise ValueError(f"Invalid layer type: {drop_layer[0]}") - except IndexError: - raise ValueError( - f"Invalid layer index: {drop_layer}. Available layers: {len(down_layers)} down layers, {len(mid_layers)} mid layers, {len(up_layers)} up layers." - ) - elif(self.pag_applied_layers): - drop_full_layers = self.pag_applied_layers - for drop_full_layer in drop_full_layers: - try: - if drop_full_layer == "down": - for down_layer in down_layers: - down_layer.processor = replace_processor - elif drop_full_layer == "mid": - for mid_layer in mid_layers: - mid_layer.processor = replace_processor - elif drop_full_layer == "up": - for up_layer in up_layers: - up_layer.processor = replace_processor - else: - raise ValueError(f"Invalid layer type: {drop_full_layer}") - except IndexError: - raise ValueError( - f"Invalid layer index: {drop_full_layer}. Available layers are: down, mid and up. If you need to specify each layer index, you can use `pag_applied_layers_index`" - ) latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) @@ -1447,273 +1379,4 @@ def __call__( if not return_dict: return (image,) - #Change the attention layers back to original ones after PAG was applied - if self.do_adversarial_guidance: - if(self.pag_applied_layers_index): - drop_layers = self.pag_applied_layers_index - for drop_layer in drop_layers: - layer_number = int(drop_layer[1:]) - try: - if drop_layer[0] == 'd': - down_layers[layer_number].processor = AttnProcessor2_0() - elif drop_layer[0] == 'm': - mid_layers[layer_number].processor = AttnProcessor2_0() - elif drop_layer[0] == 'u': - up_layers[layer_number].processor = AttnProcessor2_0() - else: - raise ValueError(f"Invalid layer type: {drop_layer[0]}") - except IndexError: - raise ValueError( - f"Invalid layer index: {drop_layer}. Available layers: {len(down_layers)} down layers, {len(mid_layers)} mid layers, {len(up_layers)} up layers." - ) - elif(self.pag_applied_layers): - drop_full_layers = self.pag_applied_layers - for drop_full_layer in drop_full_layers: - try: - if drop_full_layer == "down": - for down_layer in down_layers: - down_layer.processor = AttnProcessor2_0() - elif drop_full_layer == "mid": - for mid_layer in mid_layers: - mid_layer.processor = AttnProcessor2_0() - elif drop_full_layer == "up": - for up_layer in up_layers: - up_layer.processor = AttnProcessor2_0() - else: - raise ValueError(f"Invalid layer type: {drop_full_layer}") - except IndexError: - raise ValueError( - f"Invalid layer index: {drop_full_layer}. Available layers are: down, mid and up. If you need to specify each layer index, you can use `pag_applied_layers_index`" - ) - return StableDiffusionXLPipelineOutput(images=image) - - - -from diffusers.models.attention_processor import Attention, AttnProcessor2_0 -import torch.nn.functional as F - -class PAGIdentitySelfAttnProcessor: - r""" - Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). - """ - - def __init__(self): - if not hasattr(F, "scaled_dot_product_attention"): - raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") - - def __call__( - self, - attn: Attention, - hidden_states: torch.FloatTensor, - encoder_hidden_states: Optional[torch.FloatTensor] = None, - attention_mask: Optional[torch.FloatTensor] = None, - temb: Optional[torch.FloatTensor] = None, - *args, - **kwargs, - ) -> torch.FloatTensor: - if len(args) > 0 or kwargs.get("scale", None) is not None: - deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." - deprecate("scale", "1.0.0", deprecation_message) - - residual = hidden_states - if attn.spatial_norm is not None: - hidden_states = attn.spatial_norm(hidden_states, temb) - - input_ndim = hidden_states.ndim - if input_ndim == 4: - batch_size, channel, height, width = hidden_states.shape - hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) - - # chunk - hidden_states_org, hidden_states_ptb = hidden_states.chunk(2) - - # original path - batch_size, sequence_length, _ = hidden_states_org.shape - - if attention_mask is not None: - attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) - # scaled_dot_product_attention expects attention_mask shape to be - # (batch, heads, source_length, target_length) - attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) - - if attn.group_norm is not None: - hidden_states_org = attn.group_norm(hidden_states_org.transpose(1, 2)).transpose(1, 2) - - query = attn.to_q(hidden_states_org) - key = attn.to_k(hidden_states_org) - value = attn.to_v(hidden_states_org) - - inner_dim = key.shape[-1] - head_dim = inner_dim // attn.heads - - query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - # the output of sdp = (batch, num_heads, seq_len, head_dim) - # TODO: add support for attn.scale when we move to Torch 2.1 - hidden_states_org = F.scaled_dot_product_attention( - query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False - ) - - hidden_states_org = hidden_states_org.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) - hidden_states_org = hidden_states_org.to(query.dtype) - - # linear proj - hidden_states_org = attn.to_out[0](hidden_states_org) - # dropout - hidden_states_org = attn.to_out[1](hidden_states_org) - - if input_ndim == 4: - hidden_states_org = hidden_states_org.transpose(-1, -2).reshape(batch_size, channel, height, width) - - # perturbed path (identity attention) - batch_size, sequence_length, _ = hidden_states_ptb.shape - - if attention_mask is not None: - attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) - # scaled_dot_product_attention expects attention_mask shape to be - # (batch, heads, source_length, target_length) - attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) - - if attn.group_norm is not None: - hidden_states_ptb = attn.group_norm(hidden_states_ptb.transpose(1, 2)).transpose(1, 2) - - value = attn.to_v(hidden_states_ptb) - - # hidden_states_ptb = torch.zeros(value.shape).to(value.get_device()) - hidden_states_ptb = value - - hidden_states_ptb = hidden_states_ptb.to(query.dtype) - - # linear proj - hidden_states_ptb = attn.to_out[0](hidden_states_ptb) - # dropout - hidden_states_ptb = attn.to_out[1](hidden_states_ptb) - - if input_ndim == 4: - hidden_states_ptb = hidden_states_ptb.transpose(-1, -2).reshape(batch_size, channel, height, width) - - # cat - hidden_states = torch.cat([hidden_states_org, hidden_states_ptb]) - - if attn.residual_connection: - hidden_states = hidden_states + residual - - hidden_states = hidden_states / attn.rescale_output_factor - - return hidden_states - - - -class PAGCFGIdentitySelfAttnProcessor: - r""" - Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). - """ - - def __init__(self): - if not hasattr(F, "scaled_dot_product_attention"): - raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") - - def __call__( - self, - attn: Attention, - hidden_states: torch.FloatTensor, - encoder_hidden_states: Optional[torch.FloatTensor] = None, - attention_mask: Optional[torch.FloatTensor] = None, - temb: Optional[torch.FloatTensor] = None, - *args, - **kwargs, - ) -> torch.FloatTensor: - if len(args) > 0 or kwargs.get("scale", None) is not None: - deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." - deprecate("scale", "1.0.0", deprecation_message) - - residual = hidden_states - if attn.spatial_norm is not None: - hidden_states = attn.spatial_norm(hidden_states, temb) - - input_ndim = hidden_states.ndim - if input_ndim == 4: - batch_size, channel, height, width = hidden_states.shape - hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) - - # chunk - hidden_states_uncond, hidden_states_org, hidden_states_ptb = hidden_states.chunk(3) - hidden_states_org = torch.cat([hidden_states_uncond, hidden_states_org]) - - # original path - batch_size, sequence_length, _ = hidden_states_org.shape - - if attention_mask is not None: - attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) - # scaled_dot_product_attention expects attention_mask shape to be - # (batch, heads, source_length, target_length) - attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) - - if attn.group_norm is not None: - hidden_states_org = attn.group_norm(hidden_states_org.transpose(1, 2)).transpose(1, 2) - - query = attn.to_q(hidden_states_org) - key = attn.to_k(hidden_states_org) - value = attn.to_v(hidden_states_org) - - inner_dim = key.shape[-1] - head_dim = inner_dim // attn.heads - - query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - # the output of sdp = (batch, num_heads, seq_len, head_dim) - # TODO: add support for attn.scale when we move to Torch 2.1 - hidden_states_org = F.scaled_dot_product_attention( - query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False - ) - - hidden_states_org = hidden_states_org.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) - hidden_states_org = hidden_states_org.to(query.dtype) - - # linear proj - hidden_states_org = attn.to_out[0](hidden_states_org) - # dropout - hidden_states_org = attn.to_out[1](hidden_states_org) - - if input_ndim == 4: - hidden_states_org = hidden_states_org.transpose(-1, -2).reshape(batch_size, channel, height, width) - - # perturbed path (identity attention) - batch_size, sequence_length, _ = hidden_states_ptb.shape - - if attention_mask is not None: - attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) - # scaled_dot_product_attention expects attention_mask shape to be - # (batch, heads, source_length, target_length) - attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) - - if attn.group_norm is not None: - hidden_states_ptb = attn.group_norm(hidden_states_ptb.transpose(1, 2)).transpose(1, 2) - - value = attn.to_v(hidden_states_ptb) - hidden_states_ptb = value - hidden_states_ptb = hidden_states_ptb.to(query.dtype) - - # linear proj - hidden_states_ptb = attn.to_out[0](hidden_states_ptb) - # dropout - hidden_states_ptb = attn.to_out[1](hidden_states_ptb) - - if input_ndim == 4: - hidden_states_ptb = hidden_states_ptb.transpose(-1, -2).reshape(batch_size, channel, height, width) - - # cat - hidden_states = torch.cat([hidden_states_org, hidden_states_ptb]) - - if attn.residual_connection: - hidden_states = hidden_states + residual - - hidden_states = hidden_states / attn.rescale_output_factor - - return hidden_states \ No newline at end of file + return StableDiffusionXLPipelineOutput(images=image) \ No newline at end of file From f94376c605f90f4224604e6d7ddb5ffd349ec7b8 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Tue, 14 May 2024 12:17:32 +0200 Subject: [PATCH 03/44] update --- src/diffusers/pipelines/pag_utils.py | 61 +++++++--- .../pipeline_stable_diffusion_xl.py | 107 ++++-------------- 2 files changed, 70 insertions(+), 98 deletions(-) diff --git a/src/diffusers/pipelines/pag_utils.py b/src/diffusers/pipelines/pag_utils.py index ee2f677984a8..6a6ea6b13f85 100644 --- a/src/diffusers/pipelines/pag_utils.py +++ b/src/diffusers/pipelines/pag_utils.py @@ -12,13 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -import math -from typing import Tuple, Union, List - -import torch -import torch.fft as fft - -from ..utils.torch_utils import randn_tensor +from typing import Tuple from ..models.attention_processor import PAGCFGIdentitySelfAttnProcessor2_0, PAGIdentitySelfAttnProcessor2_0, AttnProcessor2_0 @@ -30,8 +24,9 @@ def enable_pag( pag_scale: float = 0.0, pag_adaptive_scaling: float = 0.0, pag_drop_rate: float = 0.5, - pag_applied_layers: List[str] = ['mid'], #['down', 'mid', 'up'] - pag_applied_layers_index: List[str] = None, #['d4', 'd5', 'm0'] + pag_applied_layers: Tuple[str] = ('mid',), #('down', 'mid', 'up',) + pag_applied_layers_index: Tuple[str] = None, #('d4', 'd5', 'm0',) + pag_cfg: bool = True, ): """Enables the FreeInit mechanism as in https://arxiv.org/abs/2312.07537. @@ -46,6 +41,9 @@ def enable_pag( self._pag_drop_rate = pag_drop_rate self._pag_applied_layers = pag_applied_layers self._pag_applied_layers_index = pag_applied_layers_index + self._pag_cfg = pag_cfg + + self._set_pag_attn_processor() def _get_self_attn_layers(self): down_layers = [] @@ -63,10 +61,11 @@ def _get_self_attn_layers(self): else: raise ValueError(f"Invalid layer type: {layer_type}") return up_layers, mid_layers, down_layers - def set_pag_attn_processor(self): + + def _set_pag_attn_processor(self): up_layers, mid_layers, down_layers = self._get_self_attn_layers() - if self.do_classifier_free_guidance: + if self._pag_cfg: replace_processor = PAGCFGIdentitySelfAttnProcessor2_0() else: replace_processor = PAGIdentitySelfAttnProcessor2_0() @@ -107,9 +106,31 @@ def set_pag_attn_processor(self): raise ValueError( f"Invalid layer index: {drop_full_layer}. Available layers are: down, mid and up. If you need to specify each layer index, you can use `pag_applied_layers_index`" ) + + def _get_pag_scale(self, t): + if self.do_pag_adaptive_scaling: + signal_scale = self.pag_scale - self.pag_adaptive_scaling * (1000-t) + if signal_scale<0: + signal_scale = 0 + return signal_scale + else: + return self.pag_scale + + def _apply_perturbed_attention_guidance(self, noise_pred, do_classifier_free_guidance, guidance_scale, t): + pag_scale = self._get_pag_scale(t) + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text, noise_pred_perturb = noise_pred.chunk(3) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + pag_scale * (noise_pred_text - noise_pred_perturb) + else: + noise_pred_uncond, noise_pred_perturb = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + pag_scale * (noise_pred_uncond - noise_pred_perturb) + return noise_pred def disable_pag(self): """Disables the PAG mechanism if enabled.""" + if not self.do_perturbed_attention_guidance: + raise ValueError("PAG is not enabled.") + up_layers, mid_layers, down_layers = self._get_self_attn_layers() if(self.pag_applied_layers_index): drop_layers = self.pag_applied_layers_index @@ -147,17 +168,21 @@ def disable_pag(self): raise ValueError( f"Invalid layer index: {drop_full_layer}. Available layers are: down, mid and up. If you need to specify each layer index, you can use `pag_applied_layers_index`" ) - - + self._pag_scale = None + self._pag_adaptive_scaling = None + self._pag_drop_rate = None + self._pag_applied_layers = None + self._pag_applied_layers_index = None + self._pag_cfg = None @property def pag_scale(self): return self._pag_scale @property - def do_adversarial_guidance(self): - return self._pag_scale > 0 - + def pag_cfg(self): + return self._pag_cfg + @property def pag_adaptive_scaling(self): return self._pag_adaptive_scaling @@ -178,3 +203,7 @@ def pag_applied_layers(self): def pag_applied_layers_index(self): return self._pag_applied_layers_index + @property + def do_perturbed_attention_guidance(self): + return hasattr(self, "_pag_scale") and self._pag_scale is not None and self._pag_scale > 0 + diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py index 01a1a5f01738..fdfd267ba8b1 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py @@ -559,7 +559,8 @@ def prepare_ip_adapter_image_embeds( single_negative_image_embeds = torch.stack( [single_negative_image_embeds] * num_images_per_prompt, dim=0 ) - + if self.do_perturbed_attention_guidance: + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) if do_classifier_free_guidance: single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds]) single_image_embeds = single_image_embeds.to(device) @@ -611,6 +612,7 @@ def check_inputs( height, width, callback_steps, + guidance_scale, negative_prompt=None, negative_prompt_2=None, prompt_embeds=None, @@ -699,6 +701,11 @@ def check_inputs( raise ValueError( f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" ) + + if hasattr(self, "_pag_cfg") and self.pag_cfg == False and guidance_scale != 0: + raise ValueError( + F"Cannot use guidance scale {guidance_scale} with PAG unconditional guidance. Please set `guidance_scale` to 0." + ) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): @@ -827,34 +834,6 @@ def num_timesteps(self): def interrupt(self): return self._interrupt - @property - def pag_scale(self): - return self._pag_scale - - @property - def do_adversarial_guidance(self): - return self._pag_scale > 0 - - @property - def pag_adaptive_scaling(self): - return self._pag_adaptive_scaling - - @property - def do_pag_adaptive_scaling(self): - return self._pag_adaptive_scaling > 0 - - @property - def pag_drop_rate(self): - return self._pag_drop_rate - - @property - def pag_applied_layers(self): - return self._pag_applied_layers - - @property - def pag_applied_layers_index(self): - return self._pag_applied_layers_index - @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( @@ -1070,6 +1049,7 @@ def __call__( height, width, callback_steps, + guidance_scale, negative_prompt, negative_prompt_2, prompt_embeds, @@ -1171,23 +1151,15 @@ def __call__( negative_add_time_ids = add_time_ids #cfg - if self.do_adversarial_guidance: - self.set_pag_attn_processor() - - if self.do_classifier_free_guidance and not self.do_adversarial_guidance: - prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) - add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) - add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0) - #pag - elif not self.do_classifier_free_guidance and self.do_adversarial_guidance: + if self.do_perturbed_attention_guidance: prompt_embeds = torch.cat([prompt_embeds, prompt_embeds], dim=0) add_text_embeds = torch.cat([add_text_embeds, add_text_embeds], dim=0) add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0) - #both - elif self.do_classifier_free_guidance and self.do_adversarial_guidance: - prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds, prompt_embeds], dim=0) - add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds, add_text_embeds], dim=0) - add_time_ids = torch.cat([negative_add_time_ids, add_time_ids, add_time_ids], dim=0) + + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) + add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0) prompt_embeds = prompt_embeds.to(device) add_text_embeds = add_text_embeds.to(device) @@ -1235,19 +1207,9 @@ def __call__( if self.interrupt: continue - #cfg - if self.do_classifier_free_guidance and not self.do_adversarial_guidance: - latent_model_input = torch.cat([latents] * 2) - #pag - elif not self.do_classifier_free_guidance and self.do_adversarial_guidance: - latent_model_input = torch.cat([latents] * 2) - #both - elif self.do_classifier_free_guidance and self.do_adversarial_guidance: - latent_model_input = torch.cat([latents] * 3) - #no - else: - latent_model_input = latents - + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * (prompt_embeds.shape[0] //latents.shape[0])) + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # predict the noise residual @@ -1265,34 +1227,15 @@ def __call__( )[0] # perform guidance - if self.do_classifier_free_guidance and not self.do_adversarial_guidance: + + if self.do_perturbed_attention_guidance: + noise_pred = self._apply_perturbed_attention_guidance(noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t) + + elif self.do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) - # pag - elif not self.do_classifier_free_guidance and self.do_adversarial_guidance: - noise_pred_original, noise_pred_perturb = noise_pred.chunk(2) - - signal_scale = self.pag_scale - if self.do_pag_adaptive_scaling: - signal_scale = self.pag_scale - self.pag_adaptive_scaling * (1000-t) - if signal_scale<0: - signal_scale = 0 - - noise_pred = noise_pred_original + signal_scale * (noise_pred_original - noise_pred_perturb) - - # both - elif self.do_classifier_free_guidance and self.do_adversarial_guidance: - - noise_pred_uncond, noise_pred_text, noise_pred_text_perturb = noise_pred.chunk(3) - - signal_scale = self.pag_scale - if self.do_pag_adaptive_scaling: - signal_scale = self.pag_scale - self.pag_adaptive_scaling * (1000-t) - if signal_scale<0: - signal_scale = 0 - - noise_pred = noise_pred_text + (self.guidance_scale-1.0) * (noise_pred_text - noise_pred_uncond) + signal_scale * (noise_pred_text - noise_pred_text_perturb) - + + if self.do_classifier_free_guidance and self.guidance_rescale > 0.0: # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale) From 54c3fd6304c9e003a1745b04c1eea515d524dab3 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Tue, 14 May 2024 12:19:26 +0200 Subject: [PATCH 04/44] up --- .../stable_diffusion_xl/pipeline_stable_diffusion_xl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py index fdfd267ba8b1..d929574769a8 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py @@ -702,7 +702,7 @@ def check_inputs( f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" ) - if hasattr(self, "_pag_cfg") and self.pag_cfg == False and guidance_scale != 0: + if hasattr(self, "_pag_cfg") and self.pag_cfg is False and guidance_scale != 0: raise ValueError( F"Cannot use guidance scale {guidance_scale} with PAG unconditional guidance. Please set `guidance_scale` to 0." ) From f571430a9db7146dc3da8df7134774ef9a23ecef Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Tue, 14 May 2024 12:32:02 +0200 Subject: [PATCH 05/44] style --- src/diffusers/models/attention_processor.py | 26 ++--- src/diffusers/pipelines/pag_utils.py | 109 ++++++++++-------- .../pipeline_stable_diffusion_xl.py | 25 ++-- 3 files changed, 84 insertions(+), 76 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 20f5dbebf3ed..bdeea8db81b6 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2673,7 +2673,7 @@ def __call__( if len(args) > 0 or kwargs.get("scale", None) is not None: deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." deprecate("scale", "1.0.0", deprecation_message) - + residual = hidden_states if attn.spatial_norm is not None: hidden_states = attn.spatial_norm(hidden_states, temb) @@ -2682,10 +2682,10 @@ def __call__( if input_ndim == 4: batch_size, channel, height, width = hidden_states.shape hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) - + # chunk hidden_states_org, hidden_states_ptb = hidden_states.chunk(2) - + # original path batch_size, sequence_length, _ = hidden_states_org.shape @@ -2718,7 +2718,7 @@ def __call__( hidden_states_org = hidden_states_org.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) hidden_states_org = hidden_states_org.to(query.dtype) - + # linear proj hidden_states_org = attn.to_out[0](hidden_states_org) # dropout @@ -2740,12 +2740,12 @@ def __call__( hidden_states_ptb = attn.group_norm(hidden_states_ptb.transpose(1, 2)).transpose(1, 2) value = attn.to_v(hidden_states_ptb) - + # hidden_states_ptb = torch.zeros(value.shape).to(value.get_device()) hidden_states_ptb = value - + hidden_states_ptb = hidden_states_ptb.to(query.dtype) - + # linear proj hidden_states_ptb = attn.to_out[0](hidden_states_ptb) # dropout @@ -2787,7 +2787,7 @@ def __call__( if len(args) > 0 or kwargs.get("scale", None) is not None: deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." deprecate("scale", "1.0.0", deprecation_message) - + residual = hidden_states if attn.spatial_norm is not None: hidden_states = attn.spatial_norm(hidden_states, temb) @@ -2796,11 +2796,11 @@ def __call__( if input_ndim == 4: batch_size, channel, height, width = hidden_states.shape hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) - + # chunk hidden_states_uncond, hidden_states_org, hidden_states_ptb = hidden_states.chunk(3) hidden_states_org = torch.cat([hidden_states_uncond, hidden_states_org]) - + # original path batch_size, sequence_length, _ = hidden_states_org.shape @@ -2812,7 +2812,7 @@ def __call__( if attn.group_norm is not None: hidden_states_org = attn.group_norm(hidden_states_org.transpose(1, 2)).transpose(1, 2) - + query = attn.to_q(hidden_states_org) key = attn.to_k(hidden_states_org) value = attn.to_v(hidden_states_org) @@ -2833,7 +2833,7 @@ def __call__( hidden_states_org = hidden_states_org.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) hidden_states_org = hidden_states_org.to(query.dtype) - + # linear proj hidden_states_org = attn.to_out[0](hidden_states_org) # dropout @@ -2857,7 +2857,7 @@ def __call__( value = attn.to_v(hidden_states_ptb) hidden_states_ptb = value hidden_states_ptb = hidden_states_ptb.to(query.dtype) - + # linear proj hidden_states_ptb = attn.to_out[0](hidden_states_ptb) # dropout diff --git a/src/diffusers/pipelines/pag_utils.py b/src/diffusers/pipelines/pag_utils.py index 6a6ea6b13f85..a930e3b2a8f7 100644 --- a/src/diffusers/pipelines/pag_utils.py +++ b/src/diffusers/pipelines/pag_utils.py @@ -13,7 +13,11 @@ # limitations under the License. from typing import Tuple -from ..models.attention_processor import PAGCFGIdentitySelfAttnProcessor2_0, PAGIdentitySelfAttnProcessor2_0, AttnProcessor2_0 +from ..models.attention_processor import ( + PAGCFGIdentitySelfAttnProcessor2_0, + PAGIdentitySelfAttnProcessor2_0, + AttnProcessor2_0, +) class PAGMixin: @@ -24,8 +28,8 @@ def enable_pag( pag_scale: float = 0.0, pag_adaptive_scaling: float = 0.0, pag_drop_rate: float = 0.5, - pag_applied_layers: Tuple[str] = ('mid',), #('down', 'mid', 'up',) - pag_applied_layers_index: Tuple[str] = None, #('d4', 'd5', 'm0',) + pag_applied_layers: Tuple[str] = ("mid",), # ('down', 'mid', 'up',) + pag_applied_layers_index: Tuple[str] = None, # ('d4', 'd5', 'm0',) pag_cfg: bool = True, ): """Enables the FreeInit mechanism as in https://arxiv.org/abs/2312.07537. @@ -44,19 +48,19 @@ def enable_pag( self._pag_cfg = pag_cfg self._set_pag_attn_processor() - + def _get_self_attn_layers(self): down_layers = [] mid_layers = [] up_layers = [] for name, module in self.unet.named_modules(): - if 'attn1' in name and 'to' not in name: - layer_type = name.split('.')[0].split('_')[0] - if layer_type == 'down': + if "attn1" in name and "to" not in name: + layer_type = name.split(".")[0].split("_")[0] + if layer_type == "down": down_layers.append(module) - elif layer_type == 'mid': + elif layer_type == "mid": mid_layers.append(module) - elif layer_type == 'up': + elif layer_type == "up": up_layers.append(module) else: raise ValueError(f"Invalid layer type: {layer_type}") @@ -64,22 +68,22 @@ def _get_self_attn_layers(self): def _set_pag_attn_processor(self): up_layers, mid_layers, down_layers = self._get_self_attn_layers() - + if self._pag_cfg: replace_processor = PAGCFGIdentitySelfAttnProcessor2_0() else: replace_processor = PAGIdentitySelfAttnProcessor2_0() - if(self.pag_applied_layers_index): + if self.pag_applied_layers_index: drop_layers = self.pag_applied_layers_index for drop_layer in drop_layers: layer_number = int(drop_layer[1:]) try: - if drop_layer[0] == 'd': + if drop_layer[0] == "d": down_layers[layer_number].processor = replace_processor - elif drop_layer[0] == 'm': + elif drop_layer[0] == "m": mid_layers[layer_number].processor = replace_processor - elif drop_layer[0] == 'u': + elif drop_layer[0] == "u": up_layers[layer_number].processor = replace_processor else: raise ValueError(f"Invalid layer type: {drop_layer[0]}") @@ -87,7 +91,7 @@ def _set_pag_attn_processor(self): raise ValueError( f"Invalid layer index: {drop_layer}. Available layers: {len(down_layers)} down layers, {len(mid_layers)} mid layers, {len(up_layers)} up layers." ) - elif(self.pag_applied_layers): + elif self.pag_applied_layers: drop_full_layers = self.pag_applied_layers for drop_full_layer in drop_full_layers: try: @@ -106,21 +110,25 @@ def _set_pag_attn_processor(self): raise ValueError( f"Invalid layer index: {drop_full_layer}. Available layers are: down, mid and up. If you need to specify each layer index, you can use `pag_applied_layers_index`" ) - + def _get_pag_scale(self, t): if self.do_pag_adaptive_scaling: - signal_scale = self.pag_scale - self.pag_adaptive_scaling * (1000-t) - if signal_scale<0: + signal_scale = self.pag_scale - self.pag_adaptive_scaling * (1000 - t) + if signal_scale < 0: signal_scale = 0 return signal_scale else: return self.pag_scale - + def _apply_perturbed_attention_guidance(self, noise_pred, do_classifier_free_guidance, guidance_scale, t): pag_scale = self._get_pag_scale(t) if do_classifier_free_guidance: noise_pred_uncond, noise_pred_text, noise_pred_perturb = noise_pred.chunk(3) - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + pag_scale * (noise_pred_text - noise_pred_perturb) + noise_pred = ( + noise_pred_uncond + + guidance_scale * (noise_pred_text - noise_pred_uncond) + + pag_scale * (noise_pred_text - noise_pred_perturb) + ) else: noise_pred_uncond, noise_pred_perturb = noise_pred.chunk(2) noise_pred = noise_pred_uncond + pag_scale * (noise_pred_uncond - noise_pred_perturb) @@ -130,18 +138,18 @@ def disable_pag(self): """Disables the PAG mechanism if enabled.""" if not self.do_perturbed_attention_guidance: raise ValueError("PAG is not enabled.") - + up_layers, mid_layers, down_layers = self._get_self_attn_layers() - if(self.pag_applied_layers_index): + if self.pag_applied_layers_index: drop_layers = self.pag_applied_layers_index for drop_layer in drop_layers: layer_number = int(drop_layer[1:]) try: - if drop_layer[0] == 'd': + if drop_layer[0] == "d": down_layers[layer_number].processor = AttnProcessor2_0() - elif drop_layer[0] == 'm': + elif drop_layer[0] == "m": mid_layers[layer_number].processor = AttnProcessor2_0() - elif drop_layer[0] == 'u': + elif drop_layer[0] == "u": up_layers[layer_number].processor = AttnProcessor2_0() else: raise ValueError(f"Invalid layer type: {drop_layer[0]}") @@ -149,25 +157,25 @@ def disable_pag(self): raise ValueError( f"Invalid layer index: {drop_layer}. Available layers: {len(down_layers)} down layers, {len(mid_layers)} mid layers, {len(up_layers)} up layers." ) - elif(self.pag_applied_layers): - drop_full_layers = self.pag_applied_layers - for drop_full_layer in drop_full_layers: - try: - if drop_full_layer == "down": - for down_layer in down_layers: - down_layer.processor = AttnProcessor2_0() - elif drop_full_layer == "mid": - for mid_layer in mid_layers: - mid_layer.processor = AttnProcessor2_0() - elif drop_full_layer == "up": - for up_layer in up_layers: - up_layer.processor = AttnProcessor2_0() - else: - raise ValueError(f"Invalid layer type: {drop_full_layer}") - except IndexError: - raise ValueError( - f"Invalid layer index: {drop_full_layer}. Available layers are: down, mid and up. If you need to specify each layer index, you can use `pag_applied_layers_index`" - ) + elif self.pag_applied_layers: + drop_full_layers = self.pag_applied_layers + for drop_full_layer in drop_full_layers: + try: + if drop_full_layer == "down": + for down_layer in down_layers: + down_layer.processor = AttnProcessor2_0() + elif drop_full_layer == "mid": + for mid_layer in mid_layers: + mid_layer.processor = AttnProcessor2_0() + elif drop_full_layer == "up": + for up_layer in up_layers: + up_layer.processor = AttnProcessor2_0() + else: + raise ValueError(f"Invalid layer type: {drop_full_layer}") + except IndexError: + raise ValueError( + f"Invalid layer index: {drop_full_layer}. Available layers are: down, mid and up. If you need to specify each layer index, you can use `pag_applied_layers_index`" + ) self._pag_scale = None self._pag_adaptive_scaling = None self._pag_drop_rate = None @@ -178,27 +186,27 @@ def disable_pag(self): @property def pag_scale(self): return self._pag_scale - + @property def pag_cfg(self): return self._pag_cfg - + @property def pag_adaptive_scaling(self): return self._pag_adaptive_scaling - + @property def do_pag_adaptive_scaling(self): return self._pag_adaptive_scaling > 0 - + @property def pag_drop_rate(self): return self._pag_drop_rate - + @property def pag_applied_layers(self): return self._pag_applied_layers - + @property def pag_applied_layers_index(self): return self._pag_applied_layers_index @@ -206,4 +214,3 @@ def pag_applied_layers_index(self): @property def do_perturbed_attention_guidance(self): return hasattr(self, "_pag_scale") and self._pag_scale is not None and self._pag_scale > 0 - diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py index d929574769a8..a1b55c72b225 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py @@ -701,10 +701,10 @@ def check_inputs( raise ValueError( f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" ) - + if hasattr(self, "_pag_cfg") and self.pag_cfg is False and guidance_scale != 0: raise ValueError( - F"Cannot use guidance scale {guidance_scale} with PAG unconditional guidance. Please set `guidance_scale` to 0." + f"Cannot use guidance scale {guidance_scale} with PAG unconditional guidance. Please set `guidance_scale` to 0." ) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents @@ -1067,7 +1067,7 @@ def __call__( self._cross_attention_kwargs = cross_attention_kwargs self._denoising_end = denoising_end self._interrupt = False - + # 2. Define call parameters if prompt is not None and isinstance(prompt, str): batch_size = 1 @@ -1150,17 +1150,17 @@ def __call__( else: negative_add_time_ids = add_time_ids - #cfg + # cfg if self.do_perturbed_attention_guidance: prompt_embeds = torch.cat([prompt_embeds, prompt_embeds], dim=0) add_text_embeds = torch.cat([add_text_embeds, add_text_embeds], dim=0) add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0) - + if self.do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0) - + prompt_embeds = prompt_embeds.to(device) add_text_embeds = add_text_embeds.to(device) add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) @@ -1200,7 +1200,7 @@ def __call__( timestep_cond = self.get_guidance_scale_embedding( guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim ).to(device=device, dtype=latents.dtype) - + self._num_timesteps = len(timesteps) with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): @@ -1208,7 +1208,7 @@ def __call__( continue # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * (prompt_embeds.shape[0] //latents.shape[0])) + latent_model_input = torch.cat([latents] * (prompt_embeds.shape[0] // latents.shape[0])) latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) @@ -1227,15 +1227,16 @@ def __call__( )[0] # perform guidance - + if self.do_perturbed_attention_guidance: - noise_pred = self._apply_perturbed_attention_guidance(noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t) + noise_pred = self._apply_perturbed_attention_guidance( + noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t + ) elif self.do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) - if self.do_classifier_free_guidance and self.guidance_rescale > 0.0: # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale) @@ -1322,4 +1323,4 @@ def __call__( if not return_dict: return (image,) - return StableDiffusionXLPipelineOutput(images=image) \ No newline at end of file + return StableDiffusionXLPipelineOutput(images=image) From 91d0a5b34518bf37c3b32a8bf146f5cc5c46d85f Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Tue, 14 May 2024 16:27:11 +0000 Subject: [PATCH 06/44] style --- src/diffusers/pipelines/pag_utils.py | 3 ++- .../stable_diffusion_xl/pipeline_stable_diffusion_xl.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/pag_utils.py b/src/diffusers/pipelines/pag_utils.py index a930e3b2a8f7..0765dc630de7 100644 --- a/src/diffusers/pipelines/pag_utils.py +++ b/src/diffusers/pipelines/pag_utils.py @@ -13,10 +13,11 @@ # limitations under the License. from typing import Tuple + from ..models.attention_processor import ( + AttnProcessor2_0, PAGCFGIdentitySelfAttnProcessor2_0, PAGIdentitySelfAttnProcessor2_0, - AttnProcessor2_0, ) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py index a1b55c72b225..54d1e9bf18d3 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py @@ -52,9 +52,9 @@ unscale_lora_layers, ) from ...utils.torch_utils import randn_tensor +from ..pag_utils import PAGMixin from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin from .pipeline_output import StableDiffusionXLPipelineOutput -from ..pag_utils import PAGMixin if is_invisible_watermark_available(): From 01585ab3d92c5465579067b86440dcea61dfb816 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Tue, 14 May 2024 20:44:13 +0200 Subject: [PATCH 07/44] update --- src/diffusers/pipelines/pag_utils.py | 42 +++++-------------- .../pipeline_stable_diffusion_xl.py | 2 +- 2 files changed, 11 insertions(+), 33 deletions(-) diff --git a/src/diffusers/pipelines/pag_utils.py b/src/diffusers/pipelines/pag_utils.py index a930e3b2a8f7..dcf64b921bf7 100644 --- a/src/diffusers/pipelines/pag_utils.py +++ b/src/diffusers/pipelines/pag_utils.py @@ -27,7 +27,6 @@ def enable_pag( self, pag_scale: float = 0.0, pag_adaptive_scaling: float = 0.0, - pag_drop_rate: float = 0.5, pag_applied_layers: Tuple[str] = ("mid",), # ('down', 'mid', 'up',) pag_applied_layers_index: Tuple[str] = None, # ('d4', 'd5', 'm0',) pag_cfg: bool = True, @@ -42,7 +41,6 @@ def enable_pag( """ self._pag_scale = pag_scale self._pag_adaptive_scaling = pag_adaptive_scaling - self._pag_drop_rate = pag_drop_rate self._pag_applied_layers = pag_applied_layers self._pag_applied_layers_index = pag_applied_layers_index self._pag_cfg = pag_cfg @@ -74,8 +72,8 @@ def _set_pag_attn_processor(self): else: replace_processor = PAGIdentitySelfAttnProcessor2_0() - if self.pag_applied_layers_index: - drop_layers = self.pag_applied_layers_index + if self._pag_applied_layers_index: + drop_layers = self._pag_applied_layers_index for drop_layer in drop_layers: layer_number = int(drop_layer[1:]) try: @@ -91,8 +89,8 @@ def _set_pag_attn_processor(self): raise ValueError( f"Invalid layer index: {drop_layer}. Available layers: {len(down_layers)} down layers, {len(mid_layers)} mid layers, {len(up_layers)} up layers." ) - elif self.pag_applied_layers: - drop_full_layers = self.pag_applied_layers + elif self._pag_applied_layers: + drop_full_layers = self._pag_applied_layers for drop_full_layer in drop_full_layers: try: if drop_full_layer == "down": @@ -113,12 +111,12 @@ def _set_pag_attn_processor(self): def _get_pag_scale(self, t): if self.do_pag_adaptive_scaling: - signal_scale = self.pag_scale - self.pag_adaptive_scaling * (1000 - t) + signal_scale = self._pag_scale - self.pag_adaptive_scaling * (1000 - t) if signal_scale < 0: signal_scale = 0 return signal_scale else: - return self.pag_scale + return self._pag_scale def _apply_perturbed_attention_guidance(self, noise_pred, do_classifier_free_guidance, guidance_scale, t): pag_scale = self._get_pag_scale(t) @@ -140,8 +138,8 @@ def disable_pag(self): raise ValueError("PAG is not enabled.") up_layers, mid_layers, down_layers = self._get_self_attn_layers() - if self.pag_applied_layers_index: - drop_layers = self.pag_applied_layers_index + if self._pag_applied_layers_index: + drop_layers = self._pag_applied_layers_index for drop_layer in drop_layers: layer_number = int(drop_layer[1:]) try: @@ -157,8 +155,8 @@ def disable_pag(self): raise ValueError( f"Invalid layer index: {drop_layer}. Available layers: {len(down_layers)} down layers, {len(mid_layers)} mid layers, {len(up_layers)} up layers." ) - elif self.pag_applied_layers: - drop_full_layers = self.pag_applied_layers + elif self._pag_applied_layers: + drop_full_layers = self._pag_applied_layers for drop_full_layer in drop_full_layers: try: if drop_full_layer == "down": @@ -183,14 +181,6 @@ def disable_pag(self): self._pag_applied_layers_index = None self._pag_cfg = None - @property - def pag_scale(self): - return self._pag_scale - - @property - def pag_cfg(self): - return self._pag_cfg - @property def pag_adaptive_scaling(self): return self._pag_adaptive_scaling @@ -199,18 +189,6 @@ def pag_adaptive_scaling(self): def do_pag_adaptive_scaling(self): return self._pag_adaptive_scaling > 0 - @property - def pag_drop_rate(self): - return self._pag_drop_rate - - @property - def pag_applied_layers(self): - return self._pag_applied_layers - - @property - def pag_applied_layers_index(self): - return self._pag_applied_layers_index - @property def do_perturbed_attention_guidance(self): return hasattr(self, "_pag_scale") and self._pag_scale is not None and self._pag_scale > 0 diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py index a1b55c72b225..21bfe4d27ecc 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py @@ -702,7 +702,7 @@ def check_inputs( f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" ) - if hasattr(self, "_pag_cfg") and self.pag_cfg is False and guidance_scale != 0: + if hasattr(self, "_pag_cfg") and self._pag_cfg is False and guidance_scale != 0: raise ValueError( f"Cannot use guidance scale {guidance_scale} with PAG unconditional guidance. Please set `guidance_scale` to 0." ) From 03bdbcdb075e91561544d785d6e80ab79487e2ea Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Tue, 14 May 2024 21:07:24 +0200 Subject: [PATCH 08/44] inpaint + controlnet --- .../controlnet/pipeline_controlnet_sd_xl.py | 24 +++++++++++++++--- .../pipeline_stable_diffusion_xl.py | 2 -- .../pipeline_stable_diffusion_xl_inpaint.py | 25 ++++++++++++++++--- 3 files changed, 43 insertions(+), 8 deletions(-) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py index 763188f34735..63c55a607d62 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py @@ -55,6 +55,7 @@ unscale_lora_layers, ) from ...utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor +from ..pag_utils import PAGMixin from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput @@ -181,6 +182,7 @@ class StableDiffusionXLControlNetPipeline( StableDiffusionXLLoraLoaderMixin, IPAdapterMixin, FromSingleFileMixin, + PAGMixin, ): r""" Pipeline for text-to-image generation using Stable Diffusion XL with ControlNet guidance. @@ -619,6 +621,7 @@ def check_inputs( prompt_2, image, callback_steps, + guidance_scale, negative_prompt=None, negative_prompt_2=None, prompt_embeds=None, @@ -802,6 +805,11 @@ def check_inputs( f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" ) + if hasattr(self, "_pag_cfg") and self._pag_cfg is False and guidance_scale != 0: + raise ValueError( + f"Cannot use guidance scale {guidance_scale} with PAG unconditional guidance. Please set `guidance_scale` to 0." + ) + # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.check_image def check_image(self, image, prompt, prompt_embeds): image_is_pil = isinstance(image, PIL.Image.Image) @@ -1223,6 +1231,7 @@ def __call__( prompt_2, image, callback_steps, + guidance_scale, negative_prompt, negative_prompt_2, prompt_embeds, @@ -1405,6 +1414,11 @@ def __call__( else: negative_add_time_ids = add_time_ids + if self.do_perturbed_attention_guidance: + prompt_embeds = torch.cat([prompt_embeds, prompt_embeds], dim=0) + add_text_embeds = torch.cat([add_text_embeds, add_text_embeds], dim=0) + add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0) + if self.do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) @@ -1442,8 +1456,8 @@ def __call__( # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428 if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1: torch._inductor.cudagraph_mark_step_begin() - # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + # expand the latents if we are doing classifier free guidance or perturbed attention guidance + latent_model_input = torch.cat([latents] * (prompt_embeds.shape[0] // latents.shape[0])) latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} @@ -1506,7 +1520,11 @@ def __call__( )[0] # perform guidance - if self.do_classifier_free_guidance: + if self.do_perturbed_attention_guidance: + noise_pred = self._apply_perturbed_attention_guidance( + noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t + ) + elif self.do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py index 21bfe4d27ecc..767362b74c5e 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py @@ -1150,7 +1150,6 @@ def __call__( else: negative_add_time_ids = add_time_ids - # cfg if self.do_perturbed_attention_guidance: prompt_embeds = torch.cat([prompt_embeds, prompt_embeds], dim=0) add_text_embeds = torch.cat([add_text_embeds, add_text_embeds], dim=0) @@ -1227,7 +1226,6 @@ def __call__( )[0] # perform guidance - if self.do_perturbed_attention_guidance: noise_pred = self._apply_perturbed_attention_guidance( noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py index 38f5cec931f8..56093ed64cab 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py @@ -53,6 +53,7 @@ unscale_lora_layers, ) from ...utils.torch_utils import randn_tensor +from ..pag_utils import PAGMixin from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin from .pipeline_output import StableDiffusionXLPipelineOutput @@ -330,6 +331,7 @@ class StableDiffusionXLInpaintPipeline( StableDiffusionXLLoraLoaderMixin, FromSingleFileMixin, IPAdapterMixin, + PAGMixin, ): r""" Pipeline for text-to-image generation using Stable Diffusion XL. @@ -785,6 +787,7 @@ def check_inputs( strength, callback_steps, output_type, + guidance_scale, negative_prompt=None, negative_prompt_2=None, prompt_embeds=None, @@ -877,6 +880,11 @@ def check_inputs( raise ValueError( f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" ) + + if hasattr(self, "_pag_cfg") and self._pag_cfg is False and guidance_scale != 0: + raise ValueError( + f"Cannot use guidance scale {guidance_scale} with PAG unconditional guidance. Please set `guidance_scale` to 0." + ) def prepare_latents( self, @@ -1460,6 +1468,7 @@ def __call__( strength, callback_steps, output_type, + guidance_scale, negative_prompt, negative_prompt_2, prompt_embeds, @@ -1659,6 +1668,11 @@ def denoising_value_valid(dnv): ) add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1) + if self.do_perturbed_attention_guidance: + prompt_embeds = torch.cat([prompt_embeds, prompt_embeds], dim=0) + add_text_embeds = torch.cat([add_text_embeds, add_text_embeds], dim=0) + add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0) + if self.do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) @@ -1715,8 +1729,8 @@ def denoising_value_valid(dnv): for i, t in enumerate(timesteps): if self.interrupt: continue - # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + # expand the latents if we are doing classifier free guidance or perturbed attention guidance + latent_model_input = torch.cat([latents] * (prompt_embeds.shape[0] // latents.shape[0])) # concat latents, mask, masked_image_latents in the channel dimension latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) @@ -1739,7 +1753,12 @@ def denoising_value_valid(dnv): )[0] # perform guidance - if self.do_classifier_free_guidance: + if self.do_perturbed_attention_guidance: + noise_pred = self._apply_perturbed_attention_guidance( + noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t + ) + + elif self.do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) From 1fb2c3354ede24f64f44dda828e0b5680e82dd8e Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Tue, 14 May 2024 19:08:56 +0000 Subject: [PATCH 09/44] style --- .../stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py index 56093ed64cab..7bb2e346e66e 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py @@ -880,7 +880,7 @@ def check_inputs( raise ValueError( f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" ) - + if hasattr(self, "_pag_cfg") and self._pag_cfg is False and guidance_scale != 0: raise ValueError( f"Cannot use guidance scale {guidance_scale} with PAG unconditional guidance. Please set `guidance_scale` to 0." From 219f4b98a6923cc802262814924a3349479942d8 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Tue, 14 May 2024 19:32:09 +0000 Subject: [PATCH 10/44] up --- .../pipeline_stable_diffusion_xl.py | 3 +-- .../pipeline_stable_diffusion_xl_adapter.py | 24 ++++++++++++++++--- 2 files changed, 22 insertions(+), 5 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py index 85c2f6e51d9c..b28ebc334e44 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py @@ -559,8 +559,7 @@ def prepare_ip_adapter_image_embeds( single_negative_image_embeds = torch.stack( [single_negative_image_embeds] * num_images_per_prompt, dim=0 ) - if self.do_perturbed_attention_guidance: - single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + if do_classifier_free_guidance: single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds]) single_image_embeds = single_image_embeds.to(device) diff --git a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py index 2aa2415a47bb..c921c02ba9f2 100644 --- a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +++ b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py @@ -51,6 +51,7 @@ unscale_lora_layers, ) from ...utils.torch_utils import randn_tensor +from ..pag_utils import PAGMixin from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput @@ -201,6 +202,7 @@ class StableDiffusionXLAdapterPipeline( StableDiffusionXLLoraLoaderMixin, IPAdapterMixin, FromSingleFileMixin, + PAGMixin, ): r""" Pipeline for text-to-image generation using Stable Diffusion augmented with T2I-Adapter @@ -624,6 +626,7 @@ def check_inputs( height, width, callback_steps, + guidance_scale, negative_prompt=None, negative_prompt_2=None, prompt_embeds=None, @@ -713,6 +716,11 @@ def check_inputs( f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" ) + if hasattr(self, "_pag_cfg") and self._pag_cfg is False and guidance_scale != 0: + raise ValueError( + f"Cannot use guidance scale {guidance_scale} with PAG unconditional guidance. Please set `guidance_scale` to 0." + ) + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): shape = ( @@ -1067,6 +1075,7 @@ def __call__( height, width, callback_steps, + guidance_scale, negative_prompt, negative_prompt_2, prompt_embeds, @@ -1189,6 +1198,11 @@ def __call__( else: negative_add_time_ids = add_time_ids + if self.do_perturbed_attention_guidance: + prompt_embeds = torch.cat([prompt_embeds, prompt_embeds], dim=0) + add_text_embeds = torch.cat([add_text_embeds, add_text_embeds], dim=0) + add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0) + if self.do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) @@ -1213,8 +1227,8 @@ def __call__( with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): - # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + # expand the latents if we are doing classifier free guidance or perturbed attention guidance + latent_model_input = torch.cat([latents] * (prompt_embeds.shape[0] // latents.shape[0])) latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) @@ -1241,7 +1255,11 @@ def __call__( )[0] # perform guidance - if self.do_classifier_free_guidance: + if self.do_perturbed_attention_guidance: + noise_pred = self._apply_perturbed_attention_guidance( + noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t + ) + elif self.do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) From 5641cb474e760ddd8fcda44856fce46c7b3ba72f Mon Sep 17 00:00:00 2001 From: YiYi Xu Date: Tue, 14 May 2024 09:39:12 -1000 Subject: [PATCH 11/44] Update src/diffusers/pipelines/pag_utils.py --- src/diffusers/pipelines/pag_utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/pipelines/pag_utils.py b/src/diffusers/pipelines/pag_utils.py index 1f2fb26993fb..bb1b1083235f 100644 --- a/src/diffusers/pipelines/pag_utils.py +++ b/src/diffusers/pipelines/pag_utils.py @@ -177,7 +177,6 @@ def disable_pag(self): ) self._pag_scale = None self._pag_adaptive_scaling = None - self._pag_drop_rate = None self._pag_applied_layers = None self._pag_applied_layers_index = None self._pag_cfg = None From 8950e80192dfff75ed7044503d202643e251f454 Mon Sep 17 00:00:00 2001 From: Junhwa Song Date: Wed, 15 May 2024 23:25:57 +0000 Subject: [PATCH 12/44] fix controlnet --- .../pipelines/controlnet/pipeline_controlnet_sd_xl.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py index 63c55a607d62..6d804941be54 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py @@ -848,7 +848,6 @@ def check_image(self, image, prompt, prompt_embeds): f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}" ) - # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.prepare_image def prepare_image( self, image, @@ -859,6 +858,7 @@ def prepare_image( device, dtype, do_classifier_free_guidance=False, + do_perturbed_attention_guidance=False, guess_mode=False, ): image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) @@ -873,9 +873,13 @@ def prepare_image( image = image.repeat_interleave(repeat_by, dim=0) image = image.to(device=device, dtype=dtype) - - if do_classifier_free_guidance and not guess_mode: + + if do_classifier_free_guidance and not do_perturbed_attention_guidance and not guess_mode: + image = torch.cat([image] * 2) + elif not do_classifier_free_guidance and do_perturbed_attention_guidance and not guess_mode: image = torch.cat([image] * 2) + elif do_classifier_free_guidance and do_perturbed_attention_guidance and not guess_mode: + image = torch.cat([image] * 3) return image @@ -1317,6 +1321,7 @@ def __call__( device=device, dtype=controlnet.dtype, do_classifier_free_guidance=self.do_classifier_free_guidance, + do_perturbed_attention_guidance=self.do_perturbed_attention_guidance, guess_mode=guess_mode, ) height, width = image.shape[-2:] From 4cc0b8b06bc9933641cb26926a844139d76c1dc8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ahn=20Donghoon=20=28=EC=95=88=EB=8F=99=ED=9B=88=20/=20suno?= =?UTF-8?q?=29?= Date: Wed, 5 Jun 2024 16:36:49 +0900 Subject: [PATCH 13/44] fix compatability issue between PAG and IP-adapter (#8379) * fix compatability issue between PAG and IP-adapter * fix compatibility issue between PAG and IP-adapter plus --- src/diffusers/loaders/unet.py | 4 +--- src/diffusers/pipelines/pag_utils.py | 5 +++-- .../pipeline_stable_diffusion_xl.py | 16 +++++++++++++--- 3 files changed, 17 insertions(+), 8 deletions(-) diff --git a/src/diffusers/loaders/unet.py b/src/diffusers/loaders/unet.py index 7db7bfeda600..018b57dd07eb 100644 --- a/src/diffusers/loaders/unet.py +++ b/src/diffusers/loaders/unet.py @@ -928,9 +928,7 @@ def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=F hidden_size = self.config.block_out_channels[block_id] if cross_attention_dim is None or "motion_modules" in name: - attn_processor_class = ( - AttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else AttnProcessor - ) + attn_processor_class = self.attn_processors[name].__class__ attn_procs[name] = attn_processor_class() else: diff --git a/src/diffusers/pipelines/pag_utils.py b/src/diffusers/pipelines/pag_utils.py index bb1b1083235f..c9da87523359 100644 --- a/src/diffusers/pipelines/pag_utils.py +++ b/src/diffusers/pipelines/pag_utils.py @@ -45,7 +45,7 @@ def enable_pag( self._pag_applied_layers = pag_applied_layers self._pag_applied_layers_index = pag_applied_layers_index self._pag_cfg = pag_cfg - + self._is_pag_enabled = True self._set_pag_attn_processor() def _get_self_attn_layers(self): @@ -180,6 +180,7 @@ def disable_pag(self): self._pag_applied_layers = None self._pag_applied_layers_index = None self._pag_cfg = None + self._is_pag_enabled = False @property def pag_adaptive_scaling(self): @@ -191,4 +192,4 @@ def do_pag_adaptive_scaling(self): @property def do_perturbed_attention_guidance(self): - return hasattr(self, "_pag_scale") and self._pag_scale is not None and self._pag_scale > 0 + return self._is_pag_enabled diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py index b28ebc334e44..aebf3f0d941e 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py @@ -536,7 +536,7 @@ def encode_image(self, image, device, num_images_per_prompt, output_hidden_state # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds def prepare_ip_adapter_image_embeds( - self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance, do_perturbed_attention_guidance ): if ip_adapter_image_embeds is None: if not isinstance(ip_adapter_image, list): @@ -560,6 +560,10 @@ def prepare_ip_adapter_image_embeds( [single_negative_image_embeds] * num_images_per_prompt, dim=0 ) + if do_perturbed_attention_guidance: + single_image_embeds = torch.cat([single_image_embeds, single_image_embeds], dim=0) + single_image_embeds = single_image_embeds.to(device) + if do_classifier_free_guidance: single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds]) single_image_embeds = single_image_embeds.to(device) @@ -577,11 +581,16 @@ def prepare_ip_adapter_image_embeds( single_negative_image_embeds = single_negative_image_embeds.repeat( num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:])) ) - single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds]) + if do_perturbed_attention_guidance: + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds, single_image_embeds], dim=0) + else: + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds]) else: single_image_embeds = single_image_embeds.repeat( num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:])) ) + if do_perturbed_attention_guidance: + single_image_embeds = torch.cat([single_image_embeds, single_image_embeds], dim=0) image_embeds.append(single_image_embeds) return image_embeds @@ -1170,6 +1179,7 @@ def __call__( device, batch_size * num_images_per_prompt, self.do_classifier_free_guidance, + self.do_perturbed_attention_guidance, ) # 8. Denoising loop @@ -1205,7 +1215,7 @@ def __call__( if self.interrupt: continue - # expand the latents if we are doing classifier free guidance + # expand the latents if we are doing classifier free guidance, perturbed-attention guidance, or both latent_model_input = torch.cat([latents] * (prompt_embeds.shape[0] // latents.shape[0])) latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) From 5cbf22632a84371b8b6d04ebe903fefd28b0e459 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Thu, 6 Jun 2024 12:13:41 +0200 Subject: [PATCH 14/44] up --- src/diffusers/__init__.py | 2 + src/diffusers/pipelines/__init__.py | 7 + src/diffusers/pipelines/auto_pipeline.py | 16 + .../controlnet/pipeline_controlnet_sd_xl.py | 35 +- src/diffusers/pipelines/pag/__init__.py | 48 + src/diffusers/pipelines/pag/pag_utils.py | 167 +++ .../pipelines/pag/pipeline_pag_sd_xl.py | 1311 +++++++++++++++++ src/diffusers/pipelines/pag_utils.py | 195 --- .../pipeline_stable_diffusion_xl.py | 39 +- .../pipeline_stable_diffusion_xl_inpaint.py | 25 +- .../pipeline_stable_diffusion_xl_adapter.py | 24 +- 11 files changed, 1568 insertions(+), 301 deletions(-) create mode 100644 src/diffusers/pipelines/pag/__init__.py create mode 100644 src/diffusers/pipelines/pag/pag_utils.py create mode 100644 src/diffusers/pipelines/pag/pipeline_pag_sd_xl.py delete mode 100644 src/diffusers/pipelines/pag_utils.py diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 66c98804eadc..0868dd966f7d 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -299,6 +299,7 @@ "StableDiffusionXLControlNetImg2ImgPipeline", "StableDiffusionXLControlNetInpaintPipeline", "StableDiffusionXLControlNetPipeline", + "StableDiffusionXLPAGPipeline", "StableDiffusionXLControlNetXSPipeline", "StableDiffusionXLImg2ImgPipeline", "StableDiffusionXLInpaintPipeline", @@ -681,6 +682,7 @@ StableDiffusionXLImg2ImgPipeline, StableDiffusionXLInpaintPipeline, StableDiffusionXLInstructPix2PixPipeline, + StableDiffusionXLPAGPipeline, StableDiffusionXLPipeline, StableUnCLIPImg2ImgPipeline, StableUnCLIPPipeline, diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index c2dd7ac0d551..ee51277a4f05 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -26,6 +26,7 @@ "ledits_pp": [], "stable_diffusion": [], "stable_diffusion_xl": [], + "pag": [], } try: @@ -135,6 +136,11 @@ "StableDiffusionXLControlNetPipeline", ] ) + _import_structure["pag"].extend( + [ + "StableDiffusionXLPAGPipeline", + ] + ) _import_structure["controlnet_xs"].extend( [ "StableDiffusionControlNetXSPipeline", @@ -389,6 +395,7 @@ StableDiffusionControlNetXSPipeline, StableDiffusionXLControlNetXSPipeline, ) + from .pag import StableDiffusionXLPAGPipeline from .deepfloyd_if import ( IFImg2ImgPipeline, IFImg2ImgSuperResolutionPipeline, diff --git a/src/diffusers/pipelines/auto_pipeline.py b/src/diffusers/pipelines/auto_pipeline.py index 5fb497ef2e22..c43d0518f8a0 100644 --- a/src/diffusers/pipelines/auto_pipeline.py +++ b/src/diffusers/pipelines/auto_pipeline.py @@ -26,6 +26,7 @@ StableDiffusionXLControlNetInpaintPipeline, StableDiffusionXLControlNetPipeline, ) +from .pag import StableDiffusionXLPAGPipeline from .deepfloyd_if import IFImg2ImgPipeline, IFInpaintingPipeline, IFPipeline from .kandinsky import ( KandinskyCombinedPipeline, @@ -75,6 +76,7 @@ ("lcm", LatentConsistencyModelPipeline), ("pixart-alpha", PixArtAlphaPipeline), ("pixart-sigma", PixArtSigmaPipeline), + ("stable-diffusion-xl-pag", StableDiffusionXLPAGPipeline), ] ) @@ -332,6 +334,8 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs): if "controlnet" in kwargs: orig_class_name = config["_class_name"].replace("Pipeline", "ControlNetPipeline") + if "enable_pag" in kwargs: + orig_class_name = config["_class_name"].replace("Pipeline", "PAGPipeline") text_2_image_cls = _get_task_class(AUTO_TEXT2IMAGE_PIPELINES_MAPPING, orig_class_name) @@ -384,6 +388,18 @@ def from_pipe(cls, pipeline, **kwargs): AUTO_TEXT2IMAGE_PIPELINES_MAPPING, text_2_image_cls.__name__.replace("ControlNetPipeline", "Pipeline"), ) + + if "enable_pag" in kwargs: + if kwargs["enable_pag"] is not None: + text_2_image_cls = _get_task_class( + AUTO_TEXT2IMAGE_PIPELINES_MAPPING, + text_2_image_cls.__name__.replace("PAG", "").replace("Pipeline", "PAGPipeline"), + ) + else: + text_2_image_cls = _get_task_class( + AUTO_TEXT2IMAGE_PIPELINES_MAPPING, + text_2_image_cls.__name__.replace("PAGPipeline", "Pipeline"), + ) # define expected module and optional kwargs given the pipeline signature expected_modules, optional_kwargs = text_2_image_cls._get_signature_keys(text_2_image_cls) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py index 6d804941be54..763188f34735 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py @@ -55,7 +55,6 @@ unscale_lora_layers, ) from ...utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor -from ..pag_utils import PAGMixin from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput @@ -182,7 +181,6 @@ class StableDiffusionXLControlNetPipeline( StableDiffusionXLLoraLoaderMixin, IPAdapterMixin, FromSingleFileMixin, - PAGMixin, ): r""" Pipeline for text-to-image generation using Stable Diffusion XL with ControlNet guidance. @@ -621,7 +619,6 @@ def check_inputs( prompt_2, image, callback_steps, - guidance_scale, negative_prompt=None, negative_prompt_2=None, prompt_embeds=None, @@ -805,11 +802,6 @@ def check_inputs( f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" ) - if hasattr(self, "_pag_cfg") and self._pag_cfg is False and guidance_scale != 0: - raise ValueError( - f"Cannot use guidance scale {guidance_scale} with PAG unconditional guidance. Please set `guidance_scale` to 0." - ) - # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.check_image def check_image(self, image, prompt, prompt_embeds): image_is_pil = isinstance(image, PIL.Image.Image) @@ -848,6 +840,7 @@ def check_image(self, image, prompt, prompt_embeds): f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}" ) + # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.prepare_image def prepare_image( self, image, @@ -858,7 +851,6 @@ def prepare_image( device, dtype, do_classifier_free_guidance=False, - do_perturbed_attention_guidance=False, guess_mode=False, ): image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) @@ -873,13 +865,9 @@ def prepare_image( image = image.repeat_interleave(repeat_by, dim=0) image = image.to(device=device, dtype=dtype) - - if do_classifier_free_guidance and not do_perturbed_attention_guidance and not guess_mode: - image = torch.cat([image] * 2) - elif not do_classifier_free_guidance and do_perturbed_attention_guidance and not guess_mode: + + if do_classifier_free_guidance and not guess_mode: image = torch.cat([image] * 2) - elif do_classifier_free_guidance and do_perturbed_attention_guidance and not guess_mode: - image = torch.cat([image] * 3) return image @@ -1235,7 +1223,6 @@ def __call__( prompt_2, image, callback_steps, - guidance_scale, negative_prompt, negative_prompt_2, prompt_embeds, @@ -1321,7 +1308,6 @@ def __call__( device=device, dtype=controlnet.dtype, do_classifier_free_guidance=self.do_classifier_free_guidance, - do_perturbed_attention_guidance=self.do_perturbed_attention_guidance, guess_mode=guess_mode, ) height, width = image.shape[-2:] @@ -1419,11 +1405,6 @@ def __call__( else: negative_add_time_ids = add_time_ids - if self.do_perturbed_attention_guidance: - prompt_embeds = torch.cat([prompt_embeds, prompt_embeds], dim=0) - add_text_embeds = torch.cat([add_text_embeds, add_text_embeds], dim=0) - add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0) - if self.do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) @@ -1461,8 +1442,8 @@ def __call__( # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428 if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1: torch._inductor.cudagraph_mark_step_begin() - # expand the latents if we are doing classifier free guidance or perturbed attention guidance - latent_model_input = torch.cat([latents] * (prompt_embeds.shape[0] // latents.shape[0])) + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} @@ -1525,11 +1506,7 @@ def __call__( )[0] # perform guidance - if self.do_perturbed_attention_guidance: - noise_pred = self._apply_perturbed_attention_guidance( - noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t - ) - elif self.do_classifier_free_guidance: + if self.do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) diff --git a/src/diffusers/pipelines/pag/__init__.py b/src/diffusers/pipelines/pag/__init__.py new file mode 100644 index 000000000000..976b981575c2 --- /dev/null +++ b/src/diffusers/pipelines/pag/__init__.py @@ -0,0 +1,48 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_flax_available, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_pag_sd_xl"] = ["StableDiffusionXLPAGPipeline"] + + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_pag_sd_xl import StableDiffusionXLPAGPipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) \ No newline at end of file diff --git a/src/diffusers/pipelines/pag/pag_utils.py b/src/diffusers/pipelines/pag/pag_utils.py new file mode 100644 index 000000000000..4b4927c32de4 --- /dev/null +++ b/src/diffusers/pipelines/pag/pag_utils.py @@ -0,0 +1,167 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Tuple +import torch + +from ...models.attention_processor import ( + AttnProcessor2_0, + PAGCFGIdentitySelfAttnProcessor2_0, + PAGIdentitySelfAttnProcessor2_0, +) + +from ...utils import logging +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +class PAGMixin: + r"""Mixin class for PAG.""" + + @staticmethod + def _check_input_pag_applied_layer(layer): + layer_splits = layer.split(".") + if len(layer_splits) > 3: + raise ValueError(f"pag layer should only contains block_type, block_index and attention_index{layer}.") + + if len(layer_splits) >= 1: + if layer_splits[0] not in ["down", "mid", "up"]: + raise ValueError(f"Invalid block_type in pag layer {layer}. Accept 'down', 'mid', 'up', got {layer_splits[0]}") + + if len(layer_splits) >= 2: + if not layer_splits[1].startswith("block_"): + raise ValueError(f"Invalid block_index in pag layer: {layer}. Should start with 'block_'") + + if len(layer_splits) == 3: + if not layer_splits[2].startswith("attentions_"): + raise ValueError(f"Invalid attention_index in pag layer: {layer}. Should start with 'attentions_'") + + + def _set_attn_processor_pag_applied_layers(self, replace_processor): + + def is_self_attn(name): + return "attn1" in name and "to" not in name + + def get_block_type(name): + # down_blocks.1.attentions.0.transformer_blocks.0.attn1 -> "down" + return name.split(".")[0].split("_")[0] + + def get_block_index(name): + # down_blocks.1.attentions.0.transformer_blocks.0.attn1 -> "blocks_1" + return f"block_{name.split('.')[1]}" + + def get_attn_index(name): + # down_blocks.1.attentions.0.transformer_blocks.0.attn1 -> "attentions_0" + return f"attentions_{name.split('.')[3]}" + + + for drop_layer in self.pag_applied_layers: + + self._check_input_pag_applied_layer(drop_layer) + drop_layer_splits = drop_layer.split(".") + + if len(drop_layer_splits) == 1: + # e.g. "mid" + block_type = drop_layer_splits[0] + target_modules = [] + for name, module in self.unet.named_modules(): + if not is_self_attn(name): + continue + if get_block_type(name) == block_type: + target_modules.append(module) + + elif len(drop_layer_splits) == 2: + # e.g. "down.block_1" + block_type = drop_layer_splits[0] + block_index = drop_layer_splits[1] + target_modules = [] + for name, module in self.unet.named_modules(): + if not is_self_attn(name): + continue + if get_block_type(name) == block_type and get_block_index(name) == block_index: + target_modules.append(module) + + elif len(drop_layer_splits) == 3: + # e.g. "down.blocks_1.attentions_1" + block_type = drop_layer_splits[0] + block_index = drop_layer_splits[1] + attn_index = drop_layer_splits[2] + target_modules = [] + for name, module in self.unet.named_modules(): + if not is_self_attn(name): + continue + if get_block_type(name) == block_type and get_block_index(name) == block_index and get_attn_index(name) == attn_index: + target_modules.append(module) + + if len(target_modules) == 0: + logger.warning(f"Cannot find pag layer to set attention processor: {drop_layer}") + + for module in target_modules: + module.processor = replace_processor + + def _set_pag_attn_processor(self, do_classifier_free_guidance): + if do_classifier_free_guidance: + self._set_attn_processor_pag_applied_layers(PAGCFGIdentitySelfAttnProcessor2_0()) + else: + self._set_attn_processor_pag_applied_layers(PAGIdentitySelfAttnProcessor2_0()) + + def _reset_attn_processor(self): + self._set_attn_processor_pag_applied_layers(AttnProcessor2_0()) + + def _get_pag_scale(self, t): + if self.do_pag_adaptive_scaling: + signal_scale = self.pag_scale - self.pag_adaptive_scale * (1000 - t) + if signal_scale < 0: + signal_scale = 0 + return signal_scale + else: + return self.pag_scale + + def _apply_perturbed_attention_guidance(self, noise_pred, do_classifier_free_guidance, guidance_scale, t): + pag_scale = self._get_pag_scale(t) + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text, noise_pred_perturb = noise_pred.chunk(3) + noise_pred = ( + noise_pred_uncond + + guidance_scale * (noise_pred_text - noise_pred_uncond) + + pag_scale * (noise_pred_text - noise_pred_perturb) + ) + else: + noise_pred_uncond, noise_pred_perturb = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + pag_scale * (noise_pred_uncond - noise_pred_perturb) + return noise_pred + + def _prepare_perturbed_attention_guidance(self, input, uncond_input, do_classifier_free_guidance): + input = torch.cat([input] * 2, dim=0) + + if do_classifier_free_guidance: + input = torch.cat([uncond_input, input], dim=0) + return input + + def _reset_attn_processor(self): + self._set_attn_processor_pag_applied_layers(AttnProcessor2_0()) + + @property + def pag_scale(self): + return self._pag_scale + + @property + def pag_adaptive_scale(self): + return self._pag_adaptive_scale + + @property + def do_pag_adaptive_scaling(self): + return self._is_pag_enabled and self._pag_adaptive_scale > 0 + + @property + def do_perturbed_attention_guidance(self): + return self._is_pag_enabled and self._pag_scale > 0 diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sd_xl.py b/src/diffusers/pipelines/pag/pipeline_pag_sd_xl.py new file mode 100644 index 000000000000..4a8e38f164bf --- /dev/null +++ b/src/diffusers/pipelines/pag/pipeline_pag_sd_xl.py @@ -0,0 +1,1311 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import torch +from transformers import ( + CLIPImageProcessor, + CLIPTextModel, + CLIPTextModelWithProjection, + CLIPTokenizer, + CLIPVisionModelWithProjection, +) + +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import ( + FromSingleFileMixin, + IPAdapterMixin, + StableDiffusionXLLoraLoaderMixin, + TextualInversionLoaderMixin, +) +from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel +from ...models.attention_processor import ( + AttnProcessor2_0, + FusedAttnProcessor2_0, + LoRAAttnProcessor2_0, + LoRAXFormersAttnProcessor, + XFormersAttnProcessor, +) +from ...models.lora import adjust_lora_scale_text_encoder +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + USE_PEFT_BACKEND, + deprecate, + is_invisible_watermark_available, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor +from .pag_utils import PAGMixin +from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput + + +if is_invisible_watermark_available(): + from ..stable_diffusion_xl.watermark import StableDiffusionXLWatermarker + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import StableDiffusionXLPipeline + + >>> pipe = StableDiffusionXLPipeline.from_pretrained( + ... "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16 + ... ) + >>> pipe = pipe.to("cuda") + + >>> prompt = "a photo of an astronaut riding a horse on mars" + >>> image = pipe(prompt).images[0] + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + """ + Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class StableDiffusionXLPAGPipeline( + DiffusionPipeline, + StableDiffusionMixin, + FromSingleFileMixin, + StableDiffusionXLLoraLoaderMixin, + TextualInversionLoaderMixin, + IPAdapterMixin, + PAGMixin, +): + r""" + Pipeline for text-to-image generation using Stable Diffusion XL. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files + - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion XL uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + text_encoder_2 ([` CLIPTextModelWithProjection`]): + Second frozen text-encoder. Stable Diffusion XL uses the text and pool portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection), + specifically the + [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k) + variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_2 (`CLIPTokenizer`): + Second Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `"True"`): + Whether the negative prompt embeddings shall be forced to always be set to 0. Also see the config of + `stabilityai/stable-diffusion-xl-base-1-0`. + add_watermarker (`bool`, *optional*): + Whether to use the [invisible_watermark library](https://github.com/ShieldMnt/invisible-watermark/) to + watermark output images. If not defined, it will default to True if the package is installed, otherwise no + watermarker will be used. + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->unet->vae" + _optional_components = [ + "tokenizer", + "tokenizer_2", + "text_encoder", + "text_encoder_2", + "image_encoder", + "feature_extractor", + ] + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + "add_text_embeds", + "add_time_ids", + "negative_pooled_prompt_embeds", + "negative_add_time_ids", + ] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + text_encoder_2: CLIPTextModelWithProjection, + tokenizer: CLIPTokenizer, + tokenizer_2: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + image_encoder: CLIPVisionModelWithProjection = None, + feature_extractor: CLIPImageProcessor = None, + force_zeros_for_empty_prompt: bool = True, + add_watermarker: Optional[bool] = None, + enable_pag: bool = False, + pag_applied_layers: Union[str,List[str]] = "mid", # ["mid"],["down.1"] + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + unet=unet, + scheduler=scheduler, + image_encoder=image_encoder, + feature_extractor=feature_extractor, + ) + self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + + self.default_sample_size = self.unet.config.sample_size + + add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available() + + if add_watermarker: + self.watermark = StableDiffusionXLWatermarker() + else: + self.watermark = None + + if enable_pag and pag_applied_layers is not None: + self._is_pag_enabled = True + else: + self._is_pag_enabled = False + + if not isinstance(pag_applied_layers, list): + pag_applied_layers = [pag_applied_layers] + self.pag_applied_layers = pag_applied_layers + + + def encode_prompt( + self, + prompt: str, + prompt_2: Optional[str] = None, + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt: Optional[str] = None, + negative_prompt_2: Optional[str] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + pooled_prompt_embeds: Optional[torch.Tensor] = None, + negative_pooled_prompt_embeds: Optional[torch.Tensor] = None, + lora_scale: Optional[float] = None, + clip_skip: Optional[int] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) + else: + scale_lora_layers(self.text_encoder_2, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # Define tokenizers and text encoders + tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2] + text_encoders = ( + [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2] + ) + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + # textual inversion: process multi-vector tokens if necessary + prompt_embeds_list = [] + prompts = [prompt, prompt_2] + for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, tokenizer) + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {tokenizer.model_max_length} tokens: {removed_text}" + ) + + prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) + + # We are only ALWAYS interested in the pooled output of the final text encoder + pooled_prompt_embeds = prompt_embeds[0] + if clip_skip is None: + prompt_embeds = prompt_embeds.hidden_states[-2] + else: + # "2" because SDXL always indexes from the penultimate layer. + prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)] + + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) + + # get unconditional embeddings for classifier free guidance + zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt + if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt: + negative_prompt_embeds = torch.zeros_like(prompt_embeds) + negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) + elif do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt_2 = negative_prompt_2 or negative_prompt + + # normalize str to list + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + negative_prompt_2 = ( + batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 + ) + + uncond_tokens: List[str] + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = [negative_prompt, negative_prompt_2] + + negative_prompt_embeds_list = [] + for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = tokenizer( + negative_prompt, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + negative_prompt_embeds = text_encoder( + uncond_input.input_ids.to(device), + output_hidden_states=True, + ) + # We are only ALWAYS interested in the pooled output of the final text encoder + negative_pooled_prompt_embeds = negative_prompt_embeds[0] + negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] + + negative_prompt_embeds_list.append(negative_prompt_embeds) + + negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) + + if self.text_encoder_2 is not None: + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + if self.text_encoder_2 is not None: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + if do_classifier_free_guidance: + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2, lora_scale) + + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance, do_perturbed_attention_guidance + ): + image_embeds = [] + if do_classifier_free_guidance: + negative_image_embeds = [] + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." + ) + + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers + ): + output_hidden_state = not isinstance(image_proj_layer, ImageProjection) + single_image_embeds, single_negative_image_embeds = self.encode_image( + single_ip_adapter_image, device, 1, output_hidden_state + ) + + image_embeds.append(single_image_embeds) + if do_classifier_free_guidance: + negative_image_embeds.append(single_negative_image_embeds) + else: + for single_image_embeds in ip_adapter_image_embeds: + if do_classifier_free_guidance: + single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) + negative_image_embeds.append(single_negative_image_embeds) + image_embeds.append(single_image_embeds) + + + + ip_adapter_image_embeds = [] + for i, single_image_embeds in enumerate(image_embeds): + single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0) + if do_classifier_free_guidance: + single_negative_image_embeds = torch.stack([negative_image_embeds[i]] * num_images_per_prompt, dim=0) + + if do_perturbed_attention_guidance: + single_image_embeds = self._prepare_perturbed_attention_guidance(single_image_embeds, single_negative_image_embeds, do_classifier_free_guidance) + elif do_classifier_free_guidance: + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + + return ip_adapter_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + prompt_2, + height, + width, + negative_prompt=None, + negative_prompt_2=None, + prompt_embeds=None, + negative_prompt_embeds=None, + pooled_prompt_embeds=None, + negative_pooled_prompt_embeds=None, + ip_adapter_image=None, + ip_adapter_image_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif negative_prompt_2 is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." + ) + + if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) + + if ip_adapter_image is not None and ip_adapter_image_embeds is not None: + raise ValueError( + "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined." + ) + + if ip_adapter_image_embeds is not None: + if not isinstance(ip_adapter_image_embeds, list): + raise ValueError( + f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" + ) + elif ip_adapter_image_embeds[0].ndim not in [3, 4]: + raise ValueError( + f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def _get_add_time_ids( + self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None + ): + add_time_ids = list(original_size + crops_coords_top_left + target_size) + + passed_add_embed_dim = ( + self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim + ) + expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features + + if expected_add_embed_dim != passed_add_embed_dim: + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." + ) + + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + return add_time_ids + + def upcast_vae(self): + dtype = self.vae.dtype + self.vae.to(dtype=torch.float32) + use_torch_2_0_or_xformers = isinstance( + self.vae.decoder.mid_block.attentions[0].processor, + ( + AttnProcessor2_0, + XFormersAttnProcessor, + LoRAXFormersAttnProcessor, + LoRAAttnProcessor2_0, + FusedAttnProcessor2_0, + ), + ) + # if xformers or torch_2_0 is used attention block does not need + # to be in float32 which can save lots of memory + if use_torch_2_0_or_xformers: + self.vae.post_quant_conv.to(dtype) + self.vae.decoder.conv_in.to(dtype) + self.vae.decoder.mid_block.to(dtype) + + # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding + def get_guidance_scale_embedding( + self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32 + ) -> torch.Tensor: + """ + See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 + + Args: + w (`torch.Tensor`): + Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings. + embedding_dim (`int`, *optional*, defaults to 512): + Dimension of the embeddings to generate. + dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): + Data type of the generated embeddings. + + Returns: + `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`. + """ + assert len(w.shape) == 1 + w = w * 1000.0 + + half_dim = embedding_dim // 2 + emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) + emb = w.to(dtype)[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1)) + assert emb.shape == (w.shape[0], embedding_dim) + return emb + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def guidance_rescale(self): + return self._guidance_rescale + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None + + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + def denoising_end(self): + return self._denoising_end + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + timesteps: List[int] = None, + sigmas: List[float] = None, + denoising_end: Optional[float] = None, + guidance_scale: float = 5.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + pooled_prompt_embeds: Optional[torch.Tensor] = None, + negative_pooled_prompt_embeds: Optional[torch.Tensor] = None, + ip_adapter_image: Optional[PipelineImageInput] = None, + ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + guidance_rescale: float = 0.0, + original_size: Optional[Tuple[int, int]] = None, + crops_coords_top_left: Tuple[int, int] = (0, 0), + target_size: Optional[Tuple[int, int]] = None, + negative_original_size: Optional[Tuple[int, int]] = None, + negative_crops_coords_top_left: Tuple[int, int] = (0, 0), + negative_target_size: Optional[Tuple[int, int]] = None, + clip_skip: Optional[int] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + pag_scale: float = 0.0, + pag_adaptive_scale: float = 0.0, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + Anything below 512 pixels won't work well for + [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) + and checkpoints that are not specifically fine-tuned on low resolutions. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + Anything below 512 pixels won't work well for + [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) + and checkpoints that are not specifically fine-tuned on low resolutions. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + denoising_end (`float`, *optional*): + When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be + completed before it is intentionally prematurely terminated. As a result, the returned sample will + still retain a substantial amount of noise as determined by the discrete timesteps selected by the + scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a + "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image + Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output) + guidance_scale (`float`, *optional*, defaults to 5.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should + contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead + of a plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + guidance_rescale (`float`, *optional*, defaults to 0.0): + Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of + [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). + Guidance rescale factor should fix overexposure when using zero terminal SNR. + original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. + `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as + explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): + `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position + `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting + `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + For most cases, `target_size` should be set to the desired height and width of the generated image. If + not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in + section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a specific image resolution. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): + To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a target image resolution. It should be as same + as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + + Examples: + + Returns: + [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a + `tuple`. When returning a tuple, the first element is a list with the generated images. + """ + + # 0. Default height and width to unet + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + original_size = original_size or (height, width) + target_size = target_size or (height, width) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + height, + width, + negative_prompt, + negative_prompt_2, + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ip_adapter_image, + ip_adapter_image_embeds, + callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + self._guidance_rescale = guidance_rescale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + self._denoising_end = denoising_end + self._interrupt = False + self._pag_scale = pag_scale + self._pag_adaptive_scale = pag_adaptive_scale + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. Encode input prompt + lora_scale = ( + self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None + ) + + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + lora_scale=lora_scale, + clip_skip=self.clip_skip, + ) + + # 4. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, device, timesteps, sigmas + ) + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Prepare added time ids & embeddings + add_text_embeds = pooled_prompt_embeds + if self.text_encoder_2 is None: + text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) + else: + text_encoder_projection_dim = self.text_encoder_2.config.projection_dim + + add_time_ids = self._get_add_time_ids( + original_size, + crops_coords_top_left, + target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + if negative_original_size is not None and negative_target_size is not None: + negative_add_time_ids = self._get_add_time_ids( + negative_original_size, + negative_crops_coords_top_left, + negative_target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + else: + negative_add_time_ids = add_time_ids + + if self.do_perturbed_attention_guidance: + prompt_embeds = self._prepare_perturbed_attention_guidance(prompt_embeds, negative_prompt_embeds, self.do_classifier_free_guidance) + add_text_embeds = self._prepare_perturbed_attention_guidance(add_text_embeds, negative_pooled_prompt_embeds, self.do_classifier_free_guidance) + add_time_ids = self._prepare_perturbed_attention_guidance(add_time_ids, negative_add_time_ids, self.do_classifier_free_guidance) + + elif self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) + add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0) + + prompt_embeds = prompt_embeds.to(device) + add_text_embeds = add_text_embeds.to(device) + add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) + + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + self.do_classifier_free_guidance, + self.do_perturbed_attention_guidance, + ) + + # 8. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + # 8.1 Apply denoising_end + if ( + self.denoising_end is not None + and isinstance(self.denoising_end, float) + and self.denoising_end > 0 + and self.denoising_end < 1 + ): + discrete_timestep_cutoff = int( + round( + self.scheduler.config.num_train_timesteps + - (self.denoising_end * self.scheduler.config.num_train_timesteps) + ) + ) + num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) + timesteps = timesteps[:num_inference_steps] + + # 9. Optionally get Guidance Scale Embedding + timestep_cond = None + if self.unet.config.time_cond_proj_dim is not None: + guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) + timestep_cond = self.get_guidance_scale_embedding( + guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim + ).to(device=device, dtype=latents.dtype) + + if self.do_perturbed_attention_guidance: + self._set_pag_attn_processor(self.do_classifier_free_guidance) + + self._num_timesteps = len(timesteps) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # expand the latents if we are doing classifier free guidance, perturbed-attention guidance, or both + latent_model_input = torch.cat([latents] * (prompt_embeds.shape[0] // latents.shape[0])) + + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + added_cond_kwargs["image_embeds"] = image_embeds + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + timestep_cond=timestep_cond, + cross_attention_kwargs=self.cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if self.do_perturbed_attention_guidance: + noise_pred = self._apply_perturbed_attention_guidance( + noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t + ) + + elif self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + if self.do_classifier_free_guidance and self.guidance_rescale > 0.0: + # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds) + negative_pooled_prompt_embeds = callback_outputs.pop( + "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds + ) + add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids) + negative_add_time_ids = callback_outputs.pop("negative_add_time_ids", negative_add_time_ids) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if not output_type == "latent": + # make sure the VAE is in float32 mode, as it overflows in float16 + needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast + + if needs_upcasting: + self.upcast_vae() + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + elif latents.dtype != self.vae.dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + self.vae = self.vae.to(latents.dtype) + + # unscale/denormalize the latents + # denormalize with the mean and std if available and not None + has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None + has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None + if has_latents_mean and has_latents_std: + latents_mean = ( + torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype) + ) + latents_std = ( + torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype) + ) + latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean + else: + latents = latents / self.vae.config.scaling_factor + + image = self.vae.decode(latents, return_dict=False)[0] + + # cast back to fp16 if needed + if needs_upcasting: + self.vae.to(dtype=torch.float16) + else: + image = latents + + if not output_type == "latent": + # apply watermark if available + if self.watermark is not None: + image = self.watermark.apply_watermark(image) + + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if self.do_perturbed_attention_guidance: + self._reset_attn_processor() + + if not return_dict: + return (image,) + + return StableDiffusionXLPipelineOutput(images=image) diff --git a/src/diffusers/pipelines/pag_utils.py b/src/diffusers/pipelines/pag_utils.py deleted file mode 100644 index c9da87523359..000000000000 --- a/src/diffusers/pipelines/pag_utils.py +++ /dev/null @@ -1,195 +0,0 @@ -# Copyright 2024 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Tuple - -from ..models.attention_processor import ( - AttnProcessor2_0, - PAGCFGIdentitySelfAttnProcessor2_0, - PAGIdentitySelfAttnProcessor2_0, -) - - -class PAGMixin: - r"""Mixin class for PAG.""" - - def enable_pag( - self, - pag_scale: float = 0.0, - pag_adaptive_scaling: float = 0.0, - pag_applied_layers: Tuple[str] = ("mid",), # ('down', 'mid', 'up',) - pag_applied_layers_index: Tuple[str] = None, # ('d4', 'd5', 'm0',) - pag_cfg: bool = True, - ): - """Enables the FreeInit mechanism as in https://arxiv.org/abs/2312.07537. - - This implementation has been adapted from the [official repository](https://github.com/TianxingWu/FreeInit). - - Args: - pag_scale (`float`, *optional*, defaults to `0.0`): - Guidance scale of PAG. - """ - self._pag_scale = pag_scale - self._pag_adaptive_scaling = pag_adaptive_scaling - self._pag_applied_layers = pag_applied_layers - self._pag_applied_layers_index = pag_applied_layers_index - self._pag_cfg = pag_cfg - self._is_pag_enabled = True - self._set_pag_attn_processor() - - def _get_self_attn_layers(self): - down_layers = [] - mid_layers = [] - up_layers = [] - for name, module in self.unet.named_modules(): - if "attn1" in name and "to" not in name: - layer_type = name.split(".")[0].split("_")[0] - if layer_type == "down": - down_layers.append(module) - elif layer_type == "mid": - mid_layers.append(module) - elif layer_type == "up": - up_layers.append(module) - else: - raise ValueError(f"Invalid layer type: {layer_type}") - return up_layers, mid_layers, down_layers - - def _set_pag_attn_processor(self): - up_layers, mid_layers, down_layers = self._get_self_attn_layers() - - if self._pag_cfg: - replace_processor = PAGCFGIdentitySelfAttnProcessor2_0() - else: - replace_processor = PAGIdentitySelfAttnProcessor2_0() - - if self._pag_applied_layers_index: - drop_layers = self._pag_applied_layers_index - for drop_layer in drop_layers: - layer_number = int(drop_layer[1:]) - try: - if drop_layer[0] == "d": - down_layers[layer_number].processor = replace_processor - elif drop_layer[0] == "m": - mid_layers[layer_number].processor = replace_processor - elif drop_layer[0] == "u": - up_layers[layer_number].processor = replace_processor - else: - raise ValueError(f"Invalid layer type: {drop_layer[0]}") - except IndexError: - raise ValueError( - f"Invalid layer index: {drop_layer}. Available layers: {len(down_layers)} down layers, {len(mid_layers)} mid layers, {len(up_layers)} up layers." - ) - elif self._pag_applied_layers: - drop_full_layers = self._pag_applied_layers - for drop_full_layer in drop_full_layers: - try: - if drop_full_layer == "down": - for down_layer in down_layers: - down_layer.processor = replace_processor - elif drop_full_layer == "mid": - for mid_layer in mid_layers: - mid_layer.processor = replace_processor - elif drop_full_layer == "up": - for up_layer in up_layers: - up_layer.processor = replace_processor - else: - raise ValueError(f"Invalid layer type: {drop_full_layer}") - except IndexError: - raise ValueError( - f"Invalid layer index: {drop_full_layer}. Available layers are: down, mid and up. If you need to specify each layer index, you can use `pag_applied_layers_index`" - ) - - def _get_pag_scale(self, t): - if self.do_pag_adaptive_scaling: - signal_scale = self._pag_scale - self.pag_adaptive_scaling * (1000 - t) - if signal_scale < 0: - signal_scale = 0 - return signal_scale - else: - return self._pag_scale - - def _apply_perturbed_attention_guidance(self, noise_pred, do_classifier_free_guidance, guidance_scale, t): - pag_scale = self._get_pag_scale(t) - if do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text, noise_pred_perturb = noise_pred.chunk(3) - noise_pred = ( - noise_pred_uncond - + guidance_scale * (noise_pred_text - noise_pred_uncond) - + pag_scale * (noise_pred_text - noise_pred_perturb) - ) - else: - noise_pred_uncond, noise_pred_perturb = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + pag_scale * (noise_pred_uncond - noise_pred_perturb) - return noise_pred - - def disable_pag(self): - """Disables the PAG mechanism if enabled.""" - if not self.do_perturbed_attention_guidance: - raise ValueError("PAG is not enabled.") - - up_layers, mid_layers, down_layers = self._get_self_attn_layers() - if self._pag_applied_layers_index: - drop_layers = self._pag_applied_layers_index - for drop_layer in drop_layers: - layer_number = int(drop_layer[1:]) - try: - if drop_layer[0] == "d": - down_layers[layer_number].processor = AttnProcessor2_0() - elif drop_layer[0] == "m": - mid_layers[layer_number].processor = AttnProcessor2_0() - elif drop_layer[0] == "u": - up_layers[layer_number].processor = AttnProcessor2_0() - else: - raise ValueError(f"Invalid layer type: {drop_layer[0]}") - except IndexError: - raise ValueError( - f"Invalid layer index: {drop_layer}. Available layers: {len(down_layers)} down layers, {len(mid_layers)} mid layers, {len(up_layers)} up layers." - ) - elif self._pag_applied_layers: - drop_full_layers = self._pag_applied_layers - for drop_full_layer in drop_full_layers: - try: - if drop_full_layer == "down": - for down_layer in down_layers: - down_layer.processor = AttnProcessor2_0() - elif drop_full_layer == "mid": - for mid_layer in mid_layers: - mid_layer.processor = AttnProcessor2_0() - elif drop_full_layer == "up": - for up_layer in up_layers: - up_layer.processor = AttnProcessor2_0() - else: - raise ValueError(f"Invalid layer type: {drop_full_layer}") - except IndexError: - raise ValueError( - f"Invalid layer index: {drop_full_layer}. Available layers are: down, mid and up. If you need to specify each layer index, you can use `pag_applied_layers_index`" - ) - self._pag_scale = None - self._pag_adaptive_scaling = None - self._pag_applied_layers = None - self._pag_applied_layers_index = None - self._pag_cfg = None - self._is_pag_enabled = False - - @property - def pag_adaptive_scaling(self): - return self._pag_adaptive_scaling - - @property - def do_pag_adaptive_scaling(self): - return self._pag_adaptive_scaling > 0 - - @property - def do_perturbed_attention_guidance(self): - return self._is_pag_enabled diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py index aebf3f0d941e..52d0b07fb315 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py @@ -52,7 +52,6 @@ unscale_lora_layers, ) from ...utils.torch_utils import randn_tensor -from ..pag_utils import PAGMixin from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin from .pipeline_output import StableDiffusionXLPipelineOutput @@ -169,7 +168,6 @@ class StableDiffusionXLPipeline( StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin, IPAdapterMixin, - PAGMixin, ): r""" Pipeline for text-to-image generation using Stable Diffusion XL. @@ -536,7 +534,7 @@ def encode_image(self, image, device, num_images_per_prompt, output_hidden_state # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds def prepare_ip_adapter_image_embeds( - self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance, do_perturbed_attention_guidance + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance ): if ip_adapter_image_embeds is None: if not isinstance(ip_adapter_image, list): @@ -560,10 +558,6 @@ def prepare_ip_adapter_image_embeds( [single_negative_image_embeds] * num_images_per_prompt, dim=0 ) - if do_perturbed_attention_guidance: - single_image_embeds = torch.cat([single_image_embeds, single_image_embeds], dim=0) - single_image_embeds = single_image_embeds.to(device) - if do_classifier_free_guidance: single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds]) single_image_embeds = single_image_embeds.to(device) @@ -581,16 +575,11 @@ def prepare_ip_adapter_image_embeds( single_negative_image_embeds = single_negative_image_embeds.repeat( num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:])) ) - if do_perturbed_attention_guidance: - single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds, single_image_embeds], dim=0) - else: - single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds]) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds]) else: single_image_embeds = single_image_embeds.repeat( num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:])) ) - if do_perturbed_attention_guidance: - single_image_embeds = torch.cat([single_image_embeds, single_image_embeds], dim=0) image_embeds.append(single_image_embeds) return image_embeds @@ -620,7 +609,6 @@ def check_inputs( height, width, callback_steps, - guidance_scale, negative_prompt=None, negative_prompt_2=None, prompt_embeds=None, @@ -710,11 +698,6 @@ def check_inputs( f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" ) - if hasattr(self, "_pag_cfg") and self._pag_cfg is False and guidance_scale != 0: - raise ValueError( - f"Cannot use guidance scale {guidance_scale} with PAG unconditional guidance. Please set `guidance_scale` to 0." - ) - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): shape = ( @@ -1057,7 +1040,6 @@ def __call__( height, width, callback_steps, - guidance_scale, negative_prompt, negative_prompt_2, prompt_embeds, @@ -1158,11 +1140,6 @@ def __call__( else: negative_add_time_ids = add_time_ids - if self.do_perturbed_attention_guidance: - prompt_embeds = torch.cat([prompt_embeds, prompt_embeds], dim=0) - add_text_embeds = torch.cat([add_text_embeds, add_text_embeds], dim=0) - add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0) - if self.do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) @@ -1179,7 +1156,6 @@ def __call__( device, batch_size * num_images_per_prompt, self.do_classifier_free_guidance, - self.do_perturbed_attention_guidance, ) # 8. Denoising loop @@ -1215,8 +1191,8 @@ def __call__( if self.interrupt: continue - # expand the latents if we are doing classifier free guidance, perturbed-attention guidance, or both - latent_model_input = torch.cat([latents] * (prompt_embeds.shape[0] // latents.shape[0])) + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) @@ -1235,12 +1211,7 @@ def __call__( )[0] # perform guidance - if self.do_perturbed_attention_guidance: - noise_pred = self._apply_perturbed_attention_guidance( - noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t - ) - - elif self.do_classifier_free_guidance: + if self.do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py index 7bb2e346e66e..38f5cec931f8 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py @@ -53,7 +53,6 @@ unscale_lora_layers, ) from ...utils.torch_utils import randn_tensor -from ..pag_utils import PAGMixin from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin from .pipeline_output import StableDiffusionXLPipelineOutput @@ -331,7 +330,6 @@ class StableDiffusionXLInpaintPipeline( StableDiffusionXLLoraLoaderMixin, FromSingleFileMixin, IPAdapterMixin, - PAGMixin, ): r""" Pipeline for text-to-image generation using Stable Diffusion XL. @@ -787,7 +785,6 @@ def check_inputs( strength, callback_steps, output_type, - guidance_scale, negative_prompt=None, negative_prompt_2=None, prompt_embeds=None, @@ -881,11 +878,6 @@ def check_inputs( f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" ) - if hasattr(self, "_pag_cfg") and self._pag_cfg is False and guidance_scale != 0: - raise ValueError( - f"Cannot use guidance scale {guidance_scale} with PAG unconditional guidance. Please set `guidance_scale` to 0." - ) - def prepare_latents( self, batch_size, @@ -1468,7 +1460,6 @@ def __call__( strength, callback_steps, output_type, - guidance_scale, negative_prompt, negative_prompt_2, prompt_embeds, @@ -1668,11 +1659,6 @@ def denoising_value_valid(dnv): ) add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1) - if self.do_perturbed_attention_guidance: - prompt_embeds = torch.cat([prompt_embeds, prompt_embeds], dim=0) - add_text_embeds = torch.cat([add_text_embeds, add_text_embeds], dim=0) - add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0) - if self.do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) @@ -1729,8 +1715,8 @@ def denoising_value_valid(dnv): for i, t in enumerate(timesteps): if self.interrupt: continue - # expand the latents if we are doing classifier free guidance or perturbed attention guidance - latent_model_input = torch.cat([latents] * (prompt_embeds.shape[0] // latents.shape[0])) + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents # concat latents, mask, masked_image_latents in the channel dimension latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) @@ -1753,12 +1739,7 @@ def denoising_value_valid(dnv): )[0] # perform guidance - if self.do_perturbed_attention_guidance: - noise_pred = self._apply_perturbed_attention_guidance( - noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t - ) - - elif self.do_classifier_free_guidance: + if self.do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) diff --git a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py index c921c02ba9f2..2aa2415a47bb 100644 --- a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +++ b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py @@ -51,7 +51,6 @@ unscale_lora_layers, ) from ...utils.torch_utils import randn_tensor -from ..pag_utils import PAGMixin from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput @@ -202,7 +201,6 @@ class StableDiffusionXLAdapterPipeline( StableDiffusionXLLoraLoaderMixin, IPAdapterMixin, FromSingleFileMixin, - PAGMixin, ): r""" Pipeline for text-to-image generation using Stable Diffusion augmented with T2I-Adapter @@ -626,7 +624,6 @@ def check_inputs( height, width, callback_steps, - guidance_scale, negative_prompt=None, negative_prompt_2=None, prompt_embeds=None, @@ -716,11 +713,6 @@ def check_inputs( f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" ) - if hasattr(self, "_pag_cfg") and self._pag_cfg is False and guidance_scale != 0: - raise ValueError( - f"Cannot use guidance scale {guidance_scale} with PAG unconditional guidance. Please set `guidance_scale` to 0." - ) - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): shape = ( @@ -1075,7 +1067,6 @@ def __call__( height, width, callback_steps, - guidance_scale, negative_prompt, negative_prompt_2, prompt_embeds, @@ -1198,11 +1189,6 @@ def __call__( else: negative_add_time_ids = add_time_ids - if self.do_perturbed_attention_guidance: - prompt_embeds = torch.cat([prompt_embeds, prompt_embeds], dim=0) - add_text_embeds = torch.cat([add_text_embeds, add_text_embeds], dim=0) - add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0) - if self.do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) @@ -1227,8 +1213,8 @@ def __call__( with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): - # expand the latents if we are doing classifier free guidance or perturbed attention guidance - latent_model_input = torch.cat([latents] * (prompt_embeds.shape[0] // latents.shape[0])) + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) @@ -1255,11 +1241,7 @@ def __call__( )[0] # perform guidance - if self.do_perturbed_attention_guidance: - noise_pred = self._apply_perturbed_attention_guidance( - noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t - ) - elif self.do_classifier_free_guidance: + if self.do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) From 58804a0da837cda817468ce190ee367a08453599 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Sat, 8 Jun 2024 11:19:09 +0200 Subject: [PATCH 15/44] refactor ip-adapter --- .../animatediff/pipeline_animatediff.py | 43 ++++++++----------- .../animatediff/pipeline_animatediff_sdxl.py | 43 ++++++++----------- .../pipeline_animatediff_video2video.py | 43 ++++++++----------- .../controlnet/pipeline_controlnet.py | 43 ++++++++----------- .../controlnet/pipeline_controlnet_img2img.py | 43 ++++++++----------- .../controlnet/pipeline_controlnet_inpaint.py | 43 ++++++++----------- .../pipeline_controlnet_inpaint_sd_xl.py | 43 ++++++++----------- .../controlnet/pipeline_controlnet_sd_xl.py | 43 ++++++++----------- .../pipeline_controlnet_sd_xl_img2img.py | 43 ++++++++----------- .../pipeline_latent_consistency_img2img.py | 43 ++++++++----------- .../pipeline_latent_consistency_text2img.py | 43 ++++++++----------- .../pipelines/pag/pipeline_pag_sd_xl.py | 35 ++++++++------- src/diffusers/pipelines/pia/pipeline_pia.py | 43 ++++++++----------- .../pipeline_stable_diffusion.py | 43 ++++++++----------- .../pipeline_stable_diffusion_img2img.py | 43 ++++++++----------- .../pipeline_stable_diffusion_inpaint.py | 43 ++++++++----------- .../pipeline_stable_diffusion_ldm3d.py | 43 ++++++++----------- .../pipeline_stable_diffusion_panorama.py | 43 ++++++++----------- .../pipeline_stable_diffusion_xl.py | 43 ++++++++----------- .../pipeline_stable_diffusion_xl_img2img.py | 43 ++++++++----------- .../pipeline_stable_diffusion_xl_inpaint.py | 43 ++++++++----------- .../pipeline_stable_diffusion_xl_adapter.py | 43 ++++++++----------- .../dummy_torch_and_transformers_objects.py | 15 +++++++ 23 files changed, 433 insertions(+), 520 deletions(-) diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py index 175671ece5d7..b394a34049db 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py @@ -351,6 +351,9 @@ def encode_image(self, image, device, num_images_per_prompt, output_hidden_state def prepare_ip_adapter_image_embeds( self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance ): + image_embeds = [] + if do_classifier_free_guidance: + negative_image_embeds = [] if ip_adapter_image_embeds is None: if not isinstance(ip_adapter_image, list): ip_adapter_image = [ip_adapter_image] @@ -360,7 +363,6 @@ def prepare_ip_adapter_image_embeds( f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." ) - image_embeds = [] for single_ip_adapter_image, image_proj_layer in zip( ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers ): @@ -368,36 +370,29 @@ def prepare_ip_adapter_image_embeds( single_image_embeds, single_negative_image_embeds = self.encode_image( single_ip_adapter_image, device, 1, output_hidden_state ) - single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0) - single_negative_image_embeds = torch.stack( - [single_negative_image_embeds] * num_images_per_prompt, dim=0 - ) - if do_classifier_free_guidance: - single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds]) - single_image_embeds = single_image_embeds.to(device) - image_embeds.append(single_image_embeds) + image_embeds.append(single_image_embeds[None,:]) + if do_classifier_free_guidance: + negative_image_embeds.append(single_negative_image_embeds[None,:]) else: - repeat_dims = [1] - image_embeds = [] for single_image_embeds in ip_adapter_image_embeds: if do_classifier_free_guidance: single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) - single_image_embeds = single_image_embeds.repeat( - num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:])) - ) - single_negative_image_embeds = single_negative_image_embeds.repeat( - num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:])) - ) - single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds]) - else: - single_image_embeds = single_image_embeds.repeat( - num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:])) - ) + negative_image_embeds.append(single_negative_image_embeds) image_embeds.append(single_image_embeds) - - return image_embeds + + ip_adapter_image_embeds = [] + for i, single_image_embeds in enumerate(image_embeds): + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + if do_classifier_free_guidance: + single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds # Copied from diffusers.pipelines.text_to_video_synthesis/pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents def decode_latents(self, latents): diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py index 50ed54001e16..1c9c9e0d4200 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py @@ -566,6 +566,9 @@ def encode_image(self, image, device, num_images_per_prompt, output_hidden_state def prepare_ip_adapter_image_embeds( self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance ): + image_embeds = [] + if do_classifier_free_guidance: + negative_image_embeds = [] if ip_adapter_image_embeds is None: if not isinstance(ip_adapter_image, list): ip_adapter_image = [ip_adapter_image] @@ -575,7 +578,6 @@ def prepare_ip_adapter_image_embeds( f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." ) - image_embeds = [] for single_ip_adapter_image, image_proj_layer in zip( ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers ): @@ -583,36 +585,29 @@ def prepare_ip_adapter_image_embeds( single_image_embeds, single_negative_image_embeds = self.encode_image( single_ip_adapter_image, device, 1, output_hidden_state ) - single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0) - single_negative_image_embeds = torch.stack( - [single_negative_image_embeds] * num_images_per_prompt, dim=0 - ) - if do_classifier_free_guidance: - single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds]) - single_image_embeds = single_image_embeds.to(device) - image_embeds.append(single_image_embeds) + image_embeds.append(single_image_embeds[None,:]) + if do_classifier_free_guidance: + negative_image_embeds.append(single_negative_image_embeds[None,:]) else: - repeat_dims = [1] - image_embeds = [] for single_image_embeds in ip_adapter_image_embeds: if do_classifier_free_guidance: single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) - single_image_embeds = single_image_embeds.repeat( - num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:])) - ) - single_negative_image_embeds = single_negative_image_embeds.repeat( - num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:])) - ) - single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds]) - else: - single_image_embeds = single_image_embeds.repeat( - num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:])) - ) + negative_image_embeds.append(single_negative_image_embeds) image_embeds.append(single_image_embeds) - - return image_embeds + + ip_adapter_image_embeds = [] + for i, single_image_embeds in enumerate(image_embeds): + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + if do_classifier_free_guidance: + single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds # Copied from diffusers.pipelines.text_to_video_synthesis/pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents def decode_latents(self, latents): diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py index 2adc5cdf82e5..99ec145597fe 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py @@ -455,6 +455,9 @@ def encode_image(self, image, device, num_images_per_prompt, output_hidden_state def prepare_ip_adapter_image_embeds( self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance ): + image_embeds = [] + if do_classifier_free_guidance: + negative_image_embeds = [] if ip_adapter_image_embeds is None: if not isinstance(ip_adapter_image, list): ip_adapter_image = [ip_adapter_image] @@ -464,7 +467,6 @@ def prepare_ip_adapter_image_embeds( f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." ) - image_embeds = [] for single_ip_adapter_image, image_proj_layer in zip( ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers ): @@ -472,36 +474,29 @@ def prepare_ip_adapter_image_embeds( single_image_embeds, single_negative_image_embeds = self.encode_image( single_ip_adapter_image, device, 1, output_hidden_state ) - single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0) - single_negative_image_embeds = torch.stack( - [single_negative_image_embeds] * num_images_per_prompt, dim=0 - ) - if do_classifier_free_guidance: - single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds]) - single_image_embeds = single_image_embeds.to(device) - image_embeds.append(single_image_embeds) + image_embeds.append(single_image_embeds[None,:]) + if do_classifier_free_guidance: + negative_image_embeds.append(single_negative_image_embeds[None,:]) else: - repeat_dims = [1] - image_embeds = [] for single_image_embeds in ip_adapter_image_embeds: if do_classifier_free_guidance: single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) - single_image_embeds = single_image_embeds.repeat( - num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:])) - ) - single_negative_image_embeds = single_negative_image_embeds.repeat( - num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:])) - ) - single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds]) - else: - single_image_embeds = single_image_embeds.repeat( - num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:])) - ) + negative_image_embeds.append(single_negative_image_embeds) image_embeds.append(single_image_embeds) - - return image_embeds + + ip_adapter_image_embeds = [] + for i, single_image_embeds in enumerate(image_embeds): + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + if do_classifier_free_guidance: + single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds # Copied from diffusers.pipelines.text_to_video_synthesis/pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents def decode_latents(self, latents): diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py index e64dcdc55457..361dc7c5f262 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py @@ -497,6 +497,9 @@ def encode_image(self, image, device, num_images_per_prompt, output_hidden_state def prepare_ip_adapter_image_embeds( self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance ): + image_embeds = [] + if do_classifier_free_guidance: + negative_image_embeds = [] if ip_adapter_image_embeds is None: if not isinstance(ip_adapter_image, list): ip_adapter_image = [ip_adapter_image] @@ -506,7 +509,6 @@ def prepare_ip_adapter_image_embeds( f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." ) - image_embeds = [] for single_ip_adapter_image, image_proj_layer in zip( ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers ): @@ -514,36 +516,29 @@ def prepare_ip_adapter_image_embeds( single_image_embeds, single_negative_image_embeds = self.encode_image( single_ip_adapter_image, device, 1, output_hidden_state ) - single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0) - single_negative_image_embeds = torch.stack( - [single_negative_image_embeds] * num_images_per_prompt, dim=0 - ) - if do_classifier_free_guidance: - single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds]) - single_image_embeds = single_image_embeds.to(device) - image_embeds.append(single_image_embeds) + image_embeds.append(single_image_embeds[None,:]) + if do_classifier_free_guidance: + negative_image_embeds.append(single_negative_image_embeds[None,:]) else: - repeat_dims = [1] - image_embeds = [] for single_image_embeds in ip_adapter_image_embeds: if do_classifier_free_guidance: single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) - single_image_embeds = single_image_embeds.repeat( - num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:])) - ) - single_negative_image_embeds = single_negative_image_embeds.repeat( - num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:])) - ) - single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds]) - else: - single_image_embeds = single_image_embeds.repeat( - num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:])) - ) + negative_image_embeds.append(single_negative_image_embeds) image_embeds.append(single_image_embeds) - - return image_embeds + + ip_adapter_image_embeds = [] + for i, single_image_embeds in enumerate(image_embeds): + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + if do_classifier_free_guidance: + single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker def run_safety_checker(self, image, device, dtype): diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py index 2e44efa78b73..9cfa28527782 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py @@ -475,6 +475,9 @@ def encode_image(self, image, device, num_images_per_prompt, output_hidden_state def prepare_ip_adapter_image_embeds( self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance ): + image_embeds = [] + if do_classifier_free_guidance: + negative_image_embeds = [] if ip_adapter_image_embeds is None: if not isinstance(ip_adapter_image, list): ip_adapter_image = [ip_adapter_image] @@ -484,7 +487,6 @@ def prepare_ip_adapter_image_embeds( f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." ) - image_embeds = [] for single_ip_adapter_image, image_proj_layer in zip( ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers ): @@ -492,36 +494,29 @@ def prepare_ip_adapter_image_embeds( single_image_embeds, single_negative_image_embeds = self.encode_image( single_ip_adapter_image, device, 1, output_hidden_state ) - single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0) - single_negative_image_embeds = torch.stack( - [single_negative_image_embeds] * num_images_per_prompt, dim=0 - ) - if do_classifier_free_guidance: - single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds]) - single_image_embeds = single_image_embeds.to(device) - image_embeds.append(single_image_embeds) + image_embeds.append(single_image_embeds[None,:]) + if do_classifier_free_guidance: + negative_image_embeds.append(single_negative_image_embeds[None,:]) else: - repeat_dims = [1] - image_embeds = [] for single_image_embeds in ip_adapter_image_embeds: if do_classifier_free_guidance: single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) - single_image_embeds = single_image_embeds.repeat( - num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:])) - ) - single_negative_image_embeds = single_negative_image_embeds.repeat( - num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:])) - ) - single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds]) - else: - single_image_embeds = single_image_embeds.repeat( - num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:])) - ) + negative_image_embeds.append(single_negative_image_embeds) image_embeds.append(single_image_embeds) - - return image_embeds + + ip_adapter_image_embeds = [] + for i, single_image_embeds in enumerate(image_embeds): + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + if do_classifier_free_guidance: + single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker def run_safety_checker(self, image, device, dtype): diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py index cdc34819d59e..bbd2a07cb20d 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py @@ -600,6 +600,9 @@ def encode_image(self, image, device, num_images_per_prompt, output_hidden_state def prepare_ip_adapter_image_embeds( self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance ): + image_embeds = [] + if do_classifier_free_guidance: + negative_image_embeds = [] if ip_adapter_image_embeds is None: if not isinstance(ip_adapter_image, list): ip_adapter_image = [ip_adapter_image] @@ -609,7 +612,6 @@ def prepare_ip_adapter_image_embeds( f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." ) - image_embeds = [] for single_ip_adapter_image, image_proj_layer in zip( ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers ): @@ -617,36 +619,29 @@ def prepare_ip_adapter_image_embeds( single_image_embeds, single_negative_image_embeds = self.encode_image( single_ip_adapter_image, device, 1, output_hidden_state ) - single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0) - single_negative_image_embeds = torch.stack( - [single_negative_image_embeds] * num_images_per_prompt, dim=0 - ) - if do_classifier_free_guidance: - single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds]) - single_image_embeds = single_image_embeds.to(device) - image_embeds.append(single_image_embeds) + image_embeds.append(single_image_embeds[None,:]) + if do_classifier_free_guidance: + negative_image_embeds.append(single_negative_image_embeds[None,:]) else: - repeat_dims = [1] - image_embeds = [] for single_image_embeds in ip_adapter_image_embeds: if do_classifier_free_guidance: single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) - single_image_embeds = single_image_embeds.repeat( - num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:])) - ) - single_negative_image_embeds = single_negative_image_embeds.repeat( - num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:])) - ) - single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds]) - else: - single_image_embeds = single_image_embeds.repeat( - num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:])) - ) + negative_image_embeds.append(single_negative_image_embeds) image_embeds.append(single_image_embeds) - - return image_embeds + + ip_adapter_image_embeds = [] + for i, single_image_embeds in enumerate(image_embeds): + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + if do_classifier_free_guidance: + single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker def run_safety_checker(self, image, device, dtype): diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py index 3cfdefa9d44d..a7d0a01770b7 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py @@ -515,6 +515,9 @@ def encode_image(self, image, device, num_images_per_prompt, output_hidden_state def prepare_ip_adapter_image_embeds( self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance ): + image_embeds = [] + if do_classifier_free_guidance: + negative_image_embeds = [] if ip_adapter_image_embeds is None: if not isinstance(ip_adapter_image, list): ip_adapter_image = [ip_adapter_image] @@ -524,7 +527,6 @@ def prepare_ip_adapter_image_embeds( f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." ) - image_embeds = [] for single_ip_adapter_image, image_proj_layer in zip( ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers ): @@ -532,36 +534,29 @@ def prepare_ip_adapter_image_embeds( single_image_embeds, single_negative_image_embeds = self.encode_image( single_ip_adapter_image, device, 1, output_hidden_state ) - single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0) - single_negative_image_embeds = torch.stack( - [single_negative_image_embeds] * num_images_per_prompt, dim=0 - ) - if do_classifier_free_guidance: - single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds]) - single_image_embeds = single_image_embeds.to(device) - image_embeds.append(single_image_embeds) + image_embeds.append(single_image_embeds[None,:]) + if do_classifier_free_guidance: + negative_image_embeds.append(single_negative_image_embeds[None,:]) else: - repeat_dims = [1] - image_embeds = [] for single_image_embeds in ip_adapter_image_embeds: if do_classifier_free_guidance: single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) - single_image_embeds = single_image_embeds.repeat( - num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:])) - ) - single_negative_image_embeds = single_negative_image_embeds.repeat( - num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:])) - ) - single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds]) - else: - single_image_embeds = single_image_embeds.repeat( - num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:])) - ) + negative_image_embeds.append(single_negative_image_embeds) image_embeds.append(single_image_embeds) - - return image_embeds + + ip_adapter_image_embeds = [] + for i, single_image_embeds in enumerate(image_embeds): + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + if do_classifier_free_guidance: + single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs def prepare_extra_step_kwargs(self, generator, eta): diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py index 763188f34735..28e95d317f80 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py @@ -547,6 +547,9 @@ def encode_image(self, image, device, num_images_per_prompt, output_hidden_state def prepare_ip_adapter_image_embeds( self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance ): + image_embeds = [] + if do_classifier_free_guidance: + negative_image_embeds = [] if ip_adapter_image_embeds is None: if not isinstance(ip_adapter_image, list): ip_adapter_image = [ip_adapter_image] @@ -556,7 +559,6 @@ def prepare_ip_adapter_image_embeds( f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." ) - image_embeds = [] for single_ip_adapter_image, image_proj_layer in zip( ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers ): @@ -564,36 +566,29 @@ def prepare_ip_adapter_image_embeds( single_image_embeds, single_negative_image_embeds = self.encode_image( single_ip_adapter_image, device, 1, output_hidden_state ) - single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0) - single_negative_image_embeds = torch.stack( - [single_negative_image_embeds] * num_images_per_prompt, dim=0 - ) - if do_classifier_free_guidance: - single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds]) - single_image_embeds = single_image_embeds.to(device) - image_embeds.append(single_image_embeds) + image_embeds.append(single_image_embeds[None,:]) + if do_classifier_free_guidance: + negative_image_embeds.append(single_negative_image_embeds[None,:]) else: - repeat_dims = [1] - image_embeds = [] for single_image_embeds in ip_adapter_image_embeds: if do_classifier_free_guidance: single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) - single_image_embeds = single_image_embeds.repeat( - num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:])) - ) - single_negative_image_embeds = single_negative_image_embeds.repeat( - num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:])) - ) - single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds]) - else: - single_image_embeds = single_image_embeds.repeat( - num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:])) - ) + negative_image_embeds.append(single_negative_image_embeds) image_embeds.append(single_image_embeds) - - return image_embeds + + ip_adapter_image_embeds = [] + for i, single_image_embeds in enumerate(image_embeds): + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + if do_classifier_free_guidance: + single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs def prepare_extra_step_kwargs(self, generator, eta): diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py index dbd406d928d5..41c091028e4f 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py @@ -541,6 +541,9 @@ def encode_image(self, image, device, num_images_per_prompt, output_hidden_state def prepare_ip_adapter_image_embeds( self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance ): + image_embeds = [] + if do_classifier_free_guidance: + negative_image_embeds = [] if ip_adapter_image_embeds is None: if not isinstance(ip_adapter_image, list): ip_adapter_image = [ip_adapter_image] @@ -550,7 +553,6 @@ def prepare_ip_adapter_image_embeds( f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." ) - image_embeds = [] for single_ip_adapter_image, image_proj_layer in zip( ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers ): @@ -558,36 +560,29 @@ def prepare_ip_adapter_image_embeds( single_image_embeds, single_negative_image_embeds = self.encode_image( single_ip_adapter_image, device, 1, output_hidden_state ) - single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0) - single_negative_image_embeds = torch.stack( - [single_negative_image_embeds] * num_images_per_prompt, dim=0 - ) - if do_classifier_free_guidance: - single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds]) - single_image_embeds = single_image_embeds.to(device) - image_embeds.append(single_image_embeds) + image_embeds.append(single_image_embeds[None,:]) + if do_classifier_free_guidance: + negative_image_embeds.append(single_negative_image_embeds[None,:]) else: - repeat_dims = [1] - image_embeds = [] for single_image_embeds in ip_adapter_image_embeds: if do_classifier_free_guidance: single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) - single_image_embeds = single_image_embeds.repeat( - num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:])) - ) - single_negative_image_embeds = single_negative_image_embeds.repeat( - num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:])) - ) - single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds]) - else: - single_image_embeds = single_image_embeds.repeat( - num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:])) - ) + negative_image_embeds.append(single_negative_image_embeds) image_embeds.append(single_image_embeds) - - return image_embeds + + ip_adapter_image_embeds = [] + for i, single_image_embeds in enumerate(image_embeds): + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + if do_classifier_free_guidance: + single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs def prepare_extra_step_kwargs(self, generator, eta): diff --git a/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py b/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py index 11e3781aaf7c..e76c663a3d6a 100644 --- a/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +++ b/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py @@ -440,6 +440,9 @@ def encode_image(self, image, device, num_images_per_prompt, output_hidden_state def prepare_ip_adapter_image_embeds( self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance ): + image_embeds = [] + if do_classifier_free_guidance: + negative_image_embeds = [] if ip_adapter_image_embeds is None: if not isinstance(ip_adapter_image, list): ip_adapter_image = [ip_adapter_image] @@ -449,7 +452,6 @@ def prepare_ip_adapter_image_embeds( f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." ) - image_embeds = [] for single_ip_adapter_image, image_proj_layer in zip( ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers ): @@ -457,36 +459,29 @@ def prepare_ip_adapter_image_embeds( single_image_embeds, single_negative_image_embeds = self.encode_image( single_ip_adapter_image, device, 1, output_hidden_state ) - single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0) - single_negative_image_embeds = torch.stack( - [single_negative_image_embeds] * num_images_per_prompt, dim=0 - ) - if do_classifier_free_guidance: - single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds]) - single_image_embeds = single_image_embeds.to(device) - image_embeds.append(single_image_embeds) + image_embeds.append(single_image_embeds[None,:]) + if do_classifier_free_guidance: + negative_image_embeds.append(single_negative_image_embeds[None,:]) else: - repeat_dims = [1] - image_embeds = [] for single_image_embeds in ip_adapter_image_embeds: if do_classifier_free_guidance: single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) - single_image_embeds = single_image_embeds.repeat( - num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:])) - ) - single_negative_image_embeds = single_negative_image_embeds.repeat( - num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:])) - ) - single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds]) - else: - single_image_embeds = single_image_embeds.repeat( - num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:])) - ) + negative_image_embeds.append(single_negative_image_embeds) image_embeds.append(single_image_embeds) - - return image_embeds + + ip_adapter_image_embeds = [] + for i, single_image_embeds in enumerate(image_embeds): + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + if do_classifier_free_guidance: + single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker def run_safety_checker(self, image, device, dtype): diff --git a/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py b/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py index 7f3495258402..69cc9f956a58 100644 --- a/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +++ b/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py @@ -424,6 +424,9 @@ def encode_image(self, image, device, num_images_per_prompt, output_hidden_state def prepare_ip_adapter_image_embeds( self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance ): + image_embeds = [] + if do_classifier_free_guidance: + negative_image_embeds = [] if ip_adapter_image_embeds is None: if not isinstance(ip_adapter_image, list): ip_adapter_image = [ip_adapter_image] @@ -433,7 +436,6 @@ def prepare_ip_adapter_image_embeds( f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." ) - image_embeds = [] for single_ip_adapter_image, image_proj_layer in zip( ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers ): @@ -441,36 +443,29 @@ def prepare_ip_adapter_image_embeds( single_image_embeds, single_negative_image_embeds = self.encode_image( single_ip_adapter_image, device, 1, output_hidden_state ) - single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0) - single_negative_image_embeds = torch.stack( - [single_negative_image_embeds] * num_images_per_prompt, dim=0 - ) - if do_classifier_free_guidance: - single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds]) - single_image_embeds = single_image_embeds.to(device) - image_embeds.append(single_image_embeds) + image_embeds.append(single_image_embeds[None,:]) + if do_classifier_free_guidance: + negative_image_embeds.append(single_negative_image_embeds[None,:]) else: - repeat_dims = [1] - image_embeds = [] for single_image_embeds in ip_adapter_image_embeds: if do_classifier_free_guidance: single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) - single_image_embeds = single_image_embeds.repeat( - num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:])) - ) - single_negative_image_embeds = single_negative_image_embeds.repeat( - num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:])) - ) - single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds]) - else: - single_image_embeds = single_image_embeds.repeat( - num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:])) - ) + negative_image_embeds.append(single_negative_image_embeds) image_embeds.append(single_image_embeds) - - return image_embeds + + ip_adapter_image_embeds = [] + for i, single_image_embeds in enumerate(image_embeds): + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + if do_classifier_free_guidance: + single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker def run_safety_checker(self, image, device, dtype): diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sd_xl.py b/src/diffusers/pipelines/pag/pipeline_pag_sd_xl.py index 4a8e38f164bf..c66c21b4ae1e 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_sd_xl.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_sd_xl.py @@ -548,7 +548,7 @@ def encode_image(self, image, device, num_images_per_prompt, output_hidden_state # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds def prepare_ip_adapter_image_embeds( - self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance, do_perturbed_attention_guidance + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance ): image_embeds = [] if do_classifier_free_guidance: @@ -570,32 +570,26 @@ def prepare_ip_adapter_image_embeds( single_ip_adapter_image, device, 1, output_hidden_state ) - image_embeds.append(single_image_embeds) + + image_embeds.append(single_image_embeds[None,:]) if do_classifier_free_guidance: - negative_image_embeds.append(single_negative_image_embeds) + negative_image_embeds.append(single_negative_image_embeds[None,:]) else: for single_image_embeds in ip_adapter_image_embeds: if do_classifier_free_guidance: single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) negative_image_embeds.append(single_negative_image_embeds) image_embeds.append(single_image_embeds) - - ip_adapter_image_embeds = [] for i, single_image_embeds in enumerate(image_embeds): - single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0) + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) if do_classifier_free_guidance: - single_negative_image_embeds = torch.stack([negative_image_embeds[i]] * num_images_per_prompt, dim=0) - - if do_perturbed_attention_guidance: - single_image_embeds = self._prepare_perturbed_attention_guidance(single_image_embeds, single_negative_image_embeds, do_classifier_free_guidance) - elif do_classifier_free_guidance: + single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) single_image_embeds = single_image_embeds.to(device=device) ip_adapter_image_embeds.append(single_image_embeds) - return ip_adapter_image_embeds @@ -1149,14 +1143,23 @@ def __call__( add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) if ip_adapter_image is not None or ip_adapter_image_embeds is not None: - image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image_embeds = self.prepare_ip_adapter_image_embeds( ip_adapter_image, ip_adapter_image_embeds, device, batch_size * num_images_per_prompt, self.do_classifier_free_guidance, - self.do_perturbed_attention_guidance, ) + + for i, image_embeds in enumerate(ip_adapter_image_embeds): + if self.do_classifier_free_guidance: + negative_image_embeds, image_embeds = image_embeds.chunk(2) + if self.do_perturbed_attention_guidance: + image_embeds = self._prepare_perturbed_attention_guidance(image_embeds, negative_image_embeds, self.do_classifier_free_guidance) + elif self.do_classifier_free_guidance: + image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0) + image_embeds = image_embeds.to(device) + ip_adapter_image_embeds[i] = image_embeds # 8. Denoising loop num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) @@ -1201,8 +1204,8 @@ def __call__( # predict the noise residual added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} - if ip_adapter_image is not None or ip_adapter_image_embeds is not None: - added_cond_kwargs["image_embeds"] = image_embeds + if ip_adapter_image_embeds is not None: + added_cond_kwargs["image_embeds"] = ip_adapter_image_embeds noise_pred = self.unet( latent_model_input, t, diff --git a/src/diffusers/pipelines/pia/pipeline_pia.py b/src/diffusers/pipelines/pia/pipeline_pia.py index 1ec418b0e867..a02501188036 100644 --- a/src/diffusers/pipelines/pia/pipeline_pia.py +++ b/src/diffusers/pipelines/pia/pipeline_pia.py @@ -504,6 +504,9 @@ def check_inputs( def prepare_ip_adapter_image_embeds( self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance ): + image_embeds = [] + if do_classifier_free_guidance: + negative_image_embeds = [] if ip_adapter_image_embeds is None: if not isinstance(ip_adapter_image, list): ip_adapter_image = [ip_adapter_image] @@ -513,7 +516,6 @@ def prepare_ip_adapter_image_embeds( f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." ) - image_embeds = [] for single_ip_adapter_image, image_proj_layer in zip( ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers ): @@ -521,36 +523,29 @@ def prepare_ip_adapter_image_embeds( single_image_embeds, single_negative_image_embeds = self.encode_image( single_ip_adapter_image, device, 1, output_hidden_state ) - single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0) - single_negative_image_embeds = torch.stack( - [single_negative_image_embeds] * num_images_per_prompt, dim=0 - ) - if do_classifier_free_guidance: - single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds]) - single_image_embeds = single_image_embeds.to(device) - image_embeds.append(single_image_embeds) + image_embeds.append(single_image_embeds[None,:]) + if do_classifier_free_guidance: + negative_image_embeds.append(single_negative_image_embeds[None,:]) else: - repeat_dims = [1] - image_embeds = [] for single_image_embeds in ip_adapter_image_embeds: if do_classifier_free_guidance: single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) - single_image_embeds = single_image_embeds.repeat( - num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:])) - ) - single_negative_image_embeds = single_negative_image_embeds.repeat( - num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:])) - ) - single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds]) - else: - single_image_embeds = single_image_embeds.repeat( - num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:])) - ) + negative_image_embeds.append(single_negative_image_embeds) image_embeds.append(single_image_embeds) - - return image_embeds + + ip_adapter_image_embeds = [] + for i, single_image_embeds in enumerate(image_embeds): + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + if do_classifier_free_guidance: + single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds # Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_synth.TextToVideoSDPipeline.prepare_latents def prepare_latents( diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index e8ab72421d7e..76125183c753 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -507,6 +507,9 @@ def encode_image(self, image, device, num_images_per_prompt, output_hidden_state def prepare_ip_adapter_image_embeds( self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance ): + image_embeds = [] + if do_classifier_free_guidance: + negative_image_embeds = [] if ip_adapter_image_embeds is None: if not isinstance(ip_adapter_image, list): ip_adapter_image = [ip_adapter_image] @@ -516,7 +519,6 @@ def prepare_ip_adapter_image_embeds( f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." ) - image_embeds = [] for single_ip_adapter_image, image_proj_layer in zip( ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers ): @@ -524,36 +526,29 @@ def prepare_ip_adapter_image_embeds( single_image_embeds, single_negative_image_embeds = self.encode_image( single_ip_adapter_image, device, 1, output_hidden_state ) - single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0) - single_negative_image_embeds = torch.stack( - [single_negative_image_embeds] * num_images_per_prompt, dim=0 - ) - if do_classifier_free_guidance: - single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds]) - single_image_embeds = single_image_embeds.to(device) - image_embeds.append(single_image_embeds) + image_embeds.append(single_image_embeds[None,:]) + if do_classifier_free_guidance: + negative_image_embeds.append(single_negative_image_embeds[None,:]) else: - repeat_dims = [1] - image_embeds = [] for single_image_embeds in ip_adapter_image_embeds: if do_classifier_free_guidance: single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) - single_image_embeds = single_image_embeds.repeat( - num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:])) - ) - single_negative_image_embeds = single_negative_image_embeds.repeat( - num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:])) - ) - single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds]) - else: - single_image_embeds = single_image_embeds.repeat( - num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:])) - ) + negative_image_embeds.append(single_negative_image_embeds) image_embeds.append(single_image_embeds) - - return image_embeds + + ip_adapter_image_embeds = [] + for i, single_image_embeds in enumerate(image_embeds): + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + if do_classifier_free_guidance: + single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds def run_safety_checker(self, image, device, dtype): if self.safety_checker is None: diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py index f2a5de81540d..9d34f6dc119b 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py @@ -551,6 +551,9 @@ def encode_image(self, image, device, num_images_per_prompt, output_hidden_state def prepare_ip_adapter_image_embeds( self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance ): + image_embeds = [] + if do_classifier_free_guidance: + negative_image_embeds = [] if ip_adapter_image_embeds is None: if not isinstance(ip_adapter_image, list): ip_adapter_image = [ip_adapter_image] @@ -560,7 +563,6 @@ def prepare_ip_adapter_image_embeds( f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." ) - image_embeds = [] for single_ip_adapter_image, image_proj_layer in zip( ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers ): @@ -568,36 +570,29 @@ def prepare_ip_adapter_image_embeds( single_image_embeds, single_negative_image_embeds = self.encode_image( single_ip_adapter_image, device, 1, output_hidden_state ) - single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0) - single_negative_image_embeds = torch.stack( - [single_negative_image_embeds] * num_images_per_prompt, dim=0 - ) - if do_classifier_free_guidance: - single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds]) - single_image_embeds = single_image_embeds.to(device) - image_embeds.append(single_image_embeds) + image_embeds.append(single_image_embeds[None,:]) + if do_classifier_free_guidance: + negative_image_embeds.append(single_negative_image_embeds[None,:]) else: - repeat_dims = [1] - image_embeds = [] for single_image_embeds in ip_adapter_image_embeds: if do_classifier_free_guidance: single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) - single_image_embeds = single_image_embeds.repeat( - num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:])) - ) - single_negative_image_embeds = single_negative_image_embeds.repeat( - num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:])) - ) - single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds]) - else: - single_image_embeds = single_image_embeds.repeat( - num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:])) - ) + negative_image_embeds.append(single_negative_image_embeds) image_embeds.append(single_image_embeds) - - return image_embeds + + ip_adapter_image_embeds = [] + for i, single_image_embeds in enumerate(image_embeds): + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + if do_classifier_free_guidance: + single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker def run_safety_checker(self, image, device, dtype): diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index 71dec964fdca..1033ea7a4721 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -623,6 +623,9 @@ def encode_image(self, image, device, num_images_per_prompt, output_hidden_state def prepare_ip_adapter_image_embeds( self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance ): + image_embeds = [] + if do_classifier_free_guidance: + negative_image_embeds = [] if ip_adapter_image_embeds is None: if not isinstance(ip_adapter_image, list): ip_adapter_image = [ip_adapter_image] @@ -632,7 +635,6 @@ def prepare_ip_adapter_image_embeds( f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." ) - image_embeds = [] for single_ip_adapter_image, image_proj_layer in zip( ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers ): @@ -640,36 +642,29 @@ def prepare_ip_adapter_image_embeds( single_image_embeds, single_negative_image_embeds = self.encode_image( single_ip_adapter_image, device, 1, output_hidden_state ) - single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0) - single_negative_image_embeds = torch.stack( - [single_negative_image_embeds] * num_images_per_prompt, dim=0 - ) - if do_classifier_free_guidance: - single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds]) - single_image_embeds = single_image_embeds.to(device) - image_embeds.append(single_image_embeds) + image_embeds.append(single_image_embeds[None,:]) + if do_classifier_free_guidance: + negative_image_embeds.append(single_negative_image_embeds[None,:]) else: - repeat_dims = [1] - image_embeds = [] for single_image_embeds in ip_adapter_image_embeds: if do_classifier_free_guidance: single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) - single_image_embeds = single_image_embeds.repeat( - num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:])) - ) - single_negative_image_embeds = single_negative_image_embeds.repeat( - num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:])) - ) - single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds]) - else: - single_image_embeds = single_image_embeds.repeat( - num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:])) - ) + negative_image_embeds.append(single_negative_image_embeds) image_embeds.append(single_image_embeds) - - return image_embeds + + ip_adapter_image_embeds = [] + for i, single_image_embeds in enumerate(image_embeds): + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + if do_classifier_free_guidance: + single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker def run_safety_checker(self, image, device, dtype): diff --git a/src/diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py b/src/diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py index a100a38b04ea..663b50e899f1 100644 --- a/src/diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +++ b/src/diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py @@ -490,6 +490,9 @@ def encode_image(self, image, device, num_images_per_prompt, output_hidden_state def prepare_ip_adapter_image_embeds( self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance ): + image_embeds = [] + if do_classifier_free_guidance: + negative_image_embeds = [] if ip_adapter_image_embeds is None: if not isinstance(ip_adapter_image, list): ip_adapter_image = [ip_adapter_image] @@ -499,7 +502,6 @@ def prepare_ip_adapter_image_embeds( f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." ) - image_embeds = [] for single_ip_adapter_image, image_proj_layer in zip( ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers ): @@ -507,36 +509,29 @@ def prepare_ip_adapter_image_embeds( single_image_embeds, single_negative_image_embeds = self.encode_image( single_ip_adapter_image, device, 1, output_hidden_state ) - single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0) - single_negative_image_embeds = torch.stack( - [single_negative_image_embeds] * num_images_per_prompt, dim=0 - ) - if do_classifier_free_guidance: - single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds]) - single_image_embeds = single_image_embeds.to(device) - image_embeds.append(single_image_embeds) + image_embeds.append(single_image_embeds[None,:]) + if do_classifier_free_guidance: + negative_image_embeds.append(single_negative_image_embeds[None,:]) else: - repeat_dims = [1] - image_embeds = [] for single_image_embeds in ip_adapter_image_embeds: if do_classifier_free_guidance: single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) - single_image_embeds = single_image_embeds.repeat( - num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:])) - ) - single_negative_image_embeds = single_negative_image_embeds.repeat( - num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:])) - ) - single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds]) - else: - single_image_embeds = single_image_embeds.repeat( - num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:])) - ) + negative_image_embeds.append(single_negative_image_embeds) image_embeds.append(single_image_embeds) - - return image_embeds + + ip_adapter_image_embeds = [] + for i, single_image_embeds in enumerate(image_embeds): + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + if do_classifier_free_guidance: + single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds def run_safety_checker(self, image, device, dtype): if self.safety_checker is None: diff --git a/src/diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py b/src/diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py index 2b80d2856869..ad9291701941 100644 --- a/src/diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +++ b/src/diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py @@ -462,6 +462,9 @@ def encode_image(self, image, device, num_images_per_prompt, output_hidden_state def prepare_ip_adapter_image_embeds( self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance ): + image_embeds = [] + if do_classifier_free_guidance: + negative_image_embeds = [] if ip_adapter_image_embeds is None: if not isinstance(ip_adapter_image, list): ip_adapter_image = [ip_adapter_image] @@ -471,7 +474,6 @@ def prepare_ip_adapter_image_embeds( f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." ) - image_embeds = [] for single_ip_adapter_image, image_proj_layer in zip( ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers ): @@ -479,36 +481,29 @@ def prepare_ip_adapter_image_embeds( single_image_embeds, single_negative_image_embeds = self.encode_image( single_ip_adapter_image, device, 1, output_hidden_state ) - single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0) - single_negative_image_embeds = torch.stack( - [single_negative_image_embeds] * num_images_per_prompt, dim=0 - ) - if do_classifier_free_guidance: - single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds]) - single_image_embeds = single_image_embeds.to(device) - image_embeds.append(single_image_embeds) + image_embeds.append(single_image_embeds[None,:]) + if do_classifier_free_guidance: + negative_image_embeds.append(single_negative_image_embeds[None,:]) else: - repeat_dims = [1] - image_embeds = [] for single_image_embeds in ip_adapter_image_embeds: if do_classifier_free_guidance: single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) - single_image_embeds = single_image_embeds.repeat( - num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:])) - ) - single_negative_image_embeds = single_negative_image_embeds.repeat( - num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:])) - ) - single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds]) - else: - single_image_embeds = single_image_embeds.repeat( - num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:])) - ) + negative_image_embeds.append(single_negative_image_embeds) image_embeds.append(single_image_embeds) - - return image_embeds + + ip_adapter_image_embeds = [] + for i, single_image_embeds in enumerate(image_embeds): + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + if do_classifier_free_guidance: + single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker def run_safety_checker(self, image, device, dtype): diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py index 52d0b07fb315..c43486bcb952 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py @@ -536,6 +536,9 @@ def encode_image(self, image, device, num_images_per_prompt, output_hidden_state def prepare_ip_adapter_image_embeds( self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance ): + image_embeds = [] + if do_classifier_free_guidance: + negative_image_embeds = [] if ip_adapter_image_embeds is None: if not isinstance(ip_adapter_image, list): ip_adapter_image = [ip_adapter_image] @@ -545,7 +548,6 @@ def prepare_ip_adapter_image_embeds( f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." ) - image_embeds = [] for single_ip_adapter_image, image_proj_layer in zip( ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers ): @@ -553,36 +555,29 @@ def prepare_ip_adapter_image_embeds( single_image_embeds, single_negative_image_embeds = self.encode_image( single_ip_adapter_image, device, 1, output_hidden_state ) - single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0) - single_negative_image_embeds = torch.stack( - [single_negative_image_embeds] * num_images_per_prompt, dim=0 - ) - if do_classifier_free_guidance: - single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds]) - single_image_embeds = single_image_embeds.to(device) - image_embeds.append(single_image_embeds) + image_embeds.append(single_image_embeds[None,:]) + if do_classifier_free_guidance: + negative_image_embeds.append(single_negative_image_embeds[None,:]) else: - repeat_dims = [1] - image_embeds = [] for single_image_embeds in ip_adapter_image_embeds: if do_classifier_free_guidance: single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) - single_image_embeds = single_image_embeds.repeat( - num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:])) - ) - single_negative_image_embeds = single_negative_image_embeds.repeat( - num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:])) - ) - single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds]) - else: - single_image_embeds = single_image_embeds.repeat( - num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:])) - ) + negative_image_embeds.append(single_negative_image_embeds) image_embeds.append(single_image_embeds) - - return image_embeds + + ip_adapter_image_embeds = [] + for i, single_image_embeds in enumerate(image_embeds): + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + if do_classifier_free_guidance: + single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs def prepare_extra_step_kwargs(self, generator, eta): diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py index b8698a008320..70fc4691c60c 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py @@ -780,6 +780,9 @@ def encode_image(self, image, device, num_images_per_prompt, output_hidden_state def prepare_ip_adapter_image_embeds( self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance ): + image_embeds = [] + if do_classifier_free_guidance: + negative_image_embeds = [] if ip_adapter_image_embeds is None: if not isinstance(ip_adapter_image, list): ip_adapter_image = [ip_adapter_image] @@ -789,7 +792,6 @@ def prepare_ip_adapter_image_embeds( f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." ) - image_embeds = [] for single_ip_adapter_image, image_proj_layer in zip( ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers ): @@ -797,36 +799,29 @@ def prepare_ip_adapter_image_embeds( single_image_embeds, single_negative_image_embeds = self.encode_image( single_ip_adapter_image, device, 1, output_hidden_state ) - single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0) - single_negative_image_embeds = torch.stack( - [single_negative_image_embeds] * num_images_per_prompt, dim=0 - ) - if do_classifier_free_guidance: - single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds]) - single_image_embeds = single_image_embeds.to(device) - image_embeds.append(single_image_embeds) + image_embeds.append(single_image_embeds[None,:]) + if do_classifier_free_guidance: + negative_image_embeds.append(single_negative_image_embeds[None,:]) else: - repeat_dims = [1] - image_embeds = [] for single_image_embeds in ip_adapter_image_embeds: if do_classifier_free_guidance: single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) - single_image_embeds = single_image_embeds.repeat( - num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:])) - ) - single_negative_image_embeds = single_negative_image_embeds.repeat( - num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:])) - ) - single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds]) - else: - single_image_embeds = single_image_embeds.repeat( - num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:])) - ) + negative_image_embeds.append(single_negative_image_embeds) image_embeds.append(single_image_embeds) - - return image_embeds + + ip_adapter_image_embeds = [] + for i, single_image_embeds in enumerate(image_embeds): + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + if do_classifier_free_guidance: + single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds def _get_add_time_ids( self, diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py index 38f5cec931f8..447646e0752c 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py @@ -473,6 +473,9 @@ def encode_image(self, image, device, num_images_per_prompt, output_hidden_state def prepare_ip_adapter_image_embeds( self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance ): + image_embeds = [] + if do_classifier_free_guidance: + negative_image_embeds = [] if ip_adapter_image_embeds is None: if not isinstance(ip_adapter_image, list): ip_adapter_image = [ip_adapter_image] @@ -482,7 +485,6 @@ def prepare_ip_adapter_image_embeds( f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." ) - image_embeds = [] for single_ip_adapter_image, image_proj_layer in zip( ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers ): @@ -490,36 +492,29 @@ def prepare_ip_adapter_image_embeds( single_image_embeds, single_negative_image_embeds = self.encode_image( single_ip_adapter_image, device, 1, output_hidden_state ) - single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0) - single_negative_image_embeds = torch.stack( - [single_negative_image_embeds] * num_images_per_prompt, dim=0 - ) - if do_classifier_free_guidance: - single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds]) - single_image_embeds = single_image_embeds.to(device) - image_embeds.append(single_image_embeds) + image_embeds.append(single_image_embeds[None,:]) + if do_classifier_free_guidance: + negative_image_embeds.append(single_negative_image_embeds[None,:]) else: - repeat_dims = [1] - image_embeds = [] for single_image_embeds in ip_adapter_image_embeds: if do_classifier_free_guidance: single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) - single_image_embeds = single_image_embeds.repeat( - num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:])) - ) - single_negative_image_embeds = single_negative_image_embeds.repeat( - num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:])) - ) - single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds]) - else: - single_image_embeds = single_image_embeds.repeat( - num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:])) - ) + negative_image_embeds.append(single_negative_image_embeds) image_embeds.append(single_image_embeds) - - return image_embeds + + ip_adapter_image_embeds = [] + for i, single_image_embeds in enumerate(image_embeds): + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + if do_classifier_free_guidance: + single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt def encode_prompt( diff --git a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py index 2aa2415a47bb..e43af9b9c047 100644 --- a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +++ b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py @@ -550,6 +550,9 @@ def encode_image(self, image, device, num_images_per_prompt, output_hidden_state def prepare_ip_adapter_image_embeds( self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance ): + image_embeds = [] + if do_classifier_free_guidance: + negative_image_embeds = [] if ip_adapter_image_embeds is None: if not isinstance(ip_adapter_image, list): ip_adapter_image = [ip_adapter_image] @@ -559,7 +562,6 @@ def prepare_ip_adapter_image_embeds( f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." ) - image_embeds = [] for single_ip_adapter_image, image_proj_layer in zip( ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers ): @@ -567,36 +569,29 @@ def prepare_ip_adapter_image_embeds( single_image_embeds, single_negative_image_embeds = self.encode_image( single_ip_adapter_image, device, 1, output_hidden_state ) - single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0) - single_negative_image_embeds = torch.stack( - [single_negative_image_embeds] * num_images_per_prompt, dim=0 - ) - if do_classifier_free_guidance: - single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds]) - single_image_embeds = single_image_embeds.to(device) - image_embeds.append(single_image_embeds) + image_embeds.append(single_image_embeds[None,:]) + if do_classifier_free_guidance: + negative_image_embeds.append(single_negative_image_embeds[None,:]) else: - repeat_dims = [1] - image_embeds = [] for single_image_embeds in ip_adapter_image_embeds: if do_classifier_free_guidance: single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) - single_image_embeds = single_image_embeds.repeat( - num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:])) - ) - single_negative_image_embeds = single_negative_image_embeds.repeat( - num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:])) - ) - single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds]) - else: - single_image_embeds = single_image_embeds.repeat( - num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:])) - ) + negative_image_embeds.append(single_negative_image_embeds) image_embeds.append(single_image_embeds) - - return image_embeds + + ip_adapter_image_embeds = [] + for i, single_image_embeds in enumerate(image_embeds): + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + if do_classifier_free_guidance: + single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs def prepare_extra_step_kwargs(self, generator, eta): diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 0583cf839ff7..a594b95f0338 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -1352,6 +1352,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class StableDiffusionXLPAGPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class StableDiffusionXLPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] From 7bc9229f86e248ccfb965e8d6325be11829ddd8a Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Sat, 8 Jun 2024 09:22:33 +0000 Subject: [PATCH 16/44] style --- src/diffusers/__init__.py | 2 +- src/diffusers/loaders/unet.py | 2 - src/diffusers/pipelines/__init__.py | 4 +- .../animatediff/pipeline_animatediff.py | 11 ++--- .../animatediff/pipeline_animatediff_sdxl.py | 11 ++--- .../pipeline_animatediff_video2video.py | 11 ++--- src/diffusers/pipelines/auto_pipeline.py | 4 +- .../controlnet/pipeline_controlnet.py | 11 ++--- .../controlnet/pipeline_controlnet_img2img.py | 11 ++--- .../controlnet/pipeline_controlnet_inpaint.py | 11 ++--- .../pipeline_controlnet_inpaint_sd_xl.py | 11 ++--- .../controlnet/pipeline_controlnet_sd_xl.py | 11 ++--- .../pipeline_controlnet_sd_xl_img2img.py | 11 ++--- .../pipeline_latent_consistency_img2img.py | 11 ++--- .../pipeline_latent_consistency_text2img.py | 11 ++--- src/diffusers/pipelines/pag/__init__.py | 2 +- src/diffusers/pipelines/pag/pag_utils.py | 47 ++++++++++--------- .../pipelines/pag/pipeline_pag_sd_xl.py | 43 +++++++++-------- src/diffusers/pipelines/pia/pipeline_pia.py | 11 ++--- .../pipeline_stable_diffusion.py | 11 ++--- .../pipeline_stable_diffusion_img2img.py | 11 ++--- .../pipeline_stable_diffusion_inpaint.py | 11 ++--- .../pipeline_stable_diffusion_ldm3d.py | 11 ++--- .../pipeline_stable_diffusion_panorama.py | 11 ++--- .../pipeline_stable_diffusion_xl.py | 11 ++--- .../pipeline_stable_diffusion_xl_img2img.py | 11 ++--- .../pipeline_stable_diffusion_xl_inpaint.py | 11 ++--- .../pipeline_stable_diffusion_xl_adapter.py | 11 ++--- 28 files changed, 160 insertions(+), 175 deletions(-) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 0868dd966f7d..f16c10e617e2 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -299,11 +299,11 @@ "StableDiffusionXLControlNetImg2ImgPipeline", "StableDiffusionXLControlNetInpaintPipeline", "StableDiffusionXLControlNetPipeline", - "StableDiffusionXLPAGPipeline", "StableDiffusionXLControlNetXSPipeline", "StableDiffusionXLImg2ImgPipeline", "StableDiffusionXLInpaintPipeline", "StableDiffusionXLInstructPix2PixPipeline", + "StableDiffusionXLPAGPipeline", "StableDiffusionXLPipeline", "StableUnCLIPImg2ImgPipeline", "StableUnCLIPPipeline", diff --git a/src/diffusers/loaders/unet.py b/src/diffusers/loaders/unet.py index 018b57dd07eb..a8d324b77110 100644 --- a/src/diffusers/loaders/unet.py +++ b/src/diffusers/loaders/unet.py @@ -887,8 +887,6 @@ def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_us def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=False): from ..models.attention_processor import ( - AttnProcessor, - AttnProcessor2_0, IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, ) diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index ee51277a4f05..5950ef21e615 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -24,9 +24,9 @@ "deprecated": [], "latent_diffusion": [], "ledits_pp": [], + "pag": [], "stable_diffusion": [], "stable_diffusion_xl": [], - "pag": [], } try: @@ -395,7 +395,6 @@ StableDiffusionControlNetXSPipeline, StableDiffusionXLControlNetXSPipeline, ) - from .pag import StableDiffusionXLPAGPipeline from .deepfloyd_if import ( IFImg2ImgPipeline, IFImg2ImgSuperResolutionPipeline, @@ -456,6 +455,7 @@ LEditsPPPipelineStableDiffusionXL, ) from .musicldm import MusicLDMPipeline + from .pag import StableDiffusionXLPAGPipeline from .paint_by_example import PaintByExamplePipeline from .pia import PIAPipeline from .pixart_alpha import PixArtAlphaPipeline, PixArtSigmaPipeline diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py index b394a34049db..33c00094a6ec 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py @@ -371,24 +371,23 @@ def prepare_ip_adapter_image_embeds( single_ip_adapter_image, device, 1, output_hidden_state ) - - image_embeds.append(single_image_embeds[None,:]) + image_embeds.append(single_image_embeds[None, :]) if do_classifier_free_guidance: - negative_image_embeds.append(single_negative_image_embeds[None,:]) + negative_image_embeds.append(single_negative_image_embeds[None, :]) else: for single_image_embeds in ip_adapter_image_embeds: if do_classifier_free_guidance: single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) negative_image_embeds.append(single_negative_image_embeds) image_embeds.append(single_image_embeds) - + ip_adapter_image_embeds = [] for i, single_image_embeds in enumerate(image_embeds): single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) if do_classifier_free_guidance: single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) - single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) - + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + single_image_embeds = single_image_embeds.to(device=device) ip_adapter_image_embeds.append(single_image_embeds) diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py index 1c9c9e0d4200..2ee2f65af232 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py @@ -586,24 +586,23 @@ def prepare_ip_adapter_image_embeds( single_ip_adapter_image, device, 1, output_hidden_state ) - - image_embeds.append(single_image_embeds[None,:]) + image_embeds.append(single_image_embeds[None, :]) if do_classifier_free_guidance: - negative_image_embeds.append(single_negative_image_embeds[None,:]) + negative_image_embeds.append(single_negative_image_embeds[None, :]) else: for single_image_embeds in ip_adapter_image_embeds: if do_classifier_free_guidance: single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) negative_image_embeds.append(single_negative_image_embeds) image_embeds.append(single_image_embeds) - + ip_adapter_image_embeds = [] for i, single_image_embeds in enumerate(image_embeds): single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) if do_classifier_free_guidance: single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) - single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) - + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + single_image_embeds = single_image_embeds.to(device=device) ip_adapter_image_embeds.append(single_image_embeds) diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py index 99ec145597fe..4f4953767ed0 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py @@ -475,24 +475,23 @@ def prepare_ip_adapter_image_embeds( single_ip_adapter_image, device, 1, output_hidden_state ) - - image_embeds.append(single_image_embeds[None,:]) + image_embeds.append(single_image_embeds[None, :]) if do_classifier_free_guidance: - negative_image_embeds.append(single_negative_image_embeds[None,:]) + negative_image_embeds.append(single_negative_image_embeds[None, :]) else: for single_image_embeds in ip_adapter_image_embeds: if do_classifier_free_guidance: single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) negative_image_embeds.append(single_negative_image_embeds) image_embeds.append(single_image_embeds) - + ip_adapter_image_embeds = [] for i, single_image_embeds in enumerate(image_embeds): single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) if do_classifier_free_guidance: single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) - single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) - + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + single_image_embeds = single_image_embeds.to(device=device) ip_adapter_image_embeds.append(single_image_embeds) diff --git a/src/diffusers/pipelines/auto_pipeline.py b/src/diffusers/pipelines/auto_pipeline.py index c43d0518f8a0..1c2de3114b2c 100644 --- a/src/diffusers/pipelines/auto_pipeline.py +++ b/src/diffusers/pipelines/auto_pipeline.py @@ -26,7 +26,6 @@ StableDiffusionXLControlNetInpaintPipeline, StableDiffusionXLControlNetPipeline, ) -from .pag import StableDiffusionXLPAGPipeline from .deepfloyd_if import IFImg2ImgPipeline, IFInpaintingPipeline, IFPipeline from .kandinsky import ( KandinskyCombinedPipeline, @@ -46,6 +45,7 @@ ) from .kandinsky3 import Kandinsky3Img2ImgPipeline, Kandinsky3Pipeline from .latent_consistency_models import LatentConsistencyModelImg2ImgPipeline, LatentConsistencyModelPipeline +from .pag import StableDiffusionXLPAGPipeline from .pixart_alpha import PixArtAlphaPipeline, PixArtSigmaPipeline from .stable_cascade import StableCascadeCombinedPipeline, StableCascadeDecoderPipeline from .stable_diffusion import ( @@ -388,7 +388,7 @@ def from_pipe(cls, pipeline, **kwargs): AUTO_TEXT2IMAGE_PIPELINES_MAPPING, text_2_image_cls.__name__.replace("ControlNetPipeline", "Pipeline"), ) - + if "enable_pag" in kwargs: if kwargs["enable_pag"] is not None: text_2_image_cls = _get_task_class( diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py index 361dc7c5f262..016d1c26bae1 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py @@ -517,24 +517,23 @@ def prepare_ip_adapter_image_embeds( single_ip_adapter_image, device, 1, output_hidden_state ) - - image_embeds.append(single_image_embeds[None,:]) + image_embeds.append(single_image_embeds[None, :]) if do_classifier_free_guidance: - negative_image_embeds.append(single_negative_image_embeds[None,:]) + negative_image_embeds.append(single_negative_image_embeds[None, :]) else: for single_image_embeds in ip_adapter_image_embeds: if do_classifier_free_guidance: single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) negative_image_embeds.append(single_negative_image_embeds) image_embeds.append(single_image_embeds) - + ip_adapter_image_embeds = [] for i, single_image_embeds in enumerate(image_embeds): single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) if do_classifier_free_guidance: single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) - single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) - + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + single_image_embeds = single_image_embeds.to(device=device) ip_adapter_image_embeds.append(single_image_embeds) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py index 9cfa28527782..f79cab113e5e 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py @@ -495,24 +495,23 @@ def prepare_ip_adapter_image_embeds( single_ip_adapter_image, device, 1, output_hidden_state ) - - image_embeds.append(single_image_embeds[None,:]) + image_embeds.append(single_image_embeds[None, :]) if do_classifier_free_guidance: - negative_image_embeds.append(single_negative_image_embeds[None,:]) + negative_image_embeds.append(single_negative_image_embeds[None, :]) else: for single_image_embeds in ip_adapter_image_embeds: if do_classifier_free_guidance: single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) negative_image_embeds.append(single_negative_image_embeds) image_embeds.append(single_image_embeds) - + ip_adapter_image_embeds = [] for i, single_image_embeds in enumerate(image_embeds): single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) if do_classifier_free_guidance: single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) - single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) - + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + single_image_embeds = single_image_embeds.to(device=device) ip_adapter_image_embeds.append(single_image_embeds) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py index bbd2a07cb20d..8e5f39d9b72d 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py @@ -620,24 +620,23 @@ def prepare_ip_adapter_image_embeds( single_ip_adapter_image, device, 1, output_hidden_state ) - - image_embeds.append(single_image_embeds[None,:]) + image_embeds.append(single_image_embeds[None, :]) if do_classifier_free_guidance: - negative_image_embeds.append(single_negative_image_embeds[None,:]) + negative_image_embeds.append(single_negative_image_embeds[None, :]) else: for single_image_embeds in ip_adapter_image_embeds: if do_classifier_free_guidance: single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) negative_image_embeds.append(single_negative_image_embeds) image_embeds.append(single_image_embeds) - + ip_adapter_image_embeds = [] for i, single_image_embeds in enumerate(image_embeds): single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) if do_classifier_free_guidance: single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) - single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) - + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + single_image_embeds = single_image_embeds.to(device=device) ip_adapter_image_embeds.append(single_image_embeds) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py index a7d0a01770b7..871558734531 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py @@ -535,24 +535,23 @@ def prepare_ip_adapter_image_embeds( single_ip_adapter_image, device, 1, output_hidden_state ) - - image_embeds.append(single_image_embeds[None,:]) + image_embeds.append(single_image_embeds[None, :]) if do_classifier_free_guidance: - negative_image_embeds.append(single_negative_image_embeds[None,:]) + negative_image_embeds.append(single_negative_image_embeds[None, :]) else: for single_image_embeds in ip_adapter_image_embeds: if do_classifier_free_guidance: single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) negative_image_embeds.append(single_negative_image_embeds) image_embeds.append(single_image_embeds) - + ip_adapter_image_embeds = [] for i, single_image_embeds in enumerate(image_embeds): single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) if do_classifier_free_guidance: single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) - single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) - + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + single_image_embeds = single_image_embeds.to(device=device) ip_adapter_image_embeds.append(single_image_embeds) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py index 28e95d317f80..c3b934529cd1 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py @@ -567,24 +567,23 @@ def prepare_ip_adapter_image_embeds( single_ip_adapter_image, device, 1, output_hidden_state ) - - image_embeds.append(single_image_embeds[None,:]) + image_embeds.append(single_image_embeds[None, :]) if do_classifier_free_guidance: - negative_image_embeds.append(single_negative_image_embeds[None,:]) + negative_image_embeds.append(single_negative_image_embeds[None, :]) else: for single_image_embeds in ip_adapter_image_embeds: if do_classifier_free_guidance: single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) negative_image_embeds.append(single_negative_image_embeds) image_embeds.append(single_image_embeds) - + ip_adapter_image_embeds = [] for i, single_image_embeds in enumerate(image_embeds): single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) if do_classifier_free_guidance: single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) - single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) - + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + single_image_embeds = single_image_embeds.to(device=device) ip_adapter_image_embeds.append(single_image_embeds) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py index 41c091028e4f..a0a42afbbb97 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py @@ -561,24 +561,23 @@ def prepare_ip_adapter_image_embeds( single_ip_adapter_image, device, 1, output_hidden_state ) - - image_embeds.append(single_image_embeds[None,:]) + image_embeds.append(single_image_embeds[None, :]) if do_classifier_free_guidance: - negative_image_embeds.append(single_negative_image_embeds[None,:]) + negative_image_embeds.append(single_negative_image_embeds[None, :]) else: for single_image_embeds in ip_adapter_image_embeds: if do_classifier_free_guidance: single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) negative_image_embeds.append(single_negative_image_embeds) image_embeds.append(single_image_embeds) - + ip_adapter_image_embeds = [] for i, single_image_embeds in enumerate(image_embeds): single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) if do_classifier_free_guidance: single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) - single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) - + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + single_image_embeds = single_image_embeds.to(device=device) ip_adapter_image_embeds.append(single_image_embeds) diff --git a/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py b/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py index e76c663a3d6a..038cd287d4d2 100644 --- a/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +++ b/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py @@ -460,24 +460,23 @@ def prepare_ip_adapter_image_embeds( single_ip_adapter_image, device, 1, output_hidden_state ) - - image_embeds.append(single_image_embeds[None,:]) + image_embeds.append(single_image_embeds[None, :]) if do_classifier_free_guidance: - negative_image_embeds.append(single_negative_image_embeds[None,:]) + negative_image_embeds.append(single_negative_image_embeds[None, :]) else: for single_image_embeds in ip_adapter_image_embeds: if do_classifier_free_guidance: single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) negative_image_embeds.append(single_negative_image_embeds) image_embeds.append(single_image_embeds) - + ip_adapter_image_embeds = [] for i, single_image_embeds in enumerate(image_embeds): single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) if do_classifier_free_guidance: single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) - single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) - + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + single_image_embeds = single_image_embeds.to(device=device) ip_adapter_image_embeds.append(single_image_embeds) diff --git a/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py b/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py index 69cc9f956a58..9a6176a47030 100644 --- a/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +++ b/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py @@ -444,24 +444,23 @@ def prepare_ip_adapter_image_embeds( single_ip_adapter_image, device, 1, output_hidden_state ) - - image_embeds.append(single_image_embeds[None,:]) + image_embeds.append(single_image_embeds[None, :]) if do_classifier_free_guidance: - negative_image_embeds.append(single_negative_image_embeds[None,:]) + negative_image_embeds.append(single_negative_image_embeds[None, :]) else: for single_image_embeds in ip_adapter_image_embeds: if do_classifier_free_guidance: single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) negative_image_embeds.append(single_negative_image_embeds) image_embeds.append(single_image_embeds) - + ip_adapter_image_embeds = [] for i, single_image_embeds in enumerate(image_embeds): single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) if do_classifier_free_guidance: single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) - single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) - + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + single_image_embeds = single_image_embeds.to(device=device) ip_adapter_image_embeds.append(single_image_embeds) diff --git a/src/diffusers/pipelines/pag/__init__.py b/src/diffusers/pipelines/pag/__init__.py index 976b981575c2..d60ff46a8151 100644 --- a/src/diffusers/pipelines/pag/__init__.py +++ b/src/diffusers/pipelines/pag/__init__.py @@ -45,4 +45,4 @@ module_spec=__spec__, ) for name, value in _dummy_objects.items(): - setattr(sys.modules[__name__], name, value) \ No newline at end of file + setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/pipelines/pag/pag_utils.py b/src/diffusers/pipelines/pag/pag_utils.py index 4b4927c32de4..6315372501de 100644 --- a/src/diffusers/pipelines/pag/pag_utils.py +++ b/src/diffusers/pipelines/pag/pag_utils.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Tuple import torch from ...models.attention_processor import ( @@ -20,52 +19,52 @@ PAGCFGIdentitySelfAttnProcessor2_0, PAGIdentitySelfAttnProcessor2_0, ) - from ...utils import logging + + logger = logging.get_logger(__name__) # pylint: disable=invalid-name + class PAGMixin: r"""Mixin class for PAG.""" - + @staticmethod def _check_input_pag_applied_layer(layer): layer_splits = layer.split(".") if len(layer_splits) > 3: raise ValueError(f"pag layer should only contains block_type, block_index and attention_index{layer}.") - + if len(layer_splits) >= 1: if layer_splits[0] not in ["down", "mid", "up"]: - raise ValueError(f"Invalid block_type in pag layer {layer}. Accept 'down', 'mid', 'up', got {layer_splits[0]}") - + raise ValueError( + f"Invalid block_type in pag layer {layer}. Accept 'down', 'mid', 'up', got {layer_splits[0]}" + ) + if len(layer_splits) >= 2: if not layer_splits[1].startswith("block_"): raise ValueError(f"Invalid block_index in pag layer: {layer}. Should start with 'block_'") - + if len(layer_splits) == 3: if not layer_splits[2].startswith("attentions_"): raise ValueError(f"Invalid attention_index in pag layer: {layer}. Should start with 'attentions_'") - def _set_attn_processor_pag_applied_layers(self, replace_processor): - def is_self_attn(name): return "attn1" in name and "to" not in name - def get_block_type(name): + def get_block_type(name): # down_blocks.1.attentions.0.transformer_blocks.0.attn1 -> "down" return name.split(".")[0].split("_")[0] - + def get_block_index(name): # down_blocks.1.attentions.0.transformer_blocks.0.attn1 -> "blocks_1" return f"block_{name.split('.')[1]}" - + def get_attn_index(name): # down_blocks.1.attentions.0.transformer_blocks.0.attn1 -> "attentions_0" return f"attentions_{name.split('.')[3]}" - - - for drop_layer in self.pag_applied_layers: + for drop_layer in self.pag_applied_layers: self._check_input_pag_applied_layer(drop_layer) drop_layer_splits = drop_layer.split(".") @@ -78,7 +77,7 @@ def get_attn_index(name): continue if get_block_type(name) == block_type: target_modules.append(module) - + elif len(drop_layer_splits) == 2: # e.g. "down.block_1" block_type = drop_layer_splits[0] @@ -99,12 +98,16 @@ def get_attn_index(name): for name, module in self.unet.named_modules(): if not is_self_attn(name): continue - if get_block_type(name) == block_type and get_block_index(name) == block_index and get_attn_index(name) == attn_index: + if ( + get_block_type(name) == block_type + and get_block_index(name) == block_index + and get_attn_index(name) == attn_index + ): target_modules.append(module) - + if len(target_modules) == 0: logger.warning(f"Cannot find pag layer to set attention processor: {drop_layer}") - + for module in target_modules: module.processor = replace_processor @@ -139,7 +142,7 @@ def _apply_perturbed_attention_guidance(self, noise_pred, do_classifier_free_gui noise_pred_uncond, noise_pred_perturb = noise_pred.chunk(2) noise_pred = noise_pred_uncond + pag_scale * (noise_pred_uncond - noise_pred_perturb) return noise_pred - + def _prepare_perturbed_attention_guidance(self, input, uncond_input, do_classifier_free_guidance): input = torch.cat([input] * 2, dim=0) @@ -149,11 +152,11 @@ def _prepare_perturbed_attention_guidance(self, input, uncond_input, do_classifi def _reset_attn_processor(self): self._set_attn_processor_pag_applied_layers(AttnProcessor2_0()) - + @property def pag_scale(self): return self._pag_scale - + @property def pag_adaptive_scale(self): return self._pag_adaptive_scale diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sd_xl.py b/src/diffusers/pipelines/pag/pipeline_pag_sd_xl.py index c66c21b4ae1e..4d33132f92ba 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_sd_xl.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_sd_xl.py @@ -43,7 +43,6 @@ from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( USE_PEFT_BACKEND, - deprecate, is_invisible_watermark_available, is_torch_xla_available, logging, @@ -52,9 +51,9 @@ unscale_lora_layers, ) from ...utils.torch_utils import randn_tensor -from .pag_utils import PAGMixin from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput +from .pag_utils import PAGMixin if is_invisible_watermark_available(): @@ -249,7 +248,7 @@ def __init__( force_zeros_for_empty_prompt: bool = True, add_watermarker: Optional[bool] = None, enable_pag: bool = False, - pag_applied_layers: Union[str,List[str]] = "mid", # ["mid"],["down.1"] + pag_applied_layers: Union[str, List[str]] = "mid", # ["mid"],["down.1"] ): super().__init__() @@ -276,7 +275,7 @@ def __init__( self.watermark = StableDiffusionXLWatermarker() else: self.watermark = None - + if enable_pag and pag_applied_layers is not None: self._is_pag_enabled = True else: @@ -285,7 +284,6 @@ def __init__( if not isinstance(pag_applied_layers, list): pag_applied_layers = [pag_applied_layers] self.pag_applied_layers = pag_applied_layers - def encode_prompt( self, @@ -570,24 +568,23 @@ def prepare_ip_adapter_image_embeds( single_ip_adapter_image, device, 1, output_hidden_state ) - - image_embeds.append(single_image_embeds[None,:]) + image_embeds.append(single_image_embeds[None, :]) if do_classifier_free_guidance: - negative_image_embeds.append(single_negative_image_embeds[None,:]) + negative_image_embeds.append(single_negative_image_embeds[None, :]) else: for single_image_embeds in ip_adapter_image_embeds: if do_classifier_free_guidance: single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) negative_image_embeds.append(single_negative_image_embeds) image_embeds.append(single_image_embeds) - + ip_adapter_image_embeds = [] for i, single_image_embeds in enumerate(image_embeds): single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) if do_classifier_free_guidance: single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) - single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) - + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + single_image_embeds = single_image_embeds.to(device=device) ip_adapter_image_embeds.append(single_image_embeds) @@ -1127,11 +1124,17 @@ def __call__( ) else: negative_add_time_ids = add_time_ids - + if self.do_perturbed_attention_guidance: - prompt_embeds = self._prepare_perturbed_attention_guidance(prompt_embeds, negative_prompt_embeds, self.do_classifier_free_guidance) - add_text_embeds = self._prepare_perturbed_attention_guidance(add_text_embeds, negative_pooled_prompt_embeds, self.do_classifier_free_guidance) - add_time_ids = self._prepare_perturbed_attention_guidance(add_time_ids, negative_add_time_ids, self.do_classifier_free_guidance) + prompt_embeds = self._prepare_perturbed_attention_guidance( + prompt_embeds, negative_prompt_embeds, self.do_classifier_free_guidance + ) + add_text_embeds = self._prepare_perturbed_attention_guidance( + add_text_embeds, negative_pooled_prompt_embeds, self.do_classifier_free_guidance + ) + add_time_ids = self._prepare_perturbed_attention_guidance( + add_time_ids, negative_add_time_ids, self.do_classifier_free_guidance + ) elif self.do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) @@ -1150,12 +1153,14 @@ def __call__( batch_size * num_images_per_prompt, self.do_classifier_free_guidance, ) - + for i, image_embeds in enumerate(ip_adapter_image_embeds): if self.do_classifier_free_guidance: negative_image_embeds, image_embeds = image_embeds.chunk(2) if self.do_perturbed_attention_guidance: - image_embeds = self._prepare_perturbed_attention_guidance(image_embeds, negative_image_embeds, self.do_classifier_free_guidance) + image_embeds = self._prepare_perturbed_attention_guidance( + image_embeds, negative_image_embeds, self.do_classifier_free_guidance + ) elif self.do_classifier_free_guidance: image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0) image_embeds = image_embeds.to(device) @@ -1190,7 +1195,7 @@ def __call__( if self.do_perturbed_attention_guidance: self._set_pag_attn_processor(self.do_classifier_free_guidance) - + self._num_timesteps = len(timesteps) with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): @@ -1253,7 +1258,7 @@ def __call__( ) add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids) negative_add_time_ids = callback_outputs.pop("negative_add_time_ids", negative_add_time_ids) - + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() diff --git a/src/diffusers/pipelines/pia/pipeline_pia.py b/src/diffusers/pipelines/pia/pipeline_pia.py index a02501188036..b1ce3a6d5328 100644 --- a/src/diffusers/pipelines/pia/pipeline_pia.py +++ b/src/diffusers/pipelines/pia/pipeline_pia.py @@ -524,24 +524,23 @@ def prepare_ip_adapter_image_embeds( single_ip_adapter_image, device, 1, output_hidden_state ) - - image_embeds.append(single_image_embeds[None,:]) + image_embeds.append(single_image_embeds[None, :]) if do_classifier_free_guidance: - negative_image_embeds.append(single_negative_image_embeds[None,:]) + negative_image_embeds.append(single_negative_image_embeds[None, :]) else: for single_image_embeds in ip_adapter_image_embeds: if do_classifier_free_guidance: single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) negative_image_embeds.append(single_negative_image_embeds) image_embeds.append(single_image_embeds) - + ip_adapter_image_embeds = [] for i, single_image_embeds in enumerate(image_embeds): single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) if do_classifier_free_guidance: single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) - single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) - + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + single_image_embeds = single_image_embeds.to(device=device) ip_adapter_image_embeds.append(single_image_embeds) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index 76125183c753..5a7b449fa698 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -527,24 +527,23 @@ def prepare_ip_adapter_image_embeds( single_ip_adapter_image, device, 1, output_hidden_state ) - - image_embeds.append(single_image_embeds[None,:]) + image_embeds.append(single_image_embeds[None, :]) if do_classifier_free_guidance: - negative_image_embeds.append(single_negative_image_embeds[None,:]) + negative_image_embeds.append(single_negative_image_embeds[None, :]) else: for single_image_embeds in ip_adapter_image_embeds: if do_classifier_free_guidance: single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) negative_image_embeds.append(single_negative_image_embeds) image_embeds.append(single_image_embeds) - + ip_adapter_image_embeds = [] for i, single_image_embeds in enumerate(image_embeds): single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) if do_classifier_free_guidance: single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) - single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) - + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + single_image_embeds = single_image_embeds.to(device=device) ip_adapter_image_embeds.append(single_image_embeds) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py index 9d34f6dc119b..e0850f0af2a4 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py @@ -571,24 +571,23 @@ def prepare_ip_adapter_image_embeds( single_ip_adapter_image, device, 1, output_hidden_state ) - - image_embeds.append(single_image_embeds[None,:]) + image_embeds.append(single_image_embeds[None, :]) if do_classifier_free_guidance: - negative_image_embeds.append(single_negative_image_embeds[None,:]) + negative_image_embeds.append(single_negative_image_embeds[None, :]) else: for single_image_embeds in ip_adapter_image_embeds: if do_classifier_free_guidance: single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) negative_image_embeds.append(single_negative_image_embeds) image_embeds.append(single_image_embeds) - + ip_adapter_image_embeds = [] for i, single_image_embeds in enumerate(image_embeds): single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) if do_classifier_free_guidance: single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) - single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) - + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + single_image_embeds = single_image_embeds.to(device=device) ip_adapter_image_embeds.append(single_image_embeds) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index 1033ea7a4721..3806593abe6a 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -643,24 +643,23 @@ def prepare_ip_adapter_image_embeds( single_ip_adapter_image, device, 1, output_hidden_state ) - - image_embeds.append(single_image_embeds[None,:]) + image_embeds.append(single_image_embeds[None, :]) if do_classifier_free_guidance: - negative_image_embeds.append(single_negative_image_embeds[None,:]) + negative_image_embeds.append(single_negative_image_embeds[None, :]) else: for single_image_embeds in ip_adapter_image_embeds: if do_classifier_free_guidance: single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) negative_image_embeds.append(single_negative_image_embeds) image_embeds.append(single_image_embeds) - + ip_adapter_image_embeds = [] for i, single_image_embeds in enumerate(image_embeds): single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) if do_classifier_free_guidance: single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) - single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) - + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + single_image_embeds = single_image_embeds.to(device=device) ip_adapter_image_embeds.append(single_image_embeds) diff --git a/src/diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py b/src/diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py index 663b50e899f1..e5d993ee047a 100644 --- a/src/diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +++ b/src/diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py @@ -510,24 +510,23 @@ def prepare_ip_adapter_image_embeds( single_ip_adapter_image, device, 1, output_hidden_state ) - - image_embeds.append(single_image_embeds[None,:]) + image_embeds.append(single_image_embeds[None, :]) if do_classifier_free_guidance: - negative_image_embeds.append(single_negative_image_embeds[None,:]) + negative_image_embeds.append(single_negative_image_embeds[None, :]) else: for single_image_embeds in ip_adapter_image_embeds: if do_classifier_free_guidance: single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) negative_image_embeds.append(single_negative_image_embeds) image_embeds.append(single_image_embeds) - + ip_adapter_image_embeds = [] for i, single_image_embeds in enumerate(image_embeds): single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) if do_classifier_free_guidance: single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) - single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) - + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + single_image_embeds = single_image_embeds.to(device=device) ip_adapter_image_embeds.append(single_image_embeds) diff --git a/src/diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py b/src/diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py index ad9291701941..b8db26d14fb8 100644 --- a/src/diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +++ b/src/diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py @@ -482,24 +482,23 @@ def prepare_ip_adapter_image_embeds( single_ip_adapter_image, device, 1, output_hidden_state ) - - image_embeds.append(single_image_embeds[None,:]) + image_embeds.append(single_image_embeds[None, :]) if do_classifier_free_guidance: - negative_image_embeds.append(single_negative_image_embeds[None,:]) + negative_image_embeds.append(single_negative_image_embeds[None, :]) else: for single_image_embeds in ip_adapter_image_embeds: if do_classifier_free_guidance: single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) negative_image_embeds.append(single_negative_image_embeds) image_embeds.append(single_image_embeds) - + ip_adapter_image_embeds = [] for i, single_image_embeds in enumerate(image_embeds): single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) if do_classifier_free_guidance: single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) - single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) - + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + single_image_embeds = single_image_embeds.to(device=device) ip_adapter_image_embeds.append(single_image_embeds) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py index c43486bcb952..3077be99b98f 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py @@ -556,24 +556,23 @@ def prepare_ip_adapter_image_embeds( single_ip_adapter_image, device, 1, output_hidden_state ) - - image_embeds.append(single_image_embeds[None,:]) + image_embeds.append(single_image_embeds[None, :]) if do_classifier_free_guidance: - negative_image_embeds.append(single_negative_image_embeds[None,:]) + negative_image_embeds.append(single_negative_image_embeds[None, :]) else: for single_image_embeds in ip_adapter_image_embeds: if do_classifier_free_guidance: single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) negative_image_embeds.append(single_negative_image_embeds) image_embeds.append(single_image_embeds) - + ip_adapter_image_embeds = [] for i, single_image_embeds in enumerate(image_embeds): single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) if do_classifier_free_guidance: single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) - single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) - + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + single_image_embeds = single_image_embeds.to(device=device) ip_adapter_image_embeds.append(single_image_embeds) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py index 70fc4691c60c..19f2da2a7bf5 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py @@ -800,24 +800,23 @@ def prepare_ip_adapter_image_embeds( single_ip_adapter_image, device, 1, output_hidden_state ) - - image_embeds.append(single_image_embeds[None,:]) + image_embeds.append(single_image_embeds[None, :]) if do_classifier_free_guidance: - negative_image_embeds.append(single_negative_image_embeds[None,:]) + negative_image_embeds.append(single_negative_image_embeds[None, :]) else: for single_image_embeds in ip_adapter_image_embeds: if do_classifier_free_guidance: single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) negative_image_embeds.append(single_negative_image_embeds) image_embeds.append(single_image_embeds) - + ip_adapter_image_embeds = [] for i, single_image_embeds in enumerate(image_embeds): single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) if do_classifier_free_guidance: single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) - single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) - + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + single_image_embeds = single_image_embeds.to(device=device) ip_adapter_image_embeds.append(single_image_embeds) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py index 447646e0752c..3da79a61419a 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py @@ -493,24 +493,23 @@ def prepare_ip_adapter_image_embeds( single_ip_adapter_image, device, 1, output_hidden_state ) - - image_embeds.append(single_image_embeds[None,:]) + image_embeds.append(single_image_embeds[None, :]) if do_classifier_free_guidance: - negative_image_embeds.append(single_negative_image_embeds[None,:]) + negative_image_embeds.append(single_negative_image_embeds[None, :]) else: for single_image_embeds in ip_adapter_image_embeds: if do_classifier_free_guidance: single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) negative_image_embeds.append(single_negative_image_embeds) image_embeds.append(single_image_embeds) - + ip_adapter_image_embeds = [] for i, single_image_embeds in enumerate(image_embeds): single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) if do_classifier_free_guidance: single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) - single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) - + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + single_image_embeds = single_image_embeds.to(device=device) ip_adapter_image_embeds.append(single_image_embeds) diff --git a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py index e43af9b9c047..2e51eb4bca24 100644 --- a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +++ b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py @@ -570,24 +570,23 @@ def prepare_ip_adapter_image_embeds( single_ip_adapter_image, device, 1, output_hidden_state ) - - image_embeds.append(single_image_embeds[None,:]) + image_embeds.append(single_image_embeds[None, :]) if do_classifier_free_guidance: - negative_image_embeds.append(single_negative_image_embeds[None,:]) + negative_image_embeds.append(single_negative_image_embeds[None, :]) else: for single_image_embeds in ip_adapter_image_embeds: if do_classifier_free_guidance: single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) negative_image_embeds.append(single_negative_image_embeds) image_embeds.append(single_image_embeds) - + ip_adapter_image_embeds = [] for i, single_image_embeds in enumerate(image_embeds): single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) if do_classifier_free_guidance: single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) - single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) - + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + single_image_embeds = single_image_embeds.to(device=device) ip_adapter_image_embeds.append(single_image_embeds) From 1fa54df7ba9587bbfb9e62141624a22bbd88db06 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Sat, 8 Jun 2024 09:27:53 +0000 Subject: [PATCH 17/44] style --- src/diffusers/pipelines/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 97d9556f393d..69fa87ca41cf 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -24,8 +24,8 @@ "deprecated": [], "latent_diffusion": [], "ledits_pp": [], - "pag": [], "marigold": [], + "pag": [], "stable_diffusion": [], "stable_diffusion_xl": [], } From ba366f03d13fac6fc1905f0bb94a6f9239c7bdd6 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Sun, 9 Jun 2024 18:40:55 +0000 Subject: [PATCH 18/44] u[ --- src/diffusers/pipelines/auto_pipeline.py | 7 +- src/diffusers/pipelines/pag/pag_utils.py | 204 +++++++--- .../pipelines/pag/pipeline_pag_sd_xl.py | 22 +- tests/pipelines/pag/__init__.py | 0 tests/pipelines/pag/test_pag_sdxl.py | 359 ++++++++++++++++++ 5 files changed, 521 insertions(+), 71 deletions(-) create mode 100644 tests/pipelines/pag/__init__.py create mode 100644 tests/pipelines/pag/test_pag_sdxl.py diff --git a/src/diffusers/pipelines/auto_pipeline.py b/src/diffusers/pipelines/auto_pipeline.py index 1c2de3114b2c..46ad7cd5be0c 100644 --- a/src/diffusers/pipelines/auto_pipeline.py +++ b/src/diffusers/pipelines/auto_pipeline.py @@ -335,7 +335,9 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs): if "controlnet" in kwargs: orig_class_name = config["_class_name"].replace("Pipeline", "ControlNetPipeline") if "enable_pag" in kwargs: - orig_class_name = config["_class_name"].replace("Pipeline", "PAGPipeline") + enable_pag = kwargs.pop("enable_pag") + if enable_pag: + orig_class_name = config["_class_name"].replace("Pipeline", "PAGPipeline") text_2_image_cls = _get_task_class(AUTO_TEXT2IMAGE_PIPELINES_MAPPING, orig_class_name) @@ -390,7 +392,8 @@ def from_pipe(cls, pipeline, **kwargs): ) if "enable_pag" in kwargs: - if kwargs["enable_pag"] is not None: + enable_pag = kwargs.pop("enable_pag") + if enable_pag: text_2_image_cls = _get_task_class( AUTO_TEXT2IMAGE_PIPELINES_MAPPING, text_2_image_cls.__name__.replace("PAG", "").replace("Pipeline", "PAGPipeline"), diff --git a/src/diffusers/pipelines/pag/pag_utils.py b/src/diffusers/pipelines/pag/pag_utils.py index 6315372501de..24e476676ddc 100644 --- a/src/diffusers/pipelines/pag/pag_utils.py +++ b/src/diffusers/pipelines/pag/pag_utils.py @@ -15,7 +15,6 @@ import torch from ...models.attention_processor import ( - AttnProcessor2_0, PAGCFGIdentitySelfAttnProcessor2_0, PAGIdentitySelfAttnProcessor2_0, ) @@ -30,7 +29,15 @@ class PAGMixin: @staticmethod def _check_input_pag_applied_layer(layer): + r""" + Check if each layer input in `applied_pag_layers` is valid. It should be either one of these 3 formats: + "{block_type}", "{block_type}.{block_index}", or "{block_type}.{block_index}.{attention_index}". `block_type` + can be "down", "mid", "up". `block_index` should be in the format of "block_{i}". `attention_index` should be + in the format of "attentions_{j}". + """ + layer_splits = layer.split(".") + if len(layer_splits) > 3: raise ValueError(f"pag layer should only contains block_type, block_index and attention_index{layer}.") @@ -48,79 +55,102 @@ def _check_input_pag_applied_layer(layer): if not layer_splits[2].startswith("attentions_"): raise ValueError(f"Invalid attention_index in pag layer: {layer}. Should start with 'attentions_'") - def _set_attn_processor_pag_applied_layers(self, replace_processor): - def is_self_attn(name): - return "attn1" in name and "to" not in name + def _set_pag_attn_processor(self, pag_applied_layers, do_classifier_free_guidance): + r""" + Set the attention processor for the PAG layers. + """ + if do_classifier_free_guidance: + pag_attn_proc = PAGCFGIdentitySelfAttnProcessor2_0() + else: + pag_attn_proc = PAGIdentitySelfAttnProcessor2_0() - def get_block_type(name): + def is_self_attn(module_name): + r""" + Check if the module is self-attention module based on its name. + """ + return "attn1" in module_name and "to" not in name + + def get_block_type(module_name): + r""" + Get the block type from the module name. can be "down", "mid", "up". + """ # down_blocks.1.attentions.0.transformer_blocks.0.attn1 -> "down" - return name.split(".")[0].split("_")[0] + return module_name.split(".")[0].split("_")[0] - def get_block_index(name): - # down_blocks.1.attentions.0.transformer_blocks.0.attn1 -> "blocks_1" - return f"block_{name.split('.')[1]}" + def get_block_index(module_name): + r""" + Get the block index from the module name. can be "block_0", "block_1", ... If there is only one block (e.g. + mid_block) and index is ommited from the name, it will be "block_0". + """ + # down_blocks.1.attentions.0.transformer_blocks.0.attn1 -> "block_1" + # mid_block.attentions.0.transformer_blocks.0.attn1 -> "block_0" + if "attentions" in module_name.split(".")[1]: + return "block_0" + else: + return f"block_{module_name.split('.')[1]}" - def get_attn_index(name): + def get_attn_index(module_name): + r""" + Get the attention index from the module name. can be "attentions_0", "attentions_1", ... + """ # down_blocks.1.attentions.0.transformer_blocks.0.attn1 -> "attentions_0" - return f"attentions_{name.split('.')[3]}" + # mid_block.attentions.0.transformer_blocks.0.attn1 -> "attentions_0" + if "attentions" in module_name.split(".")[2]: + return f"attentions_{module_name.split('.')[3]}" + elif "attentions" in module_name.split(".")[1]: + return f"attentions_{module_name.split('.')[2]}" + + for pag_layer_input in pag_applied_layers: + # for each PAG layer input, we find corresponding self-attention layers in the unet model + target_modules = [] - for drop_layer in self.pag_applied_layers: - self._check_input_pag_applied_layer(drop_layer) - drop_layer_splits = drop_layer.split(".") + pag_layer_input_splits = pag_layer_input.split(".") - if len(drop_layer_splits) == 1: - # e.g. "mid" - block_type = drop_layer_splits[0] - target_modules = [] + if len(pag_layer_input_splits) == 1: + # when the layer input only contains block_type. e.g. "mid", "down", "up" + block_type = pag_layer_input_splits[0] for name, module in self.unet.named_modules(): - if not is_self_attn(name): - continue - if get_block_type(name) == block_type: + if is_self_attn(name) and get_block_type(name) == block_type: target_modules.append(module) - elif len(drop_layer_splits) == 2: - # e.g. "down.block_1" - block_type = drop_layer_splits[0] - block_index = drop_layer_splits[1] - target_modules = [] + elif len(pag_layer_input_splits) == 2: + # when the layer inpput contains both block_type and block_index. e.g. "down.block_1", "mid.block_0" + block_type = pag_layer_input_splits[0] + block_index = pag_layer_input_splits[1] for name, module in self.unet.named_modules(): - if not is_self_attn(name): - continue - if get_block_type(name) == block_type and get_block_index(name) == block_index: + if ( + is_self_attn(name) + and get_block_type(name) == block_type + and get_block_index(name) == block_index + ): target_modules.append(module) - elif len(drop_layer_splits) == 3: - # e.g. "down.blocks_1.attentions_1" - block_type = drop_layer_splits[0] - block_index = drop_layer_splits[1] - attn_index = drop_layer_splits[2] - target_modules = [] + elif len(pag_layer_input_splits) == 3: + # when the layer input contains block_type, block_index and attention_index. e.g. "down.blocks_1.attentions_1" + block_type = pag_layer_input_splits[0] + block_index = pag_layer_input_splits[1] + attn_index = pag_layer_input_splits[2] + for name, module in self.unet.named_modules(): - if not is_self_attn(name): - continue if ( - get_block_type(name) == block_type + is_self_attn(name) + and get_block_type(name) == block_type and get_block_index(name) == block_index and get_attn_index(name) == attn_index ): target_modules.append(module) if len(target_modules) == 0: - logger.warning(f"Cannot find pag layer to set attention processor: {drop_layer}") + raise ValueError(f"Cannot find pag layer to set attention processor for: {pag_layer_input}") for module in target_modules: - module.processor = replace_processor - - def _set_pag_attn_processor(self, do_classifier_free_guidance): - if do_classifier_free_guidance: - self._set_attn_processor_pag_applied_layers(PAGCFGIdentitySelfAttnProcessor2_0()) - else: - self._set_attn_processor_pag_applied_layers(PAGIdentitySelfAttnProcessor2_0()) - - def _reset_attn_processor(self): - self._set_attn_processor_pag_applied_layers(AttnProcessor2_0()) + module.processor = pag_attn_proc def _get_pag_scale(self, t): + r""" + Get the scale factor for the perturbed attention guidance. + """ + if self.do_pag_adaptive_scaling: signal_scale = self.pag_scale - self.pag_adaptive_scale * (1000 - t) if signal_scale < 0: @@ -130,6 +160,18 @@ def _get_pag_scale(self, t): return self.pag_scale def _apply_perturbed_attention_guidance(self, noise_pred, do_classifier_free_guidance, guidance_scale, t): + r""" + Apply perturbed attention guidance to the noise prediction. + + Args: + noise_pred (torch.Tensor): The noise prediction tensor. + do_classifier_free_guidance (bool): Whether to apply classifier-free guidance. + guidance_scale (float): The scale factor for the guidance term. + t (int): The current time step. + + Returns: + torch.Tensor: The updated noise prediction tensor after applying perturbed attention guidance. + """ pag_scale = self._get_pag_scale(t) if do_classifier_free_guidance: noise_pred_uncond, noise_pred_text, noise_pred_perturb = noise_pred.chunk(3) @@ -143,28 +185,76 @@ def _apply_perturbed_attention_guidance(self, noise_pred, do_classifier_free_gui noise_pred = noise_pred_uncond + pag_scale * (noise_pred_uncond - noise_pred_perturb) return noise_pred - def _prepare_perturbed_attention_guidance(self, input, uncond_input, do_classifier_free_guidance): - input = torch.cat([input] * 2, dim=0) + def _prepare_perturbed_attention_guidance(self, cond, uncond, do_classifier_free_guidance): + """ + Prepares the perturbed attention guidance for the PAG model. + + Args: + cond (torch.Tensor): The conditional input tensor. + uncond (torch.Tensor): The unconditional input tensor. + do_classifier_free_guidance (bool): Flag indicating whether to perform classifier-free guidance. + + Returns: + torch.Tensor: The prepared perturbed attention guidance tensor. + """ + + cond = torch.cat([cond] * 2, dim=0) if do_classifier_free_guidance: - input = torch.cat([uncond_input, input], dim=0) - return input + cond = torch.cat([uncond, cond], dim=0) + return cond - def _reset_attn_processor(self): - self._set_attn_processor_pag_applied_layers(AttnProcessor2_0()) + def set_pag_applied_layers(self, pag_applied_layers): + r""" + set the the self-attention layers to apply PAG. Raise ValueError if the input is invalid. + """ + + if not isinstance(pag_applied_layers, list): + pag_applied_layers = [pag_applied_layers] + + for pag_layer in pag_applied_layers: + self._check_input_pag_applied_layer(pag_layer) + + self.pag_applied_layers = pag_applied_layers @property def pag_scale(self): + """ + Get the scale factor for the perturbed attention guidance. + """ return self._pag_scale @property def pag_adaptive_scale(self): + """ + Get the adaptive scale factor for the perturbed attention guidance. + """ return self._pag_adaptive_scale @property def do_pag_adaptive_scaling(self): - return self._is_pag_enabled and self._pag_adaptive_scale > 0 + """ + Check if the adaptive scaling is enabled for the perturbed attention guidance. + """ + return self._pag_adaptive_scale > 0 and self._pag_scale > 0 and len(self.pag_applied_layers) > 0 @property def do_perturbed_attention_guidance(self): - return self._is_pag_enabled and self._pag_scale > 0 + """ + Check if the perturbed attention guidance is enabled. + """ + return self._pag_scale > 0 and len(self.pag_applied_layers) > 0 + + @property + def pag_attn_processors(self): + r""" + Returns: + `dict` of PAG attention processors: A dictionary contains all PAG attention processors used in the model + with the key as the name of the layer. + """ + + processors = {} + for name, proc in self.unet.attn_processors.items(): + if proc.__class__ in (PAGCFGIdentitySelfAttnProcessor2_0, PAGIdentitySelfAttnProcessor2_0): + processors[name] = proc + return processors diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sd_xl.py b/src/diffusers/pipelines/pag/pipeline_pag_sd_xl.py index 4d33132f92ba..4df267c63305 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_sd_xl.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_sd_xl.py @@ -247,8 +247,7 @@ def __init__( feature_extractor: CLIPImageProcessor = None, force_zeros_for_empty_prompt: bool = True, add_watermarker: Optional[bool] = None, - enable_pag: bool = False, - pag_applied_layers: Union[str, List[str]] = "mid", # ["mid"],["down.1"] + pag_applied_layers: Union[str, List[str]] = "mid", # ["mid"],["down.block_1"],["up.block_0.attentions_0"] ): super().__init__() @@ -276,14 +275,7 @@ def __init__( else: self.watermark = None - if enable_pag and pag_applied_layers is not None: - self._is_pag_enabled = True - else: - self._is_pag_enabled = False - - if not isinstance(pag_applied_layers, list): - pag_applied_layers = [pag_applied_layers] - self.pag_applied_layers = pag_applied_layers + self.set_pag_applied_layers(pag_applied_layers) def encode_prompt( self, @@ -1155,8 +1147,10 @@ def __call__( ) for i, image_embeds in enumerate(ip_adapter_image_embeds): + negative_image_embeds = None if self.do_classifier_free_guidance: negative_image_embeds, image_embeds = image_embeds.chunk(2) + if self.do_perturbed_attention_guidance: image_embeds = self._prepare_perturbed_attention_guidance( image_embeds, negative_image_embeds, self.do_classifier_free_guidance @@ -1194,7 +1188,11 @@ def __call__( ).to(device=device, dtype=latents.dtype) if self.do_perturbed_attention_guidance: - self._set_pag_attn_processor(self.do_classifier_free_guidance) + original_attn_proc = self.unet.attn_processors + self._set_pag_attn_processor( + pag_applied_layers=self.pag_applied_layers, + do_classifier_free_guidance=self.do_classifier_free_guidance, + ) self._num_timesteps = len(timesteps) with self.progress_bar(total=num_inference_steps) as progress_bar: @@ -1311,7 +1309,7 @@ def __call__( self.maybe_free_model_hooks() if self.do_perturbed_attention_guidance: - self._reset_attn_processor() + self.unet.set_attn_processor(original_attn_proc) if not return_dict: return (image,) diff --git a/tests/pipelines/pag/__init__.py b/tests/pipelines/pag/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/pipelines/pag/test_pag_sdxl.py b/tests/pipelines/pag/test_pag_sdxl.py new file mode 100644 index 000000000000..9a9e13607cad --- /dev/null +++ b/tests/pipelines/pag/test_pag_sdxl.py @@ -0,0 +1,359 @@ +# coding=utf-8 +# Copyright 2024 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import inspect +import unittest + +import numpy as np +import torch +from transformers import CLIPTextConfig, CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer + +from diffusers import ( + AutoencoderKL, + AutoPipelineForText2Image, + EulerDiscreteScheduler, + StableDiffusionXLPAGPipeline, + UNet2DConditionModel, +) +from diffusers.utils.testing_utils import ( + enable_full_determinism, + require_torch_gpu, + slow, + torch_device, +) + +from ..pipeline_params import ( + TEXT_TO_IMAGE_BATCH_PARAMS, + TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS, + TEXT_TO_IMAGE_IMAGE_PARAMS, + TEXT_TO_IMAGE_PARAMS, +) +from ..test_pipelines_common import ( + IPAdapterTesterMixin, + PipelineFromPipeTesterMixin, + PipelineLatentTesterMixin, + PipelineTesterMixin, + SDXLOptionalComponentsTesterMixin, +) + + +enable_full_determinism() + + +class StableDiffusionXLPAGPipelineFastTests( + PipelineTesterMixin, + IPAdapterTesterMixin, + PipelineLatentTesterMixin, + PipelineFromPipeTesterMixin, + SDXLOptionalComponentsTesterMixin, + unittest.TestCase, +): + pipeline_class = StableDiffusionXLPAGPipeline + params = TEXT_TO_IMAGE_PARAMS.union({"pag_scale", "pag_adaptive_scale"}) + batch_params = TEXT_TO_IMAGE_BATCH_PARAMS + image_params = TEXT_TO_IMAGE_IMAGE_PARAMS + image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS + callback_cfg_params = TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS.union({"add_text_embeds", "add_time_ids"}) + + def get_dummy_components(self, time_cond_proj_dim=None): + # Copied from tests.pipelines.stable_diffusion_xl.test_stable_diffusion_xl.StableDiffusionXLPipelineFastTests.get_dummy_components + torch.manual_seed(0) + unet = UNet2DConditionModel( + block_out_channels=(2, 4), + layers_per_block=2, + time_cond_proj_dim=time_cond_proj_dim, + sample_size=32, + in_channels=4, + out_channels=4, + down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"), + up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"), + # SD2-specific config below + attention_head_dim=(2, 4), + use_linear_projection=True, + addition_embed_type="text_time", + addition_time_embed_dim=8, + transformer_layers_per_block=(1, 2), + projection_class_embeddings_input_dim=80, # 6 * 8 + 32 + cross_attention_dim=64, + norm_num_groups=1, + ) + scheduler = EulerDiscreteScheduler( + beta_start=0.00085, + beta_end=0.012, + steps_offset=1, + beta_schedule="scaled_linear", + timestep_spacing="leading", + ) + torch.manual_seed(0) + vae = AutoencoderKL( + block_out_channels=[32, 64], + in_channels=3, + out_channels=3, + down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"], + up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"], + latent_channels=4, + sample_size=128, + ) + torch.manual_seed(0) + text_encoder_config = CLIPTextConfig( + bos_token_id=0, + eos_token_id=2, + hidden_size=32, + intermediate_size=37, + layer_norm_eps=1e-05, + num_attention_heads=4, + num_hidden_layers=5, + pad_token_id=1, + vocab_size=1000, + # SD2-specific config below + hidden_act="gelu", + projection_dim=32, + ) + text_encoder = CLIPTextModel(text_encoder_config) + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + text_encoder_2 = CLIPTextModelWithProjection(text_encoder_config) + tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + components = { + "unet": unet, + "scheduler": scheduler, + "vae": vae, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + "text_encoder_2": text_encoder_2, + "tokenizer_2": tokenizer_2, + "image_encoder": None, + "feature_extractor": None, + } + return components + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + inputs = { + "prompt": "A painting of a squirrel eating a burger", + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 5.0, + "pag_scale": 0.9, + "output_type": "np", + } + return inputs + + def test_pag_disable_enable(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + components = self.get_dummy_components() + + # base pipeline + pipe_sd = self.original_pipeline_class(**components) + pipe_sd = pipe_sd.to(device) + pipe_sd.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + del inputs["pag_scale"] + assert ( + "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters + ), f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__calss__.__name__}." + out = pipe_sd(**inputs).images[0, -3:, -3:, -1] + + # pag disabled with pag_scale=0.0 + pipe_pag = self.pipeline_class(**components) + pipe_pag = pipe_pag.to(device) + pipe_pag.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + inputs["pag_scale"] = 0.0 + out_pag_disabled = pipe_pag(**inputs).images[0, -3:, -3:, -1] + + # pag enabled + pipe_pag = self.pipeline_class(**components, pag_applied_layers=["mid", "up", "down"]) + pipe_pag = pipe_pag.to(device) + pipe_pag.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + out_pag_enabled = pipe_pag(**inputs).images[0, -3:, -3:, -1] + + assert np.abs(out.flatten() - out_pag_disabled.flatten()).max() < 1e-3 + assert np.abs(out.flatten() - out_pag_enabled.flatten()).max() > 1e-3 + + def test_save_load_optional_components(self): + self._test_save_load_optional_components() + + def test_pag_applied_layers(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + components = self.get_dummy_components() + + # base pipeline + pipe = self.pipeline_class(**components) + pipe = pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + # pag_applied_layers = ["mid","up","down"] should apply to all self-attention layers + all_self_attn_layers = [k for k in pipe.unet.attn_processors.keys() if "attn1" in k] + original_attn_procs = pipe.unet.attn_processors + pag_layers = ["mid", "down", "up"] + pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False) + assert set(pipe.pag_attn_processors) == set(all_self_attn_layers) + + # pag_applied_layers = ["mid"], or ["mid.block_0"] or ["mid.block_0.attentions_0"] should apply to all self-attention layers in mid_block, i.e. + # mid_block.attentions.0.transformer_blocks.0.attn1.processor + # mid_block.attentions.0.transformer_blocks.1.attn1.processor + pipe.unet.set_attn_processor(original_attn_procs.copy()) + pag_layers = ["mid"] + pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False) + assert set(pipe.pag_attn_processors) == set( + "mid_block.attentions.0.transformer_blocks.0.attn1.processor", + "mid_block.attentions.0.transformer_blocks.1.attn1.processor", + ) + + pipe.unet.set_attn_processor(original_attn_procs.copy()) + pag_layers = ["mid.block_0"] + pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False) + assert set(pipe.pag_attn_processors) == set( + "mid_block.attentions.0.transformer_blocks.0.attn1.processor", + "mid_block.attentions.0.transformer_blocks.1.attn1.processor", + ) + + pipe.unet.set_attn_processor(original_attn_procs.copy()) + pag_layers = ["mid.block_0.attentions_0"] + pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False) + assert set(pipe.pag_attn_processors) == set( + "mid_block.attentions.0.transformer_blocks.0.attn1.processor", + "mid_block.attentions.0.transformer_blocks.1.attn1.processor", + ) + + # pag_applied_layers = ["mid.block_0.attentions_1"] does not exist in the model + pipe.unet.set_attn_processor(original_attn_procs.copy()) + pag_layers = ["mid.block_0.attentions_1"] + with self.assertRaises(ValueError): + pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False) + + # pag_applied_layers = "down" should apply to all self-attention layers in down_blocks + # down_blocks.1.attentions.0.transformer_blocks.0.attn1.processor + # down_blocks.1.attentions.0.transformer_blocks.1.attn1.processor + # down_blocks.1.attentions.1.transformer_blocks.0.attn1.processor + # down_blocks.1.attentions.1.transformer_blocks.1.attn1.processor + pipe.unet.set_attn_processor(original_attn_procs.copy()) + pag_layers = ["down"] + pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False) + assert len(pipe.pag_attn_processors) == 4 + + pipe.unet.set_attn_processor(original_attn_procs.copy()) + pag_layers = ["down.block_0"] + with self.assertRaises(ValueError): + pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False) + + pipe.unet.set_attn_processor(original_attn_procs.copy()) + pag_layers = ["down.block_1"] + pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False) + assert len(pipe.pag_attn_processors) == 4 + + pipe.unet.set_attn_processor(original_attn_procs.copy()) + pag_layers = ["down.block_1.attentions_1"] + pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False) + assert len(pipe.pag_attn_processors) == 2 + + def test_pag_inference(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + components = self.get_dummy_components() + + pipe_pag = self.pipeline_class(**components, pag_applied_layers=["mid", "up", "down"]) + pipe_pag = pipe_pag.to(device) + pipe_pag.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + image = pipe_pag(**inputs).images + image_slice = image[0, -3:, -3:, -1] + + assert image.shape == ( + 1, + 64, + 64, + 3, + ), f"the shape of the output image should be (1, 64, 64, 3) but got {image.shape}" + expected_slice = np.array( + [0.55341685, 0.55503535, 0.47299808, 0.43274558, 0.4965323, 0.46310428, 0.51455414, 0.5015592, 0.46913484] + ) + + max_diff = np.abs(image_slice.flatten() - expected_slice).max() + self.assertLessEqual(max_diff, 1e-3) + + +@slow +@require_torch_gpu +class StableDiffusionXLPAGPipelineIntegrationTests(unittest.TestCase): + pipeline_class = StableDiffusionXLPAGPipeline + repo_id = "stabilityai/stable-diffusion-xl-base-1.0" + + def setUp(self): + super().setUp() + gc.collect() + torch.cuda.empty_cache() + + def tearDown(self): + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + def get_inputs(self, device, generator_device="cpu", seed=0, guidance_scale=7.0): + generator = torch.Generator(device=generator_device).manual_seed(seed) + inputs = { + "prompt": "a polar bear sitting in a chair drinking a milkshake", + "negative_prompt": "deformed, ugly, wrong proportion, low res, bad anatomy, worst quality, low quality", + "generator": generator, + "num_inference_steps": 3, + "guidance_scale": guidance_scale, + "pag_scale": 3.0, + "output_type": "np", + } + return inputs + + def test_pag_cfg(self): + pipeline = AutoPipelineForText2Image.from_pretrained(self.repo_id, enable_pag=True, torch_dtype=torch.float16) + pipeline.enable_model_cpu_offload() + pipeline.set_progress_bar_config(disable=None) + + inputs = self.get_inputs(torch_device) + image = pipeline(**inputs).images + + image_slice = image[0, -3:, -3:, -1].flatten() + assert image.shape == (1, 1024, 1024, 3) + expected_slice = np.array( + [0.3123679, 0.31725878, 0.32026544, 0.327533, 0.3266391, 0.3303998, 0.33544615, 0.34181812, 0.34102726] + ) + assert ( + np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 + ), f"output is different from expected, {image_slice.flatten()}" + + def test_pag_cfg_uncond(self): + pipeline = AutoPipelineForText2Image.from_pretrained(self.repo_id, enable_pag=True, torch_dtype=torch.float16) + pipeline.enable_model_cpu_offload() + pipeline.set_progress_bar_config(disable=None) + + inputs = self.get_inputs(torch_device, guidance_scale=0.0) + image = pipeline(**inputs).images + + image_slice = image[0, -3:, -3:, -1].flatten() + assert image.shape == (1, 1024, 1024, 3) + expected_slice = np.array( + [0.47400922, 0.48650584, 0.4839625, 0.4724013, 0.4890427, 0.49544555, 0.51707107, 0.54299414, 0.5224372] + ) + assert ( + np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 + ), f"output is different from expected, {image_slice.flatten()}" From d5a67612d7dbbeeddc807fdf505526cd6a00bac2 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Mon, 10 Jun 2024 01:38:33 +0000 Subject: [PATCH 19/44] up --- src/diffusers/__init__.py | 2 + src/diffusers/pipelines/__init__.py | 3 +- src/diffusers/pipelines/auto_pipeline.py | 20 +- src/diffusers/pipelines/pag/__init__.py | 3 +- .../pipelines/pag/pipeline_pag_sd_xl.py | 8 +- .../pag/pipeline_pag_sd_xl_inpaint.py | 1740 +++++++++++++++++ .../dummy_torch_and_transformers_objects.py | 15 + tests/pipelines/pag/test_pag_sdxl.py | 26 +- tests/pipelines/pag/test_pag_sdxl_inpaint.py | 344 ++++ 9 files changed, 2142 insertions(+), 19 deletions(-) create mode 100644 src/diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py create mode 100644 tests/pipelines/pag/test_pag_sdxl_inpaint.py diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 5ef6613f283c..222cd9b71b4f 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -309,6 +309,7 @@ "StableDiffusionXLImg2ImgPipeline", "StableDiffusionXLInpaintPipeline", "StableDiffusionXLInstructPix2PixPipeline", + "StableDiffusionXLPAGInpaintPipeline", "StableDiffusionXLPAGPipeline", "StableDiffusionXLPipeline", "StableUnCLIPImg2ImgPipeline", @@ -694,6 +695,7 @@ StableDiffusionXLImg2ImgPipeline, StableDiffusionXLInpaintPipeline, StableDiffusionXLInstructPix2PixPipeline, + StableDiffusionXLPAGInpaintPipeline, StableDiffusionXLPAGPipeline, StableDiffusionXLPipeline, StableUnCLIPImg2ImgPipeline, diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 69fa87ca41cf..f044b0948ad7 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -140,6 +140,7 @@ _import_structure["pag"].extend( [ "StableDiffusionXLPAGPipeline", + "StableDiffusionXLPAGInpaintPipeline", ] ) _import_structure["controlnet_xs"].extend( @@ -468,7 +469,7 @@ MarigoldNormalsPipeline, ) from .musicldm import MusicLDMPipeline - from .pag import StableDiffusionXLPAGPipeline + from .pag import StableDiffusionXLPAGInpaintPipeline, StableDiffusionXLPAGPipeline from .paint_by_example import PaintByExamplePipeline from .pia import PIAPipeline from .pixart_alpha import PixArtAlphaPipeline, PixArtSigmaPipeline diff --git a/src/diffusers/pipelines/auto_pipeline.py b/src/diffusers/pipelines/auto_pipeline.py index 46ad7cd5be0c..17d2e134b340 100644 --- a/src/diffusers/pipelines/auto_pipeline.py +++ b/src/diffusers/pipelines/auto_pipeline.py @@ -45,7 +45,7 @@ ) from .kandinsky3 import Kandinsky3Img2ImgPipeline, Kandinsky3Pipeline from .latent_consistency_models import LatentConsistencyModelImg2ImgPipeline, LatentConsistencyModelPipeline -from .pag import StableDiffusionXLPAGPipeline +from .pag import StableDiffusionXLPAGInpaintPipeline, StableDiffusionXLPAGPipeline from .pixart_alpha import PixArtAlphaPipeline, PixArtSigmaPipeline from .stable_cascade import StableCascadeCombinedPipeline, StableCascadeDecoderPipeline from .stable_diffusion import ( @@ -103,6 +103,7 @@ ("kandinsky22", KandinskyV22InpaintCombinedPipeline), ("stable-diffusion-controlnet", StableDiffusionControlNetInpaintPipeline), ("stable-diffusion-xl-controlnet", StableDiffusionXLControlNetInpaintPipeline), + ("stable-diffusion-xl-pag", StableDiffusionXLPAGInpaintPipeline), ] ) @@ -900,6 +901,10 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs): if "controlnet" in kwargs: orig_class_name = config["_class_name"].replace("Pipeline", "ControlNetPipeline") + if "enable_pag" in kwargs: + enable_pag = kwargs.pop("enable_pag") + if enable_pag: + orig_class_name = config["_class_name"].replace("Pipeline", "PAGPipeline") inpainting_cls = _get_task_class(AUTO_INPAINT_PIPELINES_MAPPING, orig_class_name) @@ -956,6 +961,19 @@ def from_pipe(cls, pipeline, **kwargs): inpainting_cls.__name__.replace("ControlNetInpaintPipeline", "InpaintPipeline"), ) + if "enable_pag" in kwargs: + enable_pag = kwargs.pop("enable_pag") + if enable_pag: + inpainting_cls = _get_task_class( + AUTO_INPAINT_PIPELINES_MAPPING, + inpainting_cls.__name__.replace("PAG", "").replace("InpaintPipeline", "PAGInpaintPipeline"), + ) + else: + inpainting_cls = _get_task_class( + AUTO_INPAINT_PIPELINES_MAPPING, + inpainting_cls.__name__.replace("PAGInpaintPipeline", "InpaintPipeline"), + ) + # define expected module and optional kwargs given the pipeline signature expected_modules, optional_kwargs = inpainting_cls._get_signature_keys(inpainting_cls) diff --git a/src/diffusers/pipelines/pag/__init__.py b/src/diffusers/pipelines/pag/__init__.py index d60ff46a8151..f2ceb7a8a0f0 100644 --- a/src/diffusers/pipelines/pag/__init__.py +++ b/src/diffusers/pipelines/pag/__init__.py @@ -23,6 +23,7 @@ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) else: _import_structure["pipeline_pag_sd_xl"] = ["StableDiffusionXLPAGPipeline"] + _import_structure["pipeline_pag_sd_xl_inpaint"] = ["StableDiffusionXLPAGInpaintPipeline"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: @@ -33,7 +34,7 @@ except OptionalDependencyNotAvailable: from ...utils.dummy_torch_and_transformers_objects import * else: - from .pipeline_pag_sd_xl import StableDiffusionXLPAGPipeline + from .pipeline_pag_sd_xl import StableDiffusionXLPAGInpaintPipeline, StableDiffusionXLPAGPipeline else: import sys diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sd_xl.py b/src/diffusers/pipelines/pag/pipeline_pag_sd_xl.py index 4df267c63305..fe7c9c59a5ce 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_sd_xl.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_sd_xl.py @@ -854,7 +854,7 @@ def __call__( clip_skip: Optional[int] = None, callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], - pag_scale: float = 0.0, + pag_scale: float = 3.0, pag_adaptive_scale: float = 0.0, ): r""" @@ -993,6 +993,12 @@ def __call__( The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the `._callback_tensor_inputs` attribute of your pipeline class. + pag_scale (`float`, *optional*, defaults to 3.0): + The scale factor for the perturbed attention guidance. If it is set to 0.0, the perturbed attention + guidance will not be used. + pag_adaptive_scale (`float`, *optional*, defaults to 0.0): + The adaptive scale factor for the perturbed attention guidance. If it is set to 0.0, `pag_scale` is + used. Examples: diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py b/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py new file mode 100644 index 000000000000..bc5176c2ee21 --- /dev/null +++ b/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py @@ -0,0 +1,1740 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import PIL.Image +import torch +from transformers import ( + CLIPImageProcessor, + CLIPTextModel, + CLIPTextModelWithProjection, + CLIPTokenizer, + CLIPVisionModelWithProjection, +) + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import ( + FromSingleFileMixin, + IPAdapterMixin, + StableDiffusionXLLoraLoaderMixin, + TextualInversionLoaderMixin, +) +from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel +from ...models.attention_processor import ( + AttnProcessor2_0, + LoRAAttnProcessor2_0, + LoRAXFormersAttnProcessor, + XFormersAttnProcessor, +) +from ...models.lora import adjust_lora_scale_text_encoder +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + USE_PEFT_BACKEND, + is_invisible_watermark_available, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput +from .pag_utils import PAGMixin + + +if is_invisible_watermark_available(): + from ..stable_diffusion_xl.watermark import StableDiffusionXLWatermarker + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import StableDiffusionXLInpaintPipeline + >>> from diffusers.utils import load_image + + >>> pipe = StableDiffusionXLInpaintPipeline.from_pretrained( + ... "stabilityai/stable-diffusion-xl-base-1.0", + ... torch_dtype=torch.float16, + ... variant="fp16", + ... use_safetensors=True, + ... ) + >>> pipe.to("cuda") + + >>> img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png" + >>> mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png" + + >>> init_image = load_image(img_url).convert("RGB") + >>> mask_image = load_image(mask_url).convert("RGB") + + >>> prompt = "A majestic tiger sitting on a bench" + >>> image = pipe( + ... prompt=prompt, image=init_image, mask_image=mask_image, num_inference_steps=50, strength=0.80 + ... ).images[0] + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + """ + Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class StableDiffusionXLPAGInpaintPipeline( + DiffusionPipeline, + StableDiffusionMixin, + TextualInversionLoaderMixin, + StableDiffusionXLLoraLoaderMixin, + FromSingleFileMixin, + IPAdapterMixin, + PAGMixin, +): + r""" + Pipeline for text-to-image generation using Stable Diffusion XL. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files + - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion XL uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + text_encoder_2 ([` CLIPTextModelWithProjection`]): + Second frozen text-encoder. Stable Diffusion XL uses the text and pool portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection), + specifically the + [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k) + variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_2 (`CLIPTokenizer`): + Second Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + requires_aesthetics_score (`bool`, *optional*, defaults to `"False"`): + Whether the `unet` requires a aesthetic_score condition to be passed during inference. Also see the config + of `stabilityai/stable-diffusion-xl-refiner-1-0`. + force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `"True"`): + Whether the negative prompt embeddings shall be forced to always be set to 0. Also see the config of + `stabilityai/stable-diffusion-xl-base-1-0`. + add_watermarker (`bool`, *optional*): + Whether to use the [invisible_watermark library](https://github.com/ShieldMnt/invisible-watermark/) to + watermark output images. If not defined, it will default to True if the package is installed, otherwise no + watermarker will be used. + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->unet->vae" + + _optional_components = [ + "tokenizer", + "tokenizer_2", + "text_encoder", + "text_encoder_2", + "image_encoder", + "feature_extractor", + ] + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + "add_text_embeds", + "add_time_ids", + "negative_pooled_prompt_embeds", + "add_neg_time_ids", + "mask", + "masked_image_latents", + ] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + text_encoder_2: CLIPTextModelWithProjection, + tokenizer: CLIPTokenizer, + tokenizer_2: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + image_encoder: CLIPVisionModelWithProjection = None, + feature_extractor: CLIPImageProcessor = None, + requires_aesthetics_score: bool = False, + force_zeros_for_empty_prompt: bool = True, + add_watermarker: Optional[bool] = None, + pag_applied_layers: Union[str, List[str]] = "mid", # ["mid"], ["down.block_1", "up.block_0.attentions_0"] + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + unet=unet, + image_encoder=image_encoder, + feature_extractor=feature_extractor, + scheduler=scheduler, + ) + self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) + self.register_to_config(requires_aesthetics_score=requires_aesthetics_score) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.mask_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True + ) + + add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available() + + if add_watermarker: + self.watermark = StableDiffusionXLWatermarker() + else: + self.watermark = None + + self.set_pag_applied_layers(pag_applied_layers) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance + ): + image_embeds = [] + if do_classifier_free_guidance: + negative_image_embeds = [] + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." + ) + + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers + ): + output_hidden_state = not isinstance(image_proj_layer, ImageProjection) + single_image_embeds, single_negative_image_embeds = self.encode_image( + single_ip_adapter_image, device, 1, output_hidden_state + ) + + image_embeds.append(single_image_embeds[None, :]) + if do_classifier_free_guidance: + negative_image_embeds.append(single_negative_image_embeds[None, :]) + else: + for single_image_embeds in ip_adapter_image_embeds: + if do_classifier_free_guidance: + single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) + negative_image_embeds.append(single_negative_image_embeds) + image_embeds.append(single_image_embeds) + + ip_adapter_image_embeds = [] + for i, single_image_embeds in enumerate(image_embeds): + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + if do_classifier_free_guidance: + single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt + def encode_prompt( + self, + prompt: str, + prompt_2: Optional[str] = None, + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt: Optional[str] = None, + negative_prompt_2: Optional[str] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + pooled_prompt_embeds: Optional[torch.Tensor] = None, + negative_pooled_prompt_embeds: Optional[torch.Tensor] = None, + lora_scale: Optional[float] = None, + clip_skip: Optional[int] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) + else: + scale_lora_layers(self.text_encoder_2, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # Define tokenizers and text encoders + tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2] + text_encoders = ( + [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2] + ) + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + # textual inversion: process multi-vector tokens if necessary + prompt_embeds_list = [] + prompts = [prompt, prompt_2] + for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, tokenizer) + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {tokenizer.model_max_length} tokens: {removed_text}" + ) + + prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) + + # We are only ALWAYS interested in the pooled output of the final text encoder + pooled_prompt_embeds = prompt_embeds[0] + if clip_skip is None: + prompt_embeds = prompt_embeds.hidden_states[-2] + else: + # "2" because SDXL always indexes from the penultimate layer. + prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)] + + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) + + # get unconditional embeddings for classifier free guidance + zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt + if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt: + negative_prompt_embeds = torch.zeros_like(prompt_embeds) + negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) + elif do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt_2 = negative_prompt_2 or negative_prompt + + # normalize str to list + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + negative_prompt_2 = ( + batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 + ) + + uncond_tokens: List[str] + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = [negative_prompt, negative_prompt_2] + + negative_prompt_embeds_list = [] + for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = tokenizer( + negative_prompt, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + negative_prompt_embeds = text_encoder( + uncond_input.input_ids.to(device), + output_hidden_states=True, + ) + # We are only ALWAYS interested in the pooled output of the final text encoder + negative_pooled_prompt_embeds = negative_prompt_embeds[0] + negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] + + negative_prompt_embeds_list.append(negative_prompt_embeds) + + negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) + + if self.text_encoder_2 is not None: + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + if self.text_encoder_2 is not None: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + if do_classifier_free_guidance: + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2, lora_scale) + + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + prompt_2, + image, + mask_image, + height, + width, + strength, + output_type, + negative_prompt=None, + negative_prompt_2=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ip_adapter_image=None, + ip_adapter_image_embeds=None, + callback_on_step_end_tensor_inputs=None, + padding_mask_crop=None, + ): + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif negative_prompt_2 is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + if padding_mask_crop is not None: + if not isinstance(image, PIL.Image.Image): + raise ValueError( + f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}." + ) + if not isinstance(mask_image, PIL.Image.Image): + raise ValueError( + f"The mask image should be a PIL image when inpainting mask crop, but is of type" + f" {type(mask_image)}." + ) + if output_type != "pil": + raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.") + + if ip_adapter_image is not None and ip_adapter_image_embeds is not None: + raise ValueError( + "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined." + ) + + if ip_adapter_image_embeds is not None: + if not isinstance(ip_adapter_image_embeds, list): + raise ValueError( + f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" + ) + elif ip_adapter_image_embeds[0].ndim not in [3, 4]: + raise ValueError( + f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" + ) + + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + image=None, + timestep=None, + is_strength_max=True, + add_noise=True, + return_noise=False, + return_image_latents=False, + ): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if (image is None or timestep is None) and not is_strength_max: + raise ValueError( + "Since strength < 1. initial latents are to be initialised as a combination of Image + Noise." + "However, either the image or the noise timestep has not been provided." + ) + + if image.shape[1] == 4: + image_latents = image.to(device=device, dtype=dtype) + image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1) + elif return_image_latents or (latents is None and not is_strength_max): + image = image.to(device=device, dtype=dtype) + image_latents = self._encode_vae_image(image=image, generator=generator) + image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1) + + if latents is None and add_noise: + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + # if strength is 1. then initialise the latents to noise, else initial to image + noise + latents = noise if is_strength_max else self.scheduler.add_noise(image_latents, noise, timestep) + # if pure noise then scale the initial latents by the Scheduler's init sigma + latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents + elif add_noise: + noise = latents.to(device) + latents = noise * self.scheduler.init_noise_sigma + else: + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = image_latents.to(device) + + outputs = (latents,) + + if return_noise: + outputs += (noise,) + + if return_image_latents: + outputs += (image_latents,) + + return outputs + + def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): + dtype = image.dtype + if self.vae.config.force_upcast: + image = image.float() + self.vae.to(dtype=torch.float32) + + if isinstance(generator, list): + image_latents = [ + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(self.vae.encode(image), generator=generator) + + if self.vae.config.force_upcast: + self.vae.to(dtype) + + image_latents = image_latents.to(dtype) + image_latents = self.vae.config.scaling_factor * image_latents + + return image_latents + + def prepare_mask_latents( + self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance + ): + # resize the mask to latents shape as we concatenate the mask to the latents + # we do that before converting to dtype to avoid breaking in case we're using cpu_offload + # and half precision + mask = torch.nn.functional.interpolate( + mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor) + ) + mask = mask.to(device=device, dtype=dtype) + + # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method + if mask.shape[0] < batch_size: + if not batch_size % mask.shape[0] == 0: + raise ValueError( + "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to" + f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number" + " of masks that you pass is divisible by the total requested batch size." + ) + mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1) + + mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask + + if masked_image is not None and masked_image.shape[1] == 4: + masked_image_latents = masked_image + else: + masked_image_latents = None + + if masked_image is not None: + if masked_image_latents is None: + masked_image = masked_image.to(device=device, dtype=dtype) + masked_image_latents = self._encode_vae_image(masked_image, generator=generator) + + if masked_image_latents.shape[0] < batch_size: + if not batch_size % masked_image_latents.shape[0] == 0: + raise ValueError( + "The passed images and the required batch size don't match. Images are supposed to be duplicated" + f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed." + " Make sure the number of images that you pass is divisible by the total requested batch size." + ) + masked_image_latents = masked_image_latents.repeat( + batch_size // masked_image_latents.shape[0], 1, 1, 1 + ) + + masked_image_latents = ( + torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents + ) + + # aligning device to prevent device errors when concating it with the latent model input + masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) + + return mask, masked_image_latents + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength, device, denoising_start=None): + # get the original timestep using init_timestep + if denoising_start is None: + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + t_start = max(num_inference_steps - init_timestep, 0) + else: + t_start = 0 + + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + + # Strength is irrelevant if we directly request a timestep to start at; + # that is, strength is determined by the denoising_start instead. + if denoising_start is not None: + discrete_timestep_cutoff = int( + round( + self.scheduler.config.num_train_timesteps + - (denoising_start * self.scheduler.config.num_train_timesteps) + ) + ) + + num_inference_steps = (timesteps < discrete_timestep_cutoff).sum().item() + if self.scheduler.order == 2 and num_inference_steps % 2 == 0: + # if the scheduler is a 2nd order scheduler we might have to do +1 + # because `num_inference_steps` might be even given that every timestep + # (except the highest one) is duplicated. If `num_inference_steps` is even it would + # mean that we cut the timesteps in the middle of the denoising step + # (between 1st and 2nd derivative) which leads to incorrect results. By adding 1 + # we ensure that the denoising process always ends after the 2nd derivate step of the scheduler + num_inference_steps = num_inference_steps + 1 + + # because t_n+1 >= t_n, we slice the timesteps starting from the end + timesteps = timesteps[-num_inference_steps:] + return timesteps, num_inference_steps + + return timesteps, num_inference_steps - t_start + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline._get_add_time_ids + def _get_add_time_ids( + self, + original_size, + crops_coords_top_left, + target_size, + aesthetic_score, + negative_aesthetic_score, + negative_original_size, + negative_crops_coords_top_left, + negative_target_size, + dtype, + text_encoder_projection_dim=None, + ): + if self.config.requires_aesthetics_score: + add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,)) + add_neg_time_ids = list( + negative_original_size + negative_crops_coords_top_left + (negative_aesthetic_score,) + ) + else: + add_time_ids = list(original_size + crops_coords_top_left + target_size) + add_neg_time_ids = list(negative_original_size + crops_coords_top_left + negative_target_size) + + passed_add_embed_dim = ( + self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim + ) + expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features + + if ( + expected_add_embed_dim > passed_add_embed_dim + and (expected_add_embed_dim - passed_add_embed_dim) == self.unet.config.addition_time_embed_dim + ): + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to enable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=True)` to make sure `aesthetic_score` {aesthetic_score} and `negative_aesthetic_score` {negative_aesthetic_score} is correctly used by the model." + ) + elif ( + expected_add_embed_dim < passed_add_embed_dim + and (passed_add_embed_dim - expected_add_embed_dim) == self.unet.config.addition_time_embed_dim + ): + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to disable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=False)` to make sure `target_size` {target_size} is correctly used by the model." + ) + elif expected_add_embed_dim != passed_add_embed_dim: + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." + ) + + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + add_neg_time_ids = torch.tensor([add_neg_time_ids], dtype=dtype) + + return add_time_ids, add_neg_time_ids + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae + def upcast_vae(self): + dtype = self.vae.dtype + self.vae.to(dtype=torch.float32) + use_torch_2_0_or_xformers = isinstance( + self.vae.decoder.mid_block.attentions[0].processor, + ( + AttnProcessor2_0, + XFormersAttnProcessor, + LoRAXFormersAttnProcessor, + LoRAAttnProcessor2_0, + ), + ) + # if xformers or torch_2_0 is used attention block does not need + # to be in float32 which can save lots of memory + if use_torch_2_0_or_xformers: + self.vae.post_quant_conv.to(dtype) + self.vae.decoder.conv_in.to(dtype) + self.vae.decoder.mid_block.to(dtype) + + # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding + def get_guidance_scale_embedding( + self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32 + ) -> torch.Tensor: + """ + See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 + + Args: + w (`torch.Tensor`): + Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings. + embedding_dim (`int`, *optional*, defaults to 512): + Dimension of the embeddings to generate. + dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): + Data type of the generated embeddings. + + Returns: + `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`. + """ + assert len(w.shape) == 1 + w = w * 1000.0 + + half_dim = embedding_dim // 2 + emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) + emb = w.to(dtype)[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1)) + assert emb.shape == (w.shape[0], embedding_dim) + return emb + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def guidance_rescale(self): + return self._guidance_rescale + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None + + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + def denoising_end(self): + return self._denoising_end + + @property + def denoising_start(self): + return self._denoising_start + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + image: PipelineImageInput = None, + mask_image: PipelineImageInput = None, + masked_image_latents: torch.Tensor = None, + height: Optional[int] = None, + width: Optional[int] = None, + padding_mask_crop: Optional[int] = None, + strength: float = 0.9999, + num_inference_steps: int = 50, + timesteps: List[int] = None, + sigmas: List[float] = None, + denoising_start: Optional[float] = None, + denoising_end: Optional[float] = None, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + pooled_prompt_embeds: Optional[torch.Tensor] = None, + negative_pooled_prompt_embeds: Optional[torch.Tensor] = None, + ip_adapter_image: Optional[PipelineImageInput] = None, + ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + guidance_rescale: float = 0.0, + original_size: Tuple[int, int] = None, + crops_coords_top_left: Tuple[int, int] = (0, 0), + target_size: Tuple[int, int] = None, + negative_original_size: Optional[Tuple[int, int]] = None, + negative_crops_coords_top_left: Tuple[int, int] = (0, 0), + negative_target_size: Optional[Tuple[int, int]] = None, + aesthetic_score: float = 6.0, + negative_aesthetic_score: float = 2.5, + clip_skip: Optional[int] = None, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + pag_scale: float = 3.0, + pag_adaptive_scale: float = 0.0, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + image (`PIL.Image.Image`): + `Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will + be masked out with `mask_image` and repainted according to `prompt`. + mask_image (`PIL.Image.Image`): + `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be + repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted + to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L) + instead of 3, so the expected shape would be `(B, H, W, 1)`. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + Anything below 512 pixels won't work well for + [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) + and checkpoints that are not specifically fine-tuned on low resolutions. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + Anything below 512 pixels won't work well for + [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) + and checkpoints that are not specifically fine-tuned on low resolutions. + padding_mask_crop (`int`, *optional*, defaults to `None`): + The size of margin in the crop to be applied to the image and masking. If `None`, no crop is applied to + image and mask_image. If `padding_mask_crop` is not `None`, it will first find a rectangular region + with the same aspect ration of the image and contains all masked area, and then expand that area based + on `padding_mask_crop`. The image and mask_image will then be cropped based on the expanded area before + resizing to the original image size for inpainting. This is useful when the masked area is small while + the image is large and contain information irrelevant for inpainting, such as background. + strength (`float`, *optional*, defaults to 0.9999): + Conceptually, indicates how much to transform the masked portion of the reference `image`. Must be + between 0 and 1. `image` will be used as a starting point, adding more noise to it the larger the + `strength`. The number of denoising steps depends on the amount of noise initially added. When + `strength` is 1, added noise will be maximum and the denoising process will run for the full number of + iterations specified in `num_inference_steps`. A value of 1, therefore, essentially ignores the masked + portion of the reference `image`. Note that in the case of `denoising_start` being declared as an + integer, the value of `strength` will be ignored. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + denoising_start (`float`, *optional*): + When specified, indicates the fraction (between 0.0 and 1.0) of the total denoising process to be + bypassed before it is initiated. Consequently, the initial part of the denoising process is skipped and + it is assumed that the passed `image` is a partly denoised image. Note that when this is specified, + strength will be ignored. The `denoising_start` parameter is particularly beneficial when this pipeline + is integrated into a "Mixture of Denoisers" multi-pipeline setup, as detailed in [**Refining the Image + Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output). + denoising_end (`float`, *optional*): + When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be + completed before it is intentionally prematurely terminated. As a result, the returned sample will + still retain a substantial amount of noise (ca. final 20% of timesteps still needed) and should be + denoised by a successor pipeline that has `denoising_start` set to 0.8 so that it only denoises the + final 20% of the scheduler. The denoising_end parameter should ideally be utilized when this pipeline + forms a part of a "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image + Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output). + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should + contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. + `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as + explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): + `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position + `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting + `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + For most cases, `target_size` should be set to the desired height and width of the generated image. If + not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in + section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a specific image resolution. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): + To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a target image resolution. It should be as same + as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + aesthetic_score (`float`, *optional*, defaults to 6.0): + Used to simulate an aesthetic score of the generated image by influencing the positive text condition. + Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + negative_aesthetic_score (`float`, *optional*, defaults to 2.5): + Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). Can be used to + simulate an aesthetic score of the generated image by influencing the negative text condition. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + pag_scale (`float`, *optional*, defaults to 3.0): + The scale factor for the perturbed attention guidance. If it is set to 0.0, the perturbed attention + guidance will not be used. + pag_adaptive_scale (`float`, *optional*, defaults to 0.0): + The adaptive scale factor for the perturbed attention guidance. If it is set to 0.0, `pag_scale` is + used. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a + `tuple. `tuple. When returning a tuple, the first element is a list with the generated images. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + # 1. Check inputs + self.check_inputs( + prompt, + prompt_2, + image, + mask_image, + height, + width, + strength, + output_type, + negative_prompt, + negative_prompt_2, + prompt_embeds, + negative_prompt_embeds, + ip_adapter_image, + ip_adapter_image_embeds, + callback_on_step_end_tensor_inputs, + padding_mask_crop, + ) + + self._guidance_scale = guidance_scale + self._guidance_rescale = guidance_rescale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + self._denoising_end = denoising_end + self._denoising_start = denoising_start + self._interrupt = False + self._pag_scale = pag_scale + self._pag_adaptive_scale = pag_adaptive_scale + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. Encode input prompt + text_encoder_lora_scale = ( + self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None + ) + + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=self.clip_skip, + ) + + # 4. set timesteps + def denoising_value_valid(dnv): + return isinstance(dnv, float) and 0 < dnv < 1 + + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, device, timesteps, sigmas + ) + timesteps, num_inference_steps = self.get_timesteps( + num_inference_steps, + strength, + device, + denoising_start=self.denoising_start if denoising_value_valid(self.denoising_start) else None, + ) + # check that number of inference steps is not < 1 - as this doesn't make sense + if num_inference_steps < 1: + raise ValueError( + f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline" + f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline." + ) + # at which timestep to set the initial noise (n.b. 50% if strength is 0.5) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise + is_strength_max = strength == 1.0 + + # 5. Preprocess mask and image + if padding_mask_crop is not None: + crops_coords = self.mask_processor.get_crop_region(mask_image, width, height, pad=padding_mask_crop) + resize_mode = "fill" + else: + crops_coords = None + resize_mode = "default" + + original_image = image + init_image = self.image_processor.preprocess( + image, height=height, width=width, crops_coords=crops_coords, resize_mode=resize_mode + ) + init_image = init_image.to(dtype=torch.float32) + + mask = self.mask_processor.preprocess( + mask_image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords + ) + + if masked_image_latents is not None: + masked_image = masked_image_latents + elif init_image.shape[1] == 4: + # if images are in latent space, we can't mask it + masked_image = None + else: + masked_image = init_image * (mask < 0.5) + + # 6. Prepare latent variables + num_channels_latents = self.vae.config.latent_channels + num_channels_unet = self.unet.config.in_channels + return_image_latents = num_channels_unet == 4 + + add_noise = True if self.denoising_start is None else False + latents_outputs = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + image=init_image, + timestep=latent_timestep, + is_strength_max=is_strength_max, + add_noise=add_noise, + return_noise=True, + return_image_latents=return_image_latents, + ) + + if return_image_latents: + latents, noise, image_latents = latents_outputs + else: + latents, noise = latents_outputs + + # 7. Prepare mask latent variables + mask, masked_image_latents = self.prepare_mask_latents( + mask, + masked_image, + batch_size * num_images_per_prompt, + height, + width, + prompt_embeds.dtype, + device, + generator, + self.do_classifier_free_guidance, + ) + + # 8. Check that sizes of mask, masked image and latents match + if num_channels_unet == 9: + # default case for runwayml/stable-diffusion-inpainting + num_channels_mask = mask.shape[1] + num_channels_masked_image = masked_image_latents.shape[1] + if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels: + raise ValueError( + f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects" + f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" + f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" + f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" + " `pipeline.unet` or your `mask_image` or `image` input." + ) + elif num_channels_unet != 4: + raise ValueError( + f"The unet {self.unet.__class__} should have either 4 or 9 input channels, not {self.unet.config.in_channels}." + ) + # 8.1 Prepare extra step kwargs. + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + height, width = latents.shape[-2:] + height = height * self.vae_scale_factor + width = width * self.vae_scale_factor + + original_size = original_size or (height, width) + target_size = target_size or (height, width) + + # 10. Prepare added time ids & embeddings + if negative_original_size is None: + negative_original_size = original_size + if negative_target_size is None: + negative_target_size = target_size + + add_text_embeds = pooled_prompt_embeds + if self.text_encoder_2 is None: + text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) + else: + text_encoder_projection_dim = self.text_encoder_2.config.projection_dim + + add_time_ids, add_neg_time_ids = self._get_add_time_ids( + original_size, + crops_coords_top_left, + target_size, + aesthetic_score, + negative_aesthetic_score, + negative_original_size, + negative_crops_coords_top_left, + negative_target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1) + add_neg_time_ids = add_neg_time_ids.repeat(batch_size * num_images_per_prompt, 1) + + if self.do_perturbed_attention_guidance: + prompt_embeds = self._prepare_perturbed_attention_guidance( + prompt_embeds, negative_prompt_embeds, self.do_classifier_free_guidance + ) + add_text_embeds = self._prepare_perturbed_attention_guidance( + add_text_embeds, negative_pooled_prompt_embeds, self.do_classifier_free_guidance + ) + add_time_ids = self._prepare_perturbed_attention_guidance( + add_time_ids, add_neg_time_ids, self.do_classifier_free_guidance + ) + + elif self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) + add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0) + + prompt_embeds = prompt_embeds.to(device) + add_text_embeds = add_text_embeds.to(device) + add_time_ids = add_time_ids.to(device) + + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + ip_adapter_image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + self.do_classifier_free_guidance, + ) + for i, image_embeds in enumerate(ip_adapter_image_embeds): + negative_image_embeds = None + if self.do_classifier_free_guidance: + negative_image_embeds, image_embeds = image_embeds.chunk(2) + + if self.do_perturbed_attention_guidance: + image_embeds = self._prepare_perturbed_attention_guidance( + image_embeds, negative_image_embeds, self.do_classifier_free_guidance + ) + elif self.do_classifier_free_guidance: + image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0) + image_embeds = image_embeds.to(device) + ip_adapter_image_embeds[i] = image_embeds + + # 11. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + if ( + self.denoising_end is not None + and self.denoising_start is not None + and denoising_value_valid(self.denoising_end) + and denoising_value_valid(self.denoising_start) + and self.denoising_start >= self.denoising_end + ): + raise ValueError( + f"`denoising_start`: {self.denoising_start} cannot be larger than or equal to `denoising_end`: " + + f" {self.denoising_end} when using type float." + ) + elif self.denoising_end is not None and denoising_value_valid(self.denoising_end): + discrete_timestep_cutoff = int( + round( + self.scheduler.config.num_train_timesteps + - (self.denoising_end * self.scheduler.config.num_train_timesteps) + ) + ) + num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) + timesteps = timesteps[:num_inference_steps] + + # 11.1 Optionally get Guidance Scale Embedding + timestep_cond = None + if self.unet.config.time_cond_proj_dim is not None: + guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) + timestep_cond = self.get_guidance_scale_embedding( + guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim + ).to(device=device, dtype=latents.dtype) + + if self.do_perturbed_attention_guidance: + original_attn_proc = self.unet.attn_processors + self._set_pag_attn_processor( + pag_applied_layers=self.pag_applied_layers, + do_classifier_free_guidance=self.do_classifier_free_guidance, + ) + + self._num_timesteps = len(timesteps) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * (prompt_embeds.shape[0] // latents.shape[0])) + + # concat latents, mask, masked_image_latents in the channel dimension + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + if num_channels_unet == 9: + latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1) + + # predict the noise residual + added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + added_cond_kwargs["image_embeds"] = ip_adapter_image_embeds + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + timestep_cond=timestep_cond, + cross_attention_kwargs=self.cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if self.do_perturbed_attention_guidance: + noise_pred = self._apply_perturbed_attention_guidance( + noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t + ) + elif self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + if self.do_classifier_free_guidance and self.guidance_rescale > 0.0: + # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if num_channels_unet == 4: + init_latents_proper = image_latents + if self.do_classifier_free_guidance: + init_mask, _ = mask.chunk(2) + else: + init_mask = mask + + if i < len(timesteps) - 1: + noise_timestep = timesteps[i + 1] + init_latents_proper = self.scheduler.add_noise( + init_latents_proper, noise, torch.tensor([noise_timestep]) + ) + + latents = (1 - init_mask) * init_latents_proper + init_mask * latents + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds) + negative_pooled_prompt_embeds = callback_outputs.pop( + "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds + ) + add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids) + add_neg_time_ids = callback_outputs.pop("add_neg_time_ids", add_neg_time_ids) + mask = callback_outputs.pop("mask", mask) + masked_image_latents = callback_outputs.pop("masked_image_latents", masked_image_latents) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if not output_type == "latent": + # make sure the VAE is in float32 mode, as it overflows in float16 + needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast + + if needs_upcasting: + self.upcast_vae() + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + elif latents.dtype != self.vae.dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + self.vae = self.vae.to(latents.dtype) + + # unscale/denormalize the latents + # denormalize with the mean and std if available and not None + has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None + has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None + if has_latents_mean and has_latents_std: + latents_mean = ( + torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype) + ) + latents_std = ( + torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype) + ) + latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean + else: + latents = latents / self.vae.config.scaling_factor + + image = self.vae.decode(latents, return_dict=False)[0] + + # cast back to fp16 if needed + if needs_upcasting: + self.vae.to(dtype=torch.float16) + else: + return StableDiffusionXLPipelineOutput(images=latents) + + # apply watermark if available + if self.watermark is not None: + image = self.watermark.apply_watermark(image) + + image = self.image_processor.postprocess(image, output_type=output_type) + + if padding_mask_crop is not None: + image = [self.image_processor.apply_overlay(mask_image, original_image, i, crops_coords) for i in image] + + # Offload all models + self.maybe_free_model_hooks() + + if self.do_perturbed_attention_guidance: + self.unet.set_attn_processor(original_attn_proc) + + if not return_dict: + return (image,) + + return StableDiffusionXLPipelineOutput(images=image) diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 0761524e1e02..66e0f8d5dbb8 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -1397,6 +1397,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class StableDiffusionXLPAGInpaintPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class StableDiffusionXLPAGPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] diff --git a/tests/pipelines/pag/test_pag_sdxl.py b/tests/pipelines/pag/test_pag_sdxl.py index 9a9e13607cad..9e607d42e8dc 100644 --- a/tests/pipelines/pag/test_pag_sdxl.py +++ b/tests/pipelines/pag/test_pag_sdxl.py @@ -26,6 +26,7 @@ AutoPipelineForText2Image, EulerDiscreteScheduler, StableDiffusionXLPAGPipeline, + StableDiffusionXLPipeline, UNet2DConditionModel, ) from diffusers.utils.testing_utils import ( @@ -160,8 +161,8 @@ def test_pag_disable_enable(self): device = "cpu" # ensure determinism for the device-dependent torch.Generator components = self.get_dummy_components() - # base pipeline - pipe_sd = self.original_pipeline_class(**components) + # base pipeline (expect same output when pag is disabled) + pipe_sd = StableDiffusionXLPipeline(**components) pipe_sd = pipe_sd.to(device) pipe_sd.set_progress_bar_config(disable=None) @@ -214,29 +215,24 @@ def test_pag_applied_layers(self): # pag_applied_layers = ["mid"], or ["mid.block_0"] or ["mid.block_0.attentions_0"] should apply to all self-attention layers in mid_block, i.e. # mid_block.attentions.0.transformer_blocks.0.attn1.processor # mid_block.attentions.0.transformer_blocks.1.attn1.processor + all_self_attn_mid_layers = [ + "mid_block.attentions.0.transformer_blocks.0.attn1.processor", + "mid_block.attentions.0.transformer_blocks.1.attn1.processor", + ] pipe.unet.set_attn_processor(original_attn_procs.copy()) pag_layers = ["mid"] pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False) - assert set(pipe.pag_attn_processors) == set( - "mid_block.attentions.0.transformer_blocks.0.attn1.processor", - "mid_block.attentions.0.transformer_blocks.1.attn1.processor", - ) + assert set(pipe.pag_attn_processors) == set(all_self_attn_mid_layers) pipe.unet.set_attn_processor(original_attn_procs.copy()) pag_layers = ["mid.block_0"] pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False) - assert set(pipe.pag_attn_processors) == set( - "mid_block.attentions.0.transformer_blocks.0.attn1.processor", - "mid_block.attentions.0.transformer_blocks.1.attn1.processor", - ) + assert set(pipe.pag_attn_processors) == set(all_self_attn_mid_layers) pipe.unet.set_attn_processor(original_attn_procs.copy()) pag_layers = ["mid.block_0.attentions_0"] pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False) - assert set(pipe.pag_attn_processors) == set( - "mid_block.attentions.0.transformer_blocks.0.attn1.processor", - "mid_block.attentions.0.transformer_blocks.1.attn1.processor", - ) + assert set(pipe.pag_attn_processors) == set(all_self_attn_mid_layers) # pag_applied_layers = ["mid.block_0.attentions_1"] does not exist in the model pipe.unet.set_attn_processor(original_attn_procs.copy()) @@ -341,7 +337,7 @@ def test_pag_cfg(self): np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 ), f"output is different from expected, {image_slice.flatten()}" - def test_pag_cfg_uncond(self): + def test_pag_uncond(self): pipeline = AutoPipelineForText2Image.from_pretrained(self.repo_id, enable_pag=True, torch_dtype=torch.float16) pipeline.enable_model_cpu_offload() pipeline.set_progress_bar_config(disable=None) diff --git a/tests/pipelines/pag/test_pag_sdxl_inpaint.py b/tests/pipelines/pag/test_pag_sdxl_inpaint.py new file mode 100644 index 000000000000..92cfa8c85667 --- /dev/null +++ b/tests/pipelines/pag/test_pag_sdxl_inpaint.py @@ -0,0 +1,344 @@ +# coding=utf-8 +# Copyright 2024 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import inspect +import random +import unittest + +import numpy as np +import torch +from PIL import Image +from transformers import ( + CLIPImageProcessor, + CLIPTextConfig, + CLIPTextModel, + CLIPTextModelWithProjection, + CLIPTokenizer, + CLIPVisionConfig, + CLIPVisionModelWithProjection, +) + +from diffusers import ( + AutoencoderKL, + AutoPipelineForInpainting, + EulerDiscreteScheduler, + StableDiffusionXLInpaintPipeline, + StableDiffusionXLPAGInpaintPipeline, + UNet2DConditionModel, +) +from diffusers.utils.testing_utils import ( + enable_full_determinism, + floats_tensor, + load_image, + require_torch_gpu, + slow, + torch_device, +) + +from ..pipeline_params import ( + TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS, + TEXT_GUIDED_IMAGE_INPAINTING_PARAMS, + TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS, +) +from ..test_pipelines_common import ( + IPAdapterTesterMixin, + PipelineFromPipeTesterMixin, + PipelineLatentTesterMixin, + PipelineTesterMixin, + SDXLOptionalComponentsTesterMixin, +) + + +enable_full_determinism() + + +class StableDiffusionXLPAGInpaintPipelineFastTests( + PipelineTesterMixin, + IPAdapterTesterMixin, + PipelineLatentTesterMixin, + PipelineFromPipeTesterMixin, + SDXLOptionalComponentsTesterMixin, + unittest.TestCase, +): + pipeline_class = StableDiffusionXLPAGInpaintPipeline + params = TEXT_GUIDED_IMAGE_INPAINTING_PARAMS.union({"pag_scale", "pag_adaptive_scale"}) + batch_params = TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS + image_params = frozenset([]) + image_latents_params = frozenset([]) + callback_cfg_params = TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS.union( + {"add_text_embeds", "add_time_ids", "mask", "masked_image_latents"} + ) + + def get_dummy_components( + self, skip_first_text_encoder=False, time_cond_proj_dim=None, requires_aesthetics_score=False + ): + # copied from tests.pipelines.stable_diffusion_xl.test_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipelineFastTests.get_dummy_components + torch.manual_seed(0) + unet = UNet2DConditionModel( + block_out_channels=(32, 64), + layers_per_block=2, + sample_size=32, + in_channels=4, + out_channels=4, + time_cond_proj_dim=time_cond_proj_dim, + down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"), + up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"), + # SD2-specific config below + attention_head_dim=(2, 4), + use_linear_projection=True, + addition_embed_type="text_time", + addition_time_embed_dim=8, + transformer_layers_per_block=(1, 2), + projection_class_embeddings_input_dim=72 if requires_aesthetics_score else 80, # 5 * 8 + 32 + cross_attention_dim=64 if not skip_first_text_encoder else 32, + ) + scheduler = EulerDiscreteScheduler( + beta_start=0.00085, + beta_end=0.012, + steps_offset=1, + beta_schedule="scaled_linear", + timestep_spacing="leading", + ) + torch.manual_seed(0) + vae = AutoencoderKL( + block_out_channels=[32, 64], + in_channels=3, + out_channels=3, + down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"], + up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"], + latent_channels=4, + sample_size=128, + ) + torch.manual_seed(0) + text_encoder_config = CLIPTextConfig( + bos_token_id=0, + eos_token_id=2, + hidden_size=32, + intermediate_size=37, + layer_norm_eps=1e-05, + num_attention_heads=4, + num_hidden_layers=5, + pad_token_id=1, + vocab_size=1000, + # SD2-specific config below + hidden_act="gelu", + projection_dim=32, + ) + text_encoder = CLIPTextModel(text_encoder_config) + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + text_encoder_2 = CLIPTextModelWithProjection(text_encoder_config) + tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + torch.manual_seed(0) + image_encoder_config = CLIPVisionConfig( + hidden_size=32, + image_size=224, + projection_dim=32, + intermediate_size=37, + num_attention_heads=4, + num_channels=3, + num_hidden_layers=5, + patch_size=14, + ) + + image_encoder = CLIPVisionModelWithProjection(image_encoder_config) + + feature_extractor = CLIPImageProcessor( + crop_size=224, + do_center_crop=True, + do_normalize=True, + do_resize=True, + image_mean=[0.48145466, 0.4578275, 0.40821073], + image_std=[0.26862954, 0.26130258, 0.27577711], + resample=3, + size=224, + ) + + components = { + "unet": unet, + "scheduler": scheduler, + "vae": vae, + "text_encoder": text_encoder if not skip_first_text_encoder else None, + "tokenizer": tokenizer if not skip_first_text_encoder else None, + "text_encoder_2": text_encoder_2, + "tokenizer_2": tokenizer_2, + "image_encoder": image_encoder, + "feature_extractor": feature_extractor, + "requires_aesthetics_score": requires_aesthetics_score, + } + return components + + def get_dummy_inputs(self, device, seed=0): + # TODO: use tensor inputs instead of PIL, this is here just to leave the old expected_slices untouched + image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device) + image = image.cpu().permute(0, 2, 3, 1)[0] + init_image = Image.fromarray(np.uint8(image)).convert("RGB").resize((64, 64)) + # create mask + image[8:, 8:, :] = 255 + mask_image = Image.fromarray(np.uint8(image)).convert("L").resize((64, 64)) + + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + inputs = { + "prompt": "A painting of a squirrel eating a burger", + "image": init_image, + "mask_image": mask_image, + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 6.0, + "strength": 1.0, + "pag_scale": 0.9, + "output_type": "np", + } + return inputs + + def test_pag_disable_enable(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + components = self.get_dummy_components(requires_aesthetics_score=True) + + # base pipeline + pipe_sd = StableDiffusionXLInpaintPipeline(**components) + pipe_sd = pipe_sd.to(device) + pipe_sd.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + del inputs["pag_scale"] + assert ( + "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters + ), f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__calss__.__name__}." + out = pipe_sd(**inputs).images[0, -3:, -3:, -1] + + # pag disabled with pag_scale=0.0 + pipe_pag = self.pipeline_class(**components) + pipe_pag = pipe_pag.to(device) + pipe_pag.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + inputs["pag_scale"] = 0.0 + out_pag_disabled = pipe_pag(**inputs).images[0, -3:, -3:, -1] + + # pag enabled + pipe_pag = self.pipeline_class(**components, pag_applied_layers=["mid", "up", "down"]) + pipe_pag = pipe_pag.to(device) + pipe_pag.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + out_pag_enabled = pipe_pag(**inputs).images[0, -3:, -3:, -1] + + assert np.abs(out.flatten() - out_pag_disabled.flatten()).max() < 1e-3 + assert np.abs(out.flatten() - out_pag_enabled.flatten()).max() > 1e-3 + + def test_save_load_optional_components(self): + self._test_save_load_optional_components() + + def test_pag_inference(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + components = self.get_dummy_components(requires_aesthetics_score=True) + + pipe_pag = self.pipeline_class(**components, pag_applied_layers=["mid", "up", "down"]) + pipe_pag = pipe_pag.to(device) + pipe_pag.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + image = pipe_pag(**inputs).images + image_slice = image[0, -3:, -3:, -1] + + assert image.shape == ( + 1, + 64, + 64, + 3, + ), f"the shape of the output image should be (1, 64, 64, 3) but got {image.shape}" + expected_slice = np.array( + [0.8115454, 0.53986573, 0.5825281, 0.6028964, 0.67128646, 0.7046922, 0.6418713, 0.5933924, 0.5154763] + ) + + max_diff = np.abs(image_slice.flatten() - expected_slice).max() + assert max_diff < 1e-3, f"output is different from expected, {image_slice.flatten()}" + + +@slow +@require_torch_gpu +class StableDiffusionXLPAGInpaintPipelineIntegrationTests(unittest.TestCase): + repo_id = "stabilityai/stable-diffusion-xl-base-1.0" + + def setUp(self): + super().setUp() + gc.collect() + torch.cuda.empty_cache() + + def tearDown(self): + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + def get_inputs(self, device, generator_device="cpu", seed=0, guidance_scale=7.0): + img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png" + mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png" + + init_image = load_image(img_url).convert("RGB") + mask_image = load_image(mask_url).convert("RGB") + + generator = torch.Generator(device=generator_device).manual_seed(seed) + inputs = { + "prompt": "A majestic tiger sitting on a bench", + "generator": generator, + "image": init_image, + "mask_image": mask_image, + "strength": 0.8, + "num_inference_steps": 3, + "guidance_scale": guidance_scale, + "pag_scale": 3.0, + "output_type": "np", + } + return inputs + + def test_pag_cfg(self): + pipeline = AutoPipelineForInpainting.from_pretrained(self.repo_id, enable_pag=True, torch_dtype=torch.float16) + pipeline.enable_model_cpu_offload() + pipeline.set_progress_bar_config(disable=None) + + inputs = self.get_inputs(torch_device) + image = pipeline(**inputs).images + + image_slice = image[0, -3:, -3:, -1].flatten() + assert image.shape == (1, 1024, 1024, 3) + expected_slice = np.array( + [0.41385046, 0.39608297, 0.4360491, 0.26872507, 0.32187328, 0.4242474, 0.2603805, 0.34167895, 0.46561807] + ) + assert ( + np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 + ), f"output is different from expected, {image_slice.flatten()}" + + def test_pag_uncond(self): + pipeline = AutoPipelineForInpainting.from_pretrained(self.repo_id, enable_pag=True, torch_dtype=torch.float16) + pipeline.enable_model_cpu_offload() + pipeline.set_progress_bar_config(disable=None) + + inputs = self.get_inputs(torch_device, guidance_scale=0.0) + image = pipeline(**inputs).images + + image_slice = image[0, -3:, -3:, -1].flatten() + assert image.shape == (1, 1024, 1024, 3) + expected_slice = np.array( + [0.41597816, 0.39302617, 0.44287828, 0.2687074, 0.28315824, 0.40582314, 0.20877528, 0.2380802, 0.39447647] + ) + assert ( + np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 + ), f"output is different from expected, {image_slice.flatten()}" From 854b70e318225e24dfe27a93f1520801f57875b2 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Mon, 10 Jun 2024 02:54:58 +0000 Subject: [PATCH 20/44] fix --- src/diffusers/pipelines/pag/__init__.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/pag/__init__.py b/src/diffusers/pipelines/pag/__init__.py index f2ceb7a8a0f0..f15db0b9bee1 100644 --- a/src/diffusers/pipelines/pag/__init__.py +++ b/src/diffusers/pipelines/pag/__init__.py @@ -34,7 +34,8 @@ except OptionalDependencyNotAvailable: from ...utils.dummy_torch_and_transformers_objects import * else: - from .pipeline_pag_sd_xl import StableDiffusionXLPAGInpaintPipeline, StableDiffusionXLPAGPipeline + from .pipeline_pag_sd_xl import StableDiffusionXLPAGPipeline + from .pipeline_pag_sd_xl_inpaint import StableDiffusionXLPAGInpaintPipeline else: import sys From 9e4c1b63ccea0ce35155058e474f5e09ba7261be Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Mon, 10 Jun 2024 07:38:57 +0000 Subject: [PATCH 21/44] add controlnet pag --- src/diffusers/__init__.py | 2 + src/diffusers/pipelines/__init__.py | 7 +- src/diffusers/pipelines/auto_pipeline.py | 31 +- src/diffusers/pipelines/pag/__init__.py | 102 +- .../pag/pipeline_pag_controlnet_sd_xl.py | 1632 +++++++++++++++++ .../pipelines/pag/test_pag_controlnet_sdxl.py | 271 +++ 6 files changed, 1982 insertions(+), 63 deletions(-) create mode 100644 src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py create mode 100644 tests/pipelines/pag/test_pag_controlnet_sdxl.py diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 222cd9b71b4f..16e464f20209 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -304,6 +304,7 @@ "StableDiffusionXLAdapterPipeline", "StableDiffusionXLControlNetImg2ImgPipeline", "StableDiffusionXLControlNetInpaintPipeline", + "StableDiffusionXLControlNetPAGPipeline", "StableDiffusionXLControlNetPipeline", "StableDiffusionXLControlNetXSPipeline", "StableDiffusionXLImg2ImgPipeline", @@ -690,6 +691,7 @@ StableDiffusionXLAdapterPipeline, StableDiffusionXLControlNetImg2ImgPipeline, StableDiffusionXLControlNetInpaintPipeline, + StableDiffusionXLControlNetPAGPipeline, StableDiffusionXLControlNetPipeline, StableDiffusionXLControlNetXSPipeline, StableDiffusionXLImg2ImgPipeline, diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index f044b0948ad7..cfd9ba622f70 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -141,6 +141,7 @@ [ "StableDiffusionXLPAGPipeline", "StableDiffusionXLPAGInpaintPipeline", + "StableDiffusionXLControlNetPAGPipeline", ] ) _import_structure["controlnet_xs"].extend( @@ -469,7 +470,11 @@ MarigoldNormalsPipeline, ) from .musicldm import MusicLDMPipeline - from .pag import StableDiffusionXLPAGInpaintPipeline, StableDiffusionXLPAGPipeline + from .pag import ( + StableDiffusionXLControlNetPAGPipeline, + StableDiffusionXLPAGInpaintPipeline, + StableDiffusionXLPAGPipeline, + ) from .paint_by_example import PaintByExamplePipeline from .pia import PIAPipeline from .pixart_alpha import PixArtAlphaPipeline, PixArtSigmaPipeline diff --git a/src/diffusers/pipelines/auto_pipeline.py b/src/diffusers/pipelines/auto_pipeline.py index 17d2e134b340..665515bee595 100644 --- a/src/diffusers/pipelines/auto_pipeline.py +++ b/src/diffusers/pipelines/auto_pipeline.py @@ -45,7 +45,11 @@ ) from .kandinsky3 import Kandinsky3Img2ImgPipeline, Kandinsky3Pipeline from .latent_consistency_models import LatentConsistencyModelImg2ImgPipeline, LatentConsistencyModelPipeline -from .pag import StableDiffusionXLPAGInpaintPipeline, StableDiffusionXLPAGPipeline +from .pag import ( + StableDiffusionXLControlNetPAGPipeline, + StableDiffusionXLPAGInpaintPipeline, + StableDiffusionXLPAGPipeline, +) from .pixart_alpha import PixArtAlphaPipeline, PixArtSigmaPipeline from .stable_cascade import StableCascadeCombinedPipeline, StableCascadeDecoderPipeline from .stable_diffusion import ( @@ -77,6 +81,7 @@ ("pixart-alpha", PixArtAlphaPipeline), ("pixart-sigma", PixArtSigmaPipeline), ("stable-diffusion-xl-pag", StableDiffusionXLPAGPipeline), + ("stable-diffusion-xl-controlnet-pag", StableDiffusionXLControlNetPAGPipeline), ] ) @@ -338,7 +343,7 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs): if "enable_pag" in kwargs: enable_pag = kwargs.pop("enable_pag") if enable_pag: - orig_class_name = config["_class_name"].replace("Pipeline", "PAGPipeline") + orig_class_name = orig_class_name.replace("Pipeline", "PAGPipeline") text_2_image_cls = _get_task_class(AUTO_TEXT2IMAGE_PIPELINES_MAPPING, orig_class_name) @@ -380,29 +385,31 @@ def from_pipe(cls, pipeline, **kwargs): # derive the pipeline class to instantiate text_2_image_cls = _get_task_class(AUTO_TEXT2IMAGE_PIPELINES_MAPPING, original_cls_name) - if "controlnet" in kwargs: - if kwargs["controlnet"] is not None: + to_replace = "Pipeline" + if "enable_pag" in kwargs: + enable_pag = kwargs.pop("enable_pag") + if enable_pag: text_2_image_cls = _get_task_class( AUTO_TEXT2IMAGE_PIPELINES_MAPPING, - text_2_image_cls.__name__.replace("ControlNet", "").replace("Pipeline", "ControlNetPipeline"), + text_2_image_cls.__name__.replace("PAG", "").replace(to_replace, "PAG" + to_replace), ) + to_replace = "PAG" + to_replace else: text_2_image_cls = _get_task_class( AUTO_TEXT2IMAGE_PIPELINES_MAPPING, - text_2_image_cls.__name__.replace("ControlNetPipeline", "Pipeline"), + text_2_image_cls.__name__.replace("PAG", ""), ) - - if "enable_pag" in kwargs: - enable_pag = kwargs.pop("enable_pag") - if enable_pag: + if "controlnet" in kwargs: + if kwargs["controlnet"] is not None: text_2_image_cls = _get_task_class( AUTO_TEXT2IMAGE_PIPELINES_MAPPING, - text_2_image_cls.__name__.replace("PAG", "").replace("Pipeline", "PAGPipeline"), + text_2_image_cls.__name__.replace("ControlNet", "").replace(to_replace, "ControlNet" + to_replace), ) + to_replace = "ControlNet" + to_replace else: text_2_image_cls = _get_task_class( AUTO_TEXT2IMAGE_PIPELINES_MAPPING, - text_2_image_cls.__name__.replace("PAGPipeline", "Pipeline"), + text_2_image_cls.__name__.replace("ControlNet", ""), ) # define expected module and optional kwargs given the pipeline signature diff --git a/src/diffusers/pipelines/pag/__init__.py b/src/diffusers/pipelines/pag/__init__.py index f15db0b9bee1..782623d0a1dd 100644 --- a/src/diffusers/pipelines/pag/__init__.py +++ b/src/diffusers/pipelines/pag/__init__.py @@ -1,50 +1,52 @@ -from typing import TYPE_CHECKING - -from ...utils import ( - DIFFUSERS_SLOW_IMPORT, - OptionalDependencyNotAvailable, - _LazyModule, - get_objects_from_module, - is_flax_available, - is_torch_available, - is_transformers_available, -) - - -_dummy_objects = {} -_import_structure = {} - -try: - if not (is_transformers_available() and is_torch_available()): - raise OptionalDependencyNotAvailable() -except OptionalDependencyNotAvailable: - from ...utils import dummy_torch_and_transformers_objects # noqa F403 - - _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) -else: - _import_structure["pipeline_pag_sd_xl"] = ["StableDiffusionXLPAGPipeline"] - _import_structure["pipeline_pag_sd_xl_inpaint"] = ["StableDiffusionXLPAGInpaintPipeline"] - - -if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: - try: - if not (is_transformers_available() and is_torch_available()): - raise OptionalDependencyNotAvailable() - - except OptionalDependencyNotAvailable: - from ...utils.dummy_torch_and_transformers_objects import * - else: - from .pipeline_pag_sd_xl import StableDiffusionXLPAGPipeline - from .pipeline_pag_sd_xl_inpaint import StableDiffusionXLPAGInpaintPipeline - -else: - import sys - - sys.modules[__name__] = _LazyModule( - __name__, - globals()["__file__"], - _import_structure, - module_spec=__spec__, - ) - for name, value in _dummy_objects.items(): - setattr(sys.modules[__name__], name, value) +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_flax_available, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_pag_controlnet_sd_xl"] = ["StableDiffusionXLControlNetPAGPipeline"] + + _import_structure["pipeline_pag_sd_xl"] = ["StableDiffusionXLPAGPipeline"] + _import_structure["pipeline_pag_sd_xl_inpaint"] = ["StableDiffusionXLPAGInpaintPipeline"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_pag_controlnet_sd_xl import StableDiffusionXLControlNetPAGPipeline + from .pipeline_pag_sd_xl import StableDiffusionXLPAGPipeline + from .pipeline_pag_sd_xl_inpaint import StableDiffusionXLPAGInpaintPipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py b/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py new file mode 100644 index 000000000000..a7c0b8d67afd --- /dev/null +++ b/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py @@ -0,0 +1,1632 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import inspect +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import PIL.Image +import torch +import torch.nn.functional as F +from transformers import ( + CLIPImageProcessor, + CLIPTextModel, + CLIPTextModelWithProjection, + CLIPTokenizer, + CLIPVisionModelWithProjection, +) + +from diffusers.utils.import_utils import is_invisible_watermark_available + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import ( + FromSingleFileMixin, + IPAdapterMixin, + StableDiffusionXLLoraLoaderMixin, + TextualInversionLoaderMixin, +) +from ...models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel +from ...models.attention_processor import ( + AttnProcessor2_0, + LoRAAttnProcessor2_0, + LoRAXFormersAttnProcessor, + XFormersAttnProcessor, +) +from ...models.lora import adjust_lora_scale_text_encoder +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + USE_PEFT_BACKEND, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor +from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput +from .pag_utils import PAGMixin + + +if is_invisible_watermark_available(): + from ..stable_diffusion_xl.watermark import StableDiffusionXLWatermarker + +from ..controlnet.multicontrolnet import MultiControlNetModel + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> # !pip install opencv-python transformers accelerate + >>> from diffusers import StableDiffusionXLControlNetPipeline, ControlNetModel, AutoencoderKL + >>> from diffusers.utils import load_image + >>> import numpy as np + >>> import torch + + >>> import cv2 + >>> from PIL import Image + + >>> prompt = "aerial view, a futuristic research complex in a bright foggy jungle, hard lighting" + >>> negative_prompt = "low quality, bad quality, sketches" + + >>> # download an image + >>> image = load_image( + ... "https://hf.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/hf-logo.png" + ... ) + + >>> # initialize the models and pipeline + >>> controlnet_conditioning_scale = 0.5 # recommended for good generalization + >>> controlnet = ControlNetModel.from_pretrained( + ... "diffusers/controlnet-canny-sdxl-1.0", torch_dtype=torch.float16 + ... ) + >>> vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16) + >>> pipe = StableDiffusionXLControlNetPipeline.from_pretrained( + ... "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, vae=vae, torch_dtype=torch.float16 + ... ) + >>> pipe.enable_model_cpu_offload() + + >>> # get canny image + >>> image = np.array(image) + >>> image = cv2.Canny(image, 100, 200) + >>> image = image[:, :, None] + >>> image = np.concatenate([image, image, image], axis=2) + >>> canny_image = Image.fromarray(image) + + >>> # generate image + >>> image = pipe( + ... prompt, controlnet_conditioning_scale=controlnet_conditioning_scale, image=canny_image + ... ).images[0] + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class StableDiffusionXLControlNetPAGPipeline( + DiffusionPipeline, + StableDiffusionMixin, + TextualInversionLoaderMixin, + StableDiffusionXLLoraLoaderMixin, + IPAdapterMixin, + FromSingleFileMixin, + PAGMixin, +): + r""" + Pipeline for text-to-image generation using Stable Diffusion XL with ControlNet guidance. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + text_encoder ([`~transformers.CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + text_encoder_2 ([`~transformers.CLIPTextModelWithProjection`]): + Second frozen text-encoder + ([laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)). + tokenizer ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. + tokenizer_2 ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. + unet ([`UNet2DConditionModel`]): + A `UNet2DConditionModel` to denoise the encoded image latents. + controlnet ([`ControlNetModel`] or `List[ControlNetModel]`): + Provides additional conditioning to the `unet` during the denoising process. If you set multiple + ControlNets as a list, the outputs from each ControlNet are added together to create one combined + additional conditioning. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `"True"`): + Whether the negative prompt embeddings should always be set to 0. Also see the config of + `stabilityai/stable-diffusion-xl-base-1-0`. + add_watermarker (`bool`, *optional*): + Whether to use the [invisible_watermark](https://github.com/ShieldMnt/invisible-watermark/) library to + watermark output images. If not defined, it defaults to `True` if the package is installed; otherwise no + watermarker is used. + """ + + # leave controlnet out on purpose because it iterates with unet + model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->unet->vae" + _optional_components = [ + "tokenizer", + "tokenizer_2", + "text_encoder", + "text_encoder_2", + "feature_extractor", + "image_encoder", + ] + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + "add_text_embeds", + "add_time_ids", + "negative_pooled_prompt_embeds", + "negative_add_time_ids", + ] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + text_encoder_2: CLIPTextModelWithProjection, + tokenizer: CLIPTokenizer, + tokenizer_2: CLIPTokenizer, + unet: UNet2DConditionModel, + controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel], + scheduler: KarrasDiffusionSchedulers, + force_zeros_for_empty_prompt: bool = True, + add_watermarker: Optional[bool] = None, + feature_extractor: CLIPImageProcessor = None, + image_encoder: CLIPVisionModelWithProjection = None, + pag_applied_layers: Union[str, List[str]] = "mid", # ["down.block_2", "up.block_1.attentions_0"], "mid" + ): + super().__init__() + + if isinstance(controlnet, (list, tuple)): + controlnet = MultiControlNetModel(controlnet) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + unet=unet, + controlnet=controlnet, + scheduler=scheduler, + feature_extractor=feature_extractor, + image_encoder=image_encoder, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True) + self.control_image_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False + ) + add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available() + + if add_watermarker: + self.watermark = StableDiffusionXLWatermarker() + else: + self.watermark = None + + self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) + self.set_pag_applied_layers(pag_applied_layers) + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt + def encode_prompt( + self, + prompt: str, + prompt_2: Optional[str] = None, + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt: Optional[str] = None, + negative_prompt_2: Optional[str] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + pooled_prompt_embeds: Optional[torch.Tensor] = None, + negative_pooled_prompt_embeds: Optional[torch.Tensor] = None, + lora_scale: Optional[float] = None, + clip_skip: Optional[int] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) + else: + scale_lora_layers(self.text_encoder_2, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # Define tokenizers and text encoders + tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2] + text_encoders = ( + [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2] + ) + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + # textual inversion: process multi-vector tokens if necessary + prompt_embeds_list = [] + prompts = [prompt, prompt_2] + for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, tokenizer) + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {tokenizer.model_max_length} tokens: {removed_text}" + ) + + prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) + + # We are only ALWAYS interested in the pooled output of the final text encoder + pooled_prompt_embeds = prompt_embeds[0] + if clip_skip is None: + prompt_embeds = prompt_embeds.hidden_states[-2] + else: + # "2" because SDXL always indexes from the penultimate layer. + prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)] + + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) + + # get unconditional embeddings for classifier free guidance + zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt + if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt: + negative_prompt_embeds = torch.zeros_like(prompt_embeds) + negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) + elif do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt_2 = negative_prompt_2 or negative_prompt + + # normalize str to list + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + negative_prompt_2 = ( + batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 + ) + + uncond_tokens: List[str] + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = [negative_prompt, negative_prompt_2] + + negative_prompt_embeds_list = [] + for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = tokenizer( + negative_prompt, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + negative_prompt_embeds = text_encoder( + uncond_input.input_ids.to(device), + output_hidden_states=True, + ) + # We are only ALWAYS interested in the pooled output of the final text encoder + negative_pooled_prompt_embeds = negative_prompt_embeds[0] + negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] + + negative_prompt_embeds_list.append(negative_prompt_embeds) + + negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) + + if self.text_encoder_2 is not None: + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + if self.text_encoder_2 is not None: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + if do_classifier_free_guidance: + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2, lora_scale) + + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance + ): + image_embeds = [] + if do_classifier_free_guidance: + negative_image_embeds = [] + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." + ) + + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers + ): + output_hidden_state = not isinstance(image_proj_layer, ImageProjection) + single_image_embeds, single_negative_image_embeds = self.encode_image( + single_ip_adapter_image, device, 1, output_hidden_state + ) + + image_embeds.append(single_image_embeds[None, :]) + if do_classifier_free_guidance: + negative_image_embeds.append(single_negative_image_embeds[None, :]) + else: + for single_image_embeds in ip_adapter_image_embeds: + if do_classifier_free_guidance: + single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) + negative_image_embeds.append(single_negative_image_embeds) + image_embeds.append(single_image_embeds) + + ip_adapter_image_embeds = [] + for i, single_image_embeds in enumerate(image_embeds): + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + if do_classifier_free_guidance: + single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + prompt_2, + image, + negative_prompt=None, + negative_prompt_2=None, + prompt_embeds=None, + negative_prompt_embeds=None, + pooled_prompt_embeds=None, + ip_adapter_image=None, + ip_adapter_image_embeds=None, + negative_pooled_prompt_embeds=None, + controlnet_conditioning_scale=1.0, + control_guidance_start=0.0, + control_guidance_end=1.0, + callback_on_step_end_tensor_inputs=None, + ): + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif negative_prompt_2 is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." + ) + + if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) + + # `prompt` needs more sophisticated handling when there are multiple + # conditionings. + if isinstance(self.controlnet, MultiControlNetModel): + if isinstance(prompt, list): + logger.warning( + f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}" + " prompts. The conditionings will be fixed across the prompts." + ) + + # Check `image` + is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance( + self.controlnet, torch._dynamo.eval_frame.OptimizedModule + ) + if ( + isinstance(self.controlnet, ControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, ControlNetModel) + ): + self.check_image(image, prompt, prompt_embeds) + elif ( + isinstance(self.controlnet, MultiControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, MultiControlNetModel) + ): + if not isinstance(image, list): + raise TypeError("For multiple controlnets: `image` must be type `list`") + + # When `image` is a nested list: + # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]]) + elif any(isinstance(i, list) for i in image): + raise ValueError("A single batch of multiple conditionings are supported at the moment.") + elif len(image) != len(self.controlnet.nets): + raise ValueError( + f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets." + ) + + for image_ in image: + self.check_image(image_, prompt, prompt_embeds) + else: + assert False + + # Check `controlnet_conditioning_scale` + if ( + isinstance(self.controlnet, ControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, ControlNetModel) + ): + if not isinstance(controlnet_conditioning_scale, float): + raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.") + elif ( + isinstance(self.controlnet, MultiControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, MultiControlNetModel) + ): + if isinstance(controlnet_conditioning_scale, list): + if any(isinstance(i, list) for i in controlnet_conditioning_scale): + raise ValueError("A single batch of multiple conditionings are supported at the moment.") + elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len( + self.controlnet.nets + ): + raise ValueError( + "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have" + " the same length as the number of controlnets" + ) + else: + assert False + + if not isinstance(control_guidance_start, (tuple, list)): + control_guidance_start = [control_guidance_start] + + if not isinstance(control_guidance_end, (tuple, list)): + control_guidance_end = [control_guidance_end] + + if len(control_guidance_start) != len(control_guidance_end): + raise ValueError( + f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list." + ) + + if isinstance(self.controlnet, MultiControlNetModel): + if len(control_guidance_start) != len(self.controlnet.nets): + raise ValueError( + f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}." + ) + + for start, end in zip(control_guidance_start, control_guidance_end): + if start >= end: + raise ValueError( + f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}." + ) + if start < 0.0: + raise ValueError(f"control guidance start: {start} can't be smaller than 0.") + if end > 1.0: + raise ValueError(f"control guidance end: {end} can't be larger than 1.0.") + + if ip_adapter_image is not None and ip_adapter_image_embeds is not None: + raise ValueError( + "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined." + ) + + if ip_adapter_image_embeds is not None: + if not isinstance(ip_adapter_image_embeds, list): + raise ValueError( + f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" + ) + elif ip_adapter_image_embeds[0].ndim not in [3, 4]: + raise ValueError( + f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" + ) + + # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.check_image + def check_image(self, image, prompt, prompt_embeds): + image_is_pil = isinstance(image, PIL.Image.Image) + image_is_tensor = isinstance(image, torch.Tensor) + image_is_np = isinstance(image, np.ndarray) + image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image) + image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor) + image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray) + + if ( + not image_is_pil + and not image_is_tensor + and not image_is_np + and not image_is_pil_list + and not image_is_tensor_list + and not image_is_np_list + ): + raise TypeError( + f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}" + ) + + if image_is_pil: + image_batch_size = 1 + else: + image_batch_size = len(image) + + if prompt is not None and isinstance(prompt, str): + prompt_batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + prompt_batch_size = len(prompt) + elif prompt_embeds is not None: + prompt_batch_size = prompt_embeds.shape[0] + + if image_batch_size != 1 and image_batch_size != prompt_batch_size: + raise ValueError( + f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}" + ) + + # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.prepare_image + def prepare_image( + self, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + do_classifier_free_guidance=False, + guess_mode=False, + ): + image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) + image_batch_size = image.shape[0] + + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + + image = image.to(device=device, dtype=dtype) + + if do_classifier_free_guidance and not guess_mode: + image = torch.cat([image] * 2) + + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids + def _get_add_time_ids( + self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None + ): + add_time_ids = list(original_size + crops_coords_top_left + target_size) + + passed_add_embed_dim = ( + self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim + ) + expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features + + if expected_add_embed_dim != passed_add_embed_dim: + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." + ) + + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + return add_time_ids + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae + def upcast_vae(self): + dtype = self.vae.dtype + self.vae.to(dtype=torch.float32) + use_torch_2_0_or_xformers = isinstance( + self.vae.decoder.mid_block.attentions[0].processor, + ( + AttnProcessor2_0, + XFormersAttnProcessor, + LoRAXFormersAttnProcessor, + LoRAAttnProcessor2_0, + ), + ) + # if xformers or torch_2_0 is used attention block does not need + # to be in float32 which can save lots of memory + if use_torch_2_0_or_xformers: + self.vae.post_quant_conv.to(dtype) + self.vae.decoder.conv_in.to(dtype) + self.vae.decoder.mid_block.to(dtype) + + # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding + def get_guidance_scale_embedding( + self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32 + ) -> torch.Tensor: + """ + See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 + + Args: + w (`torch.Tensor`): + Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings. + embedding_dim (`int`, *optional*, defaults to 512): + Dimension of the embeddings to generate. + dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): + Data type of the generated embeddings. + + Returns: + `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`. + """ + assert len(w.shape) == 1 + w = w * 1000.0 + + half_dim = embedding_dim // 2 + emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) + emb = w.to(dtype)[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1)) + assert emb.shape == (w.shape[0], embedding_dim) + return emb + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None + + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + def denoising_end(self): + return self._denoising_end + + @property + def num_timesteps(self): + return self._num_timesteps + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + image: PipelineImageInput = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + timesteps: List[int] = None, + sigmas: List[float] = None, + denoising_end: Optional[float] = None, + guidance_scale: float = 5.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + pooled_prompt_embeds: Optional[torch.Tensor] = None, + negative_pooled_prompt_embeds: Optional[torch.Tensor] = None, + ip_adapter_image: Optional[PipelineImageInput] = None, + ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + controlnet_conditioning_scale: Union[float, List[float]] = 1.0, + guess_mode: bool = False, + control_guidance_start: Union[float, List[float]] = 0.0, + control_guidance_end: Union[float, List[float]] = 1.0, + original_size: Tuple[int, int] = None, + crops_coords_top_left: Tuple[int, int] = (0, 0), + target_size: Tuple[int, int] = None, + negative_original_size: Optional[Tuple[int, int]] = None, + negative_crops_coords_top_left: Tuple[int, int] = (0, 0), + negative_target_size: Optional[Tuple[int, int]] = None, + clip_skip: Optional[int] = None, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + pag_scale: float = 3.0, + pag_adaptive_scale: float = 0.0, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders. + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,: + `List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`): + The ControlNet input condition to provide guidance to the `unet` for generation. If the type is + specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted + as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or + width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`, + images must be passed as a list such that each element of the list can be correctly batched for input + to a single ControlNet. + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. Anything below 512 pixels won't work well for + [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) + and checkpoints that are not specifically fine-tuned on low resolutions. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. Anything below 512 pixels won't work well for + [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) + and checkpoints that are not specifically fine-tuned on low resolutions. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + denoising_end (`float`, *optional*): + When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be + completed before it is intentionally prematurely terminated. As a result, the returned sample will + still retain a substantial amount of noise as determined by the discrete timesteps selected by the + scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a + "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image + Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output) + guidance_scale (`float`, *optional*, defaults to 5.0): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. This is sent to `tokenizer_2` + and `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies + to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, pooled text embeddings are generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs (prompt + weighting). If not provided, pooled `negative_prompt_embeds` are generated from `negative_prompt` input + argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should + contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0): + The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added + to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set + the corresponding scale as a list. + guess_mode (`bool`, *optional*, defaults to `False`): + The ControlNet encoder tries to recognize the content of the input image even if you remove all + prompts. A `guidance_scale` value between 3.0 and 5.0 is recommended. + control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0): + The percentage of total steps at which the ControlNet starts applying. + control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0): + The percentage of total steps at which the ControlNet stops applying. + original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. + `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as + explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): + `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position + `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting + `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + For most cases, `target_size` should be set to the desired height and width of the generated image. If + not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in + section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a specific image resolution. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): + To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a target image resolution. It should be as same + as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + pag_scale (`float`, *optional*, defaults to 3.0): + The scale factor for the perturbed attention guidance. If it is set to 0.0, the perturbed attention + guidance will not be used. + pag_adaptive_scale (`float`, *optional*, defaults to 0.0): + The adaptive scale factor for the perturbed attention guidance. If it is set to 0.0, `pag_scale` is + used. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned containing the output images. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet + + # align format for control guidance + if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): + control_guidance_start = len(control_guidance_end) * [control_guidance_start] + elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): + control_guidance_end = len(control_guidance_start) * [control_guidance_end] + elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list): + mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1 + control_guidance_start, control_guidance_end = ( + mult * [control_guidance_start], + mult * [control_guidance_end], + ) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + image, + negative_prompt, + negative_prompt_2, + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + ip_adapter_image, + ip_adapter_image_embeds, + negative_pooled_prompt_embeds, + controlnet_conditioning_scale, + control_guidance_start, + control_guidance_end, + callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + self._denoising_end = denoising_end + self._pag_scale = pag_scale + self._pag_adaptive_scale = pag_adaptive_scale + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float): + controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets) + + global_pool_conditions = ( + controlnet.config.global_pool_conditions + if isinstance(controlnet, ControlNetModel) + else controlnet.nets[0].config.global_pool_conditions + ) + guess_mode = guess_mode or global_pool_conditions + + # 3.1 Encode input prompt + text_encoder_lora_scale = ( + self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None + ) + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt, + prompt_2, + device, + num_images_per_prompt, + self.do_classifier_free_guidance, + negative_prompt, + negative_prompt_2, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=self.clip_skip, + ) + + # 3.2 Encode ip_adapter_image + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + ip_adapter_image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + self.do_classifier_free_guidance, + ) + + # 4. Prepare image + if isinstance(controlnet, ControlNetModel): + image = self.prepare_image( + image=image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, + do_classifier_free_guidance=self.do_classifier_free_guidance, + guess_mode=guess_mode, + ) + height, width = image.shape[-2:] + elif isinstance(controlnet, MultiControlNetModel): + images = [] + + for image_ in image: + image_ = self.prepare_image( + image=image_, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, + do_classifier_free_guidance=self.do_classifier_free_guidance, + guess_mode=guess_mode, + ) + + images.append(image_) + + image = images + height, width = image[0].shape[-2:] + else: + assert False + + # 5. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, device, timesteps, sigmas + ) + self._num_timesteps = len(timesteps) + + # 6. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6.5 Optionally get Guidance Scale Embedding + timestep_cond = None + if self.unet.config.time_cond_proj_dim is not None: + guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) + timestep_cond = self.get_guidance_scale_embedding( + guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim + ).to(device=device, dtype=latents.dtype) + + # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7.1 Create tensor stating which controlnets to keep + controlnet_keep = [] + for i in range(len(timesteps)): + keeps = [ + 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) + for s, e in zip(control_guidance_start, control_guidance_end) + ] + controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps) + + # 7.2 Prepare added time ids & embeddings + if isinstance(image, list): + original_size = original_size or image[0].shape[-2:] + else: + original_size = original_size or image.shape[-2:] + target_size = target_size or (height, width) + + add_text_embeds = pooled_prompt_embeds + if self.text_encoder_2 is None: + text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) + else: + text_encoder_projection_dim = self.text_encoder_2.config.projection_dim + + add_time_ids = self._get_add_time_ids( + original_size, + crops_coords_top_left, + target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + + if negative_original_size is not None and negative_target_size is not None: + negative_add_time_ids = self._get_add_time_ids( + negative_original_size, + negative_crops_coords_top_left, + negative_target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + else: + negative_add_time_ids = add_time_ids + + if not guess_mode: + images = image if isinstance(image, list) else [image] + for i, single_image in enumerate(images): + if self.do_classifier_free_guidance: + single_image = single_image.chunk(2)[0] + + if self.do_perturbed_attention_guidance: + single_image = self._prepare_perturbed_attention_guidance( + single_image, single_image, self.do_classifier_free_guidance + ) + elif self.do_classifier_free_guidance: + single_image = torch.cat([single_image] * 2) + single_image = single_image.to(device) + images[i] = single_image + + image = images if isinstance(image, list) else images[0] + + if ip_adapter_image_embeds is not None: + for i, image_embeds in enumerate(ip_adapter_image_embeds): + negative_image_embeds = None + if self.do_classifier_free_guidance: + negative_image_embeds, image_embeds = image_embeds.chunk(2) + + if self.do_perturbed_attention_guidance: + image_embeds = self._prepare_perturbed_attention_guidance( + image_embeds, negative_image_embeds, self.do_classifier_free_guidance + ) + elif self.do_classifier_free_guidance: + image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0) + image_embeds = image_embeds.to(device) + ip_adapter_image_embeds[i] = image_embeds + + if self.do_perturbed_attention_guidance: + prompt_embeds = self._prepare_perturbed_attention_guidance( + prompt_embeds, negative_prompt_embeds, self.do_classifier_free_guidance + ) + add_text_embeds = self._prepare_perturbed_attention_guidance( + add_text_embeds, negative_pooled_prompt_embeds, self.do_classifier_free_guidance + ) + add_time_ids = self._prepare_perturbed_attention_guidance( + add_time_ids, negative_add_time_ids, self.do_classifier_free_guidance + ) + elif self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) + add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0) + + prompt_embeds = prompt_embeds.to(device) + add_text_embeds = add_text_embeds.to(device) + add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) + + # 8. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + + # 8.1 Apply denoising_end + if ( + self.denoising_end is not None + and isinstance(self.denoising_end, float) + and self.denoising_end > 0 + and self.denoising_end < 1 + ): + discrete_timestep_cutoff = int( + round( + self.scheduler.config.num_train_timesteps + - (self.denoising_end * self.scheduler.config.num_train_timesteps) + ) + ) + num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) + timesteps = timesteps[:num_inference_steps] + + if self.do_perturbed_attention_guidance: + original_attn_proc = self.unet.attn_processors + self._set_pag_attn_processor( + pag_applied_layers=self.pag_applied_layers, + do_classifier_free_guidance=self.do_classifier_free_guidance, + ) + + is_unet_compiled = is_compiled_module(self.unet) + is_controlnet_compiled = is_compiled_module(self.controlnet) + is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1") + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # Relevant thread: + # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428 + if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1: + torch._inductor.cudagraph_mark_step_begin() + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * (prompt_embeds.shape[0] // latents.shape[0])) + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + + # controlnet(s) inference + if guess_mode and self.do_classifier_free_guidance: + # Infer ControlNet only for the conditional batch. + control_model_input = latents + control_model_input = self.scheduler.scale_model_input(control_model_input, t) + controlnet_prompt_embeds = prompt_embeds.chunk(2)[1] + controlnet_added_cond_kwargs = { + "text_embeds": add_text_embeds.chunk(2)[1], + "time_ids": add_time_ids.chunk(2)[1], + } + else: + control_model_input = latent_model_input + controlnet_prompt_embeds = prompt_embeds + controlnet_added_cond_kwargs = added_cond_kwargs + + if isinstance(controlnet_keep[i], list): + cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])] + else: + controlnet_cond_scale = controlnet_conditioning_scale + if isinstance(controlnet_cond_scale, list): + controlnet_cond_scale = controlnet_cond_scale[0] + cond_scale = controlnet_cond_scale * controlnet_keep[i] + + down_block_res_samples, mid_block_res_sample = self.controlnet( + control_model_input, + t, + encoder_hidden_states=controlnet_prompt_embeds, + controlnet_cond=image, + conditioning_scale=cond_scale, + guess_mode=guess_mode, + added_cond_kwargs=controlnet_added_cond_kwargs, + return_dict=False, + ) + + if guess_mode and self.do_classifier_free_guidance: + # Infered ControlNet only for the conditional batch. + # To apply the output of ControlNet to both the unconditional and conditional batches, + # add 0 to the unconditional batch to keep it unchanged. + down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples] + mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample]) + + if ip_adapter_image_embeds is not None: + added_cond_kwargs["image_embeds"] = ip_adapter_image_embeds + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + timestep_cond=timestep_cond, + cross_attention_kwargs=self.cross_attention_kwargs, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if self.do_perturbed_attention_guidance: + noise_pred = self._apply_perturbed_attention_guidance( + noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t + ) + elif self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds) + negative_pooled_prompt_embeds = callback_outputs.pop( + "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds + ) + add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids) + negative_add_time_ids = callback_outputs.pop("negative_add_time_ids", negative_add_time_ids) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if not output_type == "latent": + # make sure the VAE is in float32 mode, as it overflows in float16 + needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast + + if needs_upcasting: + self.upcast_vae() + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + + # unscale/denormalize the latents + # denormalize with the mean and std if available and not None + has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None + has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None + if has_latents_mean and has_latents_std: + latents_mean = ( + torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype) + ) + latents_std = ( + torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype) + ) + latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean + else: + latents = latents / self.vae.config.scaling_factor + + image = self.vae.decode(latents, return_dict=False)[0] + + # cast back to fp16 if needed + if needs_upcasting: + self.vae.to(dtype=torch.float16) + else: + image = latents + + if not output_type == "latent": + # apply watermark if available + if self.watermark is not None: + image = self.watermark.apply_watermark(image) + + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if self.do_perturbed_attention_guidance: + self.unet.set_attn_processor(original_attn_proc) + + if not return_dict: + return (image,) + + return StableDiffusionXLPipelineOutput(images=image) diff --git a/tests/pipelines/pag/test_pag_controlnet_sdxl.py b/tests/pipelines/pag/test_pag_controlnet_sdxl.py new file mode 100644 index 000000000000..f42afb7cea47 --- /dev/null +++ b/tests/pipelines/pag/test_pag_controlnet_sdxl.py @@ -0,0 +1,271 @@ +# coding=utf-8 +# Copyright 2024 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import unittest + +import numpy as np +import torch +from transformers import CLIPTextConfig, CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer + +from diffusers import ( + AutoencoderKL, + ControlNetModel, + EulerDiscreteScheduler, + StableDiffusionXLControlNetPAGPipeline, + StableDiffusionXLControlNetPipeline, + UNet2DConditionModel, +) +from diffusers.utils.testing_utils import ( + enable_full_determinism, +) +from diffusers.utils.torch_utils import randn_tensor + +from ..pipeline_params import ( + TEXT_TO_IMAGE_BATCH_PARAMS, + TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS, + TEXT_TO_IMAGE_IMAGE_PARAMS, + TEXT_TO_IMAGE_PARAMS, +) +from ..test_pipelines_common import ( + IPAdapterTesterMixin, + PipelineFromPipeTesterMixin, + PipelineLatentTesterMixin, + PipelineTesterMixin, + SDXLOptionalComponentsTesterMixin, +) + + +enable_full_determinism() + + +class StableDiffusionXLControlNetPAGPipelineFastTests( + PipelineTesterMixin, + IPAdapterTesterMixin, + PipelineLatentTesterMixin, + PipelineFromPipeTesterMixin, + SDXLOptionalComponentsTesterMixin, + unittest.TestCase, +): + pipeline_class = StableDiffusionXLControlNetPAGPipeline + params = TEXT_TO_IMAGE_PARAMS.union({"pag_scale", "pag_adaptive_scale"}) + batch_params = TEXT_TO_IMAGE_BATCH_PARAMS + image_params = TEXT_TO_IMAGE_IMAGE_PARAMS + image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS + callback_cfg_params = TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS.union({"add_text_embeds", "add_time_ids"}) + + def get_dummy_components(self, time_cond_proj_dim=None): + # Copied from tests.pipelines.controlnet.test_controlnet_sdxl.StableDiffusionXLControlNetPipelineFastTests.get_dummy_components + torch.manual_seed(0) + unet = UNet2DConditionModel( + block_out_channels=(32, 64), + layers_per_block=2, + sample_size=32, + in_channels=4, + out_channels=4, + down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"), + up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"), + # SD2-specific config below + attention_head_dim=(2, 4), + use_linear_projection=True, + addition_embed_type="text_time", + addition_time_embed_dim=8, + transformer_layers_per_block=(1, 2), + projection_class_embeddings_input_dim=80, # 6 * 8 + 32 + cross_attention_dim=64, + time_cond_proj_dim=time_cond_proj_dim, + ) + torch.manual_seed(0) + controlnet = ControlNetModel( + block_out_channels=(32, 64), + layers_per_block=2, + in_channels=4, + down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"), + conditioning_embedding_out_channels=(16, 32), + # SD2-specific config below + attention_head_dim=(2, 4), + use_linear_projection=True, + addition_embed_type="text_time", + addition_time_embed_dim=8, + transformer_layers_per_block=(1, 2), + projection_class_embeddings_input_dim=80, # 6 * 8 + 32 + cross_attention_dim=64, + ) + torch.manual_seed(0) + scheduler = EulerDiscreteScheduler( + beta_start=0.00085, + beta_end=0.012, + steps_offset=1, + beta_schedule="scaled_linear", + timestep_spacing="leading", + ) + torch.manual_seed(0) + vae = AutoencoderKL( + block_out_channels=[32, 64], + in_channels=3, + out_channels=3, + down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"], + up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"], + latent_channels=4, + ) + torch.manual_seed(0) + text_encoder_config = CLIPTextConfig( + bos_token_id=0, + eos_token_id=2, + hidden_size=32, + intermediate_size=37, + layer_norm_eps=1e-05, + num_attention_heads=4, + num_hidden_layers=5, + pad_token_id=1, + vocab_size=1000, + # SD2-specific config below + hidden_act="gelu", + projection_dim=32, + ) + text_encoder = CLIPTextModel(text_encoder_config) + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + text_encoder_2 = CLIPTextModelWithProjection(text_encoder_config) + tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + components = { + "unet": unet, + "controlnet": controlnet, + "scheduler": scheduler, + "vae": vae, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + "text_encoder_2": text_encoder_2, + "tokenizer_2": tokenizer_2, + "feature_extractor": None, + "image_encoder": None, + } + return components + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + + controlnet_embedder_scale_factor = 2 + image = randn_tensor( + (1, 3, 32 * controlnet_embedder_scale_factor, 32 * controlnet_embedder_scale_factor), + generator=generator, + device=torch.device(device), + ) + + inputs = { + "prompt": "A painting of a squirrel eating a burger", + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 6.0, + "pag_scale": 3.0, + "output_type": "np", + "image": image, + } + + return inputs + + def test_pag_disable_enable(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + components = self.get_dummy_components() + + # base pipeline (expect same output when pag is disabled) + pipe_sd = StableDiffusionXLControlNetPipeline(**components) + pipe_sd = pipe_sd.to(device) + pipe_sd.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + del inputs["pag_scale"] + assert ( + "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters + ), f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__calss__.__name__}." + out = pipe_sd(**inputs).images[0, -3:, -3:, -1] + + # pag disabled with pag_scale=0.0 + pipe_pag = self.pipeline_class(**components) + pipe_pag = pipe_pag.to(device) + pipe_pag.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + inputs["pag_scale"] = 0.0 + out_pag_disabled = pipe_pag(**inputs).images[0, -3:, -3:, -1] + + # pag enabled + pipe_pag = self.pipeline_class(**components, pag_applied_layers=["mid", "up", "down"]) + pipe_pag = pipe_pag.to(device) + pipe_pag.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + out_pag_enabled = pipe_pag(**inputs).images[0, -3:, -3:, -1] + + assert np.abs(out.flatten() - out_pag_disabled.flatten()).max() < 1e-3 + assert np.abs(out.flatten() - out_pag_enabled.flatten()).max() > 1e-3 + + def test_save_load_optional_components(self): + self._test_save_load_optional_components() + + def test_pag_cfg(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + components = self.get_dummy_components() + + pipe_pag = self.pipeline_class(**components, pag_applied_layers=["mid", "up", "down"]) + pipe_pag = pipe_pag.to(device) + pipe_pag.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + image = pipe_pag(**inputs).images + image_slice = image[0, -3:, -3:, -1] + + assert image.shape == ( + 1, + 64, + 64, + 3, + ), f"the shape of the output image should be (1, 64, 64, 3) but got {image.shape}" + expected_slice = np.array( + [0.6819614, 0.5551478, 0.5499094, 0.5769566, 0.53942275, 0.5707505, 0.41131154, 0.47833863, 0.49982738] + ) + + max_diff = np.abs(image_slice.flatten() - expected_slice).max() + assert max_diff < 1e-3, f"output is different from expected, {image_slice.flatten()}" + + def test_pag_uncond(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + components = self.get_dummy_components() + + pipe_pag = self.pipeline_class(**components, pag_applied_layers=["mid", "up", "down"]) + pipe_pag = pipe_pag.to(device) + pipe_pag.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + inputs["guidance_scale"] = 0.0 + image = pipe_pag(**inputs).images + image_slice = image[0, -3:, -3:, -1] + + assert image.shape == ( + 1, + 64, + 64, + 3, + ), f"the shape of the output image should be (1, 64, 64, 3) but got {image.shape}" + expected_slice = np.array( + [0.66685176, 0.53207266, 0.5541569, 0.5912994, 0.5368312, 0.58433825, 0.42607725, 0.46805605, 0.5098659] + ) + + max_diff = np.abs(image_slice.flatten() - expected_slice).max() + assert max_diff < 1e-3, f"output is different from expected, {image_slice.flatten()}" From 623d237c2b29ab831c92916a7ae37023c26d309e Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Mon, 10 Jun 2024 07:39:18 +0000 Subject: [PATCH 22/44] copy --- .../utils/dummy_torch_and_transformers_objects.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 66e0f8d5dbb8..e03dfabc3750 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -1322,6 +1322,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class StableDiffusionXLControlNetPAGPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class StableDiffusionXLControlNetPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] From f30c2bcf3082d4a1abe82ddd01b11a0e6ea05c21 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Mon, 10 Jun 2024 09:09:02 +0000 Subject: [PATCH 23/44] add from pipe test for pag + controlnet --- src/diffusers/pipelines/auto_pipeline.py | 23 ++++++------ tests/pipelines/test_pipelines_auto.py | 48 ++++++++++++++++++++++++ 2 files changed, 59 insertions(+), 12 deletions(-) diff --git a/src/diffusers/pipelines/auto_pipeline.py b/src/diffusers/pipelines/auto_pipeline.py index 665515bee595..a52db0f7b454 100644 --- a/src/diffusers/pipelines/auto_pipeline.py +++ b/src/diffusers/pipelines/auto_pipeline.py @@ -385,31 +385,30 @@ def from_pipe(cls, pipeline, **kwargs): # derive the pipeline class to instantiate text_2_image_cls = _get_task_class(AUTO_TEXT2IMAGE_PIPELINES_MAPPING, original_cls_name) - to_replace = "Pipeline" - if "enable_pag" in kwargs: - enable_pag = kwargs.pop("enable_pag") - if enable_pag: + if "controlnet" in kwargs: + if kwargs["controlnet"] is not None: + to_replace = "PAGPipeline" if "PAG" in text_2_image_cls.__name__ else "Pipeline" text_2_image_cls = _get_task_class( AUTO_TEXT2IMAGE_PIPELINES_MAPPING, - text_2_image_cls.__name__.replace("PAG", "").replace(to_replace, "PAG" + to_replace), + text_2_image_cls.__name__.replace("ControlNet", "").replace(to_replace, "ControlNet" + to_replace), ) - to_replace = "PAG" + to_replace else: text_2_image_cls = _get_task_class( AUTO_TEXT2IMAGE_PIPELINES_MAPPING, - text_2_image_cls.__name__.replace("PAG", ""), + text_2_image_cls.__name__.replace("ControlNet", ""), ) - if "controlnet" in kwargs: - if kwargs["controlnet"] is not None: + + if "enable_pag" in kwargs: + enable_pag = kwargs.pop("enable_pag") + if enable_pag: text_2_image_cls = _get_task_class( AUTO_TEXT2IMAGE_PIPELINES_MAPPING, - text_2_image_cls.__name__.replace("ControlNet", "").replace(to_replace, "ControlNet" + to_replace), + text_2_image_cls.__name__.replace("PAG", "").replace("Pipeline", "PAGPipeline"), ) - to_replace = "ControlNet" + to_replace else: text_2_image_cls = _get_task_class( AUTO_TEXT2IMAGE_PIPELINES_MAPPING, - text_2_image_cls.__name__.replace("ControlNet", ""), + text_2_image_cls.__name__.replace("PAG", ""), ) # define expected module and optional kwargs given the pipeline signature diff --git a/tests/pipelines/test_pipelines_auto.py b/tests/pipelines/test_pipelines_auto.py index 284cdb0ad006..6878076d3f93 100644 --- a/tests/pipelines/test_pipelines_auto.py +++ b/tests/pipelines/test_pipelines_auto.py @@ -135,6 +135,54 @@ def test_from_pipe_controlnet_text2img(self): assert pipe.__class__.__name__ == "StableDiffusionPipeline" assert "controlnet" not in pipe.components + def test_from_pipe_pag_controlnet(self): + pipe = AutoPipelineForText2Image.from_pretrained("hf-internal-testing/tiny-stable-diffusion-xl-pipe") + controlnet = ControlNetModel.from_pretrained("hf-internal-testing/tiny-controlnet") + + pipe = AutoPipelineForText2Image.from_pipe(pipe, controlnet=controlnet, enable_pag=True) + assert pipe.__class__.__name__ == "StableDiffusionXLControlNetPAGPipeline" + assert "controlnet" in pipe.components + + pipe = AutoPipelineForText2Image.from_pipe(pipe, controlnet=None) + assert pipe.__class__.__name__ == "StableDiffusionXLPAGPipeline" + assert "controlnet" not in pipe.components + + pipe = AutoPipelineForText2Image.from_pipe(pipe, enable_pag=False) + assert pipe.__class__.__name__ == "StableDiffusionXLPipeline" + assert "controlnet" not in pipe.components + + pipe = AutoPipelineForText2Image.from_pipe(pipe, enable_pag=True) + assert pipe.__class__.__name__ == "StableDiffusionXLPAGPipeline" + assert "controlnet" not in pipe.components + + pipe = AutoPipelineForText2Image.from_pipe(pipe, controlnet=controlnet, enable_pag=False) + assert pipe.__class__.__name__ == "StableDiffusionXLControlNetPipeline" + assert "controlnet" in pipe.components + + pipe = AutoPipelineForText2Image.from_pipe(pipe, controlnet=None) + assert pipe.__class__.__name__ == "StableDiffusionXLPipeline" + assert "controlnet" not in pipe.components + + pipe = AutoPipelineForText2Image.from_pipe(pipe, controlnet=None, enable_pag=False) + assert pipe.__class__.__name__ == "StableDiffusionXLPipeline" + assert "controlnet" not in pipe.components + + pipe = AutoPipelineForInpainting.from_pipe(pipe, enable_pag=True) + assert pipe.__class__.__name__ == "StableDiffusionXLPAGInpaintPipeline" + assert "controlnet" not in pipe.components + + pipe = AutoPipelineForText2Image.from_pipe(pipe, controlnet=controlnet, enable_pag=False) + assert pipe.__class__.__name__ == "StableDiffusionXLControlNetPipeline" + assert "controlnet" in pipe.components + + pipe = AutoPipelineForInpainting.from_pipe(pipe, controlnet=None, enable_pag=True) + assert pipe.__class__.__name__ == "StableDiffusionXLPAGInpaintPipeline" + assert "controlnet" not in pipe.components + + pipe = AutoPipelineForText2Image.from_pipe(pipe, controlnet=controlnet) + assert pipe.__class__.__name__ == "StableDiffusionXLControlNetPAGPipeline" + assert "controlnet" in pipe.components + def test_from_pipe_controlnet_img2img(self): pipe = AutoPipelineForImage2Image.from_pretrained("hf-internal-testing/tiny-stable-diffusion-pipe") controlnet = ControlNetModel.from_pretrained("hf-internal-testing/tiny-controlnet") From 1df43913cc88d21987d0b2b5e6bf11fcd5caebe1 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Mon, 10 Jun 2024 18:02:05 +0000 Subject: [PATCH 24/44] up --- tests/pipelines/test_pipelines_auto.py | 144 +++++++++++++++++++------ 1 file changed, 110 insertions(+), 34 deletions(-) diff --git a/tests/pipelines/test_pipelines_auto.py b/tests/pipelines/test_pipelines_auto.py index 6878076d3f93..268fa052923a 100644 --- a/tests/pipelines/test_pipelines_auto.py +++ b/tests/pipelines/test_pipelines_auto.py @@ -123,66 +123,142 @@ def test_kwargs_local_files_only(self): shutil.rmtree(tmpdirname.parent.parent) - def test_from_pipe_controlnet_text2img(self): - pipe = AutoPipelineForText2Image.from_pretrained("hf-internal-testing/tiny-stable-diffusion-pipe") + def test_from_pipe_pag_text2img(self): + # test from StableDiffusionXLPipeline + pipe = AutoPipelineForText2Image.from_pretrained("hf-internal-testing/tiny-stable-diffusion-xl-pipe") controlnet = ControlNetModel.from_pretrained("hf-internal-testing/tiny-controlnet") - pipe = AutoPipelineForText2Image.from_pipe(pipe, controlnet=controlnet) - assert pipe.__class__.__name__ == "StableDiffusionControlNetPipeline" - assert "controlnet" in pipe.components + # - test `enable_pag` flag + pipe_pag = AutoPipelineForText2Image.from_pipe(pipe, enable_pag=True) + assert pipe_pag.__class__.__name__ == "StableDiffusionXLPAGPipeline" + assert "controlnet" not in pipe_pag.components - pipe = AutoPipelineForText2Image.from_pipe(pipe, controlnet=None) - assert pipe.__class__.__name__ == "StableDiffusionPipeline" + pipe = AutoPipelineForText2Image.from_pipe(pipe, enable_pag=False) + assert pipe.__class__.__name__ == "StableDiffusionXLPipeline" assert "controlnet" not in pipe.components - def test_from_pipe_pag_controlnet(self): - pipe = AutoPipelineForText2Image.from_pretrained("hf-internal-testing/tiny-stable-diffusion-xl-pipe") - controlnet = ControlNetModel.from_pretrained("hf-internal-testing/tiny-controlnet") + # - test `enabe_pag` + `controlnet` flag + pipe_control_pag = AutoPipelineForText2Image.from_pipe(pipe, controlnet=controlnet, enable_pag=True) + assert pipe_control_pag.__class__.__name__ == "StableDiffusionXLControlNetPAGPipeline" + assert "controlnet" in pipe_control_pag.components - pipe = AutoPipelineForText2Image.from_pipe(pipe, controlnet=controlnet, enable_pag=True) - assert pipe.__class__.__name__ == "StableDiffusionXLControlNetPAGPipeline" - assert "controlnet" in pipe.components + pipe_control = AutoPipelineForText2Image.from_pipe(pipe, controlnet=controlnet, enable_pag=False) + assert pipe_control.__class__.__name__ == "StableDiffusionXLControlNetPipeline" + assert "controlnet" in pipe_control.components - pipe = AutoPipelineForText2Image.from_pipe(pipe, controlnet=None) - assert pipe.__class__.__name__ == "StableDiffusionXLPAGPipeline" - assert "controlnet" not in pipe.components + pipe_pag = AutoPipelineForText2Image.from_pipe(pipe, controlnet=None, enable_pag=True) + assert pipe_pag.__class__.__name__ == "StableDiffusionXLPAGPipeline" + assert "controlnet" not in pipe_pag.components - pipe = AutoPipelineForText2Image.from_pipe(pipe, enable_pag=False) + pipe = AutoPipelineForText2Image.from_pipe(pipe, controlnet=None, enable_pag=False) assert pipe.__class__.__name__ == "StableDiffusionXLPipeline" assert "controlnet" not in pipe.components - pipe = AutoPipelineForText2Image.from_pipe(pipe, enable_pag=True) - assert pipe.__class__.__name__ == "StableDiffusionXLPAGPipeline" - assert "controlnet" not in pipe.components + # test from StableDiffusionXLControlNetPipeline + # - test `enable_pag` flag + pipe_control_pag = AutoPipelineForText2Image.from_pipe(pipe_control, enable_pag=True) + assert pipe_control_pag.__class__.__name__ == "StableDiffusionXLControlNetPAGPipeline" + assert "controlnet" in pipe_control_pag.components - pipe = AutoPipelineForText2Image.from_pipe(pipe, controlnet=controlnet, enable_pag=False) - assert pipe.__class__.__name__ == "StableDiffusionXLControlNetPipeline" - assert "controlnet" in pipe.components + pipe_control = AutoPipelineForText2Image.from_pipe(pipe_control, enable_pag=False) + assert pipe_control.__class__.__name__ == "StableDiffusionXLControlNetPipeline" + assert "controlnet" in pipe_control.components - pipe = AutoPipelineForText2Image.from_pipe(pipe, controlnet=None) - assert pipe.__class__.__name__ == "StableDiffusionXLPipeline" - assert "controlnet" not in pipe.components + # - test `enable_pag` + `controlnet` flag + pipe_control_pag = AutoPipelineForText2Image.from_pipe(pipe_control, controlnet=controlnet, enable_pag=True) + assert pipe_control_pag.__class__.__name__ == "StableDiffusionXLControlNetPAGPipeline" + assert "controlnet" in pipe_control_pag.components - pipe = AutoPipelineForText2Image.from_pipe(pipe, controlnet=None, enable_pag=False) + pipe_control = AutoPipelineForText2Image.from_pipe(pipe_control, controlnet=controlnet, enable_pag=False) + assert pipe_control.__class__.__name__ == "StableDiffusionXLControlNetPipeline" + assert "controlnet" in pipe_control.components + + pipe_pag = AutoPipelineForText2Image.from_pipe(pipe_control, controlnet=None, enable_pag=True) + assert pipe_pag.__class__.__name__ == "StableDiffusionXLPAGPipeline" + assert "controlnet" not in pipe_pag.components + + pipe = AutoPipelineForText2Image.from_pipe(pipe_control, controlnet=None, enable_pag=False) assert pipe.__class__.__name__ == "StableDiffusionXLPipeline" assert "controlnet" not in pipe.components - pipe = AutoPipelineForInpainting.from_pipe(pipe, enable_pag=True) - assert pipe.__class__.__name__ == "StableDiffusionXLPAGInpaintPipeline" + # test from StableDiffusionXLControlNetPAGPipeline + # - test `enable_pag` flag + pipe_control_pag = AutoPipelineForText2Image.from_pipe(pipe_control_pag, enable_pag=True) + assert pipe_control_pag.__class__.__name__ == "StableDiffusionXLControlNetPAGPipeline" + assert "controlnet" in pipe_control_pag.components + + pipe_control = AutoPipelineForText2Image.from_pipe(pipe_control_pag, enable_pag=False) + assert pipe_control.__class__.__name__ == "StableDiffusionXLControlNetPipeline" + assert "controlnet" in pipe_control.components + + # - test `enable_pag` + `controlnet` flag + pipe_control_pag = AutoPipelineForText2Image.from_pipe( + pipe_control_pag, controlnet=controlnet, enable_pag=True + ) + assert pipe_control_pag.__class__.__name__ == "StableDiffusionXLControlNetPAGPipeline" + assert "controlnet" in pipe_control_pag.components + + pipe_control = AutoPipelineForText2Image.from_pipe(pipe_control_pag, controlnet=controlnet, enable_pag=False) + assert pipe_control.__class__.__name__ == "StableDiffusionXLControlNetPipeline" + assert "controlnet" in pipe_control.components + + pipe_pag = AutoPipelineForText2Image.from_pipe(pipe_control_pag, controlnet=None, enable_pag=True) + assert pipe_pag.__class__.__name__ == "StableDiffusionXLPAGPipeline" + assert "controlnet" not in pipe_pag.components + + pipe = AutoPipelineForText2Image.from_pipe(pipe_control_pag, controlnet=None, enable_pag=False) + assert pipe.__class__.__name__ == "StableDiffusionXLPipeline" assert "controlnet" not in pipe.components - pipe = AutoPipelineForText2Image.from_pipe(pipe, controlnet=controlnet, enable_pag=False) + pipe = AutoPipelineForText2Image.from_pipe(pipe_control_pag, enable_pag=False) assert pipe.__class__.__name__ == "StableDiffusionXLControlNetPipeline" assert "controlnet" in pipe.components - pipe = AutoPipelineForInpainting.from_pipe(pipe, controlnet=None, enable_pag=True) - assert pipe.__class__.__name__ == "StableDiffusionXLPAGInpaintPipeline" - assert "controlnet" not in pipe.components + def test_from_pipe_pag_inpaint(self): + # test from tableDiffusionXLPAGInpaintPipeline + pipe = AutoPipelineForInpainting.from_pretrained("hf-internal-testing/tiny-stable-diffusion-xl-pipe") + # - test `enable_pag` flag + pipe_pag = AutoPipelineForInpainting.from_pipe(pipe, enable_pag=True) + assert pipe_pag.__class__.__name__ == "StableDiffusionXLPAGInpaintPipeline" + + pipe = AutoPipelineForInpainting.from_pipe(pipe, enable_pag=False) + assert pipe.__class__.__name__ == "StableDiffusionXLInpaintPipeline" + + # testing from StableDiffusionXLPAGInpaintPipeline + # - test `enable_pag` flag + pipe_pag = AutoPipelineForInpainting.from_pipe(pipe_pag, enable_pag=True) + assert pipe_pag.__class__.__name__ == "StableDiffusionXLPAGInpaintPipeline" + + pipe = AutoPipelineForInpainting.from_pipe(pipe_pag, enable_pag=False) + assert pipe.__class__.__name__ == "StableDiffusionXLInpaintPipeline" + + def test_from_pipe_pag_new_task(self): + # for from_pipe_new_task we only need to make sure it can map to the same pipeline from a different task, + # i.e. no need to test `enable_pag` + `controlnet` flag because it is already tested in `test_from_pipe_pag_text2img` and `test_from_pipe_pag_inpaint`etc + pipe_pag_text2img = AutoPipelineForText2Image.from_pretrained( + "hf-internal-testing/tiny-stable-diffusion-xl-pipe", enable_pag=True + ) + + # text2img pag -> inpaint pag + pipe_pag_inpaint = AutoPipelineForInpainting.from_pipe(pipe_pag_text2img) + assert pipe_pag_inpaint.__class__.__name__ == "StableDiffusionXLPAGInpaintPipeline" + + # inpaint pag -> text2img pag + pipe_pag_text2img = AutoPipelineForText2Image.from_pipe(pipe_pag_inpaint) + assert pipe_pag_text2img.__class__.__name__ == "StableDiffusionXLPAGPipeline" + + def test_from_pipe_controlnet_text2img(self): + pipe = AutoPipelineForText2Image.from_pretrained("hf-internal-testing/tiny-stable-diffusion-pipe") + controlnet = ControlNetModel.from_pretrained("hf-internal-testing/tiny-controlnet") pipe = AutoPipelineForText2Image.from_pipe(pipe, controlnet=controlnet) - assert pipe.__class__.__name__ == "StableDiffusionXLControlNetPAGPipeline" + assert pipe.__class__.__name__ == "StableDiffusionControlNetPipeline" assert "controlnet" in pipe.components + pipe = AutoPipelineForText2Image.from_pipe(pipe, controlnet=None) + assert pipe.__class__.__name__ == "StableDiffusionPipeline" + assert "controlnet" not in pipe.components + def test_from_pipe_controlnet_img2img(self): pipe = AutoPipelineForImage2Image.from_pretrained("hf-internal-testing/tiny-stable-diffusion-pipe") controlnet = ControlNetModel.from_pretrained("hf-internal-testing/tiny-controlnet") From 191505e958bf384a233ec15ebcf9fa269f28395b Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Mon, 17 Jun 2024 05:32:39 +0200 Subject: [PATCH 25/44] support guess mode --- .../pag/pipeline_pag_controlnet_sd_xl.py | 41 ++++++++++++------- 1 file changed, 26 insertions(+), 15 deletions(-) diff --git a/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py b/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py index a7c0b8d67afd..e8094042fcd8 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py @@ -632,11 +632,18 @@ def check_inputs( ip_adapter_image=None, ip_adapter_image_embeds=None, negative_pooled_prompt_embeds=None, - controlnet_conditioning_scale=1.0, - control_guidance_start=0.0, - control_guidance_end=1.0, + controlnet_conditioning_scale=None, + control_guidance_start=None, + control_guidance_end=None, callback_on_step_end_tensor_inputs=None, - ): + guidance_scale=None, + pag_scale=None, + guess_mode=None, + ): + if guess_mode and pag_scale > 0 and guidance_scale > 1: + raise ValueError( + "guess_mode cannot work with PAG and guidance scale together. Please set either `pag_scale` or `guidance_scale`; or set `guess_mode` to False." + ) if callback_on_step_end_tensor_inputs is not None and not all( k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs ): @@ -1229,6 +1236,9 @@ def __call__( control_guidance_start, control_guidance_end, callback_on_step_end_tensor_inputs, + guidance_scale, + pag_scale, + guess_mode, ) self._guidance_scale = guidance_scale @@ -1432,6 +1442,11 @@ def __call__( image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0) image_embeds = image_embeds.to(device) ip_adapter_image_embeds[i] = image_embeds + + # for guess_mode, we do not need to apply guidance on controlnet inputs + if guess_mode: + controlnet_prompt_embeds = prompt_embeds + controlnet_added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} if self.do_perturbed_attention_guidance: prompt_embeds = self._prepare_perturbed_attention_guidance( @@ -1451,6 +1466,11 @@ def __call__( prompt_embeds = prompt_embeds.to(device) add_text_embeds = add_text_embeds.to(device) add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) + added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + + if not guess_mode: + controlnet_prompt_embeds = prompt_embeds + controlnet_added_cond_kwargs = added_cond_kwargs # 8. Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order @@ -1491,22 +1511,13 @@ def __call__( latent_model_input = torch.cat([latents] * (prompt_embeds.shape[0] // latents.shape[0])) latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} - # controlnet(s) inference - if guess_mode and self.do_classifier_free_guidance: + if guess_mode: # Infer ControlNet only for the conditional batch. control_model_input = latents control_model_input = self.scheduler.scale_model_input(control_model_input, t) - controlnet_prompt_embeds = prompt_embeds.chunk(2)[1] - controlnet_added_cond_kwargs = { - "text_embeds": add_text_embeds.chunk(2)[1], - "time_ids": add_time_ids.chunk(2)[1], - } else: control_model_input = latent_model_input - controlnet_prompt_embeds = prompt_embeds - controlnet_added_cond_kwargs = added_cond_kwargs if isinstance(controlnet_keep[i], list): cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])] @@ -1527,7 +1538,7 @@ def __call__( return_dict=False, ) - if guess_mode and self.do_classifier_free_guidance: + if guess_mode and (self.do_classifier_free_guidance or self.do_perturbed_attention_guidance): # Infered ControlNet only for the conditional batch. # To apply the output of ControlNet to both the unconditional and conditional batches, # add 0 to the unconditional batch to keep it unchanged. From 58b83308013a0f32234b289d2c77348cc6e62090 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Mon, 17 Jun 2024 03:33:49 +0000 Subject: [PATCH 26/44] style --- src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py b/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py index e8094042fcd8..f91d11fccaf7 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py @@ -639,7 +639,7 @@ def check_inputs( guidance_scale=None, pag_scale=None, guess_mode=None, - ): + ): if guess_mode and pag_scale > 0 and guidance_scale > 1: raise ValueError( "guess_mode cannot work with PAG and guidance scale together. Please set either `pag_scale` or `guidance_scale`; or set `guess_mode` to False." @@ -1442,7 +1442,7 @@ def __call__( image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0) image_embeds = image_embeds.to(device) ip_adapter_image_embeds[i] = image_embeds - + # for guess_mode, we do not need to apply guidance on controlnet inputs if guess_mode: controlnet_prompt_embeds = prompt_embeds From 71cf2f7595982949960d03cc9d0d96974373114b Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Mon, 17 Jun 2024 07:54:06 +0000 Subject: [PATCH 27/44] add pag + img2img --- src/diffusers/__init__.py | 2 + src/diffusers/pipelines/__init__.py | 2 + src/diffusers/pipelines/auto_pipeline.py | 26 +- src/diffusers/pipelines/pag/__init__.py | 2 + .../pag/pipeline_pag_sd_xl_img2img.py | 1513 +++++++++++++++++ .../dummy_torch_and_transformers_objects.py | 15 + tests/pipelines/pag/test_pag_sdxl_img2img.py | 338 ++++ tests/pipelines/pag/test_pag_sdxl_inpaint.py | 4 +- tests/pipelines/test_pipelines_auto.py | 64 + 9 files changed, 1962 insertions(+), 4 deletions(-) create mode 100644 src/diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py create mode 100644 tests/pipelines/pag/test_pag_sdxl_img2img.py diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 16e464f20209..ba9dd72c8d71 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -310,6 +310,7 @@ "StableDiffusionXLImg2ImgPipeline", "StableDiffusionXLInpaintPipeline", "StableDiffusionXLInstructPix2PixPipeline", + "StableDiffusionXLPAGImg2ImgPipeline", "StableDiffusionXLPAGInpaintPipeline", "StableDiffusionXLPAGPipeline", "StableDiffusionXLPipeline", @@ -697,6 +698,7 @@ StableDiffusionXLImg2ImgPipeline, StableDiffusionXLInpaintPipeline, StableDiffusionXLInstructPix2PixPipeline, + StableDiffusionXLPAGImg2ImgPipeline, StableDiffusionXLPAGInpaintPipeline, StableDiffusionXLPAGPipeline, StableDiffusionXLPipeline, diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index cfd9ba622f70..f218784d90c3 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -142,6 +142,7 @@ "StableDiffusionXLPAGPipeline", "StableDiffusionXLPAGInpaintPipeline", "StableDiffusionXLControlNetPAGPipeline", + "StableDiffusionXLPAGImg2ImgPipeline", ] ) _import_structure["controlnet_xs"].extend( @@ -472,6 +473,7 @@ from .musicldm import MusicLDMPipeline from .pag import ( StableDiffusionXLControlNetPAGPipeline, + StableDiffusionXLPAGImg2ImgPipeline, StableDiffusionXLPAGInpaintPipeline, StableDiffusionXLPAGPipeline, ) diff --git a/src/diffusers/pipelines/auto_pipeline.py b/src/diffusers/pipelines/auto_pipeline.py index a52db0f7b454..a83f2b7d9ddc 100644 --- a/src/diffusers/pipelines/auto_pipeline.py +++ b/src/diffusers/pipelines/auto_pipeline.py @@ -47,6 +47,7 @@ from .latent_consistency_models import LatentConsistencyModelImg2ImgPipeline, LatentConsistencyModelPipeline from .pag import ( StableDiffusionXLControlNetPAGPipeline, + StableDiffusionXLPAGImg2ImgPipeline, StableDiffusionXLPAGInpaintPipeline, StableDiffusionXLPAGPipeline, ) @@ -95,6 +96,7 @@ ("kandinsky3", Kandinsky3Img2ImgPipeline), ("stable-diffusion-controlnet", StableDiffusionControlNetImg2ImgPipeline), ("stable-diffusion-xl-controlnet", StableDiffusionXLControlNetImg2ImgPipeline), + ("stable-diffusion-xl-pag", StableDiffusionXLPAGImg2ImgPipeline), ("lcm", LatentConsistencyModelImg2ImgPipeline), ] ) @@ -631,6 +633,10 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs): if "controlnet" in kwargs: orig_class_name = config["_class_name"].replace("Pipeline", "ControlNetPipeline") + if "enable_pag" in kwargs: + enable_pag = kwargs.pop("enable_pag") + if enable_pag: + orig_class_name = orig_class_name.replace("Pipeline", "PAGPipeline") image_2_image_cls = _get_task_class(AUTO_IMAGE2IMAGE_PIPELINES_MAPPING, orig_class_name) @@ -676,16 +682,32 @@ def from_pipe(cls, pipeline, **kwargs): if "controlnet" in kwargs: if kwargs["controlnet"] is not None: + to_replace = "Img2ImgPipeline" + if "PAG" in image_2_image_cls.__name__: + to_replace = "PAG" + to_replace image_2_image_cls = _get_task_class( AUTO_IMAGE2IMAGE_PIPELINES_MAPPING, image_2_image_cls.__name__.replace("ControlNet", "").replace( - "Img2ImgPipeline", "ControlNetImg2ImgPipeline" + to_replace, "ControlNet" + to_replace ), ) else: image_2_image_cls = _get_task_class( AUTO_IMAGE2IMAGE_PIPELINES_MAPPING, - image_2_image_cls.__name__.replace("ControlNetImg2ImgPipeline", "Img2ImgPipeline"), + image_2_image_cls.__name__.replace("ControlNet", ""), + ) + + if "enable_pag" in kwargs: + enable_pag = kwargs.pop("enable_pag") + if enable_pag: + image_2_image_cls = _get_task_class( + AUTO_IMAGE2IMAGE_PIPELINES_MAPPING, + image_2_image_cls.__name__.replace("PAG", "").replace("Img2ImgPipeline", "PAGImg2ImgPipeline"), + ) + else: + image_2_image_cls = _get_task_class( + AUTO_IMAGE2IMAGE_PIPELINES_MAPPING, + image_2_image_cls.__name__.replace("PAG", ""), ) # define expected module and optional kwargs given the pipeline signature diff --git a/src/diffusers/pipelines/pag/__init__.py b/src/diffusers/pipelines/pag/__init__.py index 782623d0a1dd..5695d75566fa 100644 --- a/src/diffusers/pipelines/pag/__init__.py +++ b/src/diffusers/pipelines/pag/__init__.py @@ -26,6 +26,7 @@ _import_structure["pipeline_pag_sd_xl"] = ["StableDiffusionXLPAGPipeline"] _import_structure["pipeline_pag_sd_xl_inpaint"] = ["StableDiffusionXLPAGInpaintPipeline"] + _import_structure["pipeline_pag_sd_xl_img2img"] = ["StableDiffusionXLPAGImg2ImgPipeline"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: try: @@ -37,6 +38,7 @@ else: from .pipeline_pag_controlnet_sd_xl import StableDiffusionXLControlNetPAGPipeline from .pipeline_pag_sd_xl import StableDiffusionXLPAGPipeline + from .pipeline_pag_sd_xl_img2img import StableDiffusionXLPAGImg2ImgPipeline from .pipeline_pag_sd_xl_inpaint import StableDiffusionXLPAGInpaintPipeline else: diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py b/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py new file mode 100644 index 000000000000..a62787b486c0 --- /dev/null +++ b/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py @@ -0,0 +1,1513 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import PIL.Image +import torch +from transformers import ( + CLIPImageProcessor, + CLIPTextModel, + CLIPTextModelWithProjection, + CLIPTokenizer, + CLIPVisionModelWithProjection, +) + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import ( + FromSingleFileMixin, + IPAdapterMixin, + StableDiffusionXLLoraLoaderMixin, + TextualInversionLoaderMixin, +) +from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel +from ...models.attention_processor import ( + AttnProcessor2_0, + LoRAAttnProcessor2_0, + LoRAXFormersAttnProcessor, + XFormersAttnProcessor, +) +from ...models.lora import adjust_lora_scale_text_encoder +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + USE_PEFT_BACKEND, + is_invisible_watermark_available, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput +from .pag_utils import PAGMixin + + +if is_invisible_watermark_available(): + from ..stable_diffusion_xl.watermark import StableDiffusionXLWatermarker + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import StableDiffusionXLImg2ImgPipeline + >>> from diffusers.utils import load_image + + >>> pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained( + ... "stabilityai/stable-diffusion-xl-refiner-1.0", torch_dtype=torch.float16 + ... ) + >>> pipe = pipe.to("cuda") + >>> url = "https://huggingface.co/datasets/patrickvonplaten/images/resolve/main/aa_xl/000000009.png" + + >>> init_image = load_image(url).convert("RGB") + >>> prompt = "a photo of an astronaut riding a horse on mars" + >>> image = pipe(prompt, image=init_image).images[0] + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + """ + Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class StableDiffusionXLPAGImg2ImgPipeline( + DiffusionPipeline, + StableDiffusionMixin, + TextualInversionLoaderMixin, + FromSingleFileMixin, + StableDiffusionXLLoraLoaderMixin, + IPAdapterMixin, + PAGMixin, +): + r""" + Pipeline for text-to-image generation using Stable Diffusion XL. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files + - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion XL uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + text_encoder_2 ([` CLIPTextModelWithProjection`]): + Second frozen text-encoder. Stable Diffusion XL uses the text and pool portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection), + specifically the + [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k) + variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_2 (`CLIPTokenizer`): + Second Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + requires_aesthetics_score (`bool`, *optional*, defaults to `"False"`): + Whether the `unet` requires an `aesthetic_score` condition to be passed during inference. Also see the + config of `stabilityai/stable-diffusion-xl-refiner-1-0`. + force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `"True"`): + Whether the negative prompt embeddings shall be forced to always be set to 0. Also see the config of + `stabilityai/stable-diffusion-xl-base-1-0`. + add_watermarker (`bool`, *optional*): + Whether to use the [invisible_watermark library](https://github.com/ShieldMnt/invisible-watermark/) to + watermark output images. If not defined, it will default to True if the package is installed, otherwise no + watermarker will be used. + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->unet->vae" + _optional_components = [ + "tokenizer", + "tokenizer_2", + "text_encoder", + "text_encoder_2", + "image_encoder", + "feature_extractor", + ] + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + "add_text_embeds", + "add_time_ids", + "negative_pooled_prompt_embeds", + "add_neg_time_ids", + ] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + text_encoder_2: CLIPTextModelWithProjection, + tokenizer: CLIPTokenizer, + tokenizer_2: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + image_encoder: CLIPVisionModelWithProjection = None, + feature_extractor: CLIPImageProcessor = None, + requires_aesthetics_score: bool = False, + force_zeros_for_empty_prompt: bool = True, + add_watermarker: Optional[bool] = None, + pag_applied_layers: Union[str, List[str]] = "mid", # ["mid"], ["down.block_1", "up.block_0.attentions_0"] + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + unet=unet, + image_encoder=image_encoder, + feature_extractor=feature_extractor, + scheduler=scheduler, + ) + self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) + self.register_to_config(requires_aesthetics_score=requires_aesthetics_score) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + + add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available() + + if add_watermarker: + self.watermark = StableDiffusionXLWatermarker() + else: + self.watermark = None + + self.set_pag_applied_layers(pag_applied_layers) + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt + def encode_prompt( + self, + prompt: str, + prompt_2: Optional[str] = None, + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt: Optional[str] = None, + negative_prompt_2: Optional[str] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + pooled_prompt_embeds: Optional[torch.Tensor] = None, + negative_pooled_prompt_embeds: Optional[torch.Tensor] = None, + lora_scale: Optional[float] = None, + clip_skip: Optional[int] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) + else: + scale_lora_layers(self.text_encoder_2, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # Define tokenizers and text encoders + tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2] + text_encoders = ( + [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2] + ) + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + # textual inversion: process multi-vector tokens if necessary + prompt_embeds_list = [] + prompts = [prompt, prompt_2] + for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, tokenizer) + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {tokenizer.model_max_length} tokens: {removed_text}" + ) + + prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) + + # We are only ALWAYS interested in the pooled output of the final text encoder + pooled_prompt_embeds = prompt_embeds[0] + if clip_skip is None: + prompt_embeds = prompt_embeds.hidden_states[-2] + else: + # "2" because SDXL always indexes from the penultimate layer. + prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)] + + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) + + # get unconditional embeddings for classifier free guidance + zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt + if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt: + negative_prompt_embeds = torch.zeros_like(prompt_embeds) + negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) + elif do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt_2 = negative_prompt_2 or negative_prompt + + # normalize str to list + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + negative_prompt_2 = ( + batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 + ) + + uncond_tokens: List[str] + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = [negative_prompt, negative_prompt_2] + + negative_prompt_embeds_list = [] + for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = tokenizer( + negative_prompt, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + negative_prompt_embeds = text_encoder( + uncond_input.input_ids.to(device), + output_hidden_states=True, + ) + # We are only ALWAYS interested in the pooled output of the final text encoder + negative_pooled_prompt_embeds = negative_prompt_embeds[0] + negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] + + negative_prompt_embeds_list.append(negative_prompt_embeds) + + negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) + + if self.text_encoder_2 is not None: + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + if self.text_encoder_2 is not None: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + if do_classifier_free_guidance: + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2, lora_scale) + + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + prompt_2, + strength, + num_inference_steps, + negative_prompt=None, + negative_prompt_2=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ip_adapter_image=None, + ip_adapter_image_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + if num_inference_steps is None: + raise ValueError("`num_inference_steps` cannot be None.") + elif not isinstance(num_inference_steps, int) or num_inference_steps <= 0: + raise ValueError( + f"`num_inference_steps` has to be a positive integer but is {num_inference_steps} of type" + f" {type(num_inference_steps)}." + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif negative_prompt_2 is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if ip_adapter_image is not None and ip_adapter_image_embeds is not None: + raise ValueError( + "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined." + ) + + if ip_adapter_image_embeds is not None: + if not isinstance(ip_adapter_image_embeds, list): + raise ValueError( + f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" + ) + elif ip_adapter_image_embeds[0].ndim not in [3, 4]: + raise ValueError( + f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" + ) + + def get_timesteps(self, num_inference_steps, strength, device, denoising_start=None): + # get the original timestep using init_timestep + if denoising_start is None: + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + t_start = max(num_inference_steps - init_timestep, 0) + else: + t_start = 0 + + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + + # Strength is irrelevant if we directly request a timestep to start at; + # that is, strength is determined by the denoising_start instead. + if denoising_start is not None: + discrete_timestep_cutoff = int( + round( + self.scheduler.config.num_train_timesteps + - (denoising_start * self.scheduler.config.num_train_timesteps) + ) + ) + + num_inference_steps = (timesteps < discrete_timestep_cutoff).sum().item() + if self.scheduler.order == 2 and num_inference_steps % 2 == 0: + # if the scheduler is a 2nd order scheduler we might have to do +1 + # because `num_inference_steps` might be even given that every timestep + # (except the highest one) is duplicated. If `num_inference_steps` is even it would + # mean that we cut the timesteps in the middle of the denoising step + # (between 1st and 2nd derivative) which leads to incorrect results. By adding 1 + # we ensure that the denoising process always ends after the 2nd derivate step of the scheduler + num_inference_steps = num_inference_steps + 1 + + # because t_n+1 >= t_n, we slice the timesteps starting from the end + timesteps = timesteps[-num_inference_steps:] + return timesteps, num_inference_steps + + return timesteps, num_inference_steps - t_start + + def prepare_latents( + self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None, add_noise=True + ): + if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): + raise ValueError( + f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" + ) + + latents_mean = latents_std = None + if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None: + latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1) + if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None: + latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1) + + # Offload text encoder if `enable_model_cpu_offload` was enabled + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.text_encoder_2.to("cpu") + torch.cuda.empty_cache() + + image = image.to(device=device, dtype=dtype) + + batch_size = batch_size * num_images_per_prompt + + if image.shape[1] == 4: + init_latents = image + + else: + # make sure the VAE is in float32 mode, as it overflows in float16 + if self.vae.config.force_upcast: + image = image.float() + self.vae.to(dtype=torch.float32) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + elif isinstance(generator, list): + init_latents = [ + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(batch_size) + ] + init_latents = torch.cat(init_latents, dim=0) + else: + init_latents = retrieve_latents(self.vae.encode(image), generator=generator) + + if self.vae.config.force_upcast: + self.vae.to(dtype) + + init_latents = init_latents.to(dtype) + if latents_mean is not None and latents_std is not None: + latents_mean = latents_mean.to(device=self.device, dtype=dtype) + latents_std = latents_std.to(device=self.device, dtype=dtype) + init_latents = (init_latents - latents_mean) * self.vae.config.scaling_factor / latents_std + else: + init_latents = self.vae.config.scaling_factor * init_latents + + if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: + # expand init_latents for batch_size + additional_image_per_prompt = batch_size // init_latents.shape[0] + init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0) + elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." + ) + else: + init_latents = torch.cat([init_latents], dim=0) + + if add_noise: + shape = init_latents.shape + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + # get latents + init_latents = self.scheduler.add_noise(init_latents, noise, timestep) + + latents = init_latents + + return latents + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance + ): + image_embeds = [] + if do_classifier_free_guidance: + negative_image_embeds = [] + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." + ) + + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers + ): + output_hidden_state = not isinstance(image_proj_layer, ImageProjection) + single_image_embeds, single_negative_image_embeds = self.encode_image( + single_ip_adapter_image, device, 1, output_hidden_state + ) + + image_embeds.append(single_image_embeds[None, :]) + if do_classifier_free_guidance: + negative_image_embeds.append(single_negative_image_embeds[None, :]) + else: + for single_image_embeds in ip_adapter_image_embeds: + if do_classifier_free_guidance: + single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) + negative_image_embeds.append(single_negative_image_embeds) + image_embeds.append(single_image_embeds) + + ip_adapter_image_embeds = [] + for i, single_image_embeds in enumerate(image_embeds): + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + if do_classifier_free_guidance: + single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds + + def _get_add_time_ids( + self, + original_size, + crops_coords_top_left, + target_size, + aesthetic_score, + negative_aesthetic_score, + negative_original_size, + negative_crops_coords_top_left, + negative_target_size, + dtype, + text_encoder_projection_dim=None, + ): + if self.config.requires_aesthetics_score: + add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,)) + add_neg_time_ids = list( + negative_original_size + negative_crops_coords_top_left + (negative_aesthetic_score,) + ) + else: + add_time_ids = list(original_size + crops_coords_top_left + target_size) + add_neg_time_ids = list(negative_original_size + crops_coords_top_left + negative_target_size) + + passed_add_embed_dim = ( + self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim + ) + expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features + + if ( + expected_add_embed_dim > passed_add_embed_dim + and (expected_add_embed_dim - passed_add_embed_dim) == self.unet.config.addition_time_embed_dim + ): + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to enable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=True)` to make sure `aesthetic_score` {aesthetic_score} and `negative_aesthetic_score` {negative_aesthetic_score} is correctly used by the model." + ) + elif ( + expected_add_embed_dim < passed_add_embed_dim + and (passed_add_embed_dim - expected_add_embed_dim) == self.unet.config.addition_time_embed_dim + ): + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to disable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=False)` to make sure `target_size` {target_size} is correctly used by the model." + ) + elif expected_add_embed_dim != passed_add_embed_dim: + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." + ) + + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + add_neg_time_ids = torch.tensor([add_neg_time_ids], dtype=dtype) + + return add_time_ids, add_neg_time_ids + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae + def upcast_vae(self): + dtype = self.vae.dtype + self.vae.to(dtype=torch.float32) + use_torch_2_0_or_xformers = isinstance( + self.vae.decoder.mid_block.attentions[0].processor, + ( + AttnProcessor2_0, + XFormersAttnProcessor, + LoRAXFormersAttnProcessor, + LoRAAttnProcessor2_0, + ), + ) + # if xformers or torch_2_0 is used attention block does not need + # to be in float32 which can save lots of memory + if use_torch_2_0_or_xformers: + self.vae.post_quant_conv.to(dtype) + self.vae.decoder.conv_in.to(dtype) + self.vae.decoder.mid_block.to(dtype) + + # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding + def get_guidance_scale_embedding( + self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32 + ) -> torch.Tensor: + """ + See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 + + Args: + w (`torch.Tensor`): + Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings. + embedding_dim (`int`, *optional*, defaults to 512): + Dimension of the embeddings to generate. + dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): + Data type of the generated embeddings. + + Returns: + `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`. + """ + assert len(w.shape) == 1 + w = w * 1000.0 + + half_dim = embedding_dim // 2 + emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) + emb = w.to(dtype)[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1)) + assert emb.shape == (w.shape[0], embedding_dim) + return emb + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def guidance_rescale(self): + return self._guidance_rescale + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None + + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + def denoising_end(self): + return self._denoising_end + + @property + def denoising_start(self): + return self._denoising_start + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + image: PipelineImageInput = None, + strength: float = 0.3, + num_inference_steps: int = 50, + timesteps: List[int] = None, + sigmas: List[float] = None, + denoising_start: Optional[float] = None, + denoising_end: Optional[float] = None, + guidance_scale: float = 5.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + pooled_prompt_embeds: Optional[torch.Tensor] = None, + negative_pooled_prompt_embeds: Optional[torch.Tensor] = None, + ip_adapter_image: Optional[PipelineImageInput] = None, + ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + guidance_rescale: float = 0.0, + original_size: Tuple[int, int] = None, + crops_coords_top_left: Tuple[int, int] = (0, 0), + target_size: Tuple[int, int] = None, + negative_original_size: Optional[Tuple[int, int]] = None, + negative_crops_coords_top_left: Tuple[int, int] = (0, 0), + negative_target_size: Optional[Tuple[int, int]] = None, + aesthetic_score: float = 6.0, + negative_aesthetic_score: float = 2.5, + clip_skip: Optional[int] = None, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + pag_scale: float = 3.0, + pag_adaptive_scale: float = 0.0, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + image (`torch.Tensor` or `PIL.Image.Image` or `np.ndarray` or `List[torch.Tensor]` or `List[PIL.Image.Image]` or `List[np.ndarray]`): + The image(s) to modify with the pipeline. + strength (`float`, *optional*, defaults to 0.3): + Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image` + will be used as a starting point, adding more noise to it the larger the `strength`. The number of + denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will + be maximum and the denoising process will run for the full number of iterations specified in + `num_inference_steps`. A value of 1, therefore, essentially ignores `image`. Note that in the case of + `denoising_start` being declared as an integer, the value of `strength` will be ignored. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + denoising_start (`float`, *optional*): + When specified, indicates the fraction (between 0.0 and 1.0) of the total denoising process to be + bypassed before it is initiated. Consequently, the initial part of the denoising process is skipped and + it is assumed that the passed `image` is a partly denoised image. Note that when this is specified, + strength will be ignored. The `denoising_start` parameter is particularly beneficial when this pipeline + is integrated into a "Mixture of Denoisers" multi-pipeline setup, as detailed in [**Refine Image + Quality**](https://huggingface.co/docs/diffusers/using-diffusers/sdxl#refine-image-quality). + denoising_end (`float`, *optional*): + When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be + completed before it is intentionally prematurely terminated. As a result, the returned sample will + still retain a substantial amount of noise (ca. final 20% of timesteps still needed) and should be + denoised by a successor pipeline that has `denoising_start` set to 0.8 so that it only denoises the + final 20% of the scheduler. The denoising_end parameter should ideally be utilized when this pipeline + forms a part of a "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refine Image + Quality**](https://huggingface.co/docs/diffusers/using-diffusers/sdxl#refine-image-quality). + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should + contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] instead of a + plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + guidance_rescale (`float`, *optional*, defaults to 0.0): + Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of + [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). + Guidance rescale factor should fix overexposure when using zero terminal SNR. + original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. + `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as + explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): + `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position + `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting + `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + For most cases, `target_size` should be set to the desired height and width of the generated image. If + not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in + section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a specific image resolution. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): + To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a target image resolution. It should be as same + as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + aesthetic_score (`float`, *optional*, defaults to 6.0): + Used to simulate an aesthetic score of the generated image by influencing the positive text condition. + Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + negative_aesthetic_score (`float`, *optional*, defaults to 2.5): + Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). Can be used to + simulate an aesthetic score of the generated image by influencing the negative text condition. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + pag_scale (`float`, *optional*, defaults to 3.0): + The scale factor for the perturbed attention guidance. If it is set to 0.0, the perturbed attention + guidance will not be used. + pag_adaptive_scale (`float`, *optional*, defaults to 0.0): + The adaptive scale factor for the perturbed attention guidance. If it is set to 0.0, `pag_scale` is + used. + + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a + `tuple. When returning a tuple, the first element is a list with the generated images. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + strength, + num_inference_steps, + negative_prompt, + negative_prompt_2, + prompt_embeds, + negative_prompt_embeds, + ip_adapter_image, + ip_adapter_image_embeds, + callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + self._guidance_rescale = guidance_rescale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + self._denoising_end = denoising_end + self._denoising_start = denoising_start + self._interrupt = False + self._pag_scale = pag_scale + self._pag_adaptive_scale = pag_adaptive_scale + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. Encode input prompt + text_encoder_lora_scale = ( + self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None + ) + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=self.clip_skip, + ) + + # 4. Preprocess image + image = self.image_processor.preprocess(image) + + # 5. Prepare timesteps + def denoising_value_valid(dnv): + return isinstance(dnv, float) and 0 < dnv < 1 + + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, device, timesteps, sigmas + ) + timesteps, num_inference_steps = self.get_timesteps( + num_inference_steps, + strength, + device, + denoising_start=self.denoising_start if denoising_value_valid(self.denoising_start) else None, + ) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + + add_noise = True if self.denoising_start is None else False + + # 6. Prepare latent variables + if latents is None: + latents = self.prepare_latents( + image, + latent_timestep, + batch_size, + num_images_per_prompt, + prompt_embeds.dtype, + device, + generator, + add_noise, + ) + # 7. Prepare extra step kwargs. + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + height, width = latents.shape[-2:] + height = height * self.vae_scale_factor + width = width * self.vae_scale_factor + + original_size = original_size or (height, width) + target_size = target_size or (height, width) + + # 8. Prepare added time ids & embeddings + if negative_original_size is None: + negative_original_size = original_size + if negative_target_size is None: + negative_target_size = target_size + + add_text_embeds = pooled_prompt_embeds + if self.text_encoder_2 is None: + text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) + else: + text_encoder_projection_dim = self.text_encoder_2.config.projection_dim + + add_time_ids, add_neg_time_ids = self._get_add_time_ids( + original_size, + crops_coords_top_left, + target_size, + aesthetic_score, + negative_aesthetic_score, + negative_original_size, + negative_crops_coords_top_left, + negative_target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1) + add_neg_time_ids = add_neg_time_ids.repeat(batch_size * num_images_per_prompt, 1) + + if self.do_perturbed_attention_guidance: + prompt_embeds = self._prepare_perturbed_attention_guidance( + prompt_embeds, negative_prompt_embeds, self.do_classifier_free_guidance + ) + add_text_embeds = self._prepare_perturbed_attention_guidance( + add_text_embeds, negative_pooled_prompt_embeds, self.do_classifier_free_guidance + ) + add_time_ids = self._prepare_perturbed_attention_guidance( + add_time_ids, add_neg_time_ids, self.do_classifier_free_guidance + ) + elif self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) + add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0) + + prompt_embeds = prompt_embeds.to(device) + add_text_embeds = add_text_embeds.to(device) + add_time_ids = add_time_ids.to(device) + + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + ip_adapter_image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + self.do_classifier_free_guidance, + ) + for i, image_embeds in enumerate(ip_adapter_image_embeds): + negative_image_embeds = None + if self.do_classifier_free_guidance: + negative_image_embeds, image_embeds = image_embeds.chunk(2) + + if self.do_perturbed_attention_guidance: + image_embeds = self._prepare_perturbed_attention_guidance( + image_embeds, negative_image_embeds, self.do_classifier_free_guidance + ) + elif self.do_classifier_free_guidance: + image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0) + image_embeds = image_embeds.to(device) + ip_adapter_image_embeds[i] = image_embeds + + # 9. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + # 9.1 Apply denoising_end + if ( + self.denoising_end is not None + and self.denoising_start is not None + and denoising_value_valid(self.denoising_end) + and denoising_value_valid(self.denoising_start) + and self.denoising_start >= self.denoising_end + ): + raise ValueError( + f"`denoising_start`: {self.denoising_start} cannot be larger than or equal to `denoising_end`: " + + f" {self.denoising_end} when using type float." + ) + elif self.denoising_end is not None and denoising_value_valid(self.denoising_end): + discrete_timestep_cutoff = int( + round( + self.scheduler.config.num_train_timesteps + - (self.denoising_end * self.scheduler.config.num_train_timesteps) + ) + ) + num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) + timesteps = timesteps[:num_inference_steps] + + # 9.2 Optionally get Guidance Scale Embedding + timestep_cond = None + if self.unet.config.time_cond_proj_dim is not None: + guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) + timestep_cond = self.get_guidance_scale_embedding( + guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim + ).to(device=device, dtype=latents.dtype) + + if self.do_perturbed_attention_guidance: + original_attn_proc = self.unet.attn_processors + self._set_pag_attn_processor( + pag_applied_layers=self.pag_applied_layers, + do_classifier_free_guidance=self.do_classifier_free_guidance, + ) + + self._num_timesteps = len(timesteps) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * (prompt_embeds.shape[0] // latents.shape[0])) + + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + if ip_adapter_image_embeds is not None: + added_cond_kwargs["image_embeds"] = ip_adapter_image_embeds + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + timestep_cond=timestep_cond, + cross_attention_kwargs=self.cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if self.do_perturbed_attention_guidance: + noise_pred = self._apply_perturbed_attention_guidance( + noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t + ) + elif self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + if self.do_classifier_free_guidance and self.guidance_rescale > 0.0: + # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds) + negative_pooled_prompt_embeds = callback_outputs.pop( + "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds + ) + add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids) + add_neg_time_ids = callback_outputs.pop("add_neg_time_ids", add_neg_time_ids) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if not output_type == "latent": + # make sure the VAE is in float32 mode, as it overflows in float16 + needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast + + if needs_upcasting: + self.upcast_vae() + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + elif latents.dtype != self.vae.dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + self.vae = self.vae.to(latents.dtype) + + # unscale/denormalize the latents + # denormalize with the mean and std if available and not None + has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None + has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None + if has_latents_mean and has_latents_std: + latents_mean = ( + torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype) + ) + latents_std = ( + torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype) + ) + latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean + else: + latents = latents / self.vae.config.scaling_factor + + image = self.vae.decode(latents, return_dict=False)[0] + + # cast back to fp16 if needed + if needs_upcasting: + self.vae.to(dtype=torch.float16) + else: + image = latents + + # apply watermark if available + if self.watermark is not None: + image = self.watermark.apply_watermark(image) + + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if self.do_perturbed_attention_guidance: + self.unet.set_attn_processor(original_attn_proc) + + if not return_dict: + return (image,) + + return StableDiffusionXLPipelineOutput(images=image) diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index e03dfabc3750..a71e4a1526a8 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -1412,6 +1412,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class StableDiffusionXLPAGImg2ImgPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class StableDiffusionXLPAGInpaintPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] diff --git a/tests/pipelines/pag/test_pag_sdxl_img2img.py b/tests/pipelines/pag/test_pag_sdxl_img2img.py new file mode 100644 index 000000000000..e62c4ccf370d --- /dev/null +++ b/tests/pipelines/pag/test_pag_sdxl_img2img.py @@ -0,0 +1,338 @@ +# coding=utf-8 +# Copyright 2024 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import inspect +import random +import unittest + +import numpy as np +import torch +from transformers import ( + CLIPImageProcessor, + CLIPTextConfig, + CLIPTextModel, + CLIPTextModelWithProjection, + CLIPTokenizer, + CLIPVisionConfig, + CLIPVisionModelWithProjection, +) + +from diffusers import ( + AutoencoderKL, + AutoPipelineForImage2Image, + EulerDiscreteScheduler, + StableDiffusionXLImg2ImgPipeline, + StableDiffusionXLPAGImg2ImgPipeline, + UNet2DConditionModel, +) +from diffusers.utils.testing_utils import ( + enable_full_determinism, + floats_tensor, + load_image, + require_torch_gpu, + slow, + torch_device, +) + +from ..pipeline_params import ( + IMAGE_TO_IMAGE_IMAGE_PARAMS, + TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, + TEXT_GUIDED_IMAGE_VARIATION_PARAMS, + TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS, +) +from ..test_pipelines_common import ( + IPAdapterTesterMixin, + PipelineFromPipeTesterMixin, + PipelineLatentTesterMixin, + PipelineTesterMixin, + SDXLOptionalComponentsTesterMixin, +) + + +enable_full_determinism() + + +class StableDiffusionXLPAGImg2ImgPipelineFastTests( + PipelineTesterMixin, + IPAdapterTesterMixin, + PipelineLatentTesterMixin, + PipelineFromPipeTesterMixin, + SDXLOptionalComponentsTesterMixin, + unittest.TestCase, +): + pipeline_class = StableDiffusionXLPAGImg2ImgPipeline + params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS.union({"pag_scale", "pag_adaptive_scale"}) - {"height", "width"} + batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS + image_params = IMAGE_TO_IMAGE_IMAGE_PARAMS + image_latents_params = IMAGE_TO_IMAGE_IMAGE_PARAMS + callback_cfg_params = TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS.union( + {"add_text_embeds", "add_time_ids", "add_neg_time_ids"} + ) + + # based on tests.pipelines.stable_diffusion_xl.test_stable_diffusion_xl_img2img_pipeline.get_dummy_components + def get_dummy_components( + self, skip_first_text_encoder=False, time_cond_proj_dim=None, requires_aesthetics_score=False + ): + torch.manual_seed(0) + unet = UNet2DConditionModel( + block_out_channels=(32, 64), + layers_per_block=2, + sample_size=32, + in_channels=4, + out_channels=4, + time_cond_proj_dim=time_cond_proj_dim, + down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"), + up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"), + # SD2-specific config below + attention_head_dim=(2, 4), + use_linear_projection=True, + addition_embed_type="text_time", + addition_time_embed_dim=8, + transformer_layers_per_block=(1, 2), + projection_class_embeddings_input_dim=72 if requires_aesthetics_score else 80, # 5 * 8 + 32 + cross_attention_dim=64 if not skip_first_text_encoder else 32, + ) + scheduler = EulerDiscreteScheduler( + beta_start=0.00085, + beta_end=0.012, + steps_offset=1, + beta_schedule="scaled_linear", + timestep_spacing="leading", + ) + torch.manual_seed(0) + vae = AutoencoderKL( + block_out_channels=[32, 64], + in_channels=3, + out_channels=3, + down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"], + up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"], + latent_channels=4, + sample_size=128, + ) + torch.manual_seed(0) + image_encoder_config = CLIPVisionConfig( + hidden_size=32, + image_size=224, + projection_dim=32, + intermediate_size=37, + num_attention_heads=4, + num_channels=3, + num_hidden_layers=5, + patch_size=14, + ) + + image_encoder = CLIPVisionModelWithProjection(image_encoder_config) + + feature_extractor = CLIPImageProcessor( + crop_size=224, + do_center_crop=True, + do_normalize=True, + do_resize=True, + image_mean=[0.48145466, 0.4578275, 0.40821073], + image_std=[0.26862954, 0.26130258, 0.27577711], + resample=3, + size=224, + ) + + torch.manual_seed(0) + text_encoder_config = CLIPTextConfig( + bos_token_id=0, + eos_token_id=2, + hidden_size=32, + intermediate_size=37, + layer_norm_eps=1e-05, + num_attention_heads=4, + num_hidden_layers=5, + pad_token_id=1, + vocab_size=1000, + # SD2-specific config below + hidden_act="gelu", + projection_dim=32, + ) + text_encoder = CLIPTextModel(text_encoder_config) + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + text_encoder_2 = CLIPTextModelWithProjection(text_encoder_config) + tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + components = { + "unet": unet, + "scheduler": scheduler, + "vae": vae, + "text_encoder": text_encoder if not skip_first_text_encoder else None, + "tokenizer": tokenizer if not skip_first_text_encoder else None, + "text_encoder_2": text_encoder_2, + "tokenizer_2": tokenizer_2, + "requires_aesthetics_score": requires_aesthetics_score, + "image_encoder": image_encoder, + "feature_extractor": feature_extractor, + } + return components + + # based on tests.pipelines.stable_diffusion_xl.test_stable_diffusion_xl_img2img_pipeline.StableDiffusionXLImg2ImgPipelineFastTests + # add `pag_scale` to the inputs + def get_dummy_inputs(self, device, seed=0): + image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device) + image = image / 2 + 0.5 + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + inputs = { + "prompt": "A painting of a squirrel eating a burger", + "image": image, + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 5.0, + "pag_scale": 3.0, + "output_type": "np", + "strength": 0.8, + } + return inputs + + def test_pag_disable_enable(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + components = self.get_dummy_components(requires_aesthetics_score=True) + + # base pipeline + pipe_sd = StableDiffusionXLImg2ImgPipeline(**components) + pipe_sd = pipe_sd.to(device) + pipe_sd.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + del inputs["pag_scale"] + assert ( + "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters + ), f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}." + out = pipe_sd(**inputs).images[0, -3:, -3:, -1] + + # pag disabled with pag_scale=0.0 + pipe_pag = self.pipeline_class(**components) + pipe_pag = pipe_pag.to(device) + pipe_pag.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + inputs["pag_scale"] = 0.0 + out_pag_disabled = pipe_pag(**inputs).images[0, -3:, -3:, -1] + + # pag enabled + pipe_pag = self.pipeline_class(**components, pag_applied_layers=["mid", "up", "down"]) + pipe_pag = pipe_pag.to(device) + pipe_pag.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + out_pag_enabled = pipe_pag(**inputs).images[0, -3:, -3:, -1] + + assert np.abs(out.flatten() - out_pag_disabled.flatten()).max() < 1e-3 + assert np.abs(out.flatten() - out_pag_enabled.flatten()).max() > 1e-3 + + def test_save_load_optional_components(self): + self._test_save_load_optional_components() + + def test_pag_inference(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + components = self.get_dummy_components(requires_aesthetics_score=True) + + pipe_pag = self.pipeline_class(**components, pag_applied_layers=["mid", "up", "down"]) + pipe_pag = pipe_pag.to(device) + pipe_pag.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + image = pipe_pag(**inputs).images + image_slice = image[0, -3:, -3:, -1] + + assert image.shape == ( + 1, + 32, + 32, + 3, + ), f"the shape of the output image should be (1, 64, 64, 3) but got {image.shape}" + expected_slice = np.array( + [0.46703637, 0.4917526, 0.44394222, 0.6895079, 0.56251144, 0.45474228, 0.5957122, 0.6016377, 0.5276273] + ) + + max_diff = np.abs(image_slice.flatten() - expected_slice).max() + assert max_diff < 1e-3, f"output is different from expected, {image_slice.flatten()}" + + +@slow +@require_torch_gpu +class StableDiffusionXLPAGImg2ImgPipelineIntegrationTests(unittest.TestCase): + repo_id = "stabilityai/stable-diffusion-xl-base-1.0" + + def setUp(self): + super().setUp() + gc.collect() + torch.cuda.empty_cache() + + def tearDown(self): + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + def get_inputs(self, device, generator_device="cpu", seed=0, guidance_scale=7.0): + img_url = ( + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/sdxl-text2img.png" + ) + + init_image = load_image(img_url) + + generator = torch.Generator(device=generator_device).manual_seed(seed) + inputs = { + "prompt": "a dog catching a frisbee in the jungle", + "generator": generator, + "image": init_image, + "strength": 0.8, + "num_inference_steps": 3, + "guidance_scale": guidance_scale, + "pag_scale": 3.0, + "output_type": "np", + } + return inputs + + def test_pag_cfg(self): + pipeline = AutoPipelineForImage2Image.from_pretrained(self.repo_id, enable_pag=True, torch_dtype=torch.float16) + pipeline.enable_model_cpu_offload() + pipeline.set_progress_bar_config(disable=None) + + inputs = self.get_inputs(torch_device) + image = pipeline(**inputs).images + + image_slice = image[0, -3:, -3:, -1].flatten() + assert image.shape == (1, 1024, 1024, 3) + expected_slice = np.array( + [0.20301354, 0.21078318, 0.2021082, 0.20277798, 0.20681083, 0.19562206, 0.20121682, 0.21562952, 0.21277016] + ) + assert ( + np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 + ), f"output is different from expected, {image_slice.flatten()}" + + def test_pag_uncond(self): + pipeline = AutoPipelineForImage2Image.from_pretrained(self.repo_id, enable_pag=True, torch_dtype=torch.float16) + pipeline.enable_model_cpu_offload() + pipeline.set_progress_bar_config(disable=None) + + inputs = self.get_inputs(torch_device, guidance_scale=0.0) + image = pipeline(**inputs).images + + image_slice = image[0, -3:, -3:, -1].flatten() + assert image.shape == (1, 1024, 1024, 3) + expected_slice = np.array( + [0.21303111, 0.22188407, 0.2124992, 0.21365267, 0.18823743, 0.17569828, 0.21113116, 0.19419771, 0.18919235] + ) + assert ( + np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 + ), f"output is different from expected, {image_slice.flatten()}" diff --git a/tests/pipelines/pag/test_pag_sdxl_inpaint.py b/tests/pipelines/pag/test_pag_sdxl_inpaint.py index 92cfa8c85667..9385b1fe405a 100644 --- a/tests/pipelines/pag/test_pag_sdxl_inpaint.py +++ b/tests/pipelines/pag/test_pag_sdxl_inpaint.py @@ -82,10 +82,10 @@ class StableDiffusionXLPAGInpaintPipelineFastTests( {"add_text_embeds", "add_time_ids", "mask", "masked_image_latents"} ) + # based on tests.pipelines.stable_diffusion_xl.test_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipelineFastTests.get_dummy_components def get_dummy_components( self, skip_first_text_encoder=False, time_cond_proj_dim=None, requires_aesthetics_score=False ): - # copied from tests.pipelines.stable_diffusion_xl.test_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipelineFastTests.get_dummy_components torch.manual_seed(0) unet = UNet2DConditionModel( block_out_channels=(32, 64), @@ -221,7 +221,7 @@ def test_pag_disable_enable(self): del inputs["pag_scale"] assert ( "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters - ), f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__calss__.__name__}." + ), f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}." out = pipe_sd(**inputs).images[0, -3:, -3:, -1] # pag disabled with pag_scale=0.0 diff --git a/tests/pipelines/test_pipelines_auto.py b/tests/pipelines/test_pipelines_auto.py index 268fa052923a..768026fa5460 100644 --- a/tests/pipelines/test_pipelines_auto.py +++ b/tests/pipelines/test_pipelines_auto.py @@ -123,6 +123,21 @@ def test_kwargs_local_files_only(self): shutil.rmtree(tmpdirname.parent.parent) + def test_from_pretrained_text2img(self): + repo = "hf-internal-testing/tiny-stable-diffusion-xl-pipe" + pipe = AutoPipelineForText2Image.from_pretrained(repo) + assert pipe.__class__.__name__ == "StableDiffusionXLPipeline" + + controlnet = ControlNetModel.from_pretrained("hf-internal-testing/tiny-controlnet") + pipe_control = AutoPipelineForText2Image.from_pretrained(repo, controlnet=controlnet) + assert pipe_control.__class__.__name__ == "StableDiffusionXLControlNetPipeline" + + pipe_pag = AutoPipelineForText2Image.from_pretrained(repo, enable_pag=True) + assert pipe_pag.__class__.__name__ == "StableDiffusionXLPAGPipeline" + + pipe_control_pag = AutoPipelineForText2Image.from_pretrained(repo, controlnet=controlnet, enable_pag=True) + assert pipe_control_pag.__class__.__name__ == "StableDiffusionXLControlNetPAGPipeline" + def test_from_pipe_pag_text2img(self): # test from StableDiffusionXLPipeline pipe = AutoPipelineForText2Image.from_pretrained("hf-internal-testing/tiny-stable-diffusion-xl-pipe") @@ -214,6 +229,42 @@ def test_from_pipe_pag_text2img(self): assert pipe.__class__.__name__ == "StableDiffusionXLControlNetPipeline" assert "controlnet" in pipe.components + def test_from_pretrained_img2img(self): + repo = "hf-internal-testing/tiny-stable-diffusion-xl-pipe" + + pipe = AutoPipelineForImage2Image.from_pretrained(repo) + assert pipe.__class__.__name__ == "StableDiffusionXLImg2ImgPipeline" + + pipe_pag = AutoPipelineForImage2Image.from_pretrained(repo, enable_pag=True) + assert pipe_pag.__class__.__name__ == "StableDiffusionXLPAGImg2ImgPipeline" + + def test_from_pipe_pag_img2img(self): + # test from tableDiffusionXLPAGImg2ImgPipeline + pipe = AutoPipelineForImage2Image.from_pretrained("hf-internal-testing/tiny-stable-diffusion-xl-pipe") + # - test `enable_pag` flag + pipe_pag = AutoPipelineForImage2Image.from_pipe(pipe, enable_pag=True) + assert pipe_pag.__class__.__name__ == "StableDiffusionXLPAGImg2ImgPipeline" + + pipe = AutoPipelineForImage2Image.from_pipe(pipe, enable_pag=False) + assert pipe.__class__.__name__ == "StableDiffusionXLImg2ImgPipeline" + + # testing from StableDiffusionXLPAGImg2ImgPipeline + # - test `enable_pag` flag + pipe_pag = AutoPipelineForImage2Image.from_pipe(pipe_pag, enable_pag=True) + assert pipe_pag.__class__.__name__ == "StableDiffusionXLPAGImg2ImgPipeline" + + pipe = AutoPipelineForImage2Image.from_pipe(pipe_pag, enable_pag=False) + assert pipe.__class__.__name__ == "StableDiffusionXLImg2ImgPipeline" + + def test_from_pretrained_inpaint(self): + repo = "hf-internal-testing/tiny-stable-diffusion-xl-pipe" + + pipe = AutoPipelineForInpainting.from_pretrained(repo) + assert pipe.__class__.__name__ == "StableDiffusionXLInpaintPipeline" + + pipe_pag = AutoPipelineForInpainting.from_pretrained(repo, enable_pag=True) + assert pipe_pag.__class__.__name__ == "StableDiffusionXLPAGInpaintPipeline" + def test_from_pipe_pag_inpaint(self): # test from tableDiffusionXLPAGInpaintPipeline pipe = AutoPipelineForInpainting.from_pretrained("hf-internal-testing/tiny-stable-diffusion-xl-pipe") @@ -242,10 +293,23 @@ def test_from_pipe_pag_new_task(self): # text2img pag -> inpaint pag pipe_pag_inpaint = AutoPipelineForInpainting.from_pipe(pipe_pag_text2img) assert pipe_pag_inpaint.__class__.__name__ == "StableDiffusionXLPAGInpaintPipeline" + # text2img pag -> img2img pag + pipe_pag_img2img = AutoPipelineForImage2Image.from_pipe(pipe_pag_text2img) + assert pipe_pag_img2img.__class__.__name__ == "StableDiffusionXLPAGImg2ImgPipeline" # inpaint pag -> text2img pag pipe_pag_text2img = AutoPipelineForText2Image.from_pipe(pipe_pag_inpaint) assert pipe_pag_text2img.__class__.__name__ == "StableDiffusionXLPAGPipeline" + # inpaint pag -> img2img pag + pipe_pag_img2img = AutoPipelineForImage2Image.from_pipe(pipe_pag_inpaint) + assert pipe_pag_img2img.__class__.__name__ == "StableDiffusionXLPAGImg2ImgPipeline" + + # img2img pag -> text2img pag + pipe_pag_text2img = AutoPipelineForText2Image.from_pipe(pipe_pag_img2img) + assert pipe_pag_text2img.__class__.__name__ == "StableDiffusionXLPAGPipeline" + # img2img pag -> inpaint pag + pipe_pag_inpaint = AutoPipelineForInpainting.from_pipe(pipe_pag_img2img) + assert pipe_pag_inpaint.__class__.__name__ == "StableDiffusionXLPAGInpaintPipeline" def test_from_pipe_controlnet_text2img(self): pipe = AutoPipelineForText2Image.from_pretrained("hf-internal-testing/tiny-stable-diffusion-pipe") From 1e79c5980c20165c043a8e36590bbc502f659794 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Thu, 20 Jun 2024 09:51:39 +0200 Subject: [PATCH 28/44] remove guess model support from pag controlnet pipeline --- .../pag/pipeline_pag_controlnet_sd_xl.py | 74 +++++-------------- 1 file changed, 19 insertions(+), 55 deletions(-) diff --git a/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py b/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py index f91d11fccaf7..e936ac290816 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py @@ -638,12 +638,7 @@ def check_inputs( callback_on_step_end_tensor_inputs=None, guidance_scale=None, pag_scale=None, - guess_mode=None, ): - if guess_mode and pag_scale > 0 and guidance_scale > 1: - raise ValueError( - "guess_mode cannot work with PAG and guidance scale together. Please set either `pag_scale` or `guidance_scale`; or set `guess_mode` to False." - ) if callback_on_step_end_tensor_inputs is not None and not all( k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs ): @@ -1027,7 +1022,6 @@ def __call__( return_dict: bool = True, cross_attention_kwargs: Optional[Dict[str, Any]] = None, controlnet_conditioning_scale: Union[float, List[float]] = 1.0, - guess_mode: bool = False, control_guidance_start: Union[float, List[float]] = 0.0, control_guidance_end: Union[float, List[float]] = 1.0, original_size: Tuple[int, int] = None, @@ -1139,9 +1133,6 @@ def __call__( The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set the corresponding scale as a list. - guess_mode (`bool`, *optional*, defaults to `False`): - The ControlNet encoder tries to recognize the content of the input image even if you remove all - prompts. A `guidance_scale` value between 3.0 and 5.0 is recommended. control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0): The percentage of total steps at which the ControlNet starts applying. control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0): @@ -1238,7 +1229,6 @@ def __call__( callback_on_step_end_tensor_inputs, guidance_scale, pag_scale, - guess_mode, ) self._guidance_scale = guidance_scale @@ -1261,13 +1251,6 @@ def __call__( if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float): controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets) - global_pool_conditions = ( - controlnet.config.global_pool_conditions - if isinstance(controlnet, ControlNetModel) - else controlnet.nets[0].config.global_pool_conditions - ) - guess_mode = guess_mode or global_pool_conditions - # 3.1 Encode input prompt text_encoder_lora_scale = ( self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None @@ -1314,7 +1297,7 @@ def __call__( device=device, dtype=controlnet.dtype, do_classifier_free_guidance=self.do_classifier_free_guidance, - guess_mode=guess_mode, + guess_mode=False, ) height, width = image.shape[-2:] elif isinstance(controlnet, MultiControlNetModel): @@ -1330,7 +1313,7 @@ def __call__( device=device, dtype=controlnet.dtype, do_classifier_free_guidance=self.do_classifier_free_guidance, - guess_mode=guess_mode, + guess_mode=False, ) images.append(image_) @@ -1411,22 +1394,21 @@ def __call__( else: negative_add_time_ids = add_time_ids - if not guess_mode: - images = image if isinstance(image, list) else [image] - for i, single_image in enumerate(images): - if self.do_classifier_free_guidance: - single_image = single_image.chunk(2)[0] + images = image if isinstance(image, list) else [image] + for i, single_image in enumerate(images): + if self.do_classifier_free_guidance: + single_image = single_image.chunk(2)[0] - if self.do_perturbed_attention_guidance: - single_image = self._prepare_perturbed_attention_guidance( - single_image, single_image, self.do_classifier_free_guidance - ) - elif self.do_classifier_free_guidance: - single_image = torch.cat([single_image] * 2) - single_image = single_image.to(device) - images[i] = single_image + if self.do_perturbed_attention_guidance: + single_image = self._prepare_perturbed_attention_guidance( + single_image, single_image, self.do_classifier_free_guidance + ) + elif self.do_classifier_free_guidance: + single_image = torch.cat([single_image] * 2) + single_image = single_image.to(device) + images[i] = single_image - image = images if isinstance(image, list) else images[0] + image = images if isinstance(image, list) else images[0] if ip_adapter_image_embeds is not None: for i, image_embeds in enumerate(ip_adapter_image_embeds): @@ -1443,11 +1425,6 @@ def __call__( image_embeds = image_embeds.to(device) ip_adapter_image_embeds[i] = image_embeds - # for guess_mode, we do not need to apply guidance on controlnet inputs - if guess_mode: - controlnet_prompt_embeds = prompt_embeds - controlnet_added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} - if self.do_perturbed_attention_guidance: prompt_embeds = self._prepare_perturbed_attention_guidance( prompt_embeds, negative_prompt_embeds, self.do_classifier_free_guidance @@ -1468,9 +1445,8 @@ def __call__( add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} - if not guess_mode: - controlnet_prompt_embeds = prompt_embeds - controlnet_added_cond_kwargs = added_cond_kwargs + controlnet_prompt_embeds = prompt_embeds + controlnet_added_cond_kwargs = added_cond_kwargs # 8. Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order @@ -1512,12 +1488,7 @@ def __call__( latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # controlnet(s) inference - if guess_mode: - # Infer ControlNet only for the conditional batch. - control_model_input = latents - control_model_input = self.scheduler.scale_model_input(control_model_input, t) - else: - control_model_input = latent_model_input + control_model_input = latent_model_input if isinstance(controlnet_keep[i], list): cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])] @@ -1533,18 +1504,11 @@ def __call__( encoder_hidden_states=controlnet_prompt_embeds, controlnet_cond=image, conditioning_scale=cond_scale, - guess_mode=guess_mode, + guess_mode=False, added_cond_kwargs=controlnet_added_cond_kwargs, return_dict=False, ) - if guess_mode and (self.do_classifier_free_guidance or self.do_perturbed_attention_guidance): - # Infered ControlNet only for the conditional batch. - # To apply the output of ControlNet to both the unconditional and conditional batches, - # add 0 to the unconditional batch to keep it unchanged. - down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples] - mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample]) - if ip_adapter_image_embeds is not None: added_cond_kwargs["image_embeds"] = ip_adapter_image_embeds From 14b4ddd9bcdae7b2698524bec4c6c972eb382aec Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Thu, 20 Jun 2024 09:57:57 +0200 Subject: [PATCH 29/44] noise_pred_uncond -> noise_pred_text --- src/diffusers/models/attention_processor.py | 1 - src/diffusers/pipelines/pag/pag_utils.py | 4 ++-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index b41cde1c158c..88018e9a0b1e 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -3014,7 +3014,6 @@ def __call__( value = attn.to_v(hidden_states_ptb) - # hidden_states_ptb = torch.zeros(value.shape).to(value.get_device()) hidden_states_ptb = value hidden_states_ptb = hidden_states_ptb.to(query.dtype) diff --git a/src/diffusers/pipelines/pag/pag_utils.py b/src/diffusers/pipelines/pag/pag_utils.py index 24e476676ddc..b4ae005b26a0 100644 --- a/src/diffusers/pipelines/pag/pag_utils.py +++ b/src/diffusers/pipelines/pag/pag_utils.py @@ -181,8 +181,8 @@ def _apply_perturbed_attention_guidance(self, noise_pred, do_classifier_free_gui + pag_scale * (noise_pred_text - noise_pred_perturb) ) else: - noise_pred_uncond, noise_pred_perturb = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + pag_scale * (noise_pred_uncond - noise_pred_perturb) + noise_pred_text, noise_pred_perturb = noise_pred.chunk(2) + noise_pred = noise_pred_text + pag_scale * (noise_pred_text - noise_pred_perturb) return noise_pred def _prepare_perturbed_attention_guidance(self, cond, uncond, do_classifier_free_guidance): From 91c41e858874719093bd805823e3e2868bd9b2f4 Mon Sep 17 00:00:00 2001 From: YiYi Xu Date: Wed, 19 Jun 2024 22:06:02 -1000 Subject: [PATCH 30/44] Apply suggestions from code review Co-authored-by: Sayak Paul --- src/diffusers/models/attention_processor.py | 22 +++++++-------------- src/diffusers/pipelines/pag/pag_utils.py | 2 +- 2 files changed, 8 insertions(+), 16 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 88018e9a0b1e..196e464c117d 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2926,12 +2926,14 @@ def __call__( class PAGIdentitySelfAttnProcessor2_0: r""" - Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). + Processor for implementing PAG using scaled dot-product attention (enabled by default if you're using PyTorch 2.0). + + PAG reference: https://arxiv.org/abs/2403.17377 """ def __init__(self): if not hasattr(F, "scaled_dot_product_attention"): - raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + raise ImportError("PAGIdentitySelfAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") def __call__( self, @@ -2940,12 +2942,7 @@ def __call__( encoder_hidden_states: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, temb: Optional[torch.FloatTensor] = None, - *args, - **kwargs, - ) -> torch.FloatTensor: - if len(args) > 0 or kwargs.get("scale", None) is not None: - deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." - deprecate("scale", "1.0.0", deprecation_message) + ) -> torch.Tensor: residual = hidden_states if attn.spatial_norm is not None: @@ -3044,7 +3041,7 @@ class PAGCFGIdentitySelfAttnProcessor2_0: def __init__(self): if not hasattr(F, "scaled_dot_product_attention"): - raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + raise ImportError("PAGCFGIdentitySelfAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") def __call__( self, @@ -3053,12 +3050,7 @@ def __call__( encoder_hidden_states: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, temb: Optional[torch.FloatTensor] = None, - *args, - **kwargs, - ) -> torch.FloatTensor: - if len(args) > 0 or kwargs.get("scale", None) is not None: - deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." - deprecate("scale", "1.0.0", deprecation_message) + ) -> torch.Tensor: residual = hidden_states if attn.spatial_norm is not None: diff --git a/src/diffusers/pipelines/pag/pag_utils.py b/src/diffusers/pipelines/pag/pag_utils.py index b4ae005b26a0..2009024e4e47 100644 --- a/src/diffusers/pipelines/pag/pag_utils.py +++ b/src/diffusers/pipelines/pag/pag_utils.py @@ -148,7 +148,7 @@ def get_attn_index(module_name): def _get_pag_scale(self, t): r""" - Get the scale factor for the perturbed attention guidance. + Get the scale factor for the perturbed attention guidance at timestep `t`. """ if self.do_pag_adaptive_scaling: From b72ef1c3c822e57448a8c162b636456d0f83a5a2 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Thu, 20 Jun 2024 10:09:07 +0200 Subject: [PATCH 31/44] fix more --- src/diffusers/models/attention_processor.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 196e464c117d..9f96756d887b 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2927,13 +2927,14 @@ def __call__( class PAGIdentitySelfAttnProcessor2_0: r""" Processor for implementing PAG using scaled dot-product attention (enabled by default if you're using PyTorch 2.0). - PAG reference: https://arxiv.org/abs/2403.17377 """ def __init__(self): if not hasattr(F, "scaled_dot_product_attention"): - raise ImportError("PAGIdentitySelfAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + raise ImportError( + "PAGIdentitySelfAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." + ) def __call__( self, @@ -2943,7 +2944,6 @@ def __call__( attention_mask: Optional[torch.FloatTensor] = None, temb: Optional[torch.FloatTensor] = None, ) -> torch.Tensor: - residual = hidden_states if attn.spatial_norm is not None: hidden_states = attn.spatial_norm(hidden_states, temb) @@ -3036,12 +3036,15 @@ def __call__( class PAGCFGIdentitySelfAttnProcessor2_0: r""" - Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). + Processor for implementing PAG using scaled dot-product attention (enabled by default if you're using PyTorch 2.0). + PAG reference: https://arxiv.org/abs/2403.17377 """ def __init__(self): if not hasattr(F, "scaled_dot_product_attention"): - raise ImportError("PAGCFGIdentitySelfAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + raise ImportError( + "PAGCFGIdentitySelfAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." + ) def __call__( self, @@ -3051,7 +3054,6 @@ def __call__( attention_mask: Optional[torch.FloatTensor] = None, temb: Optional[torch.FloatTensor] = None, ) -> torch.Tensor: - residual = hidden_states if attn.spatial_norm is not None: hidden_states = attn.spatial_norm(hidden_states, temb) From d12b4a0c6baf226c2ab4f98cc49e2eb6cadf6957 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Mon, 24 Jun 2024 07:51:45 +0000 Subject: [PATCH 32/44] update docstring example --- src/diffusers/models/attention_processor.py | 5 +---- .../pag/pipeline_pag_controlnet_sd_xl.py | 16 ++++++++-------- .../pipelines/pag/pipeline_pag_sd_xl.py | 15 +++++++-------- .../pipelines/pag/pipeline_pag_sd_xl_img2img.py | 14 ++++++-------- .../pipelines/pag/pipeline_pag_sd_xl_inpaint.py | 17 +++++++++-------- 5 files changed, 31 insertions(+), 36 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 83b360019e25..7b6b8afce2dd 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2646,10 +2646,7 @@ def __call__( if attn.group_norm is not None: hidden_states_ptb = attn.group_norm(hidden_states_ptb.transpose(1, 2)).transpose(1, 2) - value = attn.to_v(hidden_states_ptb) - - hidden_states_ptb = value - + hidden_states_ptb = attn.to_v(hidden_states_ptb) hidden_states_ptb = hidden_states_ptb.to(query.dtype) # linear proj diff --git a/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py b/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py index e936ac290816..21d445fba32c 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py @@ -41,8 +41,6 @@ from ...models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel from ...models.attention_processor import ( AttnProcessor2_0, - LoRAAttnProcessor2_0, - LoRAXFormersAttnProcessor, XFormersAttnProcessor, ) from ...models.lora import adjust_lora_scale_text_encoder @@ -73,7 +71,7 @@ Examples: ```py >>> # !pip install opencv-python transformers accelerate - >>> from diffusers import StableDiffusionXLControlNetPipeline, ControlNetModel, AutoencoderKL + >>> from diffusers import AutoPipelineForText2Image, ControlNetModel, AutoencoderKL >>> from diffusers.utils import load_image >>> import numpy as np >>> import torch @@ -95,8 +93,12 @@ ... "diffusers/controlnet-canny-sdxl-1.0", torch_dtype=torch.float16 ... ) >>> vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16) - >>> pipe = StableDiffusionXLControlNetPipeline.from_pretrained( - ... "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, vae=vae, torch_dtype=torch.float16 + >>> pipe = AutoPipelineForText2Image.from_pretrained( + ... "stabilityai/stable-diffusion-xl-base-1.0", + ... controlnet=controlnet, + ... vae=vae, + ... torch_dtype=torch.float16, + ... enable_pag=True, ... ) >>> pipe.enable_model_cpu_offload() @@ -109,7 +111,7 @@ >>> # generate image >>> image = pipe( - ... prompt, controlnet_conditioning_scale=controlnet_conditioning_scale, image=canny_image + ... prompt, controlnet_conditioning_scale=controlnet_conditioning_scale, image=canny_image, pag_scale=0.3 ... ).images[0] ``` """ @@ -923,8 +925,6 @@ def upcast_vae(self): ( AttnProcessor2_0, XFormersAttnProcessor, - LoRAXFormersAttnProcessor, - LoRAAttnProcessor2_0, ), ) # if xformers or torch_2_0 is used attention block does not need diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sd_xl.py b/src/diffusers/pipelines/pag/pipeline_pag_sd_xl.py index fe7c9c59a5ce..58e39fc14073 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_sd_xl.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_sd_xl.py @@ -35,8 +35,6 @@ from ...models.attention_processor import ( AttnProcessor2_0, FusedAttnProcessor2_0, - LoRAAttnProcessor2_0, - LoRAXFormersAttnProcessor, XFormersAttnProcessor, ) from ...models.lora import adjust_lora_scale_text_encoder @@ -73,15 +71,17 @@ Examples: ```py >>> import torch - >>> from diffusers import StableDiffusionXLPipeline + >>> from diffusers import AutoPipelineForText2Image - >>> pipe = StableDiffusionXLPipeline.from_pretrained( - ... "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16 + >>> pipe = AutoPipelineForText2Image.from_pretrained( + ... "stabilityai/stable-diffusion-xl-base-1.0", + ... torch_dtype=torch.float16, + ... enabe_pag=True, ... ) >>> pipe = pipe.to("cuda") >>> prompt = "a photo of an astronaut riding a horse on mars" - >>> image = pipe(prompt).images[0] + >>> image = pipe(prompt, pag_scale=0.3).images[0] ``` """ @@ -730,6 +730,7 @@ def _get_add_time_ids( add_time_ids = torch.tensor([add_time_ids], dtype=dtype) return add_time_ids + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.upcast_vae def upcast_vae(self): dtype = self.vae.dtype self.vae.to(dtype=torch.float32) @@ -738,8 +739,6 @@ def upcast_vae(self): ( AttnProcessor2_0, XFormersAttnProcessor, - LoRAXFormersAttnProcessor, - LoRAAttnProcessor2_0, FusedAttnProcessor2_0, ), ) diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py b/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py index a62787b486c0..0b110fa5ed4a 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py @@ -36,8 +36,6 @@ from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel from ...models.attention_processor import ( AttnProcessor2_0, - LoRAAttnProcessor2_0, - LoRAXFormersAttnProcessor, XFormersAttnProcessor, ) from ...models.lora import adjust_lora_scale_text_encoder @@ -74,18 +72,20 @@ Examples: ```py >>> import torch - >>> from diffusers import StableDiffusionXLImg2ImgPipeline + >>> from diffusers import AutoPipelineForImage2Image >>> from diffusers.utils import load_image - >>> pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained( - ... "stabilityai/stable-diffusion-xl-refiner-1.0", torch_dtype=torch.float16 + >>> pipe = AutoPipelineForImage2Image.from_pretrained( + ... "stabilityai/stable-diffusion-xl-refiner-1.0", + ... torch_dtype=torch.float16, + ... enable_pag=True, ... ) >>> pipe = pipe.to("cuda") >>> url = "https://huggingface.co/datasets/patrickvonplaten/images/resolve/main/aa_xl/000000009.png" >>> init_image = load_image(url).convert("RGB") >>> prompt = "a photo of an astronaut riding a horse on mars" - >>> image = pipe(prompt, image=init_image).images[0] + >>> image = pipe(prompt, image=init_image, pag_scale=0.3).images[0] ``` """ @@ -881,8 +881,6 @@ def upcast_vae(self): ( AttnProcessor2_0, XFormersAttnProcessor, - LoRAXFormersAttnProcessor, - LoRAAttnProcessor2_0, ), ) # if xformers or torch_2_0 is used attention block does not need diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py b/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py index bc5176c2ee21..3e9b633364f9 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py @@ -36,8 +36,6 @@ from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel from ...models.attention_processor import ( AttnProcessor2_0, - LoRAAttnProcessor2_0, - LoRAXFormersAttnProcessor, XFormersAttnProcessor, ) from ...models.lora import adjust_lora_scale_text_encoder @@ -75,14 +73,14 @@ Examples: ```py >>> import torch - >>> from diffusers import StableDiffusionXLInpaintPipeline + >>> from diffusers import AutoPipelineForInpainting >>> from diffusers.utils import load_image - >>> pipe = StableDiffusionXLInpaintPipeline.from_pretrained( + >>> pipe = AutoPipelineForInpainting.from_pretrained( ... "stabilityai/stable-diffusion-xl-base-1.0", ... torch_dtype=torch.float16, ... variant="fp16", - ... use_safetensors=True, + ... enable_pag=True, ... ) >>> pipe.to("cuda") @@ -94,7 +92,12 @@ >>> prompt = "A majestic tiger sitting on a bench" >>> image = pipe( - ... prompt=prompt, image=init_image, mask_image=mask_image, num_inference_steps=50, strength=0.80 + ... prompt=prompt, + ... image=init_image, + ... mask_image=mask_image, + ... num_inference_steps=50, + ... strength=0.80, + ... pag_scale=0.3, ... ).images[0] ``` """ @@ -975,8 +978,6 @@ def upcast_vae(self): ( AttnProcessor2_0, XFormersAttnProcessor, - LoRAXFormersAttnProcessor, - LoRAAttnProcessor2_0, ), ) # if xformers or torch_2_0 is used attention block does not need From 28e13013647e433e71cb9f72e7530fe67b858350 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Mon, 24 Jun 2024 08:17:45 +0000 Subject: [PATCH 33/44] add copied from --- .../pag/pipeline_pag_controlnet_sd_xl.py | 17 ++++++++++++----- .../pipelines/pag/pipeline_pag_sd_xl.py | 11 +++++++++++ .../pipelines/pag/pipeline_pag_sd_xl_img2img.py | 11 +++++++++++ .../pipelines/pag/pipeline_pag_sd_xl_inpaint.py | 12 ++++++++++++ 4 files changed, 46 insertions(+), 5 deletions(-) diff --git a/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py b/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py index 21d445fba32c..5ce62cb3a172 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py @@ -621,11 +621,13 @@ def prepare_extra_step_kwargs(self, generator, eta): extra_step_kwargs["generator"] = generator return extra_step_kwargs + # Copied from diffusers.pipelines.controlnet.pipeline_controlnet_sd_xl.StableDiffusionXLControlNetPipeline.check_inputs def check_inputs( self, prompt, prompt_2, image, + callback_steps, negative_prompt=None, negative_prompt_2=None, prompt_embeds=None, @@ -634,13 +636,17 @@ def check_inputs( ip_adapter_image=None, ip_adapter_image_embeds=None, negative_pooled_prompt_embeds=None, - controlnet_conditioning_scale=None, - control_guidance_start=None, - control_guidance_end=None, + controlnet_conditioning_scale=1.0, + control_guidance_start=0.0, + control_guidance_end=1.0, callback_on_step_end_tensor_inputs=None, - guidance_scale=None, - pag_scale=None, ): + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + if callback_on_step_end_tensor_inputs is not None and not all( k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs ): @@ -1215,6 +1221,7 @@ def __call__( prompt, prompt_2, image, + None, negative_prompt, negative_prompt_2, prompt_embeds, diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sd_xl.py b/src/diffusers/pipelines/pag/pipeline_pag_sd_xl.py index 58e39fc14073..22493ae4d1de 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_sd_xl.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_sd_xl.py @@ -277,6 +277,7 @@ def __init__( self.set_pag_applied_layers(pag_applied_layers) + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt def encode_prompt( self, prompt: str, @@ -600,12 +601,14 @@ def prepare_extra_step_kwargs(self, generator, eta): extra_step_kwargs["generator"] = generator return extra_step_kwargs + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.check_inputs def check_inputs( self, prompt, prompt_2, height, width, + callback_steps, negative_prompt=None, negative_prompt_2=None, prompt_embeds=None, @@ -619,6 +622,12 @@ def check_inputs( if height % 8 != 0 or width % 8 != 0: raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + if callback_on_step_end_tensor_inputs is not None and not all( k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs ): @@ -712,6 +721,7 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype latents = latents * self.scheduler.init_noise_sigma return latents + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids def _get_add_time_ids( self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None ): @@ -1020,6 +1030,7 @@ def __call__( prompt_2, height, width, + None, negative_prompt, negative_prompt_2, prompt_embeds, diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py b/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py index 0b110fa5ed4a..ff227e47478d 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py @@ -551,12 +551,14 @@ def prepare_extra_step_kwargs(self, generator, eta): extra_step_kwargs["generator"] = generator return extra_step_kwargs + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.check_inputs def check_inputs( self, prompt, prompt_2, strength, num_inference_steps, + callback_steps, negative_prompt=None, negative_prompt_2=None, prompt_embeds=None, @@ -574,6 +576,11 @@ def check_inputs( f"`num_inference_steps` has to be a positive integer but is {num_inference_steps} of type" f" {type(num_inference_steps)}." ) + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) if callback_on_step_end_tensor_inputs is not None and not all( k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs @@ -635,6 +642,7 @@ def check_inputs( f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" ) + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.get_timesteps def get_timesteps(self, num_inference_steps, strength, device, denoising_start=None): # get the original timestep using init_timestep if denoising_start is None: @@ -671,6 +679,7 @@ def get_timesteps(self, num_inference_steps, strength, device, denoising_start=N return timesteps, num_inference_steps - t_start + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.prepare_latents def prepare_latents( self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None, add_noise=True ): @@ -821,6 +830,7 @@ def prepare_ip_adapter_image_embeds( return ip_adapter_image_embeds + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline._get_add_time_ids def _get_add_time_ids( self, original_size, @@ -1185,6 +1195,7 @@ def __call__( prompt_2, strength, num_inference_steps, + None, negative_prompt, negative_prompt_2, prompt_embeds, diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py b/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py index 3e9b633364f9..64aff497a594 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py @@ -641,6 +641,7 @@ def prepare_extra_step_kwargs(self, generator, eta): extra_step_kwargs["generator"] = generator return extra_step_kwargs + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline.check_inputs def check_inputs( self, prompt, @@ -650,6 +651,7 @@ def check_inputs( height, width, strength, + callback_steps, output_type, negative_prompt=None, negative_prompt_2=None, @@ -666,6 +668,12 @@ def check_inputs( if height % 8 != 0 or width % 8 != 0: raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + if callback_on_step_end_tensor_inputs is not None and not all( k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs ): @@ -738,6 +746,7 @@ def check_inputs( f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" ) + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline.prepare_latents def prepare_latents( self, batch_size, @@ -804,6 +813,7 @@ def prepare_latents( return outputs + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline._encode_vae_image def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): dtype = image.dtype if self.vae.config.force_upcast: @@ -827,6 +837,7 @@ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): return image_latents + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline.prepare_mask_latents def prepare_mask_latents( self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance ): @@ -1312,6 +1323,7 @@ def __call__( height, width, strength, + None, output_type, negative_prompt, negative_prompt_2, From 5653b2a89aabd50d0baed50f165c1be048af5be4 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Tue, 25 Jun 2024 01:09:19 +0000 Subject: [PATCH 34/44] add doc --- docs/source/en/_toctree.yml | 4 + docs/source/en/api/pipelines/pag.md | 41 +++++ docs/source/en/using-diffusers/pag.md | 235 ++++++++++++++++++++++++++ 3 files changed, 280 insertions(+) create mode 100644 docs/source/en/api/pipelines/pag.md create mode 100644 docs/source/en/using-diffusers/pag.md diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index c306e1eb99e7..ed1cf62f6124 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -81,6 +81,8 @@ title: Kandinsky - local: using-diffusers/ip_adapter title: IP-Adapter + - local: using-diffusers/pag + title: PAG - local: using-diffusers/controlnet title: ControlNet - local: using-diffusers/t2i_adapter @@ -302,6 +304,8 @@ title: Hunyuan-DiT - local: api/pipelines/i2vgenxl title: I2VGen-XL + - local: api/pipelines/pag + title: PAG - local: api/pipelines/pix2pix title: InstructPix2Pix - local: api/pipelines/kandinsky diff --git a/docs/source/en/api/pipelines/pag.md b/docs/source/en/api/pipelines/pag.md new file mode 100644 index 000000000000..914db12ae8a2 --- /dev/null +++ b/docs/source/en/api/pipelines/pag.md @@ -0,0 +1,41 @@ + + +# Perturbed-Attention Guidance + +[Perturbed-Attention Guidance (PAG)](https://ku-cvlab.github.io/Perturbed-Attention-Guidance/) is a new diffusion sampling guidance that improves sample quality across both unconditional and conditional settings, achieving this without requiring further training or the integration of external modules + +PAG was introduced in [Self-Rectifying Diffusion Sampling with Perturbed-Attention Guidance](https://huggingface.co/papers/2403.17377) by Donghoon Ahn, Hyoungwon Cho, Jaewon Min, Wooseok Jang, Jungwoo Kim, SeonHwa Kim, Hyun Hee Park, Kyong Hwan Jin, Seungryong Kim + +The abstract from the paper is: + +*Recent studies have demonstrated that diffusion models are capable of generating high-quality samples, but their quality heavily depends on sampling guidance techniques, such as classifier guidance (CG) and classifier-free guidance (CFG). These techniques are often not applicable in unconditional generation or in various downstream tasks such as image restoration. In this paper, we propose a novel sampling guidance, called Perturbed-Attention Guidance (PAG), which improves diffusion sample quality across both unconditional and conditional settings, achieving this without requiring additional training or the integration of external modules. PAG is designed to progressively enhance the structure of samples throughout the denoising process. It involves generating intermediate samples with degraded structure by substituting selected self-attention maps in diffusion U-Net with an identity matrix, by considering the self-attention mechanisms' ability to capture structural information, and guiding the denoising process away from these degraded samples. In both ADM and Stable Diffusion, PAG surprisingly improves sample quality in conditional and even unconditional scenarios. Moreover, PAG significantly improves the baseline performance in various downstream tasks where existing guidances such as CG or CFG cannot be fully utilized, including ControlNet with empty prompts and image restoration such as inpainting and deblurring.* + +## StableDiffusionXLPAGPipeline +[[autodoc]] StableDiffusionXLPAGPipeline + - all + - __call__ + +## StableDiffusionXLPAGImg2ImgPipeline +[[autodoc]] StableDiffusionXLPAGImg2ImgPipeline + - all + - __call__ + +## StableDiffusionXLPAGInpaintPipeline +[[autodoc]] StableDiffusionXLPAGInpaintPipeline + - all + - __call__ + +## StableDiffusionXLControlNetPAGPipeline +[[autodoc]] StableDiffusionXLControlNetPAGPipeline + - all + - __call__ diff --git a/docs/source/en/using-diffusers/pag.md b/docs/source/en/using-diffusers/pag.md new file mode 100644 index 000000000000..ce975c253abf --- /dev/null +++ b/docs/source/en/using-diffusers/pag.md @@ -0,0 +1,235 @@ + + +# Perturbed-Attention Guidance + +[Perturbed-Attention Guidance (PAG)](https://ku-cvlab.github.io/Perturbed-Attention-Guidance/) is a new diffusion sampling guidance that improves sample quality across both unconditional and conditional settings, achieving this without requiring further training or the integration of external modules. PAG is designed to progressively enhance the structure of synthesized samples throughout the denoising process by considering the self-attention mechanisms' ability to capture structural information. It involves generating intermediate samples with degraded structure by substituting selected self-attention maps in diffusion U-Net with an identity matrix, and guiding the denoising process away from these degraded samples. + +This guide will show you how to use PAG for various tasks and use cases. + + +## General tasks + +You can apply PAG to the `StableDIffusionXLPpeline` for tasks like text-to-image, image-to-image, and inpainting. To enable PAG on a pipeline for a specfic task, simply load that pipeline using `AutoPipeline` API along with the `enable_pag=True` flag and the `pag_applied_layers` argument. + + + + +```py +from diffusers import AutoPipelineForText2Image +from diffusers.utils import load_image +import torch + +pipeline = AutoPipelineForText2Image.from_pretrained( + "stabilityai/stable-diffusion-xl-base-1.0", + enable_pag=True, + pag_applied_layers =["mid"], + torch_dtype=torch.float16 +) +pipeline.enable_model_cpu_offload() +``` + +> [!TIP] +> `pag_applied_layers` argument allow you to specify which layers PAG is applied to. You can also use `set_pag_applied_layers` to update these layers after the pipeline has been created. + +In addition to the regular pipeline arguments such as `prompt` and `guidance_scale`, you will also need to pass a `pag_scale` to generate an image. + +```py +pag_scales = 3.0 +guidance_scale = 7.0 +prompt = "an insect robot preparing a delicious meal, anime style" + +generator = torch.Generator(device="cpu").manual_seed(0) +images = pipeline( + prompt=prompt, + num_inference_steps=25, + guidance_scale=guidance_scale, + generator=generator, + pag_scale=pag_scale, +).images +images[0] +``` + +
+
+ +
generated image without PAG
+
+
+ +
generated image with PAG
+
+
+ +
+ + +Similary, you can use PAG with image-to-image pipelines + +```py +from diffusers import AutoPipelineForImage2Image +from diffusers.utils import load_image +import torch + +pipeline = AutoPipelineForImage2Image.from_pretrained( + "stabilityai/stable-diffusion-xl-base-1.0", + enable_pag=True, + pag_applied_layers = ["mid"], + torch_dtype=torch.float16 +) +pipeline.enable_model_cpu_offload() + +pag_scales = 4.0 +guidance_scales = 7.0 + +url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/sdxl-text2img.png" +init_image = load_image(url) +prompt = "a dog catching a frisbee in the jungle" + +generator = torch.Generator(device="cpu").manual_seed(0) +image = pipeline( + prompt, + image=init_image, + strength=0.8, + guidance_scale=guidance_scale, + pag_scale=pag_scale, + generator=generator).images[0] +``` + + + + +```py +from diffusers import AutoPipelineForInpainting +from diffusers.utils import load_image +import torch + +pipeline = AutoPipelineForInpainting.from_pretrained( + "stabilityai/stable-diffusion-xl-base-1.0", + enable_pag=True, + torch_dtype=torch.float16 +) +pipeline.enable_model_cpu_offload() + +img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png" +mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png" +init_image = load_image(img_url).convert("RGB") +mask_image = load_image(mask_url).convert("RGB") + +prompt = "A majestic tiger sitting on a bench" + +pag_scales = 3.0 +guidance_scales = 7.5 + +generator = torch.Generator(device="cpu").manual_seed(1) +images = pipeline( + prompt=prompt, + image=init_image, + mask_image=mask_image, + strength=0.8, + num_inference_steps=50, + guidance_scale=guidance_scale, + generator=generator, + pag_scale=pag_scale, +).images +images[0] +``` + +## PAG with IP-adapter + +[IP-Adapter](https://hf.co/papers/2308.06721) is a popular tool that can be plugged into diffusion models to enable image prompting without any changes to the underlying model. You can enable PAG on a pipeline with IP-adapter loaded. + +```py +from diffusers import AutoPipelineForText2Image +from diffusers.utils import load_image +from transformers import CLIPVisionModelWithProjection +import torch + +image_encoder = CLIPVisionModelWithProjection.from_pretrained( + "h94/IP-Adapter", + subfolder="models/image_encoder", + torch_dtype=torch.float16 +) + +pipeline = AutoPipelineForText2Image.from_pretrained( + "stabilityai/stable-diffusion-xl-base-1.0", + image_encoder=image_encoder, + enable_pag=True, + torch_dtype=torch.float16 +).to("cuda") + +pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter-plus_sdxl_vit-h.bin") + +pag_scales = 5.0 +ip_adapter_scales = 0.8 + +image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/ip_adapter_diner.png") + +pipeline.set_ip_adapter_scale(ip_adapter_scale) +generator = torch.Generator(device="cpu").manual_seed(0) +images = pipeline( + prompt="a polar bear sitting in a chair drinking a milkshake", + ip_adapter_image=image, + negative_prompt="deformed, ugly, wrong proportion, low res, bad anatomy, worst quality, low quality", + num_inference_steps=25, + guidance_scale=3.0, + generator=generator, + pag_scale=pag_scale, +).images +images[0] + +``` +PAG reduces artifacts and improve the overall compposition. +
+
+ +
generated image without PAG
+
+
+ +
generated image with PAG
+
+
+ + +## Configure parameters + +### pag_applied_layers + +`pag_applied_layers` argument allow you to specify which layers PAG is applied to. By default it will only applies to the mid blocks. It makes a significant difference on the output. You can use `set_pag_applied_layers` method to set different pag layers after the pipeline is created and find the optimal pag layer for your model. + +As an example, here is the images generated with `pag_layers = ["down.block_2"])` and `pag_layers = ["down.block_2", "up.block_1.attentions_0"]` + +```py +prompt = "an insect robot preparing a delicious meal, anime style" +pipeline.set_pag_applied_layers(pag_layers) +generator = torch.Generator(device="cpu").manual_seed(0) +images = pipeline( + prompt=prompt, + num_inference_steps=25, + guidance_scale=guidance_scale, + generator=generator, + pag_scale=pag_scale, +).images +images[0] +``` + +
+
+ +
down.block_2 + up.block1.attentions_0
+
+
+ +
down.block_2
+
+
From 17520f2218634296ff3381ee6d9ef5e7148f313e Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Tue, 25 Jun 2024 02:05:04 +0000 Subject: [PATCH 35/44] up --- docs/source/en/_toctree.yml | 6 +++--- src/diffusers/pipelines/pag/__init__.py | 3 +-- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index ed1cf62f6124..6e7cbe035fa7 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -82,7 +82,7 @@ - local: using-diffusers/ip_adapter title: IP-Adapter - local: using-diffusers/pag - title: PAG + title: PAG - local: using-diffusers/controlnet title: ControlNet - local: using-diffusers/t2i_adapter @@ -304,8 +304,6 @@ title: Hunyuan-DiT - local: api/pipelines/i2vgenxl title: I2VGen-XL - - local: api/pipelines/pag - title: PAG - local: api/pipelines/pix2pix title: InstructPix2Pix - local: api/pipelines/kandinsky @@ -326,6 +324,8 @@ title: MultiDiffusion - local: api/pipelines/musicldm title: MusicLDM + - local: api/pipelines/pag + title: PAG - local: api/pipelines/paint_by_example title: Paint by Example - local: api/pipelines/pia diff --git a/src/diffusers/pipelines/pag/__init__.py b/src/diffusers/pipelines/pag/__init__.py index 5695d75566fa..24c6fa06268b 100644 --- a/src/diffusers/pipelines/pag/__init__.py +++ b/src/diffusers/pipelines/pag/__init__.py @@ -23,10 +23,9 @@ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) else: _import_structure["pipeline_pag_controlnet_sd_xl"] = ["StableDiffusionXLControlNetPAGPipeline"] - _import_structure["pipeline_pag_sd_xl"] = ["StableDiffusionXLPAGPipeline"] - _import_structure["pipeline_pag_sd_xl_inpaint"] = ["StableDiffusionXLPAGInpaintPipeline"] _import_structure["pipeline_pag_sd_xl_img2img"] = ["StableDiffusionXLPAGImg2ImgPipeline"] + _import_structure["pipeline_pag_sd_xl_inpaint"] = ["StableDiffusionXLPAGInpaintPipeline"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: try: From 074a4f05f58db441cf5486ad6cf0ab34fcb075b2 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Tue, 25 Jun 2024 02:39:34 +0000 Subject: [PATCH 36/44] fix copies --- src/diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py b/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py index ff227e47478d..fb9938aa6a9d 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py @@ -732,8 +732,8 @@ def prepare_latents( init_latents = init_latents.to(dtype) if latents_mean is not None and latents_std is not None: - latents_mean = latents_mean.to(device=self.device, dtype=dtype) - latents_std = latents_std.to(device=self.device, dtype=dtype) + latents_mean = latents_mean.to(device=device, dtype=dtype) + latents_std = latents_std.to(device=device, dtype=dtype) init_latents = (init_latents - latents_mean) * self.vae.config.scaling_factor / latents_std else: init_latents = self.vae.config.scaling_factor * init_latents From 18d8b0e895251b9823d9895096350c566d96386b Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Tue, 25 Jun 2024 03:37:22 +0000 Subject: [PATCH 37/44] up --- docs/source/en/using-diffusers/pag.md | 82 ++++++++++++++++--- .../pag/pipeline_pag_controlnet_sd_xl.py | 2 - 2 files changed, 70 insertions(+), 14 deletions(-) diff --git a/docs/source/en/using-diffusers/pag.md b/docs/source/en/using-diffusers/pag.md index ce975c253abf..4481a9c3b1d5 100644 --- a/docs/source/en/using-diffusers/pag.md +++ b/docs/source/en/using-diffusers/pag.md @@ -41,22 +41,21 @@ pipeline.enable_model_cpu_offload() > [!TIP] > `pag_applied_layers` argument allow you to specify which layers PAG is applied to. You can also use `set_pag_applied_layers` to update these layers after the pipeline has been created. -In addition to the regular pipeline arguments such as `prompt` and `guidance_scale`, you will also need to pass a `pag_scale` to generate an image. +In addition to the regular pipeline arguments such as `prompt` and `guidance_scale`, you will also need to pass a `pag_scale` to generate an image. PAG is disabled when `pag_scale=0`. ```py -pag_scales = 3.0 -guidance_scale = 7.0 prompt = "an insect robot preparing a delicious meal, anime style" -generator = torch.Generator(device="cpu").manual_seed(0) -images = pipeline( - prompt=prompt, - num_inference_steps=25, - guidance_scale=guidance_scale, - generator=generator, - pag_scale=pag_scale, -).images -images[0] +for pag_scale in [0.0, 3.0]: + generator = torch.Generator(device="cpu").manual_seed(0) + images = pipeline( + prompt=prompt, + num_inference_steps=25, + guidance_scale=7.0, + generator=generator, + pag_scale=pag_scale, + ).images + images[0] ```
@@ -143,6 +142,63 @@ images = pipeline( ).images images[0] ``` + + +## PAG with ControlNet + +To use PAG with ControlNet, first create a `controlnet`. Then, pass the `controlne`t and PAG-related arguments to the `from_pretrained` method of the AutoPipeline for the specified task. + +```py +# pag doc example +from diffusers import AutoPipelineForText2Image, ControlNetModel +import torch + +controlnet = ControlNetModel.from_pretrained( + "diffusers/controlnet-canny-sdxl-1.0", torch_dtype=torch.float16 +) + +pipeline = AutoPipelineForText2Image.from_pretrained( + "stabilityai/stable-diffusion-xl-base-1.0", + controlnet=controlnet, + enable_pag=True, + pag_applied_layers = "mid", + torch_dtype=torch.float16 +) +pipeline.enable_model_cpu_offload() +``` + +You can use the pipeline similarly to how you normally use the controlnet pipelines. The only difference is that you can specify a `pag_scale` parameter. Note that PAG works well for unconditional generation. In this example, we will generate an image without using a prompt. + +```py +from diffusers.utils import load_image +canny_image = load_image( + "https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/pag_control_input.png" +) + +for pag_scale in [0.0, 3.0]: + generator = torch.Generator(device="cpu").manual_seed(1) + images = pipeline( + prompt="", + controlnet_conditioning_scale=controlnet_conditioning_scale, + image=canny_image, + num_inference_steps=50, + guidance_scale=0, + generator=generator, + pag_scale=pag_scale, + ).images + images[0] +``` + +
+
+ +
generated image without PAG
+
+
+ +
generated image with PAG
+
+
## PAG with IP-adapter @@ -188,7 +244,9 @@ images = pipeline( images[0] ``` + PAG reduces artifacts and improve the overall compposition. +
diff --git a/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py b/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py index 5ce62cb3a172..247fc900a7b0 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py @@ -1234,8 +1234,6 @@ def __call__( control_guidance_start, control_guidance_end, callback_on_step_end_tensor_inputs, - guidance_scale, - pag_scale, ) self._guidance_scale = guidance_scale From 0e337bfac08727f54bd0867e64396849f1bd2373 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Tue, 25 Jun 2024 03:48:51 +0000 Subject: [PATCH 38/44] up --- docs/source/en/using-diffusers/pag.md | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/docs/source/en/using-diffusers/pag.md b/docs/source/en/using-diffusers/pag.md index 4481a9c3b1d5..e57b2f2b66c3 100644 --- a/docs/source/en/using-diffusers/pag.md +++ b/docs/source/en/using-diffusers/pag.md @@ -19,7 +19,7 @@ This guide will show you how to use PAG for various tasks and use cases. ## General tasks -You can apply PAG to the `StableDIffusionXLPpeline` for tasks like text-to-image, image-to-image, and inpainting. To enable PAG on a pipeline for a specfic task, simply load that pipeline using `AutoPipeline` API along with the `enable_pag=True` flag and the `pag_applied_layers` argument. +You can apply PAG to the [`StableDiffusionXLPipeline`] for tasks such as text-to-image, image-to-image, and inpainting. To enable PAG for a specific task, load the pipeline using the [AutoPipeline](../api/pipelines/auto_pipeline) API with the `enable_pag=True` flag and the `pag_applied_layers` argument. @@ -39,9 +39,9 @@ pipeline.enable_model_cpu_offload() ``` > [!TIP] -> `pag_applied_layers` argument allow you to specify which layers PAG is applied to. You can also use `set_pag_applied_layers` to update these layers after the pipeline has been created. +> The `pag_applied_layers` argument allows you to specify which layers PAG is applied to. Additionally, you can use `set_pag_applied_layers` method to update these layers after the pipeline has been created. -In addition to the regular pipeline arguments such as `prompt` and `guidance_scale`, you will also need to pass a `pag_scale` to generate an image. PAG is disabled when `pag_scale=0`. +To generate an image, you will also need to pass a `pag_scale`. PAG is disabled when `pag_scale=0`. ```py prompt = "an insect robot preparing a delicious meal, anime style" @@ -146,10 +146,9 @@ images[0] ## PAG with ControlNet -To use PAG with ControlNet, first create a `controlnet`. Then, pass the `controlne`t and PAG-related arguments to the `from_pretrained` method of the AutoPipeline for the specified task. +To use PAG with ControlNet, first create a `controlnet`. Then, pass the `controlnet` and other PAG arguments to the `from_pretrained` method of the AutoPipeline for the specified task. ```py -# pag doc example from diffusers import AutoPipelineForText2Image, ControlNetModel import torch @@ -167,7 +166,7 @@ pipeline = AutoPipelineForText2Image.from_pretrained( pipeline.enable_model_cpu_offload() ``` -You can use the pipeline similarly to how you normally use the controlnet pipelines. The only difference is that you can specify a `pag_scale` parameter. Note that PAG works well for unconditional generation. In this example, we will generate an image without using a prompt. +You can use the pipeline in the same way you normally use ControlNet pipelines, with the added option to specify a `pag_scale` parameter. Note that PAG works well for unconditional generation. In this example, we will generate an image without a prompt. ```py from diffusers.utils import load_image @@ -202,7 +201,7 @@ for pag_scale in [0.0, 3.0]: ## PAG with IP-adapter -[IP-Adapter](https://hf.co/papers/2308.06721) is a popular tool that can be plugged into diffusion models to enable image prompting without any changes to the underlying model. You can enable PAG on a pipeline with IP-adapter loaded. +[IP-Adapter](https://hf.co/papers/2308.06721) is a popular model that can be plugged into diffusion models to enable image prompting without any changes to the underlying model. You can enable PAG on a pipeline with IP-adapter loaded. ```py from diffusers import AutoPipelineForText2Image @@ -245,7 +244,7 @@ images[0] ``` -PAG reduces artifacts and improve the overall compposition. +PAG reduces artifacts and improves the overall compposition.
@@ -263,7 +262,7 @@ PAG reduces artifacts and improve the overall compposition. ### pag_applied_layers -`pag_applied_layers` argument allow you to specify which layers PAG is applied to. By default it will only applies to the mid blocks. It makes a significant difference on the output. You can use `set_pag_applied_layers` method to set different pag layers after the pipeline is created and find the optimal pag layer for your model. +The `pag_applied_layers` argument allows you to specify which layers PAG is applied to. By default, it applies only to the mid blocks. Changing this setting will significantly impact the output. You can use the `set_pag_applied_layer`s method to adjust the PAG layers after the pipeline is created, helping you find the optimal layers for your model. As an example, here is the images generated with `pag_layers = ["down.block_2"])` and `pag_layers = ["down.block_2", "up.block_1.attentions_0"]` From 434f63abf32949796275083b25c3d989768657fc Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Tue, 25 Jun 2024 03:51:22 +0000 Subject: [PATCH 39/44] up --- docs/source/en/using-diffusers/pag.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/source/en/using-diffusers/pag.md b/docs/source/en/using-diffusers/pag.md index e57b2f2b66c3..2421bd40ba7f 100644 --- a/docs/source/en/using-diffusers/pag.md +++ b/docs/source/en/using-diffusers/pag.md @@ -143,6 +143,7 @@ images = pipeline( images[0] ``` + ## PAG with ControlNet From 41b1ddc7006248b7525cf36fb3b3110c92ec4da9 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Tue, 25 Jun 2024 04:27:18 +0000 Subject: [PATCH 40/44] up --- docs/source/en/api/pipelines/pag.md | 4 ++-- docs/source/en/using-diffusers/pag.md | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/source/en/api/pipelines/pag.md b/docs/source/en/api/pipelines/pag.md index 914db12ae8a2..7a201ed7ca49 100644 --- a/docs/source/en/api/pipelines/pag.md +++ b/docs/source/en/api/pipelines/pag.md @@ -1,4 +1,4 @@ -