From 5e6e351eb94878d285745ba948ed7afc50c90208 Mon Sep 17 00:00:00 2001 From: okotaku Date: Fri, 24 Nov 2023 02:25:17 +0000 Subject: [PATCH 01/16] Support IP-Adapter Plus --- src/diffusers/loaders/unet.py | 237 +++++++--- src/diffusers/models/__init__.py | 13 +- src/diffusers/models/embeddings.py | 156 +++++++ .../alt_diffusion/pipeline_alt_diffusion.py | 19 +- .../pipeline_alt_diffusion_img2img.py | 270 +++++++++--- .../animatediff/pipeline_animatediff.py | 185 ++++++-- .../controlnet/pipeline_controlnet.py | 329 ++++++++++---- .../controlnet/pipeline_controlnet_sd_xl.py | 395 +++++++++++++---- .../pipeline_stable_diffusion.py | 259 ++++++++--- .../pipeline_stable_diffusion_img2img.py | 271 +++++++++--- .../pipeline_stable_diffusion_inpaint.py | 380 ++++++++++++---- .../pipeline_stable_diffusion_xl.py | 303 +++++++++---- .../pipeline_stable_diffusion_xl_img2img.py | 335 ++++++++++---- .../pipeline_stable_diffusion_xl_inpaint.py | 407 +++++++++++++----- tests/models/test_models_unet_2d_condition.py | 391 ++++++++++++++--- .../test_ip_adapter_plus_stable_diffusion.py | 292 +++++++++++++ 16 files changed, 3382 insertions(+), 860 deletions(-) create mode 100644 tests/pipelines/ip_adapters/test_ip_adapter_plus_stable_diffusion.py diff --git a/src/diffusers/loaders/unet.py b/src/diffusers/loaders/unet.py index 6c805672c9cd..71a9aa86b304 100644 --- a/src/diffusers/loaders/unet.py +++ b/src/diffusers/loaders/unet.py @@ -21,8 +21,11 @@ import torch.nn.functional as F from torch import nn -from ..models.embeddings import ImageProjection -from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta +from ..models.embeddings import ImageProjection, Resampler +from ..models.modeling_utils import ( + _LOW_CPU_MEM_USAGE_DEFAULT, + load_model_dict_into_meta, +) from ..utils import ( DIFFUSERS_CACHE, HF_HUB_OFFLINE, @@ -62,7 +65,11 @@ class UNet2DConditionLoadersMixin: text_encoder_name = TEXT_ENCODER_NAME unet_name = UNET_NAME - def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs): + def load_attn_procs( + self, + pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], + **kwargs, + ): r""" Load pretrained attention processor layers into [`UNet2DConditionModel`]. 
Attention processor layers have to be defined in @@ -128,7 +135,12 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict ``` """ from ..models.attention_processor import CustomDiffusionAttnProcessor - from ..models.lora import LoRACompatibleConv, LoRACompatibleLinear, LoRAConv2dLayer, LoRALinearLayer + from ..models.lora import ( + LoRACompatibleConv, + LoRACompatibleLinear, + LoRAConv2dLayer, + LoRALinearLayer, + ) cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) force_download = kwargs.pop("force_download", False) @@ -216,12 +228,17 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict # fill attn processors lora_layers_list = [] - is_lora = all(("lora" in k or k.endswith(".alpha")) for k in state_dict.keys()) and not USE_PEFT_BACKEND + is_lora = ( + all(("lora" in k or k.endswith(".alpha")) for k in state_dict.keys()) + and not USE_PEFT_BACKEND + ) is_custom_diffusion = any("custom_diffusion" in k for k in state_dict.keys()) if is_lora: # correct keys - state_dict, network_alphas = self.convert_state_dict_legacy_attn_format(state_dict, network_alphas) + state_dict, network_alphas = self.convert_state_dict_legacy_attn_format( + state_dict, network_alphas + ) if network_alphas is not None: network_alphas_keys = list(network_alphas.keys()) @@ -233,14 +250,18 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict all_keys = list(state_dict.keys()) for key in all_keys: value = state_dict.pop(key) - attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:]) + attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join( + key.split(".")[-3:] + ) lora_grouped_dict[attn_processor_key][sub_key] = value # Create another `mapped_network_alphas` dictionary so that we can properly map them. if network_alphas is not None: for k in network_alphas_keys: if k.replace(".alpha", "") in key: - mapped_network_alphas.update({attn_processor_key: network_alphas.get(k)}) + mapped_network_alphas.update( + {attn_processor_key: network_alphas.get(k)} + ) used_network_alphas_keys.add(k) if not is_network_alphas_none: @@ -289,7 +310,9 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict mapped_network_alphas.get(key), ) else: - raise ValueError(f"Module {key} is not a LoRACompatibleConv or LoRACompatibleLinear module.") + raise ValueError( + f"Module {key} is not a LoRACompatibleConv or LoRACompatibleLinear module." 
+ ) value_dict = {k.replace("lora.", ""): v for k, v in value_dict.items()} lora_layers_list.append((attn_processor, lora)) @@ -297,7 +320,9 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict if low_cpu_mem_usage: device = next(iter(value_dict.values())).device dtype = next(iter(value_dict.values())).dtype - load_model_dict_into_meta(lora, value_dict, device=device, dtype=dtype) + load_model_dict_into_meta( + lora, value_dict, device=device, dtype=dtype + ) else: lora.load_state_dict(value_dict) @@ -309,20 +334,31 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict custom_diffusion_grouped_dict[key] = {} else: if "to_out" in key: - attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:]) + attn_processor_key, sub_key = ".".join( + key.split(".")[:-3] + ), ".".join(key.split(".")[-3:]) else: - attn_processor_key, sub_key = ".".join(key.split(".")[:-2]), ".".join(key.split(".")[-2:]) + attn_processor_key, sub_key = ".".join( + key.split(".")[:-2] + ), ".".join(key.split(".")[-2:]) custom_diffusion_grouped_dict[attn_processor_key][sub_key] = value for key, value_dict in custom_diffusion_grouped_dict.items(): if len(value_dict) == 0: attn_processors[key] = CustomDiffusionAttnProcessor( - train_kv=False, train_q_out=False, hidden_size=None, cross_attention_dim=None + train_kv=False, + train_q_out=False, + hidden_size=None, + cross_attention_dim=None, ) else: - cross_attention_dim = value_dict["to_k_custom_diffusion.weight"].shape[1] + cross_attention_dim = value_dict[ + "to_k_custom_diffusion.weight" + ].shape[1] hidden_size = value_dict["to_k_custom_diffusion.weight"].shape[0] - train_q_out = True if "to_q_custom_diffusion.weight" in value_dict else False + train_q_out = ( + True if "to_q_custom_diffusion.weight" in value_dict else False + ) attn_processors[key] = CustomDiffusionAttnProcessor( train_kv=True, train_q_out=train_q_out, @@ -349,14 +385,22 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict if not USE_PEFT_BACKEND: if _pipeline is not None: for _, component in _pipeline.components.items(): - if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"): - is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload) - is_sequential_cpu_offload = isinstance(getattr(component, "_hf_hook"), AlignDevicesHook) + if isinstance(component, nn.Module) and hasattr( + component, "_hf_hook" + ): + is_model_cpu_offload = isinstance( + getattr(component, "_hf_hook"), CpuOffload + ) + is_sequential_cpu_offload = isinstance( + getattr(component, "_hf_hook"), AlignDevicesHook + ) logger.info( "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again." 
) - remove_hook_from_module(component, recurse=is_sequential_cpu_offload) + remove_hook_from_module( + component, recurse=is_sequential_cpu_offload + ) # only custom diffusion needs to set attn processors if is_custom_diffusion: @@ -377,16 +421,23 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict def convert_state_dict_legacy_attn_format(self, state_dict, network_alphas): is_new_lora_format = all( - key.startswith(self.unet_name) or key.startswith(self.text_encoder_name) for key in state_dict.keys() + key.startswith(self.unet_name) or key.startswith(self.text_encoder_name) + for key in state_dict.keys() ) if is_new_lora_format: # Strip the `"unet"` prefix. - is_text_encoder_present = any(key.startswith(self.text_encoder_name) for key in state_dict.keys()) + is_text_encoder_present = any( + key.startswith(self.text_encoder_name) for key in state_dict.keys() + ) if is_text_encoder_present: warn_message = "The state_dict contains LoRA params corresponding to the text encoder which are not being used here. To use both UNet and text encoder related LoRA params, use [`pipe.load_lora_weights()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraLoaderMixin.load_lora_weights)." logger.warn(warn_message) unet_keys = [k for k in state_dict.keys() if k.startswith(self.unet_name)] - state_dict = {k.replace(f"{self.unet_name}.", ""): v for k, v in state_dict.items() if k in unet_keys} + state_dict = { + k.replace(f"{self.unet_name}.", ""): v + for k, v in state_dict.items() + if k in unet_keys + } # change processor format to 'pure' LoRACompatibleLinear format if any("processor" in k.split(".") for k in state_dict.keys()): @@ -394,12 +445,20 @@ def convert_state_dict_legacy_attn_format(self, state_dict, network_alphas): def format_to_lora_compatible(key): if "processor" not in key.split("."): return key - return key.replace(".processor", "").replace("to_out_lora", "to_out.0.lora").replace("_lora", ".lora") + return ( + key.replace(".processor", "") + .replace("to_out_lora", "to_out.0.lora") + .replace("_lora", ".lora") + ) - state_dict = {format_to_lora_compatible(k): v for k, v in state_dict.items()} + state_dict = { + format_to_lora_compatible(k): v for k, v in state_dict.items() + } if network_alphas is not None: - network_alphas = {format_to_lora_compatible(k): v for k, v in network_alphas.items()} + network_alphas = { + format_to_lora_compatible(k): v for k, v in network_alphas.items() + } return state_dict, network_alphas def save_attn_procs( @@ -450,14 +509,18 @@ def save_attn_procs( ) if os.path.isfile(save_directory): - logger.error(f"Provided path ({save_directory}) should be a directory, not a file") + logger.error( + f"Provided path ({save_directory}) should be a directory, not a file" + ) return if save_function is None: if safe_serialization: def save_function(weights, filename): - return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"}) + return safetensors.torch.save_file( + weights, filename, metadata={"format": "pt"} + ) else: save_function = torch.save @@ -467,7 +530,11 @@ def save_function(weights, filename): is_custom_diffusion = any( isinstance( x, - (CustomDiffusionAttnProcessor, CustomDiffusionAttnProcessor2_0, CustomDiffusionXFormersAttnProcessor), + ( + CustomDiffusionAttnProcessor, + CustomDiffusionAttnProcessor2_0, + CustomDiffusionXFormersAttnProcessor, + ), ) for (_, x) in self.attn_processors.items() ) @@ -496,13 +563,23 @@ def save_function(weights, filename): if weight_name is None: 
if safe_serialization: - weight_name = CUSTOM_DIFFUSION_WEIGHT_NAME_SAFE if is_custom_diffusion else LORA_WEIGHT_NAME_SAFE + weight_name = ( + CUSTOM_DIFFUSION_WEIGHT_NAME_SAFE + if is_custom_diffusion + else LORA_WEIGHT_NAME_SAFE + ) else: - weight_name = CUSTOM_DIFFUSION_WEIGHT_NAME if is_custom_diffusion else LORA_WEIGHT_NAME + weight_name = ( + CUSTOM_DIFFUSION_WEIGHT_NAME + if is_custom_diffusion + else LORA_WEIGHT_NAME + ) # Save the model save_function(state_dict, os.path.join(save_directory, weight_name)) - logger.info(f"Model weights saved in {os.path.join(save_directory, weight_name)}") + logger.info( + f"Model weights saved in {os.path.join(save_directory, weight_name)}" + ) def fuse_lora(self, lora_scale=1.0, safe_fusing=False): self.lora_scale = lora_scale @@ -568,7 +645,9 @@ def set_adapters( if not USE_PEFT_BACKEND: raise ValueError("PEFT backend is required for `set_adapters()`.") - adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names + adapter_names = ( + [adapter_names] if isinstance(adapter_names, str) else adapter_names + ) if weights is None: weights = [1.0] * len(adapter_names) @@ -672,11 +751,22 @@ def _load_ip_adapter_weights(self, state_dict): IPAdapterAttnProcessor2_0, ) + if "proj.weight" in state_dict["image_proj"]: + # IP-Adapter + num_image_text_embeds = 4 + else: + # IP-Adapter Plus + num_image_text_embeds = state_dict["image_proj"]["latents"].shape[1] + # set ip-adapter cross-attention processors & load state_dict attn_procs = {} key_id = 1 for name in self.attn_processors.keys(): - cross_attention_dim = None if name.endswith("attn1.processor") else self.config.cross_attention_dim + cross_attention_dim = ( + None + if name.endswith("attn1.processor") + else self.config.cross_attention_dim + ) if name.startswith("mid_block"): hidden_size = self.config.block_out_channels[-1] elif name.startswith("up_blocks"): @@ -687,20 +777,29 @@ def _load_ip_adapter_weights(self, state_dict): hidden_size = self.config.block_out_channels[block_id] if cross_attention_dim is None or "motion_modules" in name: attn_processor_class = ( - AttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else AttnProcessor + AttnProcessor2_0 + if hasattr(F, "scaled_dot_product_attention") + else AttnProcessor ) attn_procs[name] = attn_processor_class() else: attn_processor_class = ( - IPAdapterAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else IPAdapterAttnProcessor + IPAdapterAttnProcessor2_0 + if hasattr(F, "scaled_dot_product_attention") + else IPAdapterAttnProcessor ) attn_procs[name] = attn_processor_class( - hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, scale=1.0 + hidden_size=hidden_size, + cross_attention_dim=cross_attention_dim, + scale=1.0, + num_tokens=num_image_text_embeds, ).to(dtype=self.dtype, device=self.device) value_dict = {} for k, w in attn_procs[name].state_dict().items(): - value_dict.update({f"{k}": state_dict["ip_adapter"][f"{key_id}.{k}"]}) + value_dict.update( + {f"{k}": state_dict["ip_adapter"][f"{key_id}.{k}"]} + ) attn_procs[name].load_state_dict(value_dict) key_id += 2 @@ -708,28 +807,54 @@ def _load_ip_adapter_weights(self, state_dict): self.set_attn_processor(attn_procs) # create image projection layers. 
- clip_embeddings_dim = state_dict["image_proj"]["proj.weight"].shape[-1] - cross_attention_dim = state_dict["image_proj"]["proj.weight"].shape[0] // 4 + if "proj.weight" in state_dict["image_proj"]: + # IP-Adapter + clip_embeddings_dim = state_dict["image_proj"]["proj.weight"].shape[-1] + cross_attention_dim = state_dict["image_proj"]["proj.weight"].shape[0] // 4 + + image_projection = ImageProjection( + cross_attention_dim=cross_attention_dim, + image_embed_dim=clip_embeddings_dim, + num_image_text_embeds=num_image_text_embeds, + ) + image_projection.to(dtype=self.dtype, device=self.device) - image_projection = ImageProjection( - cross_attention_dim=cross_attention_dim, image_embed_dim=clip_embeddings_dim, num_image_text_embeds=4 - ) - image_projection.to(dtype=self.dtype, device=self.device) - - # load image projection layer weights - image_proj_state_dict = {} - image_proj_state_dict.update( - { - "image_embeds.weight": state_dict["image_proj"]["proj.weight"], - "image_embeds.bias": state_dict["image_proj"]["proj.bias"], - "norm.weight": state_dict["image_proj"]["norm.weight"], - "norm.bias": state_dict["image_proj"]["norm.bias"], - } - ) + # load image projection layer weights + image_proj_state_dict = {} + image_proj_state_dict.update( + { + "image_embeds.weight": state_dict["image_proj"]["proj.weight"], + "image_embeds.bias": state_dict["image_proj"]["proj.bias"], + "norm.weight": state_dict["image_proj"]["norm.weight"], + "norm.bias": state_dict["image_proj"]["norm.bias"], + } + ) - image_projection.load_state_dict(image_proj_state_dict) + image_projection.load_state_dict(image_proj_state_dict) - self.encoder_hid_proj = image_projection.to(device=self.device, dtype=self.dtype) + else: + # IP-Adapter Plus + embed_dims = state_dict["image_proj"]["proj_in.weight"].shape[1] + output_dims = state_dict["image_proj"]["proj_out.weight"].shape[0] + hidden_dims = state_dict["image_proj"]["latents"].shape[2] + num_heads = ( + state_dict["image_proj"]["layers.0.0.to_q.weight"].shape[0] // 64 + ) + + image_projection = Resampler( + embed_dims=embed_dims, + output_dims=output_dims, + hidden_dims=hidden_dims, + num_heads=num_heads, + num_queries=num_image_text_embeds, + ) + + image_proj_state_dict = state_dict["image_proj"] + image_projection.load_state_dict(image_proj_state_dict) + + self.encoder_hid_proj = image_projection.to( + device=self.device, dtype=self.dtype + ) self.config.encoder_hid_dim_type = "ip_image_proj" delete_adapter_layers diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index d45f56d43c32..e0e89daf0e7d 100644 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -14,7 +14,12 @@ from typing import TYPE_CHECKING -from ..utils import DIFFUSERS_SLOW_IMPORT, _LazyModule, is_flax_available, is_torch_available +from ..utils import ( + DIFFUSERS_SLOW_IMPORT, + _LazyModule, + is_flax_available, + is_torch_available, +) _import_structure = {} @@ -28,6 +33,7 @@ _import_structure["controlnet"] = ["ControlNetModel"] _import_structure["dual_transformer_2d"] = ["DualTransformer2DModel"] _import_structure["modeling_utils"] = ["ModelMixin"] + _import_structure["embeddings"] = ["ImageProjection"] _import_structure["prior_transformer"] = ["PriorTransformer"] _import_structure["t5_film_transformer"] = ["T5FilmDecoder"] _import_structure["transformer_2d"] = ["Transformer2DModel"] @@ -54,6 +60,7 @@ from .consistency_decoder_vae import ConsistencyDecoderVAE from .controlnet import ControlNetModel from .dual_transformer_2d import 
DualTransformer2DModel + from .embeddings import ImageProjection from .modeling_utils import ModelMixin from .prior_transformer import PriorTransformer from .t5_film_transformer import T5FilmDecoder @@ -74,4 +81,6 @@ else: import sys - sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) + sys.modules[__name__] = _LazyModule( + __name__, globals()["__file__"], _import_structure, module_spec=__spec__ + ) diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index a377ae267411..45627f401515 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -790,3 +790,159 @@ def forward(self, caption, force_drop_ids=None): hidden_states = self.act_1(hidden_states) hidden_states = self.linear_2(hidden_states) return hidden_states + + +class PerceiverAttention(nn.Module): + """PerceiverAttention of IP-Adapter Plus. + + Args: + ---- + embed_dims (int): The feature dimension. + head_dims (int): The number of head channels. Defaults to 64. + num_heads (int): Parallel attention heads. Defaults to 16. + """ + + def __init__(self, + embed_dims: int, + head_dims=64, + num_heads: int = 16) -> None: + super().__init__() + self.head_dims = head_dims + self.num_heads = num_heads + inner_dim = head_dims * num_heads + + self.norm1 = nn.LayerNorm(embed_dims) + self.norm2 = nn.LayerNorm(embed_dims) + + self.to_q = nn.Linear(embed_dims, inner_dim, bias=False) + self.to_kv = nn.Linear(embed_dims, inner_dim * 2, bias=False) + self.to_out = nn.Linear(inner_dim, embed_dims, bias=False) + + def _reshape_tensor(self, x, heads) -> torch.Tensor: + """Reshape tensor.""" + bs, length, _ = x.shape + # (bs, length, width) --> (bs, length, n_heads, dim_per_head) + x = x.view(bs, length, heads, -1) + # (bs, length, n_heads, dim_per_head) --> + # (bs, n_heads, length, dim_per_head) + x = x.transpose(1, 2) + # (bs, n_heads, length, dim_per_head) --> + # (bs*n_heads, length, dim_per_head) + return x.reshape(bs, heads, length, -1) + + def forward(self, x: torch.Tensor, latents: torch.Tensor) -> torch.Tensor: + """Forward pass. + + Args: + ---- + x (torch.Tensor): image features + shape (b, n1, D) + latents (torch.Tensor): latent features + shape (b, n2, D). + """ + x = self.norm1(x) + latents = self.norm2(latents) + + b, len_latents, _ = latents.shape + + q = self.to_q(latents) + kv_input = torch.cat((x, latents), dim=-2) + k, v = self.to_kv(kv_input).chunk(2, dim=-1) + + q = self._reshape_tensor(q, self.num_heads) + k = self._reshape_tensor(k, self.num_heads) + v = self._reshape_tensor(v, self.num_heads) + + # attention + scale = 1 / math.sqrt(math.sqrt(self.head_dims)) + # More stable with f16 than dividing afterwards + weight = (q * scale) @ (k * scale).transpose(-2, -1) + weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) + out = weight @ v + + out = out.permute(0, 2, 1, 3).reshape(b, len_latents, -1) + + return self.to_out(out) + + +class Resampler(nn.Module): + """Resampler of IP-Adapter Plus. + + Args: + ---- + embed_dims (int): The feature dimension. Defaults to 768. + output_dims (int): The number of output channels, that is the same + number of the channels in the + `unet.config.cross_attention_dim`. Defaults to 1024. + hidden_dims (int): The number of hidden channels. Defaults to 1280. + depth (int): The number of blocks. Defaults to 8. + head_dims (int): The number of head channels. Defaults to 64. + num_heads (int): Parallel attention heads. Defaults to 16. 
+ num_queries (int): The number of queries. Defaults to 8. + ffn_ratio (float): The expansion ratio of feedforward network hidden + layer channels. Defaults to 4. + """ + + def __init__( + self, + embed_dims: int = 768, + output_dims: int = 1024, + hidden_dims: int = 1280, + depth: int = 4, + head_dims: int = 64, + num_heads: int = 16, + num_queries: int = 8, + ffn_ratio: float = 4, + ) -> None: + super().__init__() + + self.latents = nn.Parameter( + torch.randn(1, num_queries, hidden_dims) / hidden_dims**0.5) + + self.proj_in = nn.Linear(embed_dims, hidden_dims) + + self.proj_out = nn.Linear(hidden_dims, output_dims) + self.norm_out = nn.LayerNorm(output_dims) + + self.layers = nn.ModuleList([]) + for _ in range(depth): + self.layers.append( + nn.ModuleList([ + PerceiverAttention( + embed_dims=hidden_dims, + head_dims=head_dims, + num_heads=num_heads), + self._get_ffn(embed_dims=hidden_dims, ffn_ratio=ffn_ratio), + ])) + + def _get_ffn(self, embed_dims, ffn_ratio=4) -> nn.Sequential: + """Get feedforward network.""" + inner_dim = int(embed_dims * ffn_ratio) + return nn.Sequential( + nn.LayerNorm(embed_dims), + nn.Linear(embed_dims, inner_dim, bias=False), + nn.GELU(), + nn.Linear(inner_dim, embed_dims, bias=False), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass. + + Args: + ---- + x (torch.Tensor): Input Tensor. + + Returns: + ------- + torch.Tensor: Output Tensor. + """ + latents = self.latents.repeat(x.size(0), 1, 1) + + x = self.proj_in(x) + + for attn, ff in self.layers: + latents = attn(x, latents) + latents + latents = ff(latents) + latents + + latents = self.proj_out(latents) + return self.norm_out(latents) diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py index 843e3b8b9410..f0ded5c2563d 100644 --- a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py +++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py @@ -22,7 +22,7 @@ from ...configuration_utils import FrozenDict from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin -from ...models import AutoencoderKL, UNet2DConditionModel +from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( @@ -456,10 +456,19 @@ def encode_image(self, image, device, num_images_per_prompt): image = self.feature_extractor(image, return_tensors="pt").pixel_values image = image.to(device=device, dtype=dtype) - image_embeds = self.image_encoder(image).image_embeds - image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) - - uncond_image_embeds = torch.zeros_like(image_embeds) + if isinstance(self.unet.encoder_hid_proj, ImageProjection): + # IP-Adapter + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + else: + # IP-Adapter Plus + image_embeds = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = self.image_encoder(torch.zeros_like(image), output_hidden_states=True).hidden_states[ + -2 + ] + uncond_image_embeds = 
uncond_image_embeds.repeat_interleave(num_images_per_prompt, dim=0) return image_embeds, uncond_image_embeds def run_safety_checker(self, image, device, dtype): diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py index b196ac4d3f69..c5706a1f126d 100644 --- a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py +++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py @@ -19,12 +19,21 @@ import PIL.Image import torch from packaging import version -from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection, XLMRobertaTokenizer +from transformers import ( + CLIPImageProcessor, + CLIPVisionModelWithProjection, + XLMRobertaTokenizer, +) from ...configuration_utils import FrozenDict from ...image_processor import PipelineImageInput, VaeImageProcessor -from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin -from ...models import AutoencoderKL, UNet2DConditionModel +from ...loaders import ( + FromSingleFileMixin, + IPAdapterMixin, + LoraLoaderMixin, + TextualInversionLoaderMixin, +) +from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( @@ -98,7 +107,10 @@ def preprocess(image): w, h = image[0].size w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8 - image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image] + image = [ + np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] + for i in image + ] image = np.concatenate(image, axis=0) image = np.array(image).astype(np.float32) / 255.0 image = image.transpose(0, 3, 1, 2) @@ -111,7 +123,11 @@ def preprocess(image): # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline with Stable->Alt, CLIPTextModel->RobertaSeriesModelWithTransformation, CLIPTokenizer->XLMRobertaTokenizer, AltDiffusionSafetyChecker->StableDiffusionSafetyChecker class AltDiffusionImg2ImgPipeline( - DiffusionPipeline, TextualInversionLoaderMixin, IPAdapterMixin, LoraLoaderMixin, FromSingleFileMixin + DiffusionPipeline, + TextualInversionLoaderMixin, + IPAdapterMixin, + LoraLoaderMixin, + FromSingleFileMixin, ): r""" Pipeline for text-guided image-to-image generation using Alt Diffusion. @@ -165,7 +181,10 @@ def __init__( ): super().__init__() - if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: + if ( + hasattr(scheduler.config, "steps_offset") + and scheduler.config.steps_offset != 1 + ): deprecation_message = ( f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" f" should be set to 1 instead of {scheduler.config.steps_offset}. 
Please make sure " @@ -174,12 +193,17 @@ def __init__( " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" " file" ) - deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) + deprecate( + "steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False + ) new_config = dict(scheduler.config) new_config["steps_offset"] = 1 scheduler._internal_dict = FrozenDict(new_config) - if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: + if ( + hasattr(scheduler.config, "clip_sample") + and scheduler.config.clip_sample is True + ): deprecation_message = ( f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." " `clip_sample` should be set to False in the configuration file. Please make sure to update the" @@ -187,7 +211,9 @@ def __init__( " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" ) - deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) + deprecate( + "clip_sample not set", "1.0.0", deprecation_message, standard_warn=False + ) new_config = dict(scheduler.config) new_config["clip_sample"] = False scheduler._internal_dict = FrozenDict(new_config) @@ -208,10 +234,16 @@ def __init__( " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." ) - is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( + is_unet_version_less_0_9_0 = hasattr( + unet.config, "_diffusers_version" + ) and version.parse( version.parse(unet.config._diffusers_version).base_version - ) < version.parse("0.9.0.dev0") - is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + ) < version.parse( + "0.9.0.dev0" + ) + is_unet_sample_size_less_64 = ( + hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + ) if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: deprecation_message = ( "The configuration file of the unet has set the default `sample_size` to smaller than" @@ -224,7 +256,9 @@ def __init__( " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" " the `unet/config.json` file" ) - deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) + deprecate( + "sample_size<64", "1.0.0", deprecation_message, standard_warn=False + ) new_config = dict(unet.config) new_config["sample_size"] = 64 unet._internal_dict = FrozenDict(new_config) @@ -347,11 +381,13 @@ def encode_prompt( return_tensors="pt", ) text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + untruncated_ids = self.tokenizer( + prompt, padding="longest", return_tensors="pt" + ).input_ids - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( - text_input_ids, untruncated_ids - ): + if untruncated_ids.shape[-1] >= text_input_ids.shape[ + -1 + ] and not torch.equal(text_input_ids, untruncated_ids): removed_text = self.tokenizer.batch_decode( untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] ) @@ -360,17 +396,24 @@ def encode_prompt( f" {self.tokenizer.model_max_length} tokens: {removed_text}" ) - if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + if ( + 
hasattr(self.text_encoder.config, "use_attention_mask") + and self.text_encoder.config.use_attention_mask + ): attention_mask = text_inputs.attention_mask.to(device) else: attention_mask = None if clip_skip is None: - prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask + ) prompt_embeds = prompt_embeds[0] else: prompt_embeds = self.text_encoder( - text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True + text_input_ids.to(device), + attention_mask=attention_mask, + output_hidden_states=True, ) # Access the `hidden_states` first, that contains a tuple of # all the hidden states from the encoder layers. Then index into @@ -380,7 +423,9 @@ def encode_prompt( # representations. The `last_hidden_states` that we typically use for # obtaining the final prompt representations passes through the LayerNorm # layer. - prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) + prompt_embeds = self.text_encoder.text_model.final_layer_norm( + prompt_embeds + ) if self.text_encoder is not None: prompt_embeds_dtype = self.text_encoder.dtype @@ -394,7 +439,9 @@ def encode_prompt( bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + prompt_embeds = prompt_embeds.view( + bs_embed * num_images_per_prompt, seq_len, -1 + ) # get unconditional embeddings for classifier free guidance if do_classifier_free_guidance and negative_prompt_embeds is None: @@ -430,7 +477,10 @@ def encode_prompt( return_tensors="pt", ) - if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + if ( + hasattr(self.text_encoder.config, "use_attention_mask") + and self.text_encoder.config.use_attention_mask + ): attention_mask = uncond_input.attention_mask.to(device) else: attention_mask = None @@ -445,10 +495,16 @@ def encode_prompt( # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] - negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + negative_prompt_embeds = negative_prompt_embeds.to( + dtype=prompt_embeds_dtype, device=device + ) - negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) - negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + negative_prompt_embeds = negative_prompt_embeds.repeat( + 1, num_images_per_prompt, 1 + ) + negative_prompt_embeds = negative_prompt_embeds.view( + batch_size * num_images_per_prompt, seq_len, -1 + ) if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND: # Retrieve the original scale by scaling back the LoRA layers @@ -463,10 +519,23 @@ def encode_image(self, image, device, num_images_per_prompt): image = self.feature_extractor(image, return_tensors="pt").pixel_values image = image.to(device=device, dtype=dtype) - image_embeds = self.image_encoder(image).image_embeds - image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) - - uncond_image_embeds = torch.zeros_like(image_embeds) + if isinstance(self.unet.encoder_hid_proj, ImageProjection): + # IP-Adapter + image_embeds = self.image_encoder(image).image_embeds + 
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + else: + # IP-Adapter Plus + image_embeds = self.image_encoder( + image, output_hidden_states=True + ).hidden_states[-2] + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_embeds = uncond_image_embeds.repeat_interleave( + num_images_per_prompt, dim=0 + ) return image_embeds, uncond_image_embeds def run_safety_checker(self, image, device, dtype): @@ -474,10 +543,14 @@ def run_safety_checker(self, image, device, dtype): has_nsfw_concept = None else: if torch.is_tensor(image): - feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + feature_extractor_input = self.image_processor.postprocess( + image, output_type="pil" + ) else: feature_extractor_input = self.image_processor.numpy_to_pil(image) - safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + safety_checker_input = self.feature_extractor( + feature_extractor_input, return_tensors="pt" + ).to(device) image, has_nsfw_concept = self.safety_checker( images=image, clip_input=safety_checker_input.pixel_values.to(dtype) ) @@ -500,13 +573,17 @@ def prepare_extra_step_kwargs(self, generator, eta): # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 # and should be between [0, 1] - accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + accepts_eta = "eta" in set( + inspect.signature(self.scheduler.step).parameters.keys() + ) extra_step_kwargs = {} if accepts_eta: extra_step_kwargs["eta"] = eta # check if the scheduler accepts generator - accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + accepts_generator = "generator" in set( + inspect.signature(self.scheduler.step).parameters.keys() + ) if accepts_generator: extra_step_kwargs["generator"] = generator return extra_step_kwargs @@ -522,16 +599,21 @@ def check_inputs( callback_on_step_end_tensor_inputs=None, ): if strength < 0 or strength > 1: - raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + raise ValueError( + f"The value of strength should in [0.0, 1.0] but is {strength}" + ) - if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + if callback_steps is not None and ( + not isinstance(callback_steps, int) or callback_steps <= 0 + ): raise ValueError( f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}." ) if callback_on_step_end_tensor_inputs is not None and not all( - k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + k in self._callback_tensor_inputs + for k in callback_on_step_end_tensor_inputs ): raise ValueError( f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" @@ -545,8 +627,12 @@ def check_inputs( raise ValueError( "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." 
) - elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt is not None and ( + not isinstance(prompt, str) and not isinstance(prompt, list) + ): + raise ValueError( + f"`prompt` has to be of type `str` or `list` but is {type(prompt)}" + ) if negative_prompt is not None and negative_prompt_embeds is not None: raise ValueError( @@ -571,7 +657,16 @@ def get_timesteps(self, num_inference_steps, strength, device): return timesteps, num_inference_steps - t_start - def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None): + def prepare_latents( + self, + image, + timestep, + batch_size, + num_images_per_prompt, + dtype, + device, + generator=None, + ): if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): raise ValueError( f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" @@ -593,16 +688,23 @@ def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dt elif isinstance(generator, list): init_latents = [ - retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) + retrieve_latents( + self.vae.encode(image[i : i + 1]), generator=generator[i] + ) for i in range(batch_size) ] init_latents = torch.cat(init_latents, dim=0) else: - init_latents = retrieve_latents(self.vae.encode(image), generator=generator) + init_latents = retrieve_latents( + self.vae.encode(image), generator=generator + ) init_latents = self.vae.config.scaling_factor * init_latents - if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: + if ( + batch_size > init_latents.shape[0] + and batch_size % init_latents.shape[0] == 0 + ): # expand init_latents for batch_size deprecation_message = ( f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial" @@ -610,10 +712,20 @@ def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dt " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update" " your script to pass as many initial images as text prompts to suppress this warning." ) - deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False) + deprecate( + "len(prompt) != len(image)", + "1.0.0", + deprecation_message, + standard_warn=False, + ) additional_image_per_prompt = batch_size // init_latents.shape[0] - init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0) - elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: + init_latents = torch.cat( + [init_latents] * additional_image_per_prompt, dim=0 + ) + elif ( + batch_size > init_latents.shape[0] + and batch_size % init_latents.shape[0] != 0 + ): raise ValueError( f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." ) @@ -845,7 +957,9 @@ def __call__( # 3. 
Encode input prompt text_encoder_lora_scale = ( - self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None + self.cross_attention_kwargs.get("scale", None) + if self.cross_attention_kwargs is not None + else None ) prompt_embeds, negative_prompt_embeds = self.encode_prompt( prompt, @@ -865,7 +979,9 @@ def __call__( prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) if ip_adapter_image is not None: - image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt) + image_embeds, negative_image_embeds = self.encode_image( + ip_adapter_image, device, num_images_per_prompt + ) if self.do_classifier_free_guidance: image_embeds = torch.cat([negative_image_embeds, image_embeds]) @@ -874,7 +990,9 @@ def __call__( # 5. set timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) - timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) + timesteps, num_inference_steps = self.get_timesteps( + num_inference_steps, strength, device + ) latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) # 6. Prepare latent variables @@ -892,12 +1010,16 @@ def __call__( extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # 7.1 Add image embeds for IP-Adapter - added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None + added_cond_kwargs = ( + {"image_embeds": image_embeds} if ip_adapter_image is not None else None + ) # 7.2 Optionally get Guidance Scale Embedding timestep_cond = None if self.unet.config.time_cond_proj_dim is not None: - guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) + guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat( + batch_size * num_images_per_prompt + ) timestep_cond = self.get_guidance_scale_embedding( guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim ).to(device=device, dtype=latents.dtype) @@ -908,8 +1030,14 @@ def __call__( with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + latent_model_input = ( + torch.cat([latents] * 2) + if self.do_classifier_free_guidance + else latents + ) + latent_model_input = self.scheduler.scale_model_input( + latent_model_input, t + ) # predict the noise residual noise_pred = self.unet( @@ -925,10 +1053,14 @@ def __call__( # perform guidance if self.do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + noise_pred = noise_pred_uncond + self.guidance_scale * ( + noise_pred_text - noise_pred_uncond + ) # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + latents = self.scheduler.step( + noise_pred, t, latents, **extra_step_kwargs, return_dict=False + )[0] if callback_on_step_end is not None: callback_kwargs = {} @@ -938,20 +1070,28 @@ def __call__( latents = callback_outputs.pop("latents", latents) prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) - negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", 
negative_prompt_embeds) + negative_prompt_embeds = callback_outputs.pop( + "negative_prompt_embeds", negative_prompt_embeds + ) # call the callback, if provided - if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + if i == len(timesteps) - 1 or ( + (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0 + ): progress_bar.update() if callback is not None and i % callback_steps == 0: step_idx = i // getattr(self.scheduler, "order", 1) callback(step_idx, t, latents) if not output_type == "latent": - image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[ - 0 - ] - image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + image = self.vae.decode( + latents / self.vae.config.scaling_factor, + return_dict=False, + generator=generator, + )[0] + image, has_nsfw_concept = self.run_safety_checker( + image, device, prompt_embeds.dtype + ) else: image = latents has_nsfw_concept = None @@ -961,7 +1101,9 @@ def __call__( else: do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] - image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + image = self.image_processor.postprocess( + image, output_type=output_type, do_denormalize=do_denormalize + ) # Offload all models self.maybe_free_model_hooks() @@ -969,4 +1111,6 @@ def __call__( if not return_dict: return (image, has_nsfw_concept) - return AltDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) + return AltDiffusionPipelineOutput( + images=image, nsfw_content_detected=has_nsfw_concept + ) diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py index 28dc220545dc..645856a72bb9 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py @@ -18,11 +18,21 @@ import numpy as np import torch -from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection +from transformers import ( + CLIPImageProcessor, + CLIPTextModel, + CLIPTokenizer, + CLIPVisionModelWithProjection, +) from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin -from ...models import AutoencoderKL, UNet2DConditionModel, UNetMotionModel +from ...models import ( + AutoencoderKL, + ImageProjection, + UNet2DConditionModel, + UNetMotionModel, +) from ...models.lora import adjust_lora_scale_text_encoder from ...models.unet_motion_model import MotionAdapter from ...schedulers import ( @@ -33,7 +43,13 @@ LMSDiscreteScheduler, PNDMScheduler, ) -from ...utils import USE_PEFT_BACKEND, BaseOutput, logging, scale_lora_layers, unscale_lora_layers +from ...utils import ( + USE_PEFT_BACKEND, + BaseOutput, + logging, + scale_lora_layers, + unscale_lora_layers, +) from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline @@ -77,7 +93,9 @@ class AnimateDiffPipelineOutput(BaseOutput): frames: Union[torch.Tensor, np.ndarray] -class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdapterMixin, LoraLoaderMixin): +class AnimateDiffPipeline( + DiffusionPipeline, TextualInversionLoaderMixin, IPAdapterMixin, LoraLoaderMixin +): r""" Pipeline for text-to-video generation. 
@@ -210,11 +228,13 @@ def encode_prompt( return_tensors="pt", ) text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + untruncated_ids = self.tokenizer( + prompt, padding="longest", return_tensors="pt" + ).input_ids - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( - text_input_ids, untruncated_ids - ): + if untruncated_ids.shape[-1] >= text_input_ids.shape[ + -1 + ] and not torch.equal(text_input_ids, untruncated_ids): removed_text = self.tokenizer.batch_decode( untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] ) @@ -223,17 +243,24 @@ def encode_prompt( f" {self.tokenizer.model_max_length} tokens: {removed_text}" ) - if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + if ( + hasattr(self.text_encoder.config, "use_attention_mask") + and self.text_encoder.config.use_attention_mask + ): attention_mask = text_inputs.attention_mask.to(device) else: attention_mask = None if clip_skip is None: - prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask + ) prompt_embeds = prompt_embeds[0] else: prompt_embeds = self.text_encoder( - text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True + text_input_ids.to(device), + attention_mask=attention_mask, + output_hidden_states=True, ) # Access the `hidden_states` first, that contains a tuple of # all the hidden states from the encoder layers. Then index into @@ -243,7 +270,9 @@ def encode_prompt( # representations. The `last_hidden_states` that we typically use for # obtaining the final prompt representations passes through the LayerNorm # layer. 
- prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) + prompt_embeds = self.text_encoder.text_model.final_layer_norm( + prompt_embeds + ) if self.text_encoder is not None: prompt_embeds_dtype = self.text_encoder.dtype @@ -257,7 +286,9 @@ def encode_prompt( bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + prompt_embeds = prompt_embeds.view( + bs_embed * num_images_per_prompt, seq_len, -1 + ) # get unconditional embeddings for classifier free guidance if do_classifier_free_guidance and negative_prompt_embeds is None: @@ -293,7 +324,10 @@ def encode_prompt( return_tensors="pt", ) - if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + if ( + hasattr(self.text_encoder.config, "use_attention_mask") + and self.text_encoder.config.use_attention_mask + ): attention_mask = uncond_input.attention_mask.to(device) else: attention_mask = None @@ -308,10 +342,16 @@ def encode_prompt( # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] - negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + negative_prompt_embeds = negative_prompt_embeds.to( + dtype=prompt_embeds_dtype, device=device + ) - negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) - negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + negative_prompt_embeds = negative_prompt_embeds.repeat( + 1, num_images_per_prompt, 1 + ) + negative_prompt_embeds = negative_prompt_embeds.view( + batch_size * num_images_per_prompt, seq_len, -1 + ) if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND: # Retrieve the original scale by scaling back the LoRA layers @@ -327,10 +367,23 @@ def encode_image(self, image, device, num_images_per_prompt): image = self.feature_extractor(image, return_tensors="pt").pixel_values image = image.to(device=device, dtype=dtype) - image_embeds = self.image_encoder(image).image_embeds - image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) - - uncond_image_embeds = torch.zeros_like(image_embeds) + if isinstance(self.unet.encoder_hid_proj, ImageProjection): + # IP-Adapter + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + else: + # IP-Adapter Plus + image_embeds = self.image_encoder( + image, output_hidden_states=True + ).hidden_states[-2] + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_embeds = uncond_image_embeds.repeat_interleave( + num_images_per_prompt, dim=0 + ) return image_embeds, uncond_image_embeds # Copied from diffusers.pipelines.text_to_video_synthesis/pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents @@ -338,7 +391,9 @@ def decode_latents(self, latents): latents = 1 / self.vae.config.scaling_factor * latents batch_size, channels, num_frames, height, width = latents.shape - latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, 
width) + latents = latents.permute(0, 2, 1, 3, 4).reshape( + batch_size * num_frames, channels, height, width + ) image = self.vae.decode(latents).sample video = ( @@ -425,13 +480,17 @@ def prepare_extra_step_kwargs(self, generator, eta): # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 # and should be between [0, 1] - accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + accepts_eta = "eta" in set( + inspect.signature(self.scheduler.step).parameters.keys() + ) extra_step_kwargs = {} if accepts_eta: extra_step_kwargs["eta"] = eta # check if the scheduler accepts generator - accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + accepts_generator = "generator" in set( + inspect.signature(self.scheduler.step).parameters.keys() + ) if accepts_generator: extra_step_kwargs["generator"] = generator return extra_step_kwargs @@ -449,15 +508,20 @@ def check_inputs( callback_on_step_end_tensor_inputs=None, ): if height % 8 != 0 or width % 8 != 0: - raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + raise ValueError( + f"`height` and `width` have to be divisible by 8 but are {height} and {width}." + ) - if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + if callback_steps is not None and ( + not isinstance(callback_steps, int) or callback_steps <= 0 + ): raise ValueError( f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}." ) if callback_on_step_end_tensor_inputs is not None and not all( - k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + k in self._callback_tensor_inputs + for k in callback_on_step_end_tensor_inputs ): raise ValueError( f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" @@ -472,8 +536,12 @@ def check_inputs( raise ValueError( "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." ) - elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt is not None and ( + not isinstance(prompt, str) and not isinstance(prompt, list) + ): + raise ValueError( + f"`prompt` has to be of type `str` or `list` but is {type(prompt)}" + ) if negative_prompt is not None and negative_prompt_embeds is not None: raise ValueError( @@ -491,7 +559,16 @@ def check_inputs( # Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_synth.TextToVideoSDPipeline.prepare_latents def prepare_latents( - self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None + self, + batch_size, + num_channels_latents, + num_frames, + height, + width, + dtype, + device, + generator, + latents=None, ): shape = ( batch_size, @@ -507,7 +584,9 @@ def prepare_latents( ) if latents is None: - latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = randn_tensor( + shape, generator=generator, device=device, dtype=dtype + ) else: latents = latents.to(device) @@ -612,7 +691,13 @@ def __call__( # 1. Check inputs. 
Raise error if not correct self.check_inputs( - prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds + prompt, + height, + width, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, ) # 2. Define call parameters @@ -631,7 +716,9 @@ def __call__( # 3. Encode input prompt text_encoder_lora_scale = ( - cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + cross_attention_kwargs.get("scale", None) + if cross_attention_kwargs is not None + else None ) prompt_embeds, negative_prompt_embeds = self.encode_prompt( prompt, @@ -651,7 +738,9 @@ def __call__( prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) if ip_adapter_image is not None: - image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_videos_per_prompt) + image_embeds, negative_image_embeds = self.encode_image( + ip_adapter_image, device, num_videos_per_prompt + ) if do_classifier_free_guidance: image_embeds = torch.cat([negative_image_embeds, image_embeds]) @@ -676,15 +765,21 @@ def __call__( # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # 7 Add image embeds for IP-Adapter - added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None + added_cond_kwargs = ( + {"image_embeds": image_embeds} if ip_adapter_image is not None else None + ) # Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + latent_model_input = ( + torch.cat([latents] * 2) if do_classifier_free_guidance else latents + ) + latent_model_input = self.scheduler.scale_model_input( + latent_model_input, t + ) # predict the noise residual noise_pred = self.unet( @@ -698,13 +793,19 @@ def __call__( # perform guidance if do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + noise_pred = noise_pred_uncond + guidance_scale * ( + noise_pred_text - noise_pred_uncond + ) # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + latents = self.scheduler.step( + noise_pred, t, latents, **extra_step_kwargs + ).prev_sample # call the callback, if provided - if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + if i == len(timesteps) - 1 or ( + (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0 + ): progress_bar.update() if callback is not None and i % callback_steps == 0: callback(i, t, latents) @@ -718,7 +819,9 @@ def __call__( if output_type == "pt": video = video_tensor else: - video = tensor2vid(video_tensor, self.image_processor, output_type=output_type) + video = tensor2vid( + video_tensor, self.image_processor, output_type=output_type + ) # Offload all models self.maybe_free_model_hooks() diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py index 41e5e75f68e5..160a90332e34 
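The pipeline hunks above thread the new `ip_adapter_image` argument through `encode_image()` and hand the resulting embeddings to the UNet via `added_cond_kwargs={"image_embeds": image_embeds}`. Purely as an illustration of what this patch enables, and not part of the diff itself, here is a minimal usage sketch using `StableDiffusionPipeline`; the weight repository `h94/IP-Adapter`, the file name `ip-adapter-plus_sd15.bin`, and the base checkpoint are assumptions for illustration, so substitute whichever IP-Adapter Plus weights and base model you actually use.

import torch
from diffusers import StableDiffusionPipeline
from diffusers.utils import load_image

# Assumed base model and IP-Adapter Plus weights; adjust to your setup.
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
pipe.load_ip_adapter(
    "h94/IP-Adapter", subfolder="models", weight_name="ip-adapter-plus_sd15.bin"
)

reference = load_image("path/or/url/to/reference.png")  # placeholder input image
image = pipe(
    prompt="a photo of a corgi in the snow",
    ip_adapter_image=reference,  # routed through the encode_image() shown above
    num_inference_steps=30,
).images[0]

The same `ip_adapter_image` keyword works for the other pipelines touched by this patch.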
100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py @@ -20,11 +20,26 @@ import PIL.Image import torch import torch.nn.functional as F -from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection +from transformers import ( + CLIPImageProcessor, + CLIPTextModel, + CLIPTokenizer, + CLIPVisionModelWithProjection, +) from ...image_processor import PipelineImageInput, VaeImageProcessor -from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin -from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel +from ...loaders import ( + FromSingleFileMixin, + IPAdapterMixin, + LoraLoaderMixin, + TextualInversionLoaderMixin, +) +from ...models import ( + AutoencoderKL, + ControlNetModel, + ImageProjection, + UNet2DConditionModel, +) from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( @@ -92,7 +107,11 @@ class StableDiffusionControlNetPipeline( - DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, IPAdapterMixin, FromSingleFileMixin + DiffusionPipeline, + TextualInversionLoaderMixin, + LoraLoaderMixin, + IPAdapterMixin, + FromSingleFileMixin, ): r""" Pipeline for text-to-image generation using Stable Diffusion with ControlNet guidance. @@ -139,7 +158,12 @@ def __init__( text_encoder: CLIPTextModel, tokenizer: CLIPTokenizer, unet: UNet2DConditionModel, - controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel], + controlnet: Union[ + ControlNetModel, + List[ControlNetModel], + Tuple[ControlNetModel], + MultiControlNetModel, + ], scheduler: KarrasDiffusionSchedulers, safety_checker: StableDiffusionSafetyChecker, feature_extractor: CLIPImageProcessor, @@ -179,9 +203,13 @@ def __init__( image_encoder=image_encoder, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) - self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True) + self.image_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True + ) self.control_image_processor = VaeImageProcessor( - vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False + vae_scale_factor=self.vae_scale_factor, + do_convert_rgb=True, + do_normalize=False, ) self.register_to_config(requires_safety_checker=requires_safety_checker) @@ -324,11 +352,13 @@ def encode_prompt( return_tensors="pt", ) text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + untruncated_ids = self.tokenizer( + prompt, padding="longest", return_tensors="pt" + ).input_ids - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( - text_input_ids, untruncated_ids - ): + if untruncated_ids.shape[-1] >= text_input_ids.shape[ + -1 + ] and not torch.equal(text_input_ids, untruncated_ids): removed_text = self.tokenizer.batch_decode( untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] ) @@ -337,17 +367,24 @@ def encode_prompt( f" {self.tokenizer.model_max_length} tokens: {removed_text}" ) - if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + if ( + hasattr(self.text_encoder.config, "use_attention_mask") + and self.text_encoder.config.use_attention_mask + ): attention_mask = 
text_inputs.attention_mask.to(device) else: attention_mask = None if clip_skip is None: - prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask + ) prompt_embeds = prompt_embeds[0] else: prompt_embeds = self.text_encoder( - text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True + text_input_ids.to(device), + attention_mask=attention_mask, + output_hidden_states=True, ) # Access the `hidden_states` first, that contains a tuple of # all the hidden states from the encoder layers. Then index into @@ -357,7 +394,9 @@ def encode_prompt( # representations. The `last_hidden_states` that we typically use for # obtaining the final prompt representations passes through the LayerNorm # layer. - prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) + prompt_embeds = self.text_encoder.text_model.final_layer_norm( + prompt_embeds + ) if self.text_encoder is not None: prompt_embeds_dtype = self.text_encoder.dtype @@ -371,7 +410,9 @@ def encode_prompt( bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + prompt_embeds = prompt_embeds.view( + bs_embed * num_images_per_prompt, seq_len, -1 + ) # get unconditional embeddings for classifier free guidance if do_classifier_free_guidance and negative_prompt_embeds is None: @@ -407,7 +448,10 @@ def encode_prompt( return_tensors="pt", ) - if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + if ( + hasattr(self.text_encoder.config, "use_attention_mask") + and self.text_encoder.config.use_attention_mask + ): attention_mask = uncond_input.attention_mask.to(device) else: attention_mask = None @@ -422,10 +466,16 @@ def encode_prompt( # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] - negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + negative_prompt_embeds = negative_prompt_embeds.to( + dtype=prompt_embeds_dtype, device=device + ) - negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) - negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + negative_prompt_embeds = negative_prompt_embeds.repeat( + 1, num_images_per_prompt, 1 + ) + negative_prompt_embeds = negative_prompt_embeds.view( + batch_size * num_images_per_prompt, seq_len, -1 + ) if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND: # Retrieve the original scale by scaling back the LoRA layers @@ -441,10 +491,23 @@ def encode_image(self, image, device, num_images_per_prompt): image = self.feature_extractor(image, return_tensors="pt").pixel_values image = image.to(device=device, dtype=dtype) - image_embeds = self.image_encoder(image).image_embeds - image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) - - uncond_image_embeds = torch.zeros_like(image_embeds) + if isinstance(self.unet.encoder_hid_proj, ImageProjection): + # IP-Adapter + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) 
+ else: + # IP-Adapter Plus + image_embeds = self.image_encoder( + image, output_hidden_states=True + ).hidden_states[-2] + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_embeds = uncond_image_embeds.repeat_interleave( + num_images_per_prompt, dim=0 + ) return image_embeds, uncond_image_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker @@ -453,10 +516,14 @@ def run_safety_checker(self, image, device, dtype): has_nsfw_concept = None else: if torch.is_tensor(image): - feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + feature_extractor_input = self.image_processor.postprocess( + image, output_type="pil" + ) else: feature_extractor_input = self.image_processor.numpy_to_pil(image) - safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + safety_checker_input = self.feature_extractor( + feature_extractor_input, return_tensors="pt" + ).to(device) image, has_nsfw_concept = self.safety_checker( images=image, clip_input=safety_checker_input.pixel_values.to(dtype) ) @@ -481,13 +548,17 @@ def prepare_extra_step_kwargs(self, generator, eta): # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 # and should be between [0, 1] - accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + accepts_eta = "eta" in set( + inspect.signature(self.scheduler.step).parameters.keys() + ) extra_step_kwargs = {} if accepts_eta: extra_step_kwargs["eta"] = eta # check if the scheduler accepts generator - accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + accepts_generator = "generator" in set( + inspect.signature(self.scheduler.step).parameters.keys() + ) if accepts_generator: extra_step_kwargs["generator"] = generator return extra_step_kwargs @@ -505,14 +576,17 @@ def check_inputs( control_guidance_end=1.0, callback_on_step_end_tensor_inputs=None, ): - if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + if callback_steps is not None and ( + not isinstance(callback_steps, int) or callback_steps <= 0 + ): raise ValueError( f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}." ) if callback_on_step_end_tensor_inputs is not None and not all( - k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + k in self._callback_tensor_inputs + for k in callback_on_step_end_tensor_inputs ): raise ValueError( f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" @@ -527,8 +601,12 @@ def check_inputs( raise ValueError( "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." 
) - elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt is not None and ( + not isinstance(prompt, str) and not isinstance(prompt, list) + ): + raise ValueError( + f"`prompt` has to be of type `str` or `list` but is {type(prompt)}" + ) if negative_prompt is not None and negative_prompt_embeds is not None: raise ValueError( @@ -574,7 +652,9 @@ def check_inputs( # When `image` is a nested list: # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]]) elif any(isinstance(i, list) for i in image): - raise ValueError("A single batch of multiple conditionings are supported at the moment.") + raise ValueError( + "A single batch of multiple conditionings are supported at the moment." + ) elif len(image) != len(self.controlnet.nets): raise ValueError( f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets." @@ -592,7 +672,9 @@ def check_inputs( and isinstance(self.controlnet._orig_mod, ControlNetModel) ): if not isinstance(controlnet_conditioning_scale, float): - raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.") + raise TypeError( + "For single controlnet: `controlnet_conditioning_scale` must be type `float`." + ) elif ( isinstance(self.controlnet, MultiControlNetModel) or is_compiled @@ -600,10 +682,12 @@ def check_inputs( ): if isinstance(controlnet_conditioning_scale, list): if any(isinstance(i, list) for i in controlnet_conditioning_scale): - raise ValueError("A single batch of multiple conditionings are supported at the moment.") - elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len( - self.controlnet.nets - ): + raise ValueError( + "A single batch of multiple conditionings are supported at the moment." + ) + elif isinstance(controlnet_conditioning_scale, list) and len( + controlnet_conditioning_scale + ) != len(self.controlnet.nets): raise ValueError( "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have" " the same length as the number of controlnets" @@ -634,16 +718,24 @@ def check_inputs( f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}." ) if start < 0.0: - raise ValueError(f"control guidance start: {start} can't be smaller than 0.") + raise ValueError( + f"control guidance start: {start} can't be smaller than 0." + ) if end > 1.0: - raise ValueError(f"control guidance end: {end} can't be larger than 1.0.") + raise ValueError( + f"control guidance end: {end} can't be larger than 1.0." 
+ ) def check_image(self, image, prompt, prompt_embeds): image_is_pil = isinstance(image, PIL.Image.Image) image_is_tensor = isinstance(image, torch.Tensor) image_is_np = isinstance(image, np.ndarray) - image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image) - image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor) + image_is_pil_list = isinstance(image, list) and isinstance( + image[0], PIL.Image.Image + ) + image_is_tensor_list = isinstance(image, list) and isinstance( + image[0], torch.Tensor + ) image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray) if ( @@ -687,7 +779,9 @@ def prepare_image( do_classifier_free_guidance=False, guess_mode=False, ): - image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) + image = self.control_image_processor.preprocess( + image, height=height, width=width + ).to(dtype=torch.float32) image_batch_size = image.shape[0] if image_batch_size == 1: @@ -706,8 +800,23 @@ def prepare_image( return image # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents - def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): - shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + shape = ( + batch_size, + num_channels_latents, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + ) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" @@ -715,7 +824,9 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype ) if latents is None: - latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = randn_tensor( + shape, generator=generator, device=device, dtype=dtype + ) else: latents = latents.to(device) @@ -943,15 +1054,31 @@ def __call__( "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", ) - controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet + controlnet = ( + self.controlnet._orig_mod + if is_compiled_module(self.controlnet) + else self.controlnet + ) # align format for control guidance - if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): - control_guidance_start = len(control_guidance_end) * [control_guidance_start] - elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): + if not isinstance(control_guidance_start, list) and isinstance( + control_guidance_end, list + ): + control_guidance_start = len(control_guidance_end) * [ + control_guidance_start + ] + elif not isinstance(control_guidance_end, list) and isinstance( + control_guidance_start, list + ): control_guidance_end = len(control_guidance_start) * [control_guidance_end] - elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list): - mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1 + elif not isinstance(control_guidance_start, list) and not isinstance( + control_guidance_end, list + ): + mult = ( + 
len(controlnet.nets) + if isinstance(controlnet, MultiControlNetModel) + else 1 + ) control_guidance_start, control_guidance_end = ( mult * [control_guidance_start], mult * [control_guidance_end], @@ -985,8 +1112,12 @@ def __call__( device = self._execution_device - if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float): - controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets) + if isinstance(controlnet, MultiControlNetModel) and isinstance( + controlnet_conditioning_scale, float + ): + controlnet_conditioning_scale = [controlnet_conditioning_scale] * len( + controlnet.nets + ) global_pool_conditions = ( controlnet.config.global_pool_conditions @@ -997,7 +1128,9 @@ def __call__( # 3. Encode input prompt text_encoder_lora_scale = ( - self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None + self.cross_attention_kwargs.get("scale", None) + if self.cross_attention_kwargs is not None + else None ) prompt_embeds, negative_prompt_embeds = self.encode_prompt( prompt, @@ -1017,7 +1150,9 @@ def __call__( prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) if ip_adapter_image is not None: - image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt) + image_embeds, negative_image_embeds = self.encode_image( + ip_adapter_image, device, num_images_per_prompt + ) if self.do_classifier_free_guidance: image_embeds = torch.cat([negative_image_embeds, image_embeds]) @@ -1079,7 +1214,9 @@ def __call__( # 6.5 Optionally get Guidance Scale Embedding timestep_cond = None if self.unet.config.time_cond_proj_dim is not None: - guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) + guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat( + batch_size * num_images_per_prompt + ) timestep_cond = self.get_guidance_scale_embedding( guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim ).to(device=device, dtype=latents.dtype) @@ -1088,7 +1225,9 @@ def __call__( extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # 7.1 Add image embeds for IP-Adapter - added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None + added_cond_kwargs = ( + {"image_embeds": image_embeds} if ip_adapter_image is not None else None + ) # 7.2 Create tensor stating which controlnets to keep controlnet_keep = [] @@ -1097,7 +1236,9 @@ def __call__( 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) for s, e in zip(control_guidance_start, control_guidance_end) ] - controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps) + controlnet_keep.append( + keeps[0] if isinstance(controlnet, ControlNetModel) else keeps + ) # 8. 
Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order @@ -1108,24 +1249,39 @@ def __call__( for i, t in enumerate(timesteps): # Relevant thread: # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428 - if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1: + if ( + is_unet_compiled and is_controlnet_compiled + ) and is_torch_higher_equal_2_1: torch._inductor.cudagraph_mark_step_begin() # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + latent_model_input = ( + torch.cat([latents] * 2) + if self.do_classifier_free_guidance + else latents + ) + latent_model_input = self.scheduler.scale_model_input( + latent_model_input, t + ) # controlnet(s) inference if guess_mode and self.do_classifier_free_guidance: # Infer ControlNet only for the conditional batch. control_model_input = latents - control_model_input = self.scheduler.scale_model_input(control_model_input, t) + control_model_input = self.scheduler.scale_model_input( + control_model_input, t + ) controlnet_prompt_embeds = prompt_embeds.chunk(2)[1] else: control_model_input = latent_model_input controlnet_prompt_embeds = prompt_embeds if isinstance(controlnet_keep[i], list): - cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])] + cond_scale = [ + c * s + for c, s in zip( + controlnet_conditioning_scale, controlnet_keep[i] + ) + ] else: controlnet_cond_scale = controlnet_conditioning_scale if isinstance(controlnet_cond_scale, list): @@ -1146,8 +1302,13 @@ def __call__( # Infered ControlNet only for the conditional batch. # To apply the output of ControlNet to both the unconditional and conditional batches, # add 0 to the unconditional batch to keep it unchanged. 
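The guess-mode block above is only reformatted by this hunk, but its behaviour is easy to skim past in the diff: when `guess_mode` is enabled together with classifier-free guidance, ControlNet runs only on the conditional half of the batch, and its residuals are zero-padded so the unconditional half passes through the UNet unchanged. A short self-contained sketch of that padding follows; the tensor shapes are stand-ins, not real ControlNet output shapes.

import torch

# Stand-in residuals as they would come back from ControlNet for the
# conditional-only batch used in guess mode.
batch = 2
down_block_res_samples = [torch.randn(batch, c, 8, 8) for c in (320, 640, 1280)]
mid_block_res_sample = torch.randn(batch, 1280, 4, 4)

# Prepend zeros so the unconditional half of the CFG batch receives no
# ControlNet signal (mirrors the torch.cat([torch.zeros_like(d), d]) lines above).
down_block_res_samples = [
    torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples
]
mid_block_res_sample = torch.cat(
    [torch.zeros_like(mid_block_res_sample), mid_block_res_sample]
)

assert mid_block_res_sample.shape[0] == 2 * batch  # [uncond, cond] ordering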
- down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples] - mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample]) + down_block_res_samples = [ + torch.cat([torch.zeros_like(d), d]) + for d in down_block_res_samples + ] + mid_block_res_sample = torch.cat( + [torch.zeros_like(mid_block_res_sample), mid_block_res_sample] + ) # predict the noise residual noise_pred = self.unet( @@ -1165,10 +1326,14 @@ def __call__( # perform guidance if self.do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + noise_pred = noise_pred_uncond + self.guidance_scale * ( + noise_pred_text - noise_pred_uncond + ) # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + latents = self.scheduler.step( + noise_pred, t, latents, **extra_step_kwargs, return_dict=False + )[0] if callback_on_step_end is not None: callback_kwargs = {} @@ -1178,10 +1343,14 @@ def __call__( latents = callback_outputs.pop("latents", latents) prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) - negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + negative_prompt_embeds = callback_outputs.pop( + "negative_prompt_embeds", negative_prompt_embeds + ) # call the callback, if provided - if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + if i == len(timesteps) - 1 or ( + (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0 + ): progress_bar.update() if callback is not None and i % callback_steps == 0: step_idx = i // getattr(self.scheduler, "order", 1) @@ -1195,10 +1364,14 @@ def __call__( torch.cuda.empty_cache() if not output_type == "latent": - image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[ - 0 - ] - image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + image = self.vae.decode( + latents / self.vae.config.scaling_factor, + return_dict=False, + generator=generator, + )[0] + image, has_nsfw_concept = self.run_safety_checker( + image, device, prompt_embeds.dtype + ) else: image = latents has_nsfw_concept = None @@ -1208,7 +1381,9 @@ def __call__( else: do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] - image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + image = self.image_processor.postprocess( + image, output_type=output_type, do_denormalize=do_denormalize + ) # Offload all models self.maybe_free_model_hooks() @@ -1216,4 +1391,6 @@ def __call__( if not return_dict: return (image, has_nsfw_concept) - return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) + return StableDiffusionPipelineOutput( + images=image, nsfw_content_detected=has_nsfw_concept + ) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py index 4696781dce0c..7d4fa0ca3950 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py @@ -37,7 +37,12 @@ StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin, ) -from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel +from 
...models import ( + AutoencoderKL, + ControlNetModel, + ImageProjection, + UNet2DConditionModel, +) from ...models.attention_processor import ( AttnProcessor2_0, LoRAAttnProcessor2_0, @@ -182,7 +187,12 @@ def __init__( tokenizer: CLIPTokenizer, tokenizer_2: CLIPTokenizer, unet: UNet2DConditionModel, - controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel], + controlnet: Union[ + ControlNetModel, + List[ControlNetModel], + Tuple[ControlNetModel], + MultiControlNetModel, + ], scheduler: KarrasDiffusionSchedulers, force_zeros_for_empty_prompt: bool = True, add_watermarker: Optional[bool] = None, @@ -207,18 +217,28 @@ def __init__( image_encoder=image_encoder, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) - self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True) + self.image_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True + ) self.control_image_processor = VaeImageProcessor( - vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False + vae_scale_factor=self.vae_scale_factor, + do_convert_rgb=True, + do_normalize=False, + ) + add_watermarker = ( + add_watermarker + if add_watermarker is not None + else is_invisible_watermark_available() ) - add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available() if add_watermarker: self.watermark = StableDiffusionXLWatermarker() else: self.watermark = None - self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) + self.register_to_config( + force_zeros_for_empty_prompt=force_zeros_for_empty_prompt + ) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing def enable_vae_slicing(self): @@ -316,7 +336,9 @@ def encode_prompt( # set lora scale so that monkey patched LoRA # function of text encoder can correctly access it - if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin): + if lora_scale is not None and isinstance( + self, StableDiffusionXLLoraLoaderMixin + ): self._lora_scale = lora_scale # dynamically adjust the LoRA scale @@ -340,9 +362,15 @@ def encode_prompt( batch_size = prompt_embeds.shape[0] # Define tokenizers and text encoders - tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2] + tokenizers = ( + [self.tokenizer, self.tokenizer_2] + if self.tokenizer is not None + else [self.tokenizer_2] + ) text_encoders = ( - [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2] + [self.text_encoder, self.text_encoder_2] + if self.text_encoder is not None + else [self.text_encoder_2] ) if prompt_embeds is None: @@ -352,7 +380,9 @@ def encode_prompt( # textual inversion: procecss multi-vector tokens if necessary prompt_embeds_list = [] prompts = [prompt, prompt_2] - for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders): + for prompt, tokenizer, text_encoder in zip( + prompts, tokenizers, text_encoders + ): if isinstance(self, TextualInversionLoaderMixin): prompt = self.maybe_convert_prompt(prompt, tokenizer) @@ -365,18 +395,24 @@ def encode_prompt( ) text_input_ids = text_inputs.input_ids - untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids - - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( - text_input_ids, untruncated_ids 
- ): - removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) + untruncated_ids = tokenizer( + prompt, padding="longest", return_tensors="pt" + ).input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[ + -1 + ] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = tokenizer.batch_decode( + untruncated_ids[:, tokenizer.model_max_length - 1 : -1] + ) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {tokenizer.model_max_length} tokens: {removed_text}" ) - prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) + prompt_embeds = text_encoder( + text_input_ids.to(device), output_hidden_states=True + ) # We are only ALWAYS interested in the pooled output of the final text encoder pooled_prompt_embeds = prompt_embeds[0] @@ -391,8 +427,14 @@ def encode_prompt( prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) # get unconditional embeddings for classifier free guidance - zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt - if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt: + zero_out_negative_prompt = ( + negative_prompt is None and self.config.force_zeros_for_empty_prompt + ) + if ( + do_classifier_free_guidance + and negative_prompt_embeds is None + and zero_out_negative_prompt + ): negative_prompt_embeds = torch.zeros_like(prompt_embeds) negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) elif do_classifier_free_guidance and negative_prompt_embeds is None: @@ -400,9 +442,15 @@ def encode_prompt( negative_prompt_2 = negative_prompt_2 or negative_prompt # normalize str to list - negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + negative_prompt = ( + batch_size * [negative_prompt] + if isinstance(negative_prompt, str) + else negative_prompt + ) negative_prompt_2 = ( - batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 + batch_size * [negative_prompt_2] + if isinstance(negative_prompt_2, str) + else negative_prompt_2 ) uncond_tokens: List[str] @@ -421,9 +469,13 @@ def encode_prompt( uncond_tokens = [negative_prompt, negative_prompt_2] negative_prompt_embeds_list = [] - for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders): + for negative_prompt, tokenizer, text_encoder in zip( + uncond_tokens, tokenizers, text_encoders + ): if isinstance(self, TextualInversionLoaderMixin): - negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer) + negative_prompt = self.maybe_convert_prompt( + negative_prompt, tokenizer + ) max_length = prompt_embeds.shape[1] uncond_input = tokenizer( @@ -447,34 +499,46 @@ def encode_prompt( negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) if self.text_encoder_2 is not None: - prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + prompt_embeds = prompt_embeds.to( + dtype=self.text_encoder_2.dtype, device=device + ) else: prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device) bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + prompt_embeds = prompt_embeds.view( + 
bs_embed * num_images_per_prompt, seq_len, -1 + ) if do_classifier_free_guidance: # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] if self.text_encoder_2 is not None: - negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + negative_prompt_embeds = negative_prompt_embeds.to( + dtype=self.text_encoder_2.dtype, device=device + ) else: - negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device) + negative_prompt_embeds = negative_prompt_embeds.to( + dtype=self.unet.dtype, device=device + ) - negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) - negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + negative_prompt_embeds = negative_prompt_embeds.repeat( + 1, num_images_per_prompt, 1 + ) + negative_prompt_embeds = negative_prompt_embeds.view( + batch_size * num_images_per_prompt, seq_len, -1 + ) - pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( - bs_embed * num_images_per_prompt, -1 - ) + pooled_prompt_embeds = pooled_prompt_embeds.repeat( + 1, num_images_per_prompt + ).view(bs_embed * num_images_per_prompt, -1) if do_classifier_free_guidance: - negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( - bs_embed * num_images_per_prompt, -1 - ) + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat( + 1, num_images_per_prompt + ).view(bs_embed * num_images_per_prompt, -1) if self.text_encoder is not None: if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: @@ -486,7 +550,12 @@ def encode_prompt( # Retrieve the original scale by scaling back the LoRA layers unscale_lora_layers(self.text_encoder_2, lora_scale) - return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + return ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image def encode_image(self, image, device, num_images_per_prompt): @@ -496,10 +565,23 @@ def encode_image(self, image, device, num_images_per_prompt): image = self.feature_extractor(image, return_tensors="pt").pixel_values image = image.to(device=device, dtype=dtype) - image_embeds = self.image_encoder(image).image_embeds - image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) - - uncond_image_embeds = torch.zeros_like(image_embeds) + if isinstance(self.unet.encoder_hid_proj, ImageProjection): + # IP-Adapter + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + else: + # IP-Adapter Plus + image_embeds = self.image_encoder( + image, output_hidden_states=True + ).hidden_states[-2] + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_embeds = uncond_image_embeds.repeat_interleave( + num_images_per_prompt, dim=0 + ) return image_embeds, uncond_image_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs @@ -509,13 +591,17 @@ def 
prepare_extra_step_kwargs(self, generator, eta): # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 # and should be between [0, 1] - accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + accepts_eta = "eta" in set( + inspect.signature(self.scheduler.step).parameters.keys() + ) extra_step_kwargs = {} if accepts_eta: extra_step_kwargs["eta"] = eta # check if the scheduler accepts generator - accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + accepts_generator = "generator" in set( + inspect.signature(self.scheduler.step).parameters.keys() + ) if accepts_generator: extra_step_kwargs["generator"] = generator return extra_step_kwargs @@ -537,14 +623,17 @@ def check_inputs( control_guidance_end=1.0, callback_on_step_end_tensor_inputs=None, ): - if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + if callback_steps is not None and ( + not isinstance(callback_steps, int) or callback_steps <= 0 + ): raise ValueError( f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}." ) if callback_on_step_end_tensor_inputs is not None and not all( - k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + k in self._callback_tensor_inputs + for k in callback_on_step_end_tensor_inputs ): raise ValueError( f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" @@ -564,10 +653,18 @@ def check_inputs( raise ValueError( "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." ) - elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): - raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + elif prompt is not None and ( + not isinstance(prompt, str) and not isinstance(prompt, list) + ): + raise ValueError( + f"`prompt` has to be of type `str` or `list` but is {type(prompt)}" + ) + elif prompt_2 is not None and ( + not isinstance(prompt_2, str) and not isinstance(prompt_2, list) + ): + raise ValueError( + f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}" + ) if negative_prompt is not None and negative_prompt_embeds is not None: raise ValueError( @@ -628,7 +725,9 @@ def check_inputs( # When `image` is a nested list: # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]]) elif any(isinstance(i, list) for i in image): - raise ValueError("A single batch of multiple conditionings are supported at the moment.") + raise ValueError( + "A single batch of multiple conditionings are supported at the moment." + ) elif len(image) != len(self.controlnet.nets): raise ValueError( f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets." 
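The `encode_image()` hunks that recur throughout this patch (most recently in the SD-XL ControlNet pipeline above) all share one dispatch rule: if the UNet's `encoder_hid_proj` is a plain `ImageProjection`, the pooled CLIP image embedding is used and the negative image embedding is simply zeros; otherwise (the IP-Adapter Plus projection), the penultimate hidden states of the image encoder are used and the negative embedding comes from encoding a zero tensor of the same shape as the input. Below is a simplified standalone sketch of that rule; `pixel_values` is assumed to be already preprocessed by the pipeline's feature extractor, and the helper name is hypothetical.

import torch
from diffusers.models import ImageProjection


def encode_ip_adapter_image(image_encoder, encoder_hid_proj, pixel_values, num_images_per_prompt=1):
    # Mirrors the branching added by this patch, stripped of pipeline plumbing.
    if isinstance(encoder_hid_proj, ImageProjection):
        # Vanilla IP-Adapter: pooled projection, zeros as the negative embedding.
        image_embeds = image_encoder(pixel_values).image_embeds
        uncond_image_embeds = torch.zeros_like(image_embeds)
    else:
        # IP-Adapter Plus: penultimate hidden states, negative from a zero input.
        image_embeds = image_encoder(
            pixel_values, output_hidden_states=True
        ).hidden_states[-2]
        uncond_image_embeds = image_encoder(
            torch.zeros_like(pixel_values), output_hidden_states=True
        ).hidden_states[-2]
    image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
    uncond_image_embeds = uncond_image_embeds.repeat_interleave(
        num_images_per_prompt, dim=0
    )
    return image_embeds, uncond_image_embeds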
@@ -646,7 +745,9 @@ def check_inputs( and isinstance(self.controlnet._orig_mod, ControlNetModel) ): if not isinstance(controlnet_conditioning_scale, float): - raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.") + raise TypeError( + "For single controlnet: `controlnet_conditioning_scale` must be type `float`." + ) elif ( isinstance(self.controlnet, MultiControlNetModel) or is_compiled @@ -654,10 +755,12 @@ def check_inputs( ): if isinstance(controlnet_conditioning_scale, list): if any(isinstance(i, list) for i in controlnet_conditioning_scale): - raise ValueError("A single batch of multiple conditionings are supported at the moment.") - elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len( - self.controlnet.nets - ): + raise ValueError( + "A single batch of multiple conditionings are supported at the moment." + ) + elif isinstance(controlnet_conditioning_scale, list) and len( + controlnet_conditioning_scale + ) != len(self.controlnet.nets): raise ValueError( "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have" " the same length as the number of controlnets" @@ -688,17 +791,25 @@ def check_inputs( f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}." ) if start < 0.0: - raise ValueError(f"control guidance start: {start} can't be smaller than 0.") + raise ValueError( + f"control guidance start: {start} can't be smaller than 0." + ) if end > 1.0: - raise ValueError(f"control guidance end: {end} can't be larger than 1.0.") + raise ValueError( + f"control guidance end: {end} can't be larger than 1.0." + ) # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.check_image def check_image(self, image, prompt, prompt_embeds): image_is_pil = isinstance(image, PIL.Image.Image) image_is_tensor = isinstance(image, torch.Tensor) image_is_np = isinstance(image, np.ndarray) - image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image) - image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor) + image_is_pil_list = isinstance(image, list) and isinstance( + image[0], PIL.Image.Image + ) + image_is_tensor_list = isinstance(image, list) and isinstance( + image[0], torch.Tensor + ) image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray) if ( @@ -743,7 +854,9 @@ def prepare_image( do_classifier_free_guidance=False, guess_mode=False, ): - image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) + image = self.control_image_processor.preprocess( + image, height=height, width=width + ).to(dtype=torch.float32) image_batch_size = image.shape[0] if image_batch_size == 1: @@ -762,8 +875,23 @@ def prepare_image( return image # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents - def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): - shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + shape = ( + batch_size, + num_channels_latents, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + ) if isinstance(generator, list) and len(generator) != 
batch_size: raise ValueError( f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" @@ -771,7 +899,9 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype ) if latents is None: - latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = randn_tensor( + shape, generator=generator, device=device, dtype=dtype + ) else: latents = latents.to(device) @@ -781,12 +911,18 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids def _get_add_time_ids( - self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None + self, + original_size, + crops_coords_top_left, + target_size, + dtype, + text_encoder_projection_dim=None, ): add_time_ids = list(original_size + crops_coords_top_left + target_size) passed_add_embed_dim = ( - self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim + self.unet.config.addition_time_embed_dim * len(add_time_ids) + + text_encoder_projection_dim ) expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features @@ -1086,15 +1222,31 @@ def __call__( "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", ) - controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet + controlnet = ( + self.controlnet._orig_mod + if is_compiled_module(self.controlnet) + else self.controlnet + ) # align format for control guidance - if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): - control_guidance_start = len(control_guidance_end) * [control_guidance_start] - elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): + if not isinstance(control_guidance_start, list) and isinstance( + control_guidance_end, list + ): + control_guidance_start = len(control_guidance_end) * [ + control_guidance_start + ] + elif not isinstance(control_guidance_end, list) and isinstance( + control_guidance_start, list + ): control_guidance_end = len(control_guidance_start) * [control_guidance_end] - elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list): - mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1 + elif not isinstance(control_guidance_start, list) and not isinstance( + control_guidance_end, list + ): + mult = ( + len(controlnet.nets) + if isinstance(controlnet, MultiControlNetModel) + else 1 + ) control_guidance_start, control_guidance_end = ( mult * [control_guidance_start], mult * [control_guidance_end], @@ -1132,8 +1284,12 @@ def __call__( device = self._execution_device - if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float): - controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets) + if isinstance(controlnet, MultiControlNetModel) and isinstance( + controlnet_conditioning_scale, float + ): + controlnet_conditioning_scale = [controlnet_conditioning_scale] * len( + controlnet.nets + ) global_pool_conditions = ( controlnet.config.global_pool_conditions @@ -1144,7 +1300,9 @@ def __call__( # 3.1 Encode input prompt text_encoder_lora_scale = ( - self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None + 
self.cross_attention_kwargs.get("scale", None) + if self.cross_attention_kwargs is not None + else None ) ( prompt_embeds, @@ -1169,7 +1327,9 @@ def __call__( # 3.2 Encode ip_adapter_image if ip_adapter_image is not None: - image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt) + image_embeds, negative_image_embeds = self.encode_image( + ip_adapter_image, device, num_images_per_prompt + ) if self.do_classifier_free_guidance: image_embeds = torch.cat([negative_image_embeds, image_embeds]) @@ -1231,7 +1391,9 @@ def __call__( # 6.5 Optionally get Guidance Scale Embedding timestep_cond = None if self.unet.config.time_cond_proj_dim is not None: - guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) + guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat( + batch_size * num_images_per_prompt + ) timestep_cond = self.get_guidance_scale_embedding( guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim ).to(device=device, dtype=latents.dtype) @@ -1246,7 +1408,9 @@ def __call__( 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) for s, e in zip(control_guidance_start, control_guidance_end) ] - controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps) + controlnet_keep.append( + keeps[0] if isinstance(controlnet, ControlNetModel) else keeps + ) # 7.2 Prepare added time ids & embeddings if isinstance(image, list): @@ -1282,12 +1446,16 @@ def __call__( if self.do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) - add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) + add_text_embeds = torch.cat( + [negative_pooled_prompt_embeds, add_text_embeds], dim=0 + ) add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0) prompt_embeds = prompt_embeds.to(device) add_text_embeds = add_text_embeds.to(device) - add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) + add_time_ids = add_time_ids.to(device).repeat( + batch_size * num_images_per_prompt, 1 + ) # 8. Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order @@ -1298,19 +1466,32 @@ def __call__( for i, t in enumerate(timesteps): # Relevant thread: # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428 - if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1: + if ( + is_unet_compiled and is_controlnet_compiled + ) and is_torch_higher_equal_2_1: torch._inductor.cudagraph_mark_step_begin() # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + latent_model_input = ( + torch.cat([latents] * 2) + if self.do_classifier_free_guidance + else latents + ) + latent_model_input = self.scheduler.scale_model_input( + latent_model_input, t + ) - added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + added_cond_kwargs = { + "text_embeds": add_text_embeds, + "time_ids": add_time_ids, + } # controlnet(s) inference if guess_mode and self.do_classifier_free_guidance: # Infer ControlNet only for the conditional batch. 
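`controlnet_keep`, assembled a few hunks above, is how `control_guidance_start` / `control_guidance_end` turn into a per-timestep multiplier for each ControlNet. A tiny worked example of the same comprehension; the step count and guidance window below are arbitrary.

# Worked example of the controlnet_keep computation shown above.
num_steps = 10
control_guidance_start, control_guidance_end = [0.2], [0.8]  # single ControlNet

controlnet_keep = []
for i in range(num_steps):
    keeps = [
        1.0 - float(i / num_steps < s or (i + 1) / num_steps > e)
        for s, e in zip(control_guidance_start, control_guidance_end)
    ]
    controlnet_keep.append(keeps[0])  # one ControlNet -> a scalar per step

print(controlnet_keep)  # [0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0]

Each per-step value is then multiplied into `controlnet_conditioning_scale` to give the `cond_scale` used inside the denoising loop.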
control_model_input = latents - control_model_input = self.scheduler.scale_model_input(control_model_input, t) + control_model_input = self.scheduler.scale_model_input( + control_model_input, t + ) controlnet_prompt_embeds = prompt_embeds.chunk(2)[1] controlnet_added_cond_kwargs = { "text_embeds": add_text_embeds.chunk(2)[1], @@ -1322,7 +1503,12 @@ def __call__( controlnet_added_cond_kwargs = added_cond_kwargs if isinstance(controlnet_keep[i], list): - cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])] + cond_scale = [ + c * s + for c, s in zip( + controlnet_conditioning_scale, controlnet_keep[i] + ) + ] else: controlnet_cond_scale = controlnet_conditioning_scale if isinstance(controlnet_cond_scale, list): @@ -1344,8 +1530,13 @@ def __call__( # Infered ControlNet only for the conditional batch. # To apply the output of ControlNet to both the unconditional and conditional batches, # add 0 to the unconditional batch to keep it unchanged. - down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples] - mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample]) + down_block_res_samples = [ + torch.cat([torch.zeros_like(d), d]) + for d in down_block_res_samples + ] + mid_block_res_sample = torch.cat( + [torch.zeros_like(mid_block_res_sample), mid_block_res_sample] + ) if ip_adapter_image is not None: added_cond_kwargs["image_embeds"] = image_embeds @@ -1366,10 +1557,14 @@ def __call__( # perform guidance if self.do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + noise_pred = noise_pred_uncond + guidance_scale * ( + noise_pred_text - noise_pred_uncond + ) # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + latents = self.scheduler.step( + noise_pred, t, latents, **extra_step_kwargs, return_dict=False + )[0] if callback_on_step_end is not None: callback_kwargs = {} @@ -1379,10 +1574,14 @@ def __call__( latents = callback_outputs.pop("latents", latents) prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) - negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + negative_prompt_embeds = callback_outputs.pop( + "negative_prompt_embeds", negative_prompt_embeds + ) # call the callback, if provided - if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + if i == len(timesteps) - 1 or ( + (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0 + ): progress_bar.update() if callback is not None and i % callback_steps == 0: step_idx = i // getattr(self.scheduler, "order", 1) @@ -1391,17 +1590,25 @@ def __call__( # manually for max memory savings if self.vae.dtype == torch.float16 and self.vae.config.force_upcast: self.upcast_vae() - latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + latents = latents.to( + next(iter(self.vae.post_quant_conv.parameters())).dtype + ) if not output_type == "latent": # make sure the VAE is in float32 mode, as it overflows in float16 - needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast + needs_upcasting = ( + self.vae.dtype == torch.float16 and self.vae.config.force_upcast + ) if needs_upcasting: self.upcast_vae() - latents = 
latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + latents = latents.to( + next(iter(self.vae.post_quant_conv.parameters())).dtype + ) - image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image = self.vae.decode( + latents / self.vae.config.scaling_factor, return_dict=False + )[0] # cast back to fp16 if needed if needs_upcasting: diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index a05abe00f2b1..f66b0ce9141f 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -17,12 +17,22 @@ import torch from packaging import version -from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection +from transformers import ( + CLIPImageProcessor, + CLIPTextModel, + CLIPTokenizer, + CLIPVisionModelWithProjection, +) from ...configuration_utils import FrozenDict from ...image_processor import PipelineImageInput, VaeImageProcessor -from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin -from ...models import AutoencoderKL, UNet2DConditionModel +from ...loaders import ( + FromSingleFileMixin, + IPAdapterMixin, + LoraLoaderMixin, + TextualInversionLoaderMixin, +) +from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( @@ -61,17 +71,25 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 """ - std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_text = noise_pred_text.std( + dim=list(range(1, noise_pred_text.ndim)), keepdim=True + ) std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) # rescale the results from guidance (fixes overexposure) noise_pred_rescaled = noise_cfg * (std_text / std_cfg) # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images - noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + noise_cfg = ( + guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + ) return noise_cfg class StableDiffusionPipeline( - DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, IPAdapterMixin, FromSingleFileMixin + DiffusionPipeline, + TextualInversionLoaderMixin, + LoraLoaderMixin, + IPAdapterMixin, + FromSingleFileMixin, ): r""" Pipeline for text-to-image generation using Stable Diffusion. @@ -125,7 +143,10 @@ def __init__( ): super().__init__() - if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: + if ( + hasattr(scheduler.config, "steps_offset") + and scheduler.config.steps_offset != 1 + ): deprecation_message = ( f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" f" should be set to 1 instead of {scheduler.config.steps_offset}. 
Please make sure " @@ -134,12 +155,17 @@ def __init__( " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" " file" ) - deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) + deprecate( + "steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False + ) new_config = dict(scheduler.config) new_config["steps_offset"] = 1 scheduler._internal_dict = FrozenDict(new_config) - if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: + if ( + hasattr(scheduler.config, "clip_sample") + and scheduler.config.clip_sample is True + ): deprecation_message = ( f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." " `clip_sample` should be set to False in the configuration file. Please make sure to update the" @@ -147,7 +173,9 @@ def __init__( " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" ) - deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) + deprecate( + "clip_sample not set", "1.0.0", deprecation_message, standard_warn=False + ) new_config = dict(scheduler.config) new_config["clip_sample"] = False scheduler._internal_dict = FrozenDict(new_config) @@ -168,10 +196,16 @@ def __init__( " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." ) - is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( + is_unet_version_less_0_9_0 = hasattr( + unet.config, "_diffusers_version" + ) and version.parse( version.parse(unet.config._diffusers_version).base_version - ) < version.parse("0.9.0.dev0") - is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + ) < version.parse( + "0.9.0.dev0" + ) + is_unet_sample_size_less_64 = ( + hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + ) if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: deprecation_message = ( "The configuration file of the unet has set the default `sample_size` to smaller than" @@ -184,7 +218,9 @@ def __init__( " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" " the `unet/config.json` file" ) - deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) + deprecate( + "sample_size<64", "1.0.0", deprecation_message, standard_warn=False + ) new_config = dict(unet.config) new_config["sample_size"] = 64 unet._internal_dict = FrozenDict(new_config) @@ -336,11 +372,13 @@ def encode_prompt( return_tensors="pt", ) text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + untruncated_ids = self.tokenizer( + prompt, padding="longest", return_tensors="pt" + ).input_ids - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( - text_input_ids, untruncated_ids - ): + if untruncated_ids.shape[-1] >= text_input_ids.shape[ + -1 + ] and not torch.equal(text_input_ids, untruncated_ids): removed_text = self.tokenizer.batch_decode( untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] ) @@ -349,17 +387,24 @@ def encode_prompt( f" {self.tokenizer.model_max_length} tokens: {removed_text}" ) - if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + if ( + 
hasattr(self.text_encoder.config, "use_attention_mask") + and self.text_encoder.config.use_attention_mask + ): attention_mask = text_inputs.attention_mask.to(device) else: attention_mask = None if clip_skip is None: - prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask + ) prompt_embeds = prompt_embeds[0] else: prompt_embeds = self.text_encoder( - text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True + text_input_ids.to(device), + attention_mask=attention_mask, + output_hidden_states=True, ) # Access the `hidden_states` first, that contains a tuple of # all the hidden states from the encoder layers. Then index into @@ -369,7 +414,9 @@ def encode_prompt( # representations. The `last_hidden_states` that we typically use for # obtaining the final prompt representations passes through the LayerNorm # layer. - prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) + prompt_embeds = self.text_encoder.text_model.final_layer_norm( + prompt_embeds + ) if self.text_encoder is not None: prompt_embeds_dtype = self.text_encoder.dtype @@ -383,7 +430,9 @@ def encode_prompt( bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + prompt_embeds = prompt_embeds.view( + bs_embed * num_images_per_prompt, seq_len, -1 + ) # get unconditional embeddings for classifier free guidance if do_classifier_free_guidance and negative_prompt_embeds is None: @@ -419,7 +468,10 @@ def encode_prompt( return_tensors="pt", ) - if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + if ( + hasattr(self.text_encoder.config, "use_attention_mask") + and self.text_encoder.config.use_attention_mask + ): attention_mask = uncond_input.attention_mask.to(device) else: attention_mask = None @@ -434,10 +486,16 @@ def encode_prompt( # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] - negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + negative_prompt_embeds = negative_prompt_embeds.to( + dtype=prompt_embeds_dtype, device=device + ) - negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) - negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + negative_prompt_embeds = negative_prompt_embeds.repeat( + 1, num_images_per_prompt, 1 + ) + negative_prompt_embeds = negative_prompt_embeds.view( + batch_size * num_images_per_prompt, seq_len, -1 + ) if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND: # Retrieve the original scale by scaling back the LoRA layers @@ -452,10 +510,23 @@ def encode_image(self, image, device, num_images_per_prompt): image = self.feature_extractor(image, return_tensors="pt").pixel_values image = image.to(device=device, dtype=dtype) - image_embeds = self.image_encoder(image).image_embeds - image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) - - uncond_image_embeds = torch.zeros_like(image_embeds) + if isinstance(self.unet.encoder_hid_proj, ImageProjection): + # IP-Adapter + image_embeds = self.image_encoder(image).image_embeds + 
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + else: + # IP-Adapter Plus + image_embeds = self.image_encoder( + image, output_hidden_states=True + ).hidden_states[-2] + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_embeds = uncond_image_embeds.repeat_interleave( + num_images_per_prompt, dim=0 + ) return image_embeds, uncond_image_embeds def run_safety_checker(self, image, device, dtype): @@ -463,10 +534,14 @@ def run_safety_checker(self, image, device, dtype): has_nsfw_concept = None else: if torch.is_tensor(image): - feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + feature_extractor_input = self.image_processor.postprocess( + image, output_type="pil" + ) else: feature_extractor_input = self.image_processor.numpy_to_pil(image) - safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + safety_checker_input = self.feature_extractor( + feature_extractor_input, return_tensors="pt" + ).to(device) image, has_nsfw_concept = self.safety_checker( images=image, clip_input=safety_checker_input.pixel_values.to(dtype) ) @@ -489,13 +564,17 @@ def prepare_extra_step_kwargs(self, generator, eta): # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 # and should be between [0, 1] - accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + accepts_eta = "eta" in set( + inspect.signature(self.scheduler.step).parameters.keys() + ) extra_step_kwargs = {} if accepts_eta: extra_step_kwargs["eta"] = eta # check if the scheduler accepts generator - accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + accepts_generator = "generator" in set( + inspect.signature(self.scheduler.step).parameters.keys() + ) if accepts_generator: extra_step_kwargs["generator"] = generator return extra_step_kwargs @@ -512,15 +591,20 @@ def check_inputs( callback_on_step_end_tensor_inputs=None, ): if height % 8 != 0 or width % 8 != 0: - raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + raise ValueError( + f"`height` and `width` have to be divisible by 8 but are {height} and {width}." + ) - if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + if callback_steps is not None and ( + not isinstance(callback_steps, int) or callback_steps <= 0 + ): raise ValueError( f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}." ) if callback_on_step_end_tensor_inputs is not None and not all( - k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + k in self._callback_tensor_inputs + for k in callback_on_step_end_tensor_inputs ): raise ValueError( f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" @@ -535,8 +619,12 @@ def check_inputs( raise ValueError( "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." 
) - elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt is not None and ( + not isinstance(prompt, str) and not isinstance(prompt, list) + ): + raise ValueError( + f"`prompt` has to be of type `str` or `list` but is {type(prompt)}" + ) if negative_prompt is not None and negative_prompt_embeds is not None: raise ValueError( @@ -552,8 +640,23 @@ def check_inputs( f" {negative_prompt_embeds.shape}." ) - def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): - shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + shape = ( + batch_size, + num_channels_latents, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + ) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" @@ -561,7 +664,9 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype ) if latents is None: - latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = randn_tensor( + shape, generator=generator, device=device, dtype=dtype + ) else: latents = latents.to(device) @@ -800,7 +905,9 @@ def __call__( # 3. Encode input prompt lora_scale = ( - self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None + self.cross_attention_kwargs.get("scale", None) + if self.cross_attention_kwargs is not None + else None ) prompt_embeds, negative_prompt_embeds = self.encode_prompt( @@ -822,7 +929,9 @@ def __call__( prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) if ip_adapter_image is not None: - image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt) + image_embeds, negative_image_embeds = self.encode_image( + ip_adapter_image, device, num_images_per_prompt + ) if self.do_classifier_free_guidance: image_embeds = torch.cat([negative_image_embeds, image_embeds]) @@ -847,12 +956,16 @@ def __call__( extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # 6.1 Add image embeds for IP-Adapter - added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None + added_cond_kwargs = ( + {"image_embeds": image_embeds} if ip_adapter_image is not None else None + ) # 6.2 Optionally get Guidance Scale Embedding timestep_cond = None if self.unet.config.time_cond_proj_dim is not None: - guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) + guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat( + batch_size * num_images_per_prompt + ) timestep_cond = self.get_guidance_scale_embedding( guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim ).to(device=device, dtype=latents.dtype) @@ -863,8 +976,14 @@ def __call__( with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents - latent_model_input = 
self.scheduler.scale_model_input(latent_model_input, t) + latent_model_input = ( + torch.cat([latents] * 2) + if self.do_classifier_free_guidance + else latents + ) + latent_model_input = self.scheduler.scale_model_input( + latent_model_input, t + ) # predict the noise residual noise_pred = self.unet( @@ -880,14 +999,22 @@ def __call__( # perform guidance if self.do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + noise_pred = noise_pred_uncond + self.guidance_scale * ( + noise_pred_text - noise_pred_uncond + ) if self.do_classifier_free_guidance and self.guidance_rescale > 0.0: # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf - noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale) + noise_pred = rescale_noise_cfg( + noise_pred, + noise_pred_text, + guidance_rescale=self.guidance_rescale, + ) # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + latents = self.scheduler.step( + noise_pred, t, latents, **extra_step_kwargs, return_dict=False + )[0] if callback_on_step_end is not None: callback_kwargs = {} @@ -897,20 +1024,28 @@ def __call__( latents = callback_outputs.pop("latents", latents) prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) - negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + negative_prompt_embeds = callback_outputs.pop( + "negative_prompt_embeds", negative_prompt_embeds + ) # call the callback, if provided - if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + if i == len(timesteps) - 1 or ( + (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0 + ): progress_bar.update() if callback is not None and i % callback_steps == 0: step_idx = i // getattr(self.scheduler, "order", 1) callback(step_idx, t, latents) if not output_type == "latent": - image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[ - 0 - ] - image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + image = self.vae.decode( + latents / self.vae.config.scaling_factor, + return_dict=False, + generator=generator, + )[0] + image, has_nsfw_concept = self.run_safety_checker( + image, device, prompt_embeds.dtype + ) else: image = latents has_nsfw_concept = None @@ -920,7 +1055,9 @@ def __call__( else: do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] - image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + image = self.image_processor.postprocess( + image, output_type=output_type, do_denormalize=do_denormalize + ) # Offload all models self.maybe_free_model_hooks() @@ -928,4 +1065,6 @@ def __call__( if not return_dict: return (image, has_nsfw_concept) - return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) + return StableDiffusionPipelineOutput( + images=image, nsfw_content_detected=has_nsfw_concept + ) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py index 029cd2b04839..797802b37043 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +++ 
b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py @@ -19,12 +19,22 @@ import PIL.Image import torch from packaging import version -from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection +from transformers import ( + CLIPImageProcessor, + CLIPTextModel, + CLIPTokenizer, + CLIPVisionModelWithProjection, +) from ...configuration_utils import FrozenDict from ...image_processor import PipelineImageInput, VaeImageProcessor -from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin -from ...models import AutoencoderKL, UNet2DConditionModel +from ...loaders import ( + FromSingleFileMixin, + IPAdapterMixin, + LoraLoaderMixin, + TextualInversionLoaderMixin, +) +from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( @@ -94,7 +104,10 @@ def preprocess(image): w, h = image[0].size w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8 - image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image] + image = [ + np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] + for i in image + ] image = np.concatenate(image, axis=0) image = np.array(image).astype(np.float32) / 255.0 image = image.transpose(0, 3, 1, 2) @@ -106,7 +119,11 @@ def preprocess(image): class StableDiffusionImg2ImgPipeline( - DiffusionPipeline, TextualInversionLoaderMixin, IPAdapterMixin, LoraLoaderMixin, FromSingleFileMixin + DiffusionPipeline, + TextualInversionLoaderMixin, + IPAdapterMixin, + LoraLoaderMixin, + FromSingleFileMixin, ): r""" Pipeline for text-guided image-to-image generation using Stable Diffusion. @@ -160,7 +177,10 @@ def __init__( ): super().__init__() - if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: + if ( + hasattr(scheduler.config, "steps_offset") + and scheduler.config.steps_offset != 1 + ): deprecation_message = ( f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " @@ -169,12 +189,17 @@ def __init__( " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" " file" ) - deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) + deprecate( + "steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False + ) new_config = dict(scheduler.config) new_config["steps_offset"] = 1 scheduler._internal_dict = FrozenDict(new_config) - if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: + if ( + hasattr(scheduler.config, "clip_sample") + and scheduler.config.clip_sample is True + ): deprecation_message = ( f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." " `clip_sample` should be set to False in the configuration file. Please make sure to update the" @@ -182,7 +207,9 @@ def __init__( " future versions. 
If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" ) - deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) + deprecate( + "clip_sample not set", "1.0.0", deprecation_message, standard_warn=False + ) new_config = dict(scheduler.config) new_config["clip_sample"] = False scheduler._internal_dict = FrozenDict(new_config) @@ -203,10 +230,16 @@ def __init__( " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." ) - is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( + is_unet_version_less_0_9_0 = hasattr( + unet.config, "_diffusers_version" + ) and version.parse( version.parse(unet.config._diffusers_version).base_version - ) < version.parse("0.9.0.dev0") - is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + ) < version.parse( + "0.9.0.dev0" + ) + is_unet_sample_size_less_64 = ( + hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + ) if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: deprecation_message = ( "The configuration file of the unet has set the default `sample_size` to smaller than" @@ -219,7 +252,9 @@ def __init__( " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" " the `unet/config.json` file" ) - deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) + deprecate( + "sample_size<64", "1.0.0", deprecation_message, standard_warn=False + ) new_config = dict(unet.config) new_config["sample_size"] = 64 unet._internal_dict = FrozenDict(new_config) @@ -344,11 +379,13 @@ def encode_prompt( return_tensors="pt", ) text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + untruncated_ids = self.tokenizer( + prompt, padding="longest", return_tensors="pt" + ).input_ids - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( - text_input_ids, untruncated_ids - ): + if untruncated_ids.shape[-1] >= text_input_ids.shape[ + -1 + ] and not torch.equal(text_input_ids, untruncated_ids): removed_text = self.tokenizer.batch_decode( untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] ) @@ -357,17 +394,24 @@ def encode_prompt( f" {self.tokenizer.model_max_length} tokens: {removed_text}" ) - if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + if ( + hasattr(self.text_encoder.config, "use_attention_mask") + and self.text_encoder.config.use_attention_mask + ): attention_mask = text_inputs.attention_mask.to(device) else: attention_mask = None if clip_skip is None: - prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask + ) prompt_embeds = prompt_embeds[0] else: prompt_embeds = self.text_encoder( - text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True + text_input_ids.to(device), + attention_mask=attention_mask, + output_hidden_states=True, ) # Access the `hidden_states` first, that contains a tuple of # all the hidden states from the encoder layers. Then index into @@ -377,7 +421,9 @@ def encode_prompt( # representations. 
The `last_hidden_states` that we typically use for # obtaining the final prompt representations passes through the LayerNorm # layer. - prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) + prompt_embeds = self.text_encoder.text_model.final_layer_norm( + prompt_embeds + ) if self.text_encoder is not None: prompt_embeds_dtype = self.text_encoder.dtype @@ -391,7 +437,9 @@ def encode_prompt( bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + prompt_embeds = prompt_embeds.view( + bs_embed * num_images_per_prompt, seq_len, -1 + ) # get unconditional embeddings for classifier free guidance if do_classifier_free_guidance and negative_prompt_embeds is None: @@ -427,7 +475,10 @@ def encode_prompt( return_tensors="pt", ) - if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + if ( + hasattr(self.text_encoder.config, "use_attention_mask") + and self.text_encoder.config.use_attention_mask + ): attention_mask = uncond_input.attention_mask.to(device) else: attention_mask = None @@ -442,10 +493,16 @@ def encode_prompt( # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] - negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + negative_prompt_embeds = negative_prompt_embeds.to( + dtype=prompt_embeds_dtype, device=device + ) - negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) - negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + negative_prompt_embeds = negative_prompt_embeds.repeat( + 1, num_images_per_prompt, 1 + ) + negative_prompt_embeds = negative_prompt_embeds.view( + batch_size * num_images_per_prompt, seq_len, -1 + ) if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND: # Retrieve the original scale by scaling back the LoRA layers @@ -461,10 +518,23 @@ def encode_image(self, image, device, num_images_per_prompt): image = self.feature_extractor(image, return_tensors="pt").pixel_values image = image.to(device=device, dtype=dtype) - image_embeds = self.image_encoder(image).image_embeds - image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) - - uncond_image_embeds = torch.zeros_like(image_embeds) + if isinstance(self.unet.encoder_hid_proj, ImageProjection): + # IP-Adapter + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + else: + # IP-Adapter Plus + image_embeds = self.image_encoder( + image, output_hidden_states=True + ).hidden_states[-2] + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_embeds = uncond_image_embeds.repeat_interleave( + num_images_per_prompt, dim=0 + ) return image_embeds, uncond_image_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker @@ -473,10 +543,14 @@ def run_safety_checker(self, image, device, dtype): has_nsfw_concept = None else: if torch.is_tensor(image): - 
feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + feature_extractor_input = self.image_processor.postprocess( + image, output_type="pil" + ) else: feature_extractor_input = self.image_processor.numpy_to_pil(image) - safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + safety_checker_input = self.feature_extractor( + feature_extractor_input, return_tensors="pt" + ).to(device) image, has_nsfw_concept = self.safety_checker( images=image, clip_input=safety_checker_input.pixel_values.to(dtype) ) @@ -501,13 +575,17 @@ def prepare_extra_step_kwargs(self, generator, eta): # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 # and should be between [0, 1] - accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + accepts_eta = "eta" in set( + inspect.signature(self.scheduler.step).parameters.keys() + ) extra_step_kwargs = {} if accepts_eta: extra_step_kwargs["eta"] = eta # check if the scheduler accepts generator - accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + accepts_generator = "generator" in set( + inspect.signature(self.scheduler.step).parameters.keys() + ) if accepts_generator: extra_step_kwargs["generator"] = generator return extra_step_kwargs @@ -523,16 +601,21 @@ def check_inputs( callback_on_step_end_tensor_inputs=None, ): if strength < 0 or strength > 1: - raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + raise ValueError( + f"The value of strength should in [0.0, 1.0] but is {strength}" + ) - if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + if callback_steps is not None and ( + not isinstance(callback_steps, int) or callback_steps <= 0 + ): raise ValueError( f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}." ) if callback_on_step_end_tensor_inputs is not None and not all( - k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + k in self._callback_tensor_inputs + for k in callback_on_step_end_tensor_inputs ): raise ValueError( f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" @@ -546,8 +629,12 @@ def check_inputs( raise ValueError( "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." 
) - elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt is not None and ( + not isinstance(prompt, str) and not isinstance(prompt, list) + ): + raise ValueError( + f"`prompt` has to be of type `str` or `list` but is {type(prompt)}" + ) if negative_prompt is not None and negative_prompt_embeds is not None: raise ValueError( @@ -572,7 +659,16 @@ def get_timesteps(self, num_inference_steps, strength, device): return timesteps, num_inference_steps - t_start - def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None): + def prepare_latents( + self, + image, + timestep, + batch_size, + num_images_per_prompt, + dtype, + device, + generator=None, + ): if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): raise ValueError( f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" @@ -594,16 +690,23 @@ def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dt elif isinstance(generator, list): init_latents = [ - retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) + retrieve_latents( + self.vae.encode(image[i : i + 1]), generator=generator[i] + ) for i in range(batch_size) ] init_latents = torch.cat(init_latents, dim=0) else: - init_latents = retrieve_latents(self.vae.encode(image), generator=generator) + init_latents = retrieve_latents( + self.vae.encode(image), generator=generator + ) init_latents = self.vae.config.scaling_factor * init_latents - if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: + if ( + batch_size > init_latents.shape[0] + and batch_size % init_latents.shape[0] == 0 + ): # expand init_latents for batch_size deprecation_message = ( f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial" @@ -611,10 +714,20 @@ def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dt " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update" " your script to pass as many initial images as text prompts to suppress this warning." ) - deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False) + deprecate( + "len(prompt) != len(image)", + "1.0.0", + deprecation_message, + standard_warn=False, + ) additional_image_per_prompt = batch_size // init_latents.shape[0] - init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0) - elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: + init_latents = torch.cat( + [init_latents] * additional_image_per_prompt, dim=0 + ) + elif ( + batch_size > init_latents.shape[0] + and batch_size % init_latents.shape[0] != 0 + ): raise ValueError( f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." ) @@ -849,7 +962,9 @@ def __call__( # 3. 
Encode input prompt text_encoder_lora_scale = ( - self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None + self.cross_attention_kwargs.get("scale", None) + if self.cross_attention_kwargs is not None + else None ) prompt_embeds, negative_prompt_embeds = self.encode_prompt( prompt, @@ -869,7 +984,9 @@ def __call__( prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) if ip_adapter_image is not None: - image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt) + image_embeds, negative_image_embeds = self.encode_image( + ip_adapter_image, device, num_images_per_prompt + ) if self.do_classifier_free_guidance: image_embeds = torch.cat([negative_image_embeds, image_embeds]) @@ -878,7 +995,9 @@ def __call__( # 5. set timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) - timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) + timesteps, num_inference_steps = self.get_timesteps( + num_inference_steps, strength, device + ) latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) # 6. Prepare latent variables @@ -896,12 +1015,16 @@ def __call__( extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # 7.1 Add image embeds for IP-Adapter - added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None + added_cond_kwargs = ( + {"image_embeds": image_embeds} if ip_adapter_image is not None else None + ) # 7.2 Optionally get Guidance Scale Embedding timestep_cond = None if self.unet.config.time_cond_proj_dim is not None: - guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) + guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat( + batch_size * num_images_per_prompt + ) timestep_cond = self.get_guidance_scale_embedding( guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim ).to(device=device, dtype=latents.dtype) @@ -912,8 +1035,14 @@ def __call__( with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + latent_model_input = ( + torch.cat([latents] * 2) + if self.do_classifier_free_guidance + else latents + ) + latent_model_input = self.scheduler.scale_model_input( + latent_model_input, t + ) # predict the noise residual noise_pred = self.unet( @@ -929,10 +1058,14 @@ def __call__( # perform guidance if self.do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + noise_pred = noise_pred_uncond + self.guidance_scale * ( + noise_pred_text - noise_pred_uncond + ) # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + latents = self.scheduler.step( + noise_pred, t, latents, **extra_step_kwargs, return_dict=False + )[0] if callback_on_step_end is not None: callback_kwargs = {} @@ -942,20 +1075,28 @@ def __call__( latents = callback_outputs.pop("latents", latents) prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) - negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", 
negative_prompt_embeds) + negative_prompt_embeds = callback_outputs.pop( + "negative_prompt_embeds", negative_prompt_embeds + ) # call the callback, if provided - if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + if i == len(timesteps) - 1 or ( + (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0 + ): progress_bar.update() if callback is not None and i % callback_steps == 0: step_idx = i // getattr(self.scheduler, "order", 1) callback(step_idx, t, latents) if not output_type == "latent": - image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[ - 0 - ] - image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + image = self.vae.decode( + latents / self.vae.config.scaling_factor, + return_dict=False, + generator=generator, + )[0] + image, has_nsfw_concept = self.run_safety_checker( + image, device, prompt_embeds.dtype + ) else: image = latents has_nsfw_concept = None @@ -965,7 +1106,9 @@ def __call__( else: do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] - image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + image = self.image_processor.postprocess( + image, output_type=output_type, do_denormalize=do_denormalize + ) # Offload all models self.maybe_free_model_hooks() @@ -973,4 +1116,6 @@ def __call__( if not return_dict: return (image, has_nsfw_concept) - return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) + return StableDiffusionPipelineOutput( + images=image, nsfw_content_detected=has_nsfw_concept + ) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index 09e50c60a807..737a2d6938b7 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -19,15 +19,36 @@ import PIL.Image import torch from packaging import version -from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection +from transformers import ( + CLIPImageProcessor, + CLIPTextModel, + CLIPTokenizer, + CLIPVisionModelWithProjection, +) from ...configuration_utils import FrozenDict from ...image_processor import PipelineImageInput, VaeImageProcessor -from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin -from ...models import AsymmetricAutoencoderKL, AutoencoderKL, UNet2DConditionModel +from ...loaders import ( + FromSingleFileMixin, + IPAdapterMixin, + LoraLoaderMixin, + TextualInversionLoaderMixin, +) +from ...models import ( + AsymmetricAutoencoderKL, + AutoencoderKL, + ImageProjection, + UNet2DConditionModel, +) from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers -from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers +from ...utils import ( + USE_PEFT_BACKEND, + deprecate, + logging, + scale_lora_layers, + unscale_lora_layers, +) from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline from . 
import StableDiffusionPipelineOutput @@ -37,7 +58,9 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name -def prepare_mask_and_masked_image(image, mask, height, width, return_image: bool = False): +def prepare_mask_and_masked_image( + image, mask, height, width, return_image: bool = False +): """ Prepares a pair (image, mask) to be consumed by the Stable Diffusion pipeline. This means that those inputs will be converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for the @@ -79,11 +102,15 @@ def prepare_mask_and_masked_image(image, mask, height, width, return_image: bool if isinstance(image, torch.Tensor): if not isinstance(mask, torch.Tensor): - raise TypeError(f"`image` is a torch.Tensor but `mask` (type: {type(mask)} is not") + raise TypeError( + f"`image` is a torch.Tensor but `mask` (type: {type(mask)} is not" + ) # Batch single image if image.ndim == 3: - assert image.shape[0] == 3, "Image outside a batch should be of shape (3, H, W)" + assert ( + image.shape[0] == 3 + ), "Image outside a batch should be of shape (3, H, W)" image = image.unsqueeze(0) # Batch and add channel dim for single mask @@ -100,9 +127,15 @@ def prepare_mask_and_masked_image(image, mask, height, width, return_image: bool else: mask = mask.unsqueeze(1) - assert image.ndim == 4 and mask.ndim == 4, "Image and Mask must have 4 dimensions" - assert image.shape[-2:] == mask.shape[-2:], "Image and Mask must have the same spatial dimensions" - assert image.shape[0] == mask.shape[0], "Image and Mask must have the same batch size" + assert ( + image.ndim == 4 and mask.ndim == 4 + ), "Image and Mask must have 4 dimensions" + assert ( + image.shape[-2:] == mask.shape[-2:] + ), "Image and Mask must have the same spatial dimensions" + assert ( + image.shape[0] == mask.shape[0] + ), "Image and Mask must have the same batch size" # Check image is in [-1, 1] if image.min() < -1 or image.max() > 1: @@ -119,14 +152,18 @@ def prepare_mask_and_masked_image(image, mask, height, width, return_image: bool # Image as float32 image = image.to(dtype=torch.float32) elif isinstance(mask, torch.Tensor): - raise TypeError(f"`mask` is a torch.Tensor but `image` (type: {type(image)} is not") + raise TypeError( + f"`mask` is a torch.Tensor but `image` (type: {type(image)} is not" + ) else: # preprocess image if isinstance(image, (PIL.Image.Image, np.ndarray)): image = [image] if isinstance(image, list) and isinstance(image[0], PIL.Image.Image): # resize all images w.r.t passed height an width - image = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in image] + image = [ + i.resize((width, height), resample=PIL.Image.LANCZOS) for i in image + ] image = [np.array(i.convert("RGB"))[None, :] for i in image] image = np.concatenate(image, axis=0) elif isinstance(image, list) and isinstance(image[0], np.ndarray): @@ -141,7 +178,9 @@ def prepare_mask_and_masked_image(image, mask, height, width, return_image: bool if isinstance(mask, list) and isinstance(mask[0], PIL.Image.Image): mask = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in mask] - mask = np.concatenate([np.array(m.convert("L"))[None, None, :] for m in mask], axis=0) + mask = np.concatenate( + [np.array(m.convert("L"))[None, None, :] for m in mask], axis=0 + ) mask = mask.astype(np.float32) / 255.0 elif isinstance(mask, list) and isinstance(mask[0], np.ndarray): mask = np.concatenate([m[None, None, :] for m in mask], axis=0) @@ -170,7 +209,11 @@ def retrieve_latents(encoder_output, 
generator): class StableDiffusionInpaintPipeline( - DiffusionPipeline, TextualInversionLoaderMixin, IPAdapterMixin, LoraLoaderMixin, FromSingleFileMixin + DiffusionPipeline, + TextualInversionLoaderMixin, + IPAdapterMixin, + LoraLoaderMixin, + FromSingleFileMixin, ): r""" Pipeline for text-guided image inpainting using Stable Diffusion. @@ -207,7 +250,13 @@ class StableDiffusionInpaintPipeline( model_cpu_offload_seq = "text_encoder->unet->vae" _optional_components = ["safety_checker", "feature_extractor", "image_encoder"] _exclude_from_cpu_offload = ["safety_checker"] - _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "mask", "masked_image_latents"] + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + "mask", + "masked_image_latents", + ] def __init__( self, @@ -223,7 +272,10 @@ def __init__( ): super().__init__() - if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: + if ( + hasattr(scheduler.config, "steps_offset") + and scheduler.config.steps_offset != 1 + ): deprecation_message = ( f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " @@ -232,12 +284,17 @@ def __init__( " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" " file" ) - deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) + deprecate( + "steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False + ) new_config = dict(scheduler.config) new_config["steps_offset"] = 1 scheduler._internal_dict = FrozenDict(new_config) - if hasattr(scheduler.config, "skip_prk_steps") and scheduler.config.skip_prk_steps is False: + if ( + hasattr(scheduler.config, "skip_prk_steps") + and scheduler.config.skip_prk_steps is False + ): deprecation_message = ( f"The configuration file of this scheduler: {scheduler} has not set the configuration" " `skip_prk_steps`. `skip_prk_steps` should be set to True in the configuration file. Please make" @@ -246,7 +303,12 @@ def __init__( " Hub, it would be very nice if you could open a Pull request for the" " `scheduler/scheduler_config.json` file" ) - deprecate("skip_prk_steps not set", "1.0.0", deprecation_message, standard_warn=False) + deprecate( + "skip_prk_steps not set", + "1.0.0", + deprecation_message, + standard_warn=False, + ) new_config = dict(scheduler.config) new_config["skip_prk_steps"] = True scheduler._internal_dict = FrozenDict(new_config) @@ -267,10 +329,16 @@ def __init__( " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." 
) - is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( + is_unet_version_less_0_9_0 = hasattr( + unet.config, "_diffusers_version" + ) and version.parse( version.parse(unet.config._diffusers_version).base_version - ) < version.parse("0.9.0.dev0") - is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + ) < version.parse( + "0.9.0.dev0" + ) + is_unet_sample_size_less_64 = ( + hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + ) if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: deprecation_message = ( "The configuration file of the unet has set the default `sample_size` to smaller than" @@ -283,14 +351,18 @@ def __init__( " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" " the `unet/config.json` file" ) - deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) + deprecate( + "sample_size<64", "1.0.0", deprecation_message, standard_warn=False + ) new_config = dict(unet.config) new_config["sample_size"] = 64 unet._internal_dict = FrozenDict(new_config) # Check shapes, assume num_channels_latents == 4, num_channels_mask == 1, num_channels_masked == 4 if unet.config.in_channels != 9: - logger.info(f"You have loaded a UNet with {unet.config.in_channels} input channels which.") + logger.info( + f"You have loaded a UNet with {unet.config.in_channels} input channels which." + ) self.register_modules( vae=vae, @@ -305,7 +377,10 @@ def __init__( self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.mask_processor = VaeImageProcessor( - vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True + vae_scale_factor=self.vae_scale_factor, + do_normalize=False, + do_binarize=True, + do_convert_grayscale=True, ) self.register_to_config(requires_safety_checker=requires_safety_checker) @@ -415,11 +490,13 @@ def encode_prompt( return_tensors="pt", ) text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + untruncated_ids = self.tokenizer( + prompt, padding="longest", return_tensors="pt" + ).input_ids - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( - text_input_ids, untruncated_ids - ): + if untruncated_ids.shape[-1] >= text_input_ids.shape[ + -1 + ] and not torch.equal(text_input_ids, untruncated_ids): removed_text = self.tokenizer.batch_decode( untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] ) @@ -428,17 +505,24 @@ def encode_prompt( f" {self.tokenizer.model_max_length} tokens: {removed_text}" ) - if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + if ( + hasattr(self.text_encoder.config, "use_attention_mask") + and self.text_encoder.config.use_attention_mask + ): attention_mask = text_inputs.attention_mask.to(device) else: attention_mask = None if clip_skip is None: - prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask + ) prompt_embeds = prompt_embeds[0] else: prompt_embeds = self.text_encoder( - text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True + text_input_ids.to(device), + attention_mask=attention_mask, + 
output_hidden_states=True, ) # Access the `hidden_states` first, that contains a tuple of # all the hidden states from the encoder layers. Then index into @@ -448,7 +532,9 @@ def encode_prompt( # representations. The `last_hidden_states` that we typically use for # obtaining the final prompt representations passes through the LayerNorm # layer. - prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) + prompt_embeds = self.text_encoder.text_model.final_layer_norm( + prompt_embeds + ) if self.text_encoder is not None: prompt_embeds_dtype = self.text_encoder.dtype @@ -462,7 +548,9 @@ def encode_prompt( bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + prompt_embeds = prompt_embeds.view( + bs_embed * num_images_per_prompt, seq_len, -1 + ) # get unconditional embeddings for classifier free guidance if do_classifier_free_guidance and negative_prompt_embeds is None: @@ -498,7 +586,10 @@ def encode_prompt( return_tensors="pt", ) - if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + if ( + hasattr(self.text_encoder.config, "use_attention_mask") + and self.text_encoder.config.use_attention_mask + ): attention_mask = uncond_input.attention_mask.to(device) else: attention_mask = None @@ -513,10 +604,16 @@ def encode_prompt( # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] - negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + negative_prompt_embeds = negative_prompt_embeds.to( + dtype=prompt_embeds_dtype, device=device + ) - negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) - negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + negative_prompt_embeds = negative_prompt_embeds.repeat( + 1, num_images_per_prompt, 1 + ) + negative_prompt_embeds = negative_prompt_embeds.view( + batch_size * num_images_per_prompt, seq_len, -1 + ) if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND: # Retrieve the original scale by scaling back the LoRA layers @@ -532,10 +629,23 @@ def encode_image(self, image, device, num_images_per_prompt): image = self.feature_extractor(image, return_tensors="pt").pixel_values image = image.to(device=device, dtype=dtype) - image_embeds = self.image_encoder(image).image_embeds - image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) - - uncond_image_embeds = torch.zeros_like(image_embeds) + if isinstance(self.unet.encoder_hid_proj, ImageProjection): + # IP-Adapter + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + else: + # IP-Adapter Plus + image_embeds = self.image_encoder( + image, output_hidden_states=True + ).hidden_states[-2] + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_embeds = uncond_image_embeds.repeat_interleave( + num_images_per_prompt, dim=0 + ) return image_embeds, uncond_image_embeds # Copied from 
diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker @@ -544,10 +654,14 @@ def run_safety_checker(self, image, device, dtype): has_nsfw_concept = None else: if torch.is_tensor(image): - feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + feature_extractor_input = self.image_processor.postprocess( + image, output_type="pil" + ) else: feature_extractor_input = self.image_processor.numpy_to_pil(image) - safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + safety_checker_input = self.feature_extractor( + feature_extractor_input, return_tensors="pt" + ).to(device) image, has_nsfw_concept = self.safety_checker( images=image, clip_input=safety_checker_input.pixel_values.to(dtype) ) @@ -560,13 +674,17 @@ def prepare_extra_step_kwargs(self, generator, eta): # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 # and should be between [0, 1] - accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + accepts_eta = "eta" in set( + inspect.signature(self.scheduler.step).parameters.keys() + ) extra_step_kwargs = {} if accepts_eta: extra_step_kwargs["eta"] = eta # check if the scheduler accepts generator - accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + accepts_generator = "generator" in set( + inspect.signature(self.scheduler.step).parameters.keys() + ) if accepts_generator: extra_step_kwargs["generator"] = generator return extra_step_kwargs @@ -584,19 +702,26 @@ def check_inputs( callback_on_step_end_tensor_inputs=None, ): if strength < 0 or strength > 1: - raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + raise ValueError( + f"The value of strength should in [0.0, 1.0] but is {strength}" + ) if height % self.vae_scale_factor != 0 or width % self.vae_scale_factor != 0: - raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + raise ValueError( + f"`height` and `width` have to be divisible by 8 but are {height} and {width}." + ) - if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + if callback_steps is not None and ( + not isinstance(callback_steps, int) or callback_steps <= 0 + ): raise ValueError( f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}." ) if callback_on_step_end_tensor_inputs is not None and not all( - k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + k in self._callback_tensor_inputs + for k in callback_on_step_end_tensor_inputs ): raise ValueError( f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" @@ -611,8 +736,12 @@ def check_inputs( raise ValueError( "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." 
) - elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt is not None and ( + not isinstance(prompt, str) and not isinstance(prompt, list) + ): + raise ValueError( + f"`prompt` has to be of type `str` or `list` but is {type(prompt)}" + ) if negative_prompt is not None and negative_prompt_embeds is not None: raise ValueError( @@ -644,7 +773,12 @@ def prepare_latents( return_noise=False, return_image_latents=False, ): - shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + shape = ( + batch_size, + num_channels_latents, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + ) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" @@ -664,14 +798,24 @@ def prepare_latents( image_latents = image else: image_latents = self._encode_vae_image(image=image, generator=generator) - image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1) + image_latents = image_latents.repeat( + batch_size // image_latents.shape[0], 1, 1, 1 + ) if latents is None: noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) # if strength is 1. then initialise the latents to noise, else initial to image + noise - latents = noise if is_strength_max else self.scheduler.add_noise(image_latents, noise, timestep) + latents = ( + noise + if is_strength_max + else self.scheduler.add_noise(image_latents, noise, timestep) + ) # if pure noise then scale the initial latents by the Scheduler's init sigma - latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents + latents = ( + latents * self.scheduler.init_noise_sigma + if is_strength_max + else latents + ) else: noise = latents.to(device) latents = noise * self.scheduler.init_noise_sigma @@ -689,19 +833,32 @@ def prepare_latents( def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): if isinstance(generator, list): image_latents = [ - retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) + retrieve_latents( + self.vae.encode(image[i : i + 1]), generator=generator[i] + ) for i in range(image.shape[0]) ] image_latents = torch.cat(image_latents, dim=0) else: - image_latents = retrieve_latents(self.vae.encode(image), generator=generator) + image_latents = retrieve_latents( + self.vae.encode(image), generator=generator + ) image_latents = self.vae.config.scaling_factor * image_latents return image_latents def prepare_mask_latents( - self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance + self, + mask, + masked_image, + batch_size, + height, + width, + dtype, + device, + generator, + do_classifier_free_guidance, ): # resize the mask to latents shape as we concatenate the mask to the latents # we do that before converting to dtype to avoid breaking in case we're using cpu_offload @@ -716,7 +873,9 @@ def prepare_mask_latents( if masked_image.shape[1] == 4: masked_image_latents = masked_image else: - masked_image_latents = self._encode_vae_image(masked_image, generator=generator) + masked_image_latents = self._encode_vae_image( + masked_image, generator=generator + ) # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method if mask.shape[0] < 
batch_size: @@ -734,11 +893,15 @@ def prepare_mask_latents( f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed." " Make sure the number of images that you pass is divisible by the total requested batch size." ) - masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1) + masked_image_latents = masked_image_latents.repeat( + batch_size // masked_image_latents.shape[0], 1, 1, 1 + ) mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask masked_image_latents = ( - torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents + torch.cat([masked_image_latents] * 2) + if do_classifier_free_guidance + else masked_image_latents ) # aligning device to prevent device errors when concating it with the latent model input @@ -1029,7 +1192,9 @@ def __call__( # 3. Encode input prompt text_encoder_lora_scale = ( - cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + cross_attention_kwargs.get("scale", None) + if cross_attention_kwargs is not None + else None ) prompt_embeds, negative_prompt_embeds = self.encode_prompt( prompt, @@ -1049,7 +1214,9 @@ def __call__( prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) if ip_adapter_image is not None: - image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt) + image_embeds, negative_image_embeds = self.encode_image( + ip_adapter_image, device, num_images_per_prompt + ) if self.do_classifier_free_guidance: image_embeds = torch.cat([negative_image_embeds, image_embeds]) @@ -1101,7 +1268,9 @@ def __call__( latents, noise = latents_outputs # 7. Prepare mask latent variables - mask_condition = self.mask_processor.preprocess(mask_image, height=height, width=width) + mask_condition = self.mask_processor.preprocess( + mask_image, height=height, width=width + ) if masked_image_latents is None: masked_image = init_image * (mask_condition < 0.5) @@ -1125,7 +1294,10 @@ def __call__( # default case for runwayml/stable-diffusion-inpainting num_channels_mask = mask.shape[1] num_channels_masked_image = masked_image_latents.shape[1] - if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels: + if ( + num_channels_latents + num_channels_mask + num_channels_masked_image + != self.unet.config.in_channels + ): raise ValueError( f"Incorrect configuration settings! 
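The channel-count guard reformatted above protects the 9-channel inpainting UNet layout (4 latent + 1 mask + 4 masked-image channels, as in the runwayml inpainting checkpoint). A small shape-only sketch of that check and the concatenation it protects:

```py
import torch

latents = torch.randn(1, 4, 64, 64)
mask = torch.ones(1, 1, 64, 64)
masked_image_latents = torch.randn(1, 4, 64, 64)
unet_in_channels = 9  # stand-in for self.unet.config.in_channels

total = latents.shape[1] + mask.shape[1] + masked_image_latents.shape[1]
if total != unet_in_channels:
    raise ValueError(
        f"expected {unet_in_channels} input channels but would pass {total}"
    )

latent_model_input = torch.cat([latents, mask, masked_image_latents], dim=1)
print(latent_model_input.shape)  # torch.Size([1, 9, 64, 64])
```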
The config of `pipeline.unet`: {self.unet.config} expects" f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" @@ -1142,12 +1314,16 @@ def __call__( extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # 9.1 Add image embeds for IP-Adapter - added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None + added_cond_kwargs = ( + {"image_embeds": image_embeds} if ip_adapter_image is not None else None + ) # 9.2 Optionally get Guidance Scale Embedding timestep_cond = None if self.unet.config.time_cond_proj_dim is not None: - guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) + guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat( + batch_size * num_images_per_prompt + ) timestep_cond = self.get_guidance_scale_embedding( guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim ).to(device=device, dtype=latents.dtype) @@ -1158,13 +1334,21 @@ def __call__( with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + latent_model_input = ( + torch.cat([latents] * 2) + if self.do_classifier_free_guidance + else latents + ) # concat latents, mask, masked_image_latents in the channel dimension - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + latent_model_input = self.scheduler.scale_model_input( + latent_model_input, t + ) if num_channels_unet == 9: - latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1) + latent_model_input = torch.cat( + [latent_model_input, mask, masked_image_latents], dim=1 + ) # predict the noise residual noise_pred = self.unet( @@ -1180,10 +1364,14 @@ def __call__( # perform guidance if self.do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + noise_pred = noise_pred_uncond + self.guidance_scale * ( + noise_pred_text - noise_pred_uncond + ) # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + latents = self.scheduler.step( + noise_pred, t, latents, **extra_step_kwargs, return_dict=False + )[0] if num_channels_unet == 4: init_latents_proper = image_latents if self.do_classifier_free_guidance: @@ -1197,7 +1385,9 @@ def __call__( init_latents_proper, noise, torch.tensor([noise_timestep]) ) - latents = (1 - init_mask) * init_latents_proper + init_mask * latents + latents = ( + 1 - init_mask + ) * init_latents_proper + init_mask * latents if callback_on_step_end is not None: callback_kwargs = {} @@ -1207,12 +1397,18 @@ def __call__( latents = callback_outputs.pop("latents", latents) prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) - negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + negative_prompt_embeds = callback_outputs.pop( + "negative_prompt_embeds", negative_prompt_embeds + ) mask = callback_outputs.pop("mask", mask) - masked_image_latents = callback_outputs.pop("masked_image_latents", masked_image_latents) + masked_image_latents = callback_outputs.pop( + "masked_image_latents", masked_image_latents + ) # call the callback, if provided - if i == 
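For plain 4-channel UNets (`num_channels_unet == 4`) the loop above re-injects the known image content each step instead of concatenating a mask channel. A minimal sketch of that per-step blend:

```py
import torch

latents = torch.randn(1, 4, 64, 64)              # current model prediction
init_latents_proper = torch.randn(1, 4, 64, 64)  # original image latents, noised to this step
init_mask = (torch.rand(1, 1, 64, 64) > 0.5).float()  # 1 where inpainting is allowed

# keep the original content where the mask is 0, the model's output where it is 1
latents = (1 - init_mask) * init_latents_proper + init_mask * latents
print(latents.shape)  # torch.Size([1, 4, 64, 64])
```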
len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + if i == len(timesteps) - 1 or ( + (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0 + ): progress_bar.update() if callback is not None and i % callback_steps == 0: step_idx = i // getattr(self.scheduler, "order", 1) @@ -1221,15 +1417,27 @@ def __call__( if not output_type == "latent": condition_kwargs = {} if isinstance(self.vae, AsymmetricAutoencoderKL): - init_image = init_image.to(device=device, dtype=masked_image_latents.dtype) + init_image = init_image.to( + device=device, dtype=masked_image_latents.dtype + ) init_image_condition = init_image.clone() init_image = self._encode_vae_image(init_image, generator=generator) - mask_condition = mask_condition.to(device=device, dtype=masked_image_latents.dtype) - condition_kwargs = {"image": init_image_condition, "mask": mask_condition} + mask_condition = mask_condition.to( + device=device, dtype=masked_image_latents.dtype + ) + condition_kwargs = { + "image": init_image_condition, + "mask": mask_condition, + } image = self.vae.decode( - latents / self.vae.config.scaling_factor, return_dict=False, generator=generator, **condition_kwargs + latents / self.vae.config.scaling_factor, + return_dict=False, + generator=generator, + **condition_kwargs, )[0] - image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + image, has_nsfw_concept = self.run_safety_checker( + image, device, prompt_embeds.dtype + ) else: image = latents has_nsfw_concept = None @@ -1239,7 +1447,9 @@ def __call__( else: do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] - image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + image = self.image_processor.postprocess( + image, output_type=output_type, do_denormalize=do_denormalize + ) # Offload all models self.maybe_free_model_hooks() @@ -1247,4 +1457,6 @@ def __call__( if not return_dict: return (image, has_nsfw_concept) - return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) + return StableDiffusionPipelineOutput( + images=image, nsfw_content_detected=has_nsfw_concept + ) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py index e32791693012..37f56c80ec15 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py @@ -31,7 +31,7 @@ StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin, ) -from ...models import AutoencoderKL, UNet2DConditionModel +from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel from ...models.attention_processor import ( AttnProcessor2_0, LoRAAttnProcessor2_0, @@ -91,12 +91,16 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). 
See Section 3.4 """ - std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_text = noise_pred_text.std( + dim=list(range(1, noise_pred_text.ndim)), keepdim=True + ) std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) # rescale the results from guidance (fixes overexposure) noise_pred_rescaled = noise_cfg * (std_text / std_cfg) # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images - noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + noise_cfg = ( + guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + ) return noise_cfg @@ -198,13 +202,19 @@ def __init__( image_encoder=image_encoder, feature_extractor=feature_extractor, ) - self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) + self.register_to_config( + force_zeros_for_empty_prompt=force_zeros_for_empty_prompt + ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.default_sample_size = self.unet.config.sample_size - add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available() + add_watermarker = ( + add_watermarker + if add_watermarker is not None + else is_invisible_watermark_available() + ) if add_watermarker: self.watermark = StableDiffusionXLWatermarker() @@ -306,7 +316,9 @@ def encode_prompt( # set lora scale so that monkey patched LoRA # function of text encoder can correctly access it - if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin): + if lora_scale is not None and isinstance( + self, StableDiffusionXLLoraLoaderMixin + ): self._lora_scale = lora_scale # dynamically adjust the LoRA scale @@ -330,9 +342,15 @@ def encode_prompt( batch_size = prompt_embeds.shape[0] # Define tokenizers and text encoders - tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2] + tokenizers = ( + [self.tokenizer, self.tokenizer_2] + if self.tokenizer is not None + else [self.tokenizer_2] + ) text_encoders = ( - [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2] + [self.text_encoder, self.text_encoder_2] + if self.text_encoder is not None + else [self.text_encoder_2] ) if prompt_embeds is None: @@ -342,7 +360,9 @@ def encode_prompt( # textual inversion: procecss multi-vector tokens if necessary prompt_embeds_list = [] prompts = [prompt, prompt_2] - for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders): + for prompt, tokenizer, text_encoder in zip( + prompts, tokenizers, text_encoders + ): if isinstance(self, TextualInversionLoaderMixin): prompt = self.maybe_convert_prompt(prompt, tokenizer) @@ -355,18 +375,24 @@ def encode_prompt( ) text_input_ids = text_inputs.input_ids - untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids - - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( - text_input_ids, untruncated_ids - ): - removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) + untruncated_ids = tokenizer( + prompt, padding="longest", return_tensors="pt" + ).input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[ + -1 + ] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = tokenizer.batch_decode( + untruncated_ids[:, 
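The `rescale_noise_cfg` helper reformatted above applies the guidance-rescale trick from the cited paper: match the standard deviation of the guided prediction to the text branch, then blend by `guidance_rescale`. A self-contained sketch of the same computation:

```py
import torch

def rescale_noise_cfg_sketch(noise_cfg, noise_pred_text, guidance_rescale=0.7):
    dims = list(range(1, noise_pred_text.ndim))
    std_text = noise_pred_text.std(dim=dims, keepdim=True)
    std_cfg = noise_cfg.std(dim=dims, keepdim=True)
    rescaled = noise_cfg * (std_text / std_cfg)   # fixes overexposure
    return guidance_rescale * rescaled + (1 - guidance_rescale) * noise_cfg

noise_cfg = torch.randn(2, 4, 8, 8) * 3.0   # exaggerated CFG output
noise_text = torch.randn(2, 4, 8, 8)
print(rescale_noise_cfg_sketch(noise_cfg, noise_text).std())
```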
tokenizer.model_max_length - 1 : -1] + ) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {tokenizer.model_max_length} tokens: {removed_text}" ) - prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) + prompt_embeds = text_encoder( + text_input_ids.to(device), output_hidden_states=True + ) # We are only ALWAYS interested in the pooled output of the final text encoder pooled_prompt_embeds = prompt_embeds[0] @@ -381,8 +407,14 @@ def encode_prompt( prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) # get unconditional embeddings for classifier free guidance - zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt - if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt: + zero_out_negative_prompt = ( + negative_prompt is None and self.config.force_zeros_for_empty_prompt + ) + if ( + do_classifier_free_guidance + and negative_prompt_embeds is None + and zero_out_negative_prompt + ): negative_prompt_embeds = torch.zeros_like(prompt_embeds) negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) elif do_classifier_free_guidance and negative_prompt_embeds is None: @@ -390,9 +422,15 @@ def encode_prompt( negative_prompt_2 = negative_prompt_2 or negative_prompt # normalize str to list - negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + negative_prompt = ( + batch_size * [negative_prompt] + if isinstance(negative_prompt, str) + else negative_prompt + ) negative_prompt_2 = ( - batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 + batch_size * [negative_prompt_2] + if isinstance(negative_prompt_2, str) + else negative_prompt_2 ) uncond_tokens: List[str] @@ -411,9 +449,13 @@ def encode_prompt( uncond_tokens = [negative_prompt, negative_prompt_2] negative_prompt_embeds_list = [] - for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders): + for negative_prompt, tokenizer, text_encoder in zip( + uncond_tokens, tokenizers, text_encoders + ): if isinstance(self, TextualInversionLoaderMixin): - negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer) + negative_prompt = self.maybe_convert_prompt( + negative_prompt, tokenizer + ) max_length = prompt_embeds.shape[1] uncond_input = tokenizer( @@ -437,34 +479,46 @@ def encode_prompt( negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) if self.text_encoder_2 is not None: - prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + prompt_embeds = prompt_embeds.to( + dtype=self.text_encoder_2.dtype, device=device + ) else: prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device) bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + prompt_embeds = prompt_embeds.view( + bs_embed * num_images_per_prompt, seq_len, -1 + ) if do_classifier_free_guidance: # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] if self.text_encoder_2 is not None: - negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + 
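The "mps friendly" duplication reflowed above repeats prompt embeddings along the sequence axis and then reshapes, which gives the same interleaved result as `repeat_interleave` on the batch axis. A shape-only sketch:

```py
import torch

bs, seq_len, dim = 2, 77, 2048
num_images_per_prompt = 3
prompt_embeds = torch.randn(bs, seq_len, dim)

# repeat along dim 1 then view: equivalent to repeat_interleave on dim 0,
# but avoids an op that has been problematic on mps backends
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(bs * num_images_per_prompt, seq_len, -1)
print(prompt_embeds.shape)  # torch.Size([6, 77, 2048])
```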
negative_prompt_embeds = negative_prompt_embeds.to( + dtype=self.text_encoder_2.dtype, device=device + ) else: - negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device) + negative_prompt_embeds = negative_prompt_embeds.to( + dtype=self.unet.dtype, device=device + ) - negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) - negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + negative_prompt_embeds = negative_prompt_embeds.repeat( + 1, num_images_per_prompt, 1 + ) + negative_prompt_embeds = negative_prompt_embeds.view( + batch_size * num_images_per_prompt, seq_len, -1 + ) - pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( - bs_embed * num_images_per_prompt, -1 - ) + pooled_prompt_embeds = pooled_prompt_embeds.repeat( + 1, num_images_per_prompt + ).view(bs_embed * num_images_per_prompt, -1) if do_classifier_free_guidance: - negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( - bs_embed * num_images_per_prompt, -1 - ) + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat( + 1, num_images_per_prompt + ).view(bs_embed * num_images_per_prompt, -1) if self.text_encoder is not None: if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: @@ -476,7 +530,12 @@ def encode_prompt( # Retrieve the original scale by scaling back the LoRA layers unscale_lora_layers(self.text_encoder_2, lora_scale) - return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + return ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image def encode_image(self, image, device, num_images_per_prompt): @@ -486,10 +545,23 @@ def encode_image(self, image, device, num_images_per_prompt): image = self.feature_extractor(image, return_tensors="pt").pixel_values image = image.to(device=device, dtype=dtype) - image_embeds = self.image_encoder(image).image_embeds - image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) - - uncond_image_embeds = torch.zeros_like(image_embeds) + if isinstance(self.unet.encoder_hid_proj, ImageProjection): + # IP-Adapter + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + else: + # IP-Adapter Plus + image_embeds = self.image_encoder( + image, output_hidden_states=True + ).hidden_states[-2] + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_embeds = uncond_image_embeds.repeat_interleave( + num_images_per_prompt, dim=0 + ) return image_embeds, uncond_image_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs @@ -499,13 +571,17 @@ def prepare_extra_step_kwargs(self, generator, eta): # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 # and should be between [0, 1] - accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + accepts_eta = "eta" in set( + inspect.signature(self.scheduler.step).parameters.keys() + ) extra_step_kwargs = {} if 
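The `encode_image` change above is the substantive part of this hunk: when the UNet's `encoder_hid_proj` is a plain `ImageProjection`, the pipeline keeps the original IP-Adapter path (pooled `image_embeds`, zeroed unconditional embeds); otherwise it assumes IP-Adapter Plus and feeds the penultimate hidden states of the image encoder, encoding an all-zeros image for the unconditional branch. A toy sketch of that branch with a hypothetical stand-in encoder (shapes only, not the real CLIP model):

```py
import torch
from types import SimpleNamespace

class ToyCLIPVision:
    """Hypothetical stand-in for the pipelines' CLIP vision encoder."""
    def __call__(self, pixel_values, output_hidden_states=False):
        b = pixel_values.shape[0]
        pooled = torch.randn(b, 1024)                            # .image_embeds
        hidden = [torch.randn(b, 257, 1280) for _ in range(3)]   # .hidden_states
        return SimpleNamespace(image_embeds=pooled, hidden_states=hidden)

def encode_image_sketch(image_encoder, unet_uses_plain_projection, image, n):
    # `unet_uses_plain_projection` stands in for
    # isinstance(unet.encoder_hid_proj, ImageProjection) in the hunk above
    if unet_uses_plain_projection:
        emb = image_encoder(image).image_embeds
        uncond = torch.zeros_like(emb)
    else:
        # IP-Adapter Plus: penultimate hidden states, zeros-image for uncond
        emb = image_encoder(image, output_hidden_states=True).hidden_states[-2]
        uncond = image_encoder(torch.zeros_like(image),
                               output_hidden_states=True).hidden_states[-2]
    return emb.repeat_interleave(n, dim=0), uncond.repeat_interleave(n, dim=0)

enc = ToyCLIPVision()
image = torch.randn(1, 3, 224, 224)
plain, _ = encode_image_sketch(enc, True, image, 2)
plus, _ = encode_image_sketch(enc, False, image, 2)
print(plain.shape, plus.shape)  # torch.Size([2, 1024]) torch.Size([2, 257, 1280])
```

The unconditional branch differs on purpose: the Plus resampler consumes a token sequence, so zeroing the pooled embedding is no longer meaningful and a blank image is encoded instead.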
accepts_eta: extra_step_kwargs["eta"] = eta # check if the scheduler accepts generator - accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + accepts_generator = "generator" in set( + inspect.signature(self.scheduler.step).parameters.keys() + ) if accepts_generator: extra_step_kwargs["generator"] = generator return extra_step_kwargs @@ -526,16 +602,21 @@ def check_inputs( callback_on_step_end_tensor_inputs=None, ): if height % 8 != 0 or width % 8 != 0: - raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + raise ValueError( + f"`height` and `width` have to be divisible by 8 but are {height} and {width}." + ) - if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + if callback_steps is not None and ( + not isinstance(callback_steps, int) or callback_steps <= 0 + ): raise ValueError( f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}." ) if callback_on_step_end_tensor_inputs is not None and not all( - k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + k in self._callback_tensor_inputs + for k in callback_on_step_end_tensor_inputs ): raise ValueError( f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" @@ -555,10 +636,18 @@ def check_inputs( raise ValueError( "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." ) - elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): - raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + elif prompt is not None and ( + not isinstance(prompt, str) and not isinstance(prompt, list) + ): + raise ValueError( + f"`prompt` has to be of type `str` or `list` but is {type(prompt)}" + ) + elif prompt_2 is not None and ( + not isinstance(prompt_2, str) and not isinstance(prompt_2, list) + ): + raise ValueError( + f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}" + ) if negative_prompt is not None and negative_prompt_embeds is not None: raise ValueError( @@ -590,8 +679,23 @@ def check_inputs( ) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents - def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): - shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + shape = ( + batch_size, + num_channels_latents, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + ) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" @@ -599,7 +703,9 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype ) if latents is None: - latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = 
randn_tensor( + shape, generator=generator, device=device, dtype=dtype + ) else: latents = latents.to(device) @@ -608,12 +714,18 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype return latents def _get_add_time_ids( - self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None + self, + original_size, + crops_coords_top_left, + target_size, + dtype, + text_encoder_projection_dim=None, ): add_time_ids = list(original_size + crops_coords_top_left + target_size) passed_add_embed_dim = ( - self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim + self.unet.config.addition_time_embed_dim * len(add_time_ids) + + text_encoder_projection_dim ) expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features @@ -959,7 +1071,9 @@ def __call__( # 3. Encode input prompt lora_scale = ( - self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None + self.cross_attention_kwargs.get("scale", None) + if self.cross_attention_kwargs is not None + else None ) ( @@ -1031,21 +1145,29 @@ def __call__( if self.do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) - add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) + add_text_embeds = torch.cat( + [negative_pooled_prompt_embeds, add_text_embeds], dim=0 + ) add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0) prompt_embeds = prompt_embeds.to(device) add_text_embeds = add_text_embeds.to(device) - add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) + add_time_ids = add_time_ids.to(device).repeat( + batch_size * num_images_per_prompt, 1 + ) if ip_adapter_image is not None: - image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt) + image_embeds, negative_image_embeds = self.encode_image( + ip_adapter_image, device, num_images_per_prompt + ) if self.do_classifier_free_guidance: image_embeds = torch.cat([negative_image_embeds, image_embeds]) image_embeds = image_embeds.to(device) # 8. Denoising loop - num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + num_warmup_steps = max( + len(timesteps) - num_inference_steps * self.scheduler.order, 0 + ) # 8.1 Apply denoising_end if ( @@ -1060,13 +1182,17 @@ def __call__( - (self.denoising_end * self.scheduler.config.num_train_timesteps) ) ) - num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) + num_inference_steps = len( + list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)) + ) timesteps = timesteps[:num_inference_steps] # 9. 
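`_get_add_time_ids`, reflowed above, packs the SDXL micro-conditioning (original size, crop coordinates, target size) into a six-value tensor and checks that its embedded width matches the UNet's `add_embedding`. A sketch of that check; the dimensions below (256, 1280, 2816) are the usual SDXL-base values and are stated here as assumptions:

```py
import torch

def get_add_time_ids_sketch(original_size, crops_coords_top_left, target_size,
                            addition_time_embed_dim, text_encoder_projection_dim,
                            expected_add_embed_dim):
    add_time_ids = list(original_size + crops_coords_top_left + target_size)
    passed = addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
    if passed != expected_add_embed_dim:
        raise ValueError(f"add_embedding expects {expected_add_embed_dim}, got {passed}")
    return torch.tensor([add_time_ids])

ids = get_add_time_ids_sketch((1024, 1024), (0, 0), (1024, 1024),
                              addition_time_embed_dim=256,
                              text_encoder_projection_dim=1280,
                              expected_add_embed_dim=2816)
print(ids)  # tensor([[1024, 1024, 0, 0, 1024, 1024]])
```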
Optionally get Guidance Scale Embedding timestep_cond = None if self.unet.config.time_cond_proj_dim is not None: - guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) + guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat( + batch_size * num_images_per_prompt + ) timestep_cond = self.get_guidance_scale_embedding( guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim ).to(device=device, dtype=latents.dtype) @@ -1075,12 +1201,21 @@ def __call__( with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + latent_model_input = ( + torch.cat([latents] * 2) + if self.do_classifier_free_guidance + else latents + ) - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + latent_model_input = self.scheduler.scale_model_input( + latent_model_input, t + ) # predict the noise residual - added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + added_cond_kwargs = { + "text_embeds": add_text_embeds, + "time_ids": add_time_ids, + } if ip_adapter_image is not None: added_cond_kwargs["image_embeds"] = image_embeds noise_pred = self.unet( @@ -1096,14 +1231,22 @@ def __call__( # perform guidance if self.do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + noise_pred = noise_pred_uncond + self.guidance_scale * ( + noise_pred_text - noise_pred_uncond + ) if self.do_classifier_free_guidance and self.guidance_rescale > 0.0: # Based on 3.4. 
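Inside the denoising loop above the batched prediction is split back into unconditional and text halves and recombined with the guidance scale. A minimal sketch of that classifier-free-guidance step:

```py
import torch

guidance_scale = 7.5
noise_pred = torch.randn(2, 4, 128, 128)  # batched [uncond, text] prediction

noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
guided = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
print(guided.shape)  # torch.Size([1, 4, 128, 128])
```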
in https://arxiv.org/pdf/2305.08891.pdf - noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale) + noise_pred = rescale_noise_cfg( + noise_pred, + noise_pred_text, + guidance_rescale=self.guidance_rescale, + ) # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + latents = self.scheduler.step( + noise_pred, t, latents, **extra_step_kwargs, return_dict=False + )[0] if callback_on_step_end is not None: callback_kwargs = {} @@ -1113,16 +1256,24 @@ def __call__( latents = callback_outputs.pop("latents", latents) prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) - negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) - add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds) + negative_prompt_embeds = callback_outputs.pop( + "negative_prompt_embeds", negative_prompt_embeds + ) + add_text_embeds = callback_outputs.pop( + "add_text_embeds", add_text_embeds + ) negative_pooled_prompt_embeds = callback_outputs.pop( "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds ) add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids) - negative_add_time_ids = callback_outputs.pop("negative_add_time_ids", negative_add_time_ids) + negative_add_time_ids = callback_outputs.pop( + "negative_add_time_ids", negative_add_time_ids + ) # call the callback, if provided - if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + if i == len(timesteps) - 1 or ( + (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0 + ): progress_bar.update() if callback is not None and i % callback_steps == 0: step_idx = i // getattr(self.scheduler, "order", 1) @@ -1133,13 +1284,19 @@ def __call__( if not output_type == "latent": # make sure the VAE is in float32 mode, as it overflows in float16 - needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast + needs_upcasting = ( + self.vae.dtype == torch.float16 and self.vae.config.force_upcast + ) if needs_upcasting: self.upcast_vae() - latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + latents = latents.to( + next(iter(self.vae.post_quant_conv.parameters())).dtype + ) - image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image = self.vae.decode( + latents / self.vae.config.scaling_factor, return_dict=False + )[0] # cast back to fp16 if needed if needs_upcasting: diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py index d40a037e67fe..b00aa4fdd0e9 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py @@ -32,7 +32,7 @@ StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin, ) -from ...models import AutoencoderKL, UNet2DConditionModel +from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel from ...models.attention_processor import ( AttnProcessor2_0, LoRAAttnProcessor2_0, @@ -95,12 +95,16 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): Rescale `noise_cfg` according to `guidance_rescale`. 
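The decode path above upcasts the SDXL VAE to float32 when it is running in float16 and `force_upcast` is set, because the decoder overflows in half precision. A toy sketch of that dance, with a hypothetical stand-in module instead of the real `AutoencoderKL`:

```py
import torch
import torch.nn as nn

class ToyVAE(nn.Module):
    """Hypothetical stand-in; `force_upcast` mirrors vae.config.force_upcast."""
    def __init__(self):
        super().__init__()
        self.post_quant_conv = nn.Conv2d(4, 4, 1)
        self.force_upcast = True

    def decode(self, latents):
        return self.post_quant_conv(latents)

vae = ToyVAE().half()
latents = torch.randn(1, 4, 8, 8, dtype=torch.float16)

needs_upcasting = vae.post_quant_conv.weight.dtype == torch.float16 and vae.force_upcast
if needs_upcasting:
    vae = vae.float()  # decode in fp32 to avoid overflow
    latents = latents.to(next(iter(vae.post_quant_conv.parameters())).dtype)

image = vae.decode(latents)
print(image.dtype)  # torch.float32
```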
Based on findings of [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 """ - std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_text = noise_pred_text.std( + dim=list(range(1, noise_pred_text.ndim)), keepdim=True + ) std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) # rescale the results from guidance (fixes overexposure) noise_pred_rescaled = noise_cfg * (std_text / std_cfg) # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images - noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + noise_cfg = ( + guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + ) return noise_cfg @@ -216,12 +220,18 @@ def __init__( feature_extractor=feature_extractor, scheduler=scheduler, ) - self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) + self.register_to_config( + force_zeros_for_empty_prompt=force_zeros_for_empty_prompt + ) self.register_to_config(requires_aesthetics_score=requires_aesthetics_score) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) - add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available() + add_watermarker = ( + add_watermarker + if add_watermarker is not None + else is_invisible_watermark_available() + ) if add_watermarker: self.watermark = StableDiffusionXLWatermarker() @@ -324,7 +334,9 @@ def encode_prompt( # set lora scale so that monkey patched LoRA # function of text encoder can correctly access it - if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin): + if lora_scale is not None and isinstance( + self, StableDiffusionXLLoraLoaderMixin + ): self._lora_scale = lora_scale # dynamically adjust the LoRA scale @@ -348,9 +360,15 @@ def encode_prompt( batch_size = prompt_embeds.shape[0] # Define tokenizers and text encoders - tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2] + tokenizers = ( + [self.tokenizer, self.tokenizer_2] + if self.tokenizer is not None + else [self.tokenizer_2] + ) text_encoders = ( - [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2] + [self.text_encoder, self.text_encoder_2] + if self.text_encoder is not None + else [self.text_encoder_2] ) if prompt_embeds is None: @@ -360,7 +378,9 @@ def encode_prompt( # textual inversion: procecss multi-vector tokens if necessary prompt_embeds_list = [] prompts = [prompt, prompt_2] - for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders): + for prompt, tokenizer, text_encoder in zip( + prompts, tokenizers, text_encoders + ): if isinstance(self, TextualInversionLoaderMixin): prompt = self.maybe_convert_prompt(prompt, tokenizer) @@ -373,18 +393,24 @@ def encode_prompt( ) text_input_ids = text_inputs.input_ids - untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids - - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( - text_input_ids, untruncated_ids - ): - removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) + untruncated_ids = tokenizer( + prompt, padding="longest", return_tensors="pt" + ).input_ids + + if untruncated_ids.shape[-1] >= 
text_input_ids.shape[ + -1 + ] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = tokenizer.batch_decode( + untruncated_ids[:, tokenizer.model_max_length - 1 : -1] + ) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {tokenizer.model_max_length} tokens: {removed_text}" ) - prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) + prompt_embeds = text_encoder( + text_input_ids.to(device), output_hidden_states=True + ) # We are only ALWAYS interested in the pooled output of the final text encoder pooled_prompt_embeds = prompt_embeds[0] @@ -399,8 +425,14 @@ def encode_prompt( prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) # get unconditional embeddings for classifier free guidance - zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt - if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt: + zero_out_negative_prompt = ( + negative_prompt is None and self.config.force_zeros_for_empty_prompt + ) + if ( + do_classifier_free_guidance + and negative_prompt_embeds is None + and zero_out_negative_prompt + ): negative_prompt_embeds = torch.zeros_like(prompt_embeds) negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) elif do_classifier_free_guidance and negative_prompt_embeds is None: @@ -408,9 +440,15 @@ def encode_prompt( negative_prompt_2 = negative_prompt_2 or negative_prompt # normalize str to list - negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + negative_prompt = ( + batch_size * [negative_prompt] + if isinstance(negative_prompt, str) + else negative_prompt + ) negative_prompt_2 = ( - batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 + batch_size * [negative_prompt_2] + if isinstance(negative_prompt_2, str) + else negative_prompt_2 ) uncond_tokens: List[str] @@ -429,9 +467,13 @@ def encode_prompt( uncond_tokens = [negative_prompt, negative_prompt_2] negative_prompt_embeds_list = [] - for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders): + for negative_prompt, tokenizer, text_encoder in zip( + uncond_tokens, tokenizers, text_encoders + ): if isinstance(self, TextualInversionLoaderMixin): - negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer) + negative_prompt = self.maybe_convert_prompt( + negative_prompt, tokenizer + ) max_length = prompt_embeds.shape[1] uncond_input = tokenizer( @@ -455,34 +497,46 @@ def encode_prompt( negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) if self.text_encoder_2 is not None: - prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + prompt_embeds = prompt_embeds.to( + dtype=self.text_encoder_2.dtype, device=device + ) else: prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device) bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + prompt_embeds = prompt_embeds.view( + bs_embed * num_images_per_prompt, seq_len, -1 + ) if do_classifier_free_guidance: # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] if 
self.text_encoder_2 is not None: - negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + negative_prompt_embeds = negative_prompt_embeds.to( + dtype=self.text_encoder_2.dtype, device=device + ) else: - negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device) + negative_prompt_embeds = negative_prompt_embeds.to( + dtype=self.unet.dtype, device=device + ) - negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) - negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + negative_prompt_embeds = negative_prompt_embeds.repeat( + 1, num_images_per_prompt, 1 + ) + negative_prompt_embeds = negative_prompt_embeds.view( + batch_size * num_images_per_prompt, seq_len, -1 + ) - pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( - bs_embed * num_images_per_prompt, -1 - ) + pooled_prompt_embeds = pooled_prompt_embeds.repeat( + 1, num_images_per_prompt + ).view(bs_embed * num_images_per_prompt, -1) if do_classifier_free_guidance: - negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( - bs_embed * num_images_per_prompt, -1 - ) + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat( + 1, num_images_per_prompt + ).view(bs_embed * num_images_per_prompt, -1) if self.text_encoder is not None: if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: @@ -494,7 +548,12 @@ def encode_prompt( # Retrieve the original scale by scaling back the LoRA layers unscale_lora_layers(self.text_encoder_2, lora_scale) - return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + return ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs def prepare_extra_step_kwargs(self, generator, eta): @@ -503,13 +562,17 @@ def prepare_extra_step_kwargs(self, generator, eta): # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 # and should be between [0, 1] - accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + accepts_eta = "eta" in set( + inspect.signature(self.scheduler.step).parameters.keys() + ) extra_step_kwargs = {} if accepts_eta: extra_step_kwargs["eta"] = eta # check if the scheduler accepts generator - accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + accepts_generator = "generator" in set( + inspect.signature(self.scheduler.step).parameters.keys() + ) if accepts_generator: extra_step_kwargs["generator"] = generator return extra_step_kwargs @@ -528,7 +591,9 @@ def check_inputs( callback_on_step_end_tensor_inputs=None, ): if strength < 0 or strength > 1: - raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + raise ValueError( + f"The value of strength should in [0.0, 1.0] but is {strength}" + ) if num_inference_steps is None: raise ValueError("`num_inference_steps` cannot be None.") elif not isinstance(num_inference_steps, int) or num_inference_steps <= 0: @@ -536,14 +601,17 @@ def check_inputs( f"`num_inference_steps` has to be a positive integer but is {num_inference_steps} of type" f" {type(num_inference_steps)}." 
) - if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + if callback_steps is not None and ( + not isinstance(callback_steps, int) or callback_steps <= 0 + ): raise ValueError( f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}." ) if callback_on_step_end_tensor_inputs is not None and not all( - k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + k in self._callback_tensor_inputs + for k in callback_on_step_end_tensor_inputs ): raise ValueError( f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" @@ -563,10 +631,18 @@ def check_inputs( raise ValueError( "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." ) - elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): - raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + elif prompt is not None and ( + not isinstance(prompt, str) and not isinstance(prompt, list) + ): + raise ValueError( + f"`prompt` has to be of type `str` or `list` but is {type(prompt)}" + ) + elif prompt_2 is not None and ( + not isinstance(prompt_2, str) and not isinstance(prompt_2, list) + ): + raise ValueError( + f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}" + ) if negative_prompt is not None and negative_prompt_embeds is not None: raise ValueError( @@ -587,10 +663,14 @@ def check_inputs( f" {negative_prompt_embeds.shape}." 
) - def get_timesteps(self, num_inference_steps, strength, device, denoising_start=None): + def get_timesteps( + self, num_inference_steps, strength, device, denoising_start=None + ): # get the original timestep using init_timestep if denoising_start is None: - init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + init_timestep = min( + int(num_inference_steps * strength), num_inference_steps + ) t_start = max(num_inference_steps - init_timestep, 0) else: t_start = 0 @@ -624,7 +704,15 @@ def get_timesteps(self, num_inference_steps, strength, device, denoising_start=N return timesteps, num_inference_steps - t_start def prepare_latents( - self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None, add_noise=True + self, + image, + timestep, + batch_size, + num_images_per_prompt, + dtype, + device, + generator=None, + add_noise=True, ): if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): raise ValueError( @@ -657,12 +745,16 @@ def prepare_latents( elif isinstance(generator, list): init_latents = [ - retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) + retrieve_latents( + self.vae.encode(image[i : i + 1]), generator=generator[i] + ) for i in range(batch_size) ] init_latents = torch.cat(init_latents, dim=0) else: - init_latents = retrieve_latents(self.vae.encode(image), generator=generator) + init_latents = retrieve_latents( + self.vae.encode(image), generator=generator + ) if self.vae.config.force_upcast: self.vae.to(dtype) @@ -670,11 +762,19 @@ def prepare_latents( init_latents = init_latents.to(dtype) init_latents = self.vae.config.scaling_factor * init_latents - if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: + if ( + batch_size > init_latents.shape[0] + and batch_size % init_latents.shape[0] == 0 + ): # expand init_latents for batch_size additional_image_per_prompt = batch_size // init_latents.shape[0] - init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0) - elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: + init_latents = torch.cat( + [init_latents] * additional_image_per_prompt, dim=0 + ) + elif ( + batch_size > init_latents.shape[0] + and batch_size % init_latents.shape[0] != 0 + ): raise ValueError( f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." 
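`get_timesteps`, reflowed above, trims the schedule for img2img: with `denoising_start` unset, only the last `strength * num_inference_steps` timesteps are kept. A small sketch of that branch (the `scheduler_order` slicing factor is an assumption of this sketch):

```py
import torch

def get_timesteps_sketch(timesteps, num_inference_steps, strength, scheduler_order=1):
    init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
    t_start = max(num_inference_steps - init_timestep, 0)
    timesteps = timesteps[t_start * scheduler_order :]
    return timesteps, num_inference_steps - t_start

timesteps = torch.linspace(999, 0, 10).long()
ts, n = get_timesteps_sketch(timesteps, num_inference_steps=10, strength=0.3)
print(n, ts.tolist())  # 3 [222, 111, 0] -- only the tail of the schedule
```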
) @@ -699,10 +799,23 @@ def encode_image(self, image, device, num_images_per_prompt): image = self.feature_extractor(image, return_tensors="pt").pixel_values image = image.to(device=device, dtype=dtype) - image_embeds = self.image_encoder(image).image_embeds - image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) - - uncond_image_embeds = torch.zeros_like(image_embeds) + if isinstance(self.unet.encoder_hid_proj, ImageProjection): + # IP-Adapter + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + else: + # IP-Adapter Plus + image_embeds = self.image_encoder( + image, output_hidden_states=True + ).hidden_states[-2] + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_embeds = uncond_image_embeds.repeat_interleave( + num_images_per_prompt, dim=0 + ) return image_embeds, uncond_image_embeds def _get_add_time_ids( @@ -719,29 +832,38 @@ def _get_add_time_ids( text_encoder_projection_dim=None, ): if self.config.requires_aesthetics_score: - add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,)) + add_time_ids = list( + original_size + crops_coords_top_left + (aesthetic_score,) + ) add_neg_time_ids = list( - negative_original_size + negative_crops_coords_top_left + (negative_aesthetic_score,) + negative_original_size + + negative_crops_coords_top_left + + (negative_aesthetic_score,) ) else: add_time_ids = list(original_size + crops_coords_top_left + target_size) - add_neg_time_ids = list(negative_original_size + crops_coords_top_left + negative_target_size) + add_neg_time_ids = list( + negative_original_size + crops_coords_top_left + negative_target_size + ) passed_add_embed_dim = ( - self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim + self.unet.config.addition_time_embed_dim * len(add_time_ids) + + text_encoder_projection_dim ) expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features if ( expected_add_embed_dim > passed_add_embed_dim - and (expected_add_embed_dim - passed_add_embed_dim) == self.unet.config.addition_time_embed_dim + and (expected_add_embed_dim - passed_add_embed_dim) + == self.unet.config.addition_time_embed_dim ): raise ValueError( f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to enable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=True)` to make sure `aesthetic_score` {aesthetic_score} and `negative_aesthetic_score` {negative_aesthetic_score} is correctly used by the model." ) elif ( expected_add_embed_dim < passed_add_embed_dim - and (passed_add_embed_dim - expected_add_embed_dim) == self.unet.config.addition_time_embed_dim + and (passed_add_embed_dim - expected_add_embed_dim) + == self.unet.config.addition_time_embed_dim ): raise ValueError( f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to disable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=False)` to make sure `target_size` {target_size} is correctly used by the model." @@ -1107,7 +1229,9 @@ def __call__( # 3. 
Encode input prompt text_encoder_lora_scale = ( - self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None + self.cross_attention_kwargs.get("scale", None) + if self.cross_attention_kwargs is not None + else None ) ( prompt_embeds, @@ -1196,8 +1320,12 @@ def denoising_value_valid(dnv): if self.do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) - add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) - add_neg_time_ids = add_neg_time_ids.repeat(batch_size * num_images_per_prompt, 1) + add_text_embeds = torch.cat( + [negative_pooled_prompt_embeds, add_text_embeds], dim=0 + ) + add_neg_time_ids = add_neg_time_ids.repeat( + batch_size * num_images_per_prompt, 1 + ) add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0) prompt_embeds = prompt_embeds.to(device) @@ -1205,13 +1333,17 @@ def denoising_value_valid(dnv): add_time_ids = add_time_ids.to(device) if ip_adapter_image is not None: - image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt) + image_embeds, negative_image_embeds = self.encode_image( + ip_adapter_image, device, num_images_per_prompt + ) if self.do_classifier_free_guidance: image_embeds = torch.cat([negative_image_embeds, image_embeds]) image_embeds = image_embeds.to(device) # 9. Denoising loop - num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + num_warmup_steps = max( + len(timesteps) - num_inference_steps * self.scheduler.order, 0 + ) # 9.1 Apply denoising_end if ( @@ -1225,20 +1357,26 @@ def denoising_value_valid(dnv): f"`denoising_start`: {self.denoising_start} cannot be larger than or equal to `denoising_end`: " + f" {self.denoising_end} when using type float." 
) - elif self.denoising_end is not None and denoising_value_valid(self.denoising_end): + elif self.denoising_end is not None and denoising_value_valid( + self.denoising_end + ): discrete_timestep_cutoff = int( round( self.scheduler.config.num_train_timesteps - (self.denoising_end * self.scheduler.config.num_train_timesteps) ) ) - num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) + num_inference_steps = len( + list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)) + ) timesteps = timesteps[:num_inference_steps] # 9.2 Optionally get Guidance Scale Embedding timestep_cond = None if self.unet.config.time_cond_proj_dim is not None: - guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) + guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat( + batch_size * num_images_per_prompt + ) timestep_cond = self.get_guidance_scale_embedding( guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim ).to(device=device, dtype=latents.dtype) @@ -1247,12 +1385,21 @@ def denoising_value_valid(dnv): with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + latent_model_input = ( + torch.cat([latents] * 2) + if self.do_classifier_free_guidance + else latents + ) - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + latent_model_input = self.scheduler.scale_model_input( + latent_model_input, t + ) # predict the noise residual - added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + added_cond_kwargs = { + "text_embeds": add_text_embeds, + "time_ids": add_time_ids, + } if ip_adapter_image is not None: added_cond_kwargs["image_embeds"] = image_embeds noise_pred = self.unet( @@ -1268,14 +1415,22 @@ def denoising_value_valid(dnv): # perform guidance if self.do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + noise_pred = noise_pred_uncond + self.guidance_scale * ( + noise_pred_text - noise_pred_uncond + ) if self.do_classifier_free_guidance and self.guidance_rescale > 0.0: # Based on 3.4. 
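The `denoising_end` handling above converts a fraction of the training schedule into a discrete timestep cutoff and drops everything below it, which is how the SDXL base/refiner split hands off mid-denoising. A sketch of that cutoff:

```py
import torch

def apply_denoising_end(timesteps, denoising_end, num_train_timesteps=1000):
    cutoff = int(round(num_train_timesteps - denoising_end * num_train_timesteps))
    kept = [t for t in timesteps.tolist() if t >= cutoff]
    return torch.tensor(kept), len(kept)

timesteps = torch.linspace(999, 0, 10).long()
ts, n = apply_denoising_end(timesteps, denoising_end=0.8)
print(n, ts.tolist())  # 8 steps: roughly the first 80% of the schedule
```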
in https://arxiv.org/pdf/2305.08891.pdf - noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale) + noise_pred = rescale_noise_cfg( + noise_pred, + noise_pred_text, + guidance_rescale=self.guidance_rescale, + ) # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + latents = self.scheduler.step( + noise_pred, t, latents, **extra_step_kwargs, return_dict=False + )[0] if callback_on_step_end is not None: callback_kwargs = {} @@ -1285,16 +1440,24 @@ def denoising_value_valid(dnv): latents = callback_outputs.pop("latents", latents) prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) - negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) - add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds) + negative_prompt_embeds = callback_outputs.pop( + "negative_prompt_embeds", negative_prompt_embeds + ) + add_text_embeds = callback_outputs.pop( + "add_text_embeds", add_text_embeds + ) negative_pooled_prompt_embeds = callback_outputs.pop( "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds ) add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids) - add_neg_time_ids = callback_outputs.pop("add_neg_time_ids", add_neg_time_ids) + add_neg_time_ids = callback_outputs.pop( + "add_neg_time_ids", add_neg_time_ids + ) # call the callback, if provided - if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + if i == len(timesteps) - 1 or ( + (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0 + ): progress_bar.update() if callback is not None and i % callback_steps == 0: step_idx = i // getattr(self.scheduler, "order", 1) @@ -1305,13 +1468,19 @@ def denoising_value_valid(dnv): if not output_type == "latent": # make sure the VAE is in float32 mode, as it overflows in float16 - needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast + needs_upcasting = ( + self.vae.dtype == torch.float16 and self.vae.config.force_upcast + ) if needs_upcasting: self.upcast_vae() - latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + latents = latents.to( + next(iter(self.vae.post_quant_conv.parameters())).dtype + ) - image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image = self.vae.decode( + latents / self.vae.config.scaling_factor, return_dict=False + )[0] # cast back to fp16 if needed if needs_upcasting: diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py index 3a9d068d60f3..560973ece087 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py @@ -33,7 +33,7 @@ StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin, ) -from ...models import AutoencoderKL, UNet2DConditionModel +from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel from ...models.attention_processor import ( AttnProcessor2_0, LoRAAttnProcessor2_0, @@ -106,12 +106,16 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): Rescale `noise_cfg` according to `guidance_rescale`. 
Based on findings of [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 """ - std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_text = noise_pred_text.std( + dim=list(range(1, noise_pred_text.ndim)), keepdim=True + ) std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) # rescale the results from guidance (fixes overexposure) noise_pred_rescaled = noise_cfg * (std_text / std_cfg) # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images - noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + noise_cfg = ( + guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + ) return noise_cfg @@ -122,7 +126,9 @@ def mask_pil_to_torch(mask, height, width): if isinstance(mask, list) and isinstance(mask[0], PIL.Image.Image): mask = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in mask] - mask = np.concatenate([np.array(m.convert("L"))[None, None, :] for m in mask], axis=0) + mask = np.concatenate( + [np.array(m.convert("L"))[None, None, :] for m in mask], axis=0 + ) mask = mask.astype(np.float32) / 255.0 elif isinstance(mask, list) and isinstance(mask[0], np.ndarray): mask = np.concatenate([m[None, None, :] for m in mask], axis=0) @@ -131,7 +137,9 @@ def mask_pil_to_torch(mask, height, width): return mask -def prepare_mask_and_masked_image(image, mask, height, width, return_image: bool = False): +def prepare_mask_and_masked_image( + image, mask, height, width, return_image: bool = False +): """ Prepares a pair (image, mask) to be consumed by the Stable Diffusion pipeline. This means that those inputs will be converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for the @@ -194,9 +202,13 @@ def prepare_mask_and_masked_image(image, mask, height, width, return_image: bool else: mask = mask.unsqueeze(1) - assert image.ndim == 4 and mask.ndim == 4, "Image and Mask must have 4 dimensions" + assert ( + image.ndim == 4 and mask.ndim == 4 + ), "Image and Mask must have 4 dimensions" # assert image.shape[-2:] == mask.shape[-2:], "Image and Mask must have the same spatial dimensions" - assert image.shape[0] == mask.shape[0], "Image and Mask must have the same batch size" + assert ( + image.shape[0] == mask.shape[0] + ), "Image and Mask must have the same batch size" # Check image is in [-1, 1] # if image.min() < -1 or image.max() > 1: @@ -213,14 +225,18 @@ def prepare_mask_and_masked_image(image, mask, height, width, return_image: bool # Image as float32 image = image.to(dtype=torch.float32) elif isinstance(mask, torch.Tensor): - raise TypeError(f"`mask` is a torch.Tensor but `image` (type: {type(image)} is not") + raise TypeError( + f"`mask` is a torch.Tensor but `image` (type: {type(image)} is not" + ) else: # preprocess image if isinstance(image, (PIL.Image.Image, np.ndarray)): image = [image] if isinstance(image, list) and isinstance(image[0], PIL.Image.Image): # resize all images w.r.t passed height an width - image = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in image] + image = [ + i.resize((width, height), resample=PIL.Image.LANCZOS) for i in image + ] image = [np.array(i.convert("RGB"))[None, :] for i in image] image = np.concatenate(image, axis=0) elif isinstance(image, list) and isinstance(image[0], np.ndarray): @@ -364,15 +380,24 @@ def __init__( feature_extractor=feature_extractor, 
scheduler=scheduler, ) - self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) + self.register_to_config( + force_zeros_for_empty_prompt=force_zeros_for_empty_prompt + ) self.register_to_config(requires_aesthetics_score=requires_aesthetics_score) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.mask_processor = VaeImageProcessor( - vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True + vae_scale_factor=self.vae_scale_factor, + do_normalize=False, + do_binarize=True, + do_convert_grayscale=True, ) - add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available() + add_watermarker = ( + add_watermarker + if add_watermarker is not None + else is_invisible_watermark_available() + ) if add_watermarker: self.watermark = StableDiffusionXLWatermarker() @@ -420,10 +445,23 @@ def encode_image(self, image, device, num_images_per_prompt): image = self.feature_extractor(image, return_tensors="pt").pixel_values image = image.to(device=device, dtype=dtype) - image_embeds = self.image_encoder(image).image_embeds - image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) - - uncond_image_embeds = torch.zeros_like(image_embeds) + if isinstance(self.unet.encoder_hid_proj, ImageProjection): + # IP-Adapter + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + else: + # IP-Adapter Plus + image_embeds = self.image_encoder( + image, output_hidden_states=True + ).hidden_states[-2] + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_embeds = uncond_image_embeds.repeat_interleave( + num_images_per_prompt, dim=0 + ) return image_embeds, uncond_image_embeds # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt @@ -489,7 +527,9 @@ def encode_prompt( # set lora scale so that monkey patched LoRA # function of text encoder can correctly access it - if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin): + if lora_scale is not None and isinstance( + self, StableDiffusionXLLoraLoaderMixin + ): self._lora_scale = lora_scale # dynamically adjust the LoRA scale @@ -513,9 +553,15 @@ def encode_prompt( batch_size = prompt_embeds.shape[0] # Define tokenizers and text encoders - tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2] + tokenizers = ( + [self.tokenizer, self.tokenizer_2] + if self.tokenizer is not None + else [self.tokenizer_2] + ) text_encoders = ( - [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2] + [self.text_encoder, self.text_encoder_2] + if self.text_encoder is not None + else [self.text_encoder_2] ) if prompt_embeds is None: @@ -525,7 +571,9 @@ def encode_prompt( # textual inversion: procecss multi-vector tokens if necessary prompt_embeds_list = [] prompts = [prompt, prompt_2] - for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders): + for prompt, tokenizer, text_encoder in zip( + prompts, tokenizers, text_encoders + ): if isinstance(self, 
TextualInversionLoaderMixin): prompt = self.maybe_convert_prompt(prompt, tokenizer) @@ -538,18 +586,24 @@ def encode_prompt( ) text_input_ids = text_inputs.input_ids - untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids - - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( - text_input_ids, untruncated_ids - ): - removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) + untruncated_ids = tokenizer( + prompt, padding="longest", return_tensors="pt" + ).input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[ + -1 + ] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = tokenizer.batch_decode( + untruncated_ids[:, tokenizer.model_max_length - 1 : -1] + ) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {tokenizer.model_max_length} tokens: {removed_text}" ) - prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) + prompt_embeds = text_encoder( + text_input_ids.to(device), output_hidden_states=True + ) # We are only ALWAYS interested in the pooled output of the final text encoder pooled_prompt_embeds = prompt_embeds[0] @@ -564,8 +618,14 @@ def encode_prompt( prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) # get unconditional embeddings for classifier free guidance - zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt - if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt: + zero_out_negative_prompt = ( + negative_prompt is None and self.config.force_zeros_for_empty_prompt + ) + if ( + do_classifier_free_guidance + and negative_prompt_embeds is None + and zero_out_negative_prompt + ): negative_prompt_embeds = torch.zeros_like(prompt_embeds) negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) elif do_classifier_free_guidance and negative_prompt_embeds is None: @@ -573,9 +633,15 @@ def encode_prompt( negative_prompt_2 = negative_prompt_2 or negative_prompt # normalize str to list - negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + negative_prompt = ( + batch_size * [negative_prompt] + if isinstance(negative_prompt, str) + else negative_prompt + ) negative_prompt_2 = ( - batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 + batch_size * [negative_prompt_2] + if isinstance(negative_prompt_2, str) + else negative_prompt_2 ) uncond_tokens: List[str] @@ -594,9 +660,13 @@ def encode_prompt( uncond_tokens = [negative_prompt, negative_prompt_2] negative_prompt_embeds_list = [] - for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders): + for negative_prompt, tokenizer, text_encoder in zip( + uncond_tokens, tokenizers, text_encoders + ): if isinstance(self, TextualInversionLoaderMixin): - negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer) + negative_prompt = self.maybe_convert_prompt( + negative_prompt, tokenizer + ) max_length = prompt_embeds.shape[1] uncond_input = tokenizer( @@ -620,34 +690,46 @@ def encode_prompt( negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) if self.text_encoder_2 is not None: - prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + prompt_embeds = prompt_embeds.to( + dtype=self.text_encoder_2.dtype, device=device + ) else: prompt_embeds = 
prompt_embeds.to(dtype=self.unet.dtype, device=device) bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + prompt_embeds = prompt_embeds.view( + bs_embed * num_images_per_prompt, seq_len, -1 + ) if do_classifier_free_guidance: # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] if self.text_encoder_2 is not None: - negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + negative_prompt_embeds = negative_prompt_embeds.to( + dtype=self.text_encoder_2.dtype, device=device + ) else: - negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device) + negative_prompt_embeds = negative_prompt_embeds.to( + dtype=self.unet.dtype, device=device + ) - negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) - negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + negative_prompt_embeds = negative_prompt_embeds.repeat( + 1, num_images_per_prompt, 1 + ) + negative_prompt_embeds = negative_prompt_embeds.view( + batch_size * num_images_per_prompt, seq_len, -1 + ) - pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( - bs_embed * num_images_per_prompt, -1 - ) + pooled_prompt_embeds = pooled_prompt_embeds.repeat( + 1, num_images_per_prompt + ).view(bs_embed * num_images_per_prompt, -1) if do_classifier_free_guidance: - negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( - bs_embed * num_images_per_prompt, -1 - ) + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat( + 1, num_images_per_prompt + ).view(bs_embed * num_images_per_prompt, -1) if self.text_encoder is not None: if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: @@ -659,7 +741,12 @@ def encode_prompt( # Retrieve the original scale by scaling back the LoRA layers unscale_lora_layers(self.text_encoder_2, lora_scale) - return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + return ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs def prepare_extra_step_kwargs(self, generator, eta): @@ -668,13 +755,17 @@ def prepare_extra_step_kwargs(self, generator, eta): # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 # and should be between [0, 1] - accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + accepts_eta = "eta" in set( + inspect.signature(self.scheduler.step).parameters.keys() + ) extra_step_kwargs = {} if accepts_eta: extra_step_kwargs["eta"] = eta # check if the scheduler accepts generator - accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + accepts_generator = "generator" in set( + inspect.signature(self.scheduler.step).parameters.keys() + ) if accepts_generator: extra_step_kwargs["generator"] = generator return extra_step_kwargs @@ -694,19 +785,26 @@ def check_inputs( callback_on_step_end_tensor_inputs=None, ): if strength < 0 or strength > 1: - raise 
ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + raise ValueError( + f"The value of strength should in [0.0, 1.0] but is {strength}" + ) if height % 8 != 0 or width % 8 != 0: - raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + raise ValueError( + f"`height` and `width` have to be divisible by 8 but are {height} and {width}." + ) - if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + if callback_steps is not None and ( + not isinstance(callback_steps, int) or callback_steps <= 0 + ): raise ValueError( f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}." ) if callback_on_step_end_tensor_inputs is not None and not all( - k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + k in self._callback_tensor_inputs + for k in callback_on_step_end_tensor_inputs ): raise ValueError( f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" @@ -726,10 +824,18 @@ def check_inputs( raise ValueError( "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." ) - elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): - raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + elif prompt is not None and ( + not isinstance(prompt, str) and not isinstance(prompt, list) + ): + raise ValueError( + f"`prompt` has to be of type `str` or `list` but is {type(prompt)}" + ) + elif prompt_2 is not None and ( + not isinstance(prompt_2, str) and not isinstance(prompt_2, list) + ): + raise ValueError( + f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}" + ) if negative_prompt is not None and negative_prompt_embeds is not None: raise ValueError( @@ -767,7 +873,12 @@ def prepare_latents( return_noise=False, return_image_latents=False, ): - shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + shape = ( + batch_size, + num_channels_latents, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + ) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" @@ -782,18 +893,30 @@ def prepare_latents( if image.shape[1] == 4: image_latents = image.to(device=device, dtype=dtype) - image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1) + image_latents = image_latents.repeat( + batch_size // image_latents.shape[0], 1, 1, 1 + ) elif return_image_latents or (latents is None and not is_strength_max): image = image.to(device=device, dtype=dtype) image_latents = self._encode_vae_image(image=image, generator=generator) - image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1) + image_latents = image_latents.repeat( + batch_size // image_latents.shape[0], 1, 1, 1 + ) if latents is None and add_noise: noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) # if strength is 1. 
then initialise the latents to noise, else initial to image + noise - latents = noise if is_strength_max else self.scheduler.add_noise(image_latents, noise, timestep) + latents = ( + noise + if is_strength_max + else self.scheduler.add_noise(image_latents, noise, timestep) + ) # if pure noise then scale the initial latents by the Scheduler's init sigma - latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents + latents = ( + latents * self.scheduler.init_noise_sigma + if is_strength_max + else latents + ) elif add_noise: noise = latents.to(device) latents = noise * self.scheduler.init_noise_sigma @@ -819,12 +942,16 @@ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): if isinstance(generator, list): image_latents = [ - retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) + retrieve_latents( + self.vae.encode(image[i : i + 1]), generator=generator[i] + ) for i in range(image.shape[0]) ] image_latents = torch.cat(image_latents, dim=0) else: - image_latents = retrieve_latents(self.vae.encode(image), generator=generator) + image_latents = retrieve_latents( + self.vae.encode(image), generator=generator + ) if self.vae.config.force_upcast: self.vae.to(dtype) @@ -835,7 +962,16 @@ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): return image_latents def prepare_mask_latents( - self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance + self, + mask, + masked_image, + batch_size, + height, + width, + dtype, + device, + generator, + do_classifier_free_guidance, ): # resize the mask to latents shape as we concatenate the mask to the latents # we do that before converting to dtype to avoid breaking in case we're using cpu_offload @@ -865,7 +1001,9 @@ def prepare_mask_latents( if masked_image is not None: if masked_image_latents is None: masked_image = masked_image.to(device=device, dtype=dtype) - masked_image_latents = self._encode_vae_image(masked_image, generator=generator) + masked_image_latents = self._encode_vae_image( + masked_image, generator=generator + ) if masked_image_latents.shape[0] < batch_size: if not batch_size % masked_image_latents.shape[0] == 0: @@ -879,7 +1017,9 @@ def prepare_mask_latents( ) masked_image_latents = ( - torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents + torch.cat([masked_image_latents] * 2) + if do_classifier_free_guidance + else masked_image_latents ) # aligning device to prevent device errors when concating it with the latent model input @@ -888,10 +1028,14 @@ def prepare_mask_latents( return mask, masked_image_latents # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.get_timesteps - def get_timesteps(self, num_inference_steps, strength, device, denoising_start=None): + def get_timesteps( + self, num_inference_steps, strength, device, denoising_start=None + ): # get the original timestep using init_timestep if denoising_start is None: - init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + init_timestep = min( + int(num_inference_steps * strength), num_inference_steps + ) t_start = max(num_inference_steps - init_timestep, 0) else: t_start = 0 @@ -939,29 +1083,38 @@ def _get_add_time_ids( text_encoder_projection_dim=None, ): if self.config.requires_aesthetics_score: - add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,)) + add_time_ids = list( 
+ original_size + crops_coords_top_left + (aesthetic_score,) + ) add_neg_time_ids = list( - negative_original_size + negative_crops_coords_top_left + (negative_aesthetic_score,) + negative_original_size + + negative_crops_coords_top_left + + (negative_aesthetic_score,) ) else: add_time_ids = list(original_size + crops_coords_top_left + target_size) - add_neg_time_ids = list(negative_original_size + crops_coords_top_left + negative_target_size) + add_neg_time_ids = list( + negative_original_size + crops_coords_top_left + negative_target_size + ) passed_add_embed_dim = ( - self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim + self.unet.config.addition_time_embed_dim * len(add_time_ids) + + text_encoder_projection_dim ) expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features if ( expected_add_embed_dim > passed_add_embed_dim - and (expected_add_embed_dim - passed_add_embed_dim) == self.unet.config.addition_time_embed_dim + and (expected_add_embed_dim - passed_add_embed_dim) + == self.unet.config.addition_time_embed_dim ): raise ValueError( f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to enable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=True)` to make sure `aesthetic_score` {aesthetic_score} and `negative_aesthetic_score` {negative_aesthetic_score} is correctly used by the model." ) elif ( expected_add_embed_dim < passed_add_embed_dim - and (passed_add_embed_dim - expected_add_embed_dim) == self.unet.config.addition_time_embed_dim + and (passed_add_embed_dim - expected_add_embed_dim) + == self.unet.config.addition_time_embed_dim ): raise ValueError( f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to disable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=False)` to make sure `target_size` {target_size} is correctly used by the model." @@ -1348,7 +1501,9 @@ def __call__( # 3. Encode input prompt text_encoder_lora_scale = ( - self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None + self.cross_attention_kwargs.get("scale", None) + if self.cross_attention_kwargs is not None + else None ) ( @@ -1454,7 +1609,10 @@ def denoising_value_valid(dnv): # default case for runwayml/stable-diffusion-inpainting num_channels_mask = mask.shape[1] num_channels_masked_image = masked_image_latents.shape[1] - if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels: + if ( + num_channels_latents + num_channels_mask + num_channels_masked_image + != self.unet.config.in_channels + ): raise ValueError( f"Incorrect configuration settings! 
The config of `pipeline.unet`: {self.unet.config} expects" f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" @@ -1505,8 +1663,12 @@ def denoising_value_valid(dnv): if self.do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) - add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) - add_neg_time_ids = add_neg_time_ids.repeat(batch_size * num_images_per_prompt, 1) + add_text_embeds = torch.cat( + [negative_pooled_prompt_embeds, add_text_embeds], dim=0 + ) + add_neg_time_ids = add_neg_time_ids.repeat( + batch_size * num_images_per_prompt, 1 + ) add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0) prompt_embeds = prompt_embeds.to(device) @@ -1514,13 +1676,17 @@ def denoising_value_valid(dnv): add_time_ids = add_time_ids.to(device) if ip_adapter_image is not None: - image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt) + image_embeds, negative_image_embeds = self.encode_image( + ip_adapter_image, device, num_images_per_prompt + ) if self.do_classifier_free_guidance: image_embeds = torch.cat([negative_image_embeds, image_embeds]) image_embeds = image_embeds.to(device) # 11. Denoising loop - num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + num_warmup_steps = max( + len(timesteps) - num_inference_steps * self.scheduler.order, 0 + ) if ( self.denoising_end is not None @@ -1533,20 +1699,26 @@ def denoising_value_valid(dnv): f"`denoising_start`: {self.denoising_start} cannot be larger than or equal to `denoising_end`: " + f" {self.denoising_end} when using type float." ) - elif self.denoising_end is not None and denoising_value_valid(self.denoising_end): + elif self.denoising_end is not None and denoising_value_valid( + self.denoising_end + ): discrete_timestep_cutoff = int( round( self.scheduler.config.num_train_timesteps - (self.denoising_end * self.scheduler.config.num_train_timesteps) ) ) - num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) + num_inference_steps = len( + list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)) + ) timesteps = timesteps[:num_inference_steps] # 11.1 Optionally get Guidance Scale Embedding timestep_cond = None if self.unet.config.time_cond_proj_dim is not None: - guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) + guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat( + batch_size * num_images_per_prompt + ) timestep_cond = self.get_guidance_scale_embedding( guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim ).to(device=device, dtype=latents.dtype) @@ -1555,16 +1727,27 @@ def denoising_value_valid(dnv): with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + latent_model_input = ( + torch.cat([latents] * 2) + if self.do_classifier_free_guidance + else latents + ) # concat latents, mask, masked_image_latents in the channel dimension - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + latent_model_input = self.scheduler.scale_model_input( + latent_model_input, t + ) if num_channels_unet == 9: - latent_model_input = torch.cat([latent_model_input, mask, 
masked_image_latents], dim=1) + latent_model_input = torch.cat( + [latent_model_input, mask, masked_image_latents], dim=1 + ) # predict the noise residual - added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + added_cond_kwargs = { + "text_embeds": add_text_embeds, + "time_ids": add_time_ids, + } if ip_adapter_image is not None: added_cond_kwargs["image_embeds"] = image_embeds noise_pred = self.unet( @@ -1580,14 +1763,22 @@ def denoising_value_valid(dnv): # perform guidance if self.do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + noise_pred = noise_pred_uncond + self.guidance_scale * ( + noise_pred_text - noise_pred_uncond + ) if self.do_classifier_free_guidance and self.guidance_rescale > 0.0: # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf - noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale) + noise_pred = rescale_noise_cfg( + noise_pred, + noise_pred_text, + guidance_rescale=self.guidance_rescale, + ) # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + latents = self.scheduler.step( + noise_pred, t, latents, **extra_step_kwargs, return_dict=False + )[0] if num_channels_unet == 4: init_latents_proper = image_latents @@ -1602,7 +1793,9 @@ def denoising_value_valid(dnv): init_latents_proper, noise, torch.tensor([noise_timestep]) ) - latents = (1 - init_mask) * init_latents_proper + init_mask * latents + latents = ( + 1 - init_mask + ) * init_latents_proper + init_mask * latents if callback_on_step_end is not None: callback_kwargs = {} @@ -1612,18 +1805,28 @@ def denoising_value_valid(dnv): latents = callback_outputs.pop("latents", latents) prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) - negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) - add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds) + negative_prompt_embeds = callback_outputs.pop( + "negative_prompt_embeds", negative_prompt_embeds + ) + add_text_embeds = callback_outputs.pop( + "add_text_embeds", add_text_embeds + ) negative_pooled_prompt_embeds = callback_outputs.pop( "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds ) add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids) - add_neg_time_ids = callback_outputs.pop("add_neg_time_ids", add_neg_time_ids) + add_neg_time_ids = callback_outputs.pop( + "add_neg_time_ids", add_neg_time_ids + ) mask = callback_outputs.pop("mask", mask) - masked_image_latents = callback_outputs.pop("masked_image_latents", masked_image_latents) + masked_image_latents = callback_outputs.pop( + "masked_image_latents", masked_image_latents + ) # call the callback, if provided - if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + if i == len(timesteps) - 1 or ( + (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0 + ): progress_bar.update() if callback is not None and i % callback_steps == 0: step_idx = i // getattr(self.scheduler, "order", 1) @@ -1634,13 +1837,19 @@ def denoising_value_valid(dnv): if not output_type == "latent": # make sure the VAE is in float32 mode, as it overflows in float16 - needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast + needs_upcasting = ( + self.vae.dtype == 
torch.float16 and self.vae.config.force_upcast + ) if needs_upcasting: self.upcast_vae() - latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + latents = latents.to( + next(iter(self.vae.post_quant_conv.parameters())).dtype + ) - image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image = self.vae.decode( + latents / self.vae.config.scaling_factor, return_dict=False + )[0] # cast back to fp16 if needed if needs_upcasting: diff --git a/tests/models/test_models_unet_2d_condition.py b/tests/models/test_models_unet_2d_condition.py index 06bf2685560d..fc349a182dc4 100644 --- a/tests/models/test_models_unet_2d_condition.py +++ b/tests/models/test_models_unet_2d_condition.py @@ -24,8 +24,11 @@ from pytest import mark from diffusers import UNet2DConditionModel -from diffusers.models.attention_processor import CustomDiffusionAttnProcessor, IPAdapterAttnProcessor -from diffusers.models.embeddings import ImageProjection +from diffusers.models.attention_processor import ( + CustomDiffusionAttnProcessor, + IPAdapterAttnProcessor, +) +from diffusers.models.embeddings import ImageProjection, Resampler from diffusers.utils import logging from diffusers.utils.import_utils import is_xformers_available from diffusers.utils.testing_utils import ( @@ -52,7 +55,11 @@ def create_ip_adapter_state_dict(model): key_id = 1 for name in model.attn_processors.keys(): - cross_attention_dim = None if name.endswith("attn1.processor") else model.config.cross_attention_dim + cross_attention_dim = ( + None + if name.endswith("attn1.processor") + else model.config.cross_attention_dim + ) if name.startswith("mid_block"): hidden_size = model.config.block_out_channels[-1] elif name.startswith("up_blocks"): @@ -63,7 +70,9 @@ def create_ip_adapter_state_dict(model): hidden_size = model.config.block_out_channels[block_id] if cross_attention_dim is not None: sd = IPAdapterAttnProcessor( - hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, scale=1.0 + hidden_size=hidden_size, + cross_attention_dim=cross_attention_dim, + scale=1.0, ).state_dict() ip_cross_attn_state_dict.update( { @@ -77,7 +86,9 @@ def create_ip_adapter_state_dict(model): # "image_proj" (ImageProjection layer weights) cross_attention_dim = model.config["cross_attention_dim"] image_projection = ImageProjection( - cross_attention_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, num_image_text_embeds=4 + cross_attention_dim=cross_attention_dim, + image_embed_dim=cross_attention_dim, + num_image_text_embeds=4, ) ip_image_projection_state_dict = {} @@ -93,7 +104,68 @@ def create_ip_adapter_state_dict(model): del sd ip_state_dict = {} - ip_state_dict.update({"image_proj": ip_image_projection_state_dict, "ip_adapter": ip_cross_attn_state_dict}) + ip_state_dict.update( + { + "image_proj": ip_image_projection_state_dict, + "ip_adapter": ip_cross_attn_state_dict, + } + ) + return ip_state_dict + + +def create_ip_adapter_plus_state_dict(model): + # "ip_adapter" (cross-attention weights) + ip_cross_attn_state_dict = {} + key_id = 1 + + for name in model.attn_processors.keys(): + cross_attention_dim = ( + None + if name.endswith("attn1.processor") + else model.config.cross_attention_dim + ) + if name.startswith("mid_block"): + hidden_size = model.config.block_out_channels[-1] + elif name.startswith("up_blocks"): + block_id = int(name[len("up_blocks.")]) + hidden_size = list(reversed(model.config.block_out_channels))[block_id] + elif name.startswith("down_blocks"): + block_id = 
int(name[len("down_blocks.")]) + hidden_size = model.config.block_out_channels[block_id] + if cross_attention_dim is not None: + sd = IPAdapterAttnProcessor( + hidden_size=hidden_size, + cross_attention_dim=cross_attention_dim, + scale=1.0, + ).state_dict() + ip_cross_attn_state_dict.update( + { + f"{key_id}.to_k_ip.weight": sd["to_k_ip.weight"], + f"{key_id}.to_v_ip.weight": sd["to_v_ip.weight"], + } + ) + + key_id += 2 + + # "image_proj" (ImageProjection layer weights) + cross_attention_dim = model.config["cross_attention_dim"] + image_projection = Resampler( + embed_dims=cross_attention_dim, + output_dims=cross_attention_dim, + hidden_dims=32, + num_heads=2, + num_queries=4, + ) + + ip_image_projection_state_dict = image_projection.state_dict() + + ip_state_dict = {} + ip_state_dict.update( + { + "image_proj": ip_image_projection_state_dict, + "ip_adapter": ip_cross_attn_state_dict, + } + ) return ip_state_dict @@ -104,7 +176,11 @@ def create_custom_diffusion_layers(model, mock_weights: bool = True): st = model.state_dict() for name, _ in model.attn_processors.items(): - cross_attention_dim = None if name.endswith("attn1.processor") else model.config.cross_attention_dim + cross_attention_dim = ( + None + if name.endswith("attn1.processor") + else model.config.cross_attention_dim + ) if name.startswith("mid_block"): hidden_size = model.config.block_out_channels[-1] elif name.startswith("up_blocks"): @@ -120,8 +196,12 @@ def create_custom_diffusion_layers(model, mock_weights: bool = True): } if train_q_out: weights["to_q_custom_diffusion.weight"] = st[layer_name + ".to_q.weight"] - weights["to_out_custom_diffusion.0.weight"] = st[layer_name + ".to_out.0.weight"] - weights["to_out_custom_diffusion.0.bias"] = st[layer_name + ".to_out.0.bias"] + weights["to_out_custom_diffusion.0.weight"] = st[ + layer_name + ".to_out.0.weight" + ] + weights["to_out_custom_diffusion.0.bias"] = st[ + layer_name + ".to_out.0.bias" + ] if cross_attention_dim is not None: custom_diffusion_attn_procs[name] = CustomDiffusionAttnProcessor( train_kv=train_kv, @@ -160,7 +240,11 @@ def dummy_input(self): time_step = torch.tensor([10]).to(torch_device) encoder_hidden_states = floats_tensor((batch_size, 4, 32)).to(torch_device) - return {"sample": noise, "timestep": time_step, "encoder_hidden_states": encoder_hidden_states} + return { + "sample": noise, + "timestep": time_step, + "encoder_hidden_states": encoder_hidden_states, + } @property def input_shape(self): @@ -196,7 +280,9 @@ def test_xformers_enable_works(self): model.enable_xformers_memory_efficient_attention() assert ( - model.mid_block.attentions[0].transformer_blocks[0].attn1.processor.__class__.__name__ + model.mid_block.attentions[0] + .transformer_blocks[0] + .attn1.processor.__class__.__name__ == "XFormersAttnProcessor" ), "xformers is not enabled" @@ -239,7 +325,11 @@ def test_gradient_checkpointing(self): named_params = dict(model.named_parameters()) named_params_2 = dict(model_2.named_parameters()) for name, param in named_params.items(): - self.assertTrue(torch_all_close(param.grad.data, named_params_2[name].grad.data, atol=5e-5)) + self.assertTrue( + torch_all_close( + param.grad.data, named_params_2[name].grad.data, atol=5e-5 + ) + ) def test_model_with_attention_head_dim_tuple(self): init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() @@ -258,7 +348,9 @@ def test_model_with_attention_head_dim_tuple(self): self.assertIsNotNone(output) expected_shape = inputs_dict["sample"].shape - self.assertEqual(output.shape, expected_shape, 
"Input and output shapes do not match") + self.assertEqual( + output.shape, expected_shape, "Input and output shapes do not match" + ) def test_model_with_use_linear_projection(self): init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() @@ -277,7 +369,9 @@ def test_model_with_use_linear_projection(self): self.assertIsNotNone(output) expected_shape = inputs_dict["sample"].shape - self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match") + self.assertEqual( + output.shape, expected_shape, "Input and output shapes do not match" + ) def test_model_with_cross_attention_dim_tuple(self): init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() @@ -296,7 +390,9 @@ def test_model_with_cross_attention_dim_tuple(self): self.assertIsNotNone(output) expected_shape = inputs_dict["sample"].shape - self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match") + self.assertEqual( + output.shape, expected_shape, "Input and output shapes do not match" + ) def test_model_with_simple_projection(self): init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() @@ -306,7 +402,9 @@ def test_model_with_simple_projection(self): init_dict["class_embed_type"] = "simple_projection" init_dict["projection_class_embeddings_input_dim"] = sample_size - inputs_dict["class_labels"] = floats_tensor((batch_size, sample_size)).to(torch_device) + inputs_dict["class_labels"] = floats_tensor((batch_size, sample_size)).to( + torch_device + ) model = self.model_class(**init_dict) model.to(torch_device) @@ -320,7 +418,9 @@ def test_model_with_simple_projection(self): self.assertIsNotNone(output) expected_shape = inputs_dict["sample"].shape - self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match") + self.assertEqual( + output.shape, expected_shape, "Input and output shapes do not match" + ) def test_model_with_class_embeddings_concat(self): init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() @@ -331,7 +431,9 @@ def test_model_with_class_embeddings_concat(self): init_dict["projection_class_embeddings_input_dim"] = sample_size init_dict["class_embeddings_concat"] = True - inputs_dict["class_labels"] = floats_tensor((batch_size, sample_size)).to(torch_device) + inputs_dict["class_labels"] = floats_tensor((batch_size, sample_size)).to( + torch_device + ) model = self.model_class(**init_dict) model.to(torch_device) @@ -345,7 +447,9 @@ def test_model_with_class_embeddings_concat(self): self.assertIsNotNone(output) expected_shape = inputs_dict["sample"].shape - self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match") + self.assertEqual( + output.shape, expected_shape, "Input and output shapes do not match" + ) def test_model_attention_slicing(self): init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() @@ -434,13 +538,26 @@ def __init__(self, num): self.number = 0 self.counter = 0 - def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, number=None): + def __call__( + self, + attn, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + number=None, + ): batch_size, sequence_length, _ = hidden_states.shape - attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + attention_mask = attn.prepare_attention_mask( + attention_mask, sequence_length, batch_size + ) query = attn.to_q(hidden_states) - encoder_hidden_states = encoder_hidden_states if encoder_hidden_states 
is not None else hidden_states + encoder_hidden_states = ( + encoder_hidden_states + if encoder_hidden_states is not None + else hidden_states + ) key = attn.to_k(encoder_hidden_states) value = attn.to_v(encoder_hidden_states) @@ -503,21 +620,33 @@ def test_model_xattn_mask(self, mask_dtype): full_cond_out = model(**inputs_dict).sample assert full_cond_out is not None - keepall_mask = torch.ones(*cond.shape[:-1], device=cond.device, dtype=mask_dtype) - full_cond_keepallmask_out = model(**{**inputs_dict, "encoder_attention_mask": keepall_mask}).sample + keepall_mask = torch.ones( + *cond.shape[:-1], device=cond.device, dtype=mask_dtype + ) + full_cond_keepallmask_out = model( + **{**inputs_dict, "encoder_attention_mask": keepall_mask} + ).sample assert full_cond_keepallmask_out.allclose( full_cond_out, rtol=1e-05, atol=1e-05 ), "a 'keep all' mask should give the same result as no mask" trunc_cond = cond[:, :-1, :] - trunc_cond_out = model(**{**inputs_dict, "encoder_hidden_states": trunc_cond}).sample + trunc_cond_out = model( + **{**inputs_dict, "encoder_hidden_states": trunc_cond} + ).sample assert not trunc_cond_out.allclose( full_cond_out, rtol=1e-05, atol=1e-05 ), "discarding the last token from our cond should change the result" batch, tokens, _ = cond.shape - mask_last = (torch.arange(tokens) < tokens - 1).expand(batch, -1).to(cond.device, mask_dtype) - masked_cond_out = model(**{**inputs_dict, "encoder_attention_mask": mask_last}).sample + mask_last = ( + (torch.arange(tokens) < tokens - 1) + .expand(batch, -1) + .to(cond.device, mask_dtype) + ) + masked_cond_out = model( + **{**inputs_dict, "encoder_attention_mask": mask_last} + ).sample assert masked_cond_out.allclose( trunc_cond_out, rtol=1e-05, atol=1e-05 ), "masking the last token from our cond should be equivalent to truncating that token out of the condition" @@ -542,12 +671,24 @@ def test_model_xattn_padding(self): assert full_cond_out is not None batch, tokens, _ = cond.shape - keeplast_mask = (torch.arange(tokens) == tokens - 1).expand(batch, -1).to(cond.device, torch.bool) - keeplast_out = model(**{**inputs_dict, "encoder_attention_mask": keeplast_mask}).sample - assert not keeplast_out.allclose(full_cond_out), "a 'keep last token' mask should change the result" - - trunc_mask = torch.zeros(batch, tokens - 1, device=cond.device, dtype=torch.bool) - trunc_mask_out = model(**{**inputs_dict, "encoder_attention_mask": trunc_mask}).sample + keeplast_mask = ( + (torch.arange(tokens) == tokens - 1) + .expand(batch, -1) + .to(cond.device, torch.bool) + ) + keeplast_out = model( + **{**inputs_dict, "encoder_attention_mask": keeplast_mask} + ).sample + assert not keeplast_out.allclose( + full_cond_out + ), "a 'keep last token' mask should change the result" + + trunc_mask = torch.zeros( + batch, tokens - 1, device=cond.device, dtype=torch.bool + ) + trunc_mask_out = model( + **{**inputs_dict, "encoder_attention_mask": trunc_mask} + ).sample assert trunc_mask_out.allclose( keeplast_out ), "a mask with fewer tokens than condition, will be padded with 'keep' tokens. a 'discard-all' mask missing the final token is thus equivalent to a 'keep last' mask." 
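For context on the dummy `image_proj` weights that `create_ip_adapter_plus_state_dict` builds above: IP-Adapter Plus swaps the single `ImageProjection` for a `Resampler` that condenses a sequence of CLIP hidden states into a fixed number of image tokens. A minimal shape sketch follows; the constructor arguments are the ones the test helper uses, while the single-tensor `forward` call and the output shape are assumptions based on the reference IP-Adapter Plus design rather than code shown in this patch.

import torch

from diffusers.models.embeddings import Resampler  # module added in this PR

# Same constructor arguments as create_ip_adapter_plus_state_dict uses above.
resampler = Resampler(
    embed_dims=32,
    output_dims=32,
    hidden_dims=32,
    num_heads=2,
    num_queries=4,
)

# Assumed input: penultimate CLIP hidden states of shape (batch, seq_len, embed_dims).
clip_hidden_states = torch.randn(2, 257, 32)
image_tokens = resampler(clip_hidden_states)

# Expected output under the IP-Adapter Plus design: (batch, num_queries, output_dims),
# i.e. (2, 4, 32), a fixed set of image tokens consumed by the IPAdapterAttnProcessor layers.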
@@ -564,7 +705,9 @@ def test_custom_diffusion_processors(self): with torch.no_grad(): sample1 = model(**inputs_dict).sample - custom_diffusion_attn_procs = create_custom_diffusion_layers(model, mock_weights=False) + custom_diffusion_attn_procs = create_custom_diffusion_layers( + model, mock_weights=False + ) # make sure we can set a list of attention processors model.set_attn_processor(custom_diffusion_attn_procs) @@ -591,7 +734,9 @@ def test_custom_diffusion_save_load(self): with torch.no_grad(): old_sample = model(**inputs_dict).sample - custom_diffusion_attn_procs = create_custom_diffusion_layers(model, mock_weights=False) + custom_diffusion_attn_procs = create_custom_diffusion_layers( + model, mock_weights=False + ) model.set_attn_processor(custom_diffusion_attn_procs) with torch.no_grad(): @@ -599,10 +744,16 @@ def test_custom_diffusion_save_load(self): with tempfile.TemporaryDirectory() as tmpdirname: model.save_attn_procs(tmpdirname, safe_serialization=False) - self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_custom_diffusion_weights.bin"))) + self.assertTrue( + os.path.isfile( + os.path.join(tmpdirname, "pytorch_custom_diffusion_weights.bin") + ) + ) torch.manual_seed(0) new_model = self.model_class(**init_dict) - new_model.load_attn_procs(tmpdirname, weight_name="pytorch_custom_diffusion_weights.bin") + new_model.load_attn_procs( + tmpdirname, weight_name="pytorch_custom_diffusion_weights.bin" + ) new_model.to(torch_device) with torch.no_grad(): @@ -626,7 +777,9 @@ def test_custom_diffusion_xformers_on_off(self): torch.manual_seed(0) model = self.model_class(**init_dict) model.to(torch_device) - custom_diffusion_attn_procs = create_custom_diffusion_layers(model, mock_weights=False) + custom_diffusion_attn_procs = create_custom_diffusion_layers( + model, mock_weights=False + ) model.set_attn_processor(custom_diffusion_attn_procs) # default @@ -672,7 +825,9 @@ def test_asymmetrical_unet(self): expected_shape = inputs_dict["sample"].shape # Check if input and output shapes are the same - self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match") + self.assertEqual( + output.shape, expected_shape, "Input and output shapes do not match" + ) def test_ip_adapter(self): init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() @@ -688,22 +843,98 @@ def test_ip_adapter(self): # update inputs_dict for ip-adapter batch_size = inputs_dict["encoder_hidden_states"].shape[0] - image_embeds = floats_tensor((batch_size, 1, model.cross_attention_dim)).to(torch_device) + image_embeds = floats_tensor((batch_size, 1, model.cross_attention_dim)).to( + torch_device + ) inputs_dict["added_cond_kwargs"] = {"image_embeds": image_embeds} # make ip_adapter_1 and ip_adapter_2 ip_adapter_1 = create_ip_adapter_state_dict(model) - image_proj_state_dict_2 = {k: w + 1.0 for k, w in ip_adapter_1["image_proj"].items()} - cross_attn_state_dict_2 = {k: w + 1.0 for k, w in ip_adapter_1["ip_adapter"].items()} + image_proj_state_dict_2 = { + k: w + 1.0 for k, w in ip_adapter_1["image_proj"].items() + } + cross_attn_state_dict_2 = { + k: w + 1.0 for k, w in ip_adapter_1["ip_adapter"].items() + } + ip_adapter_2 = {} + ip_adapter_2.update( + { + "image_proj": image_proj_state_dict_2, + "ip_adapter": cross_attn_state_dict_2, + } + ) + + # forward pass ip_adapter_1 + model._load_ip_adapter_weights(ip_adapter_1) + assert model.config.encoder_hid_dim_type == "ip_image_proj" + assert model.encoder_hid_proj is not None + assert 
model.down_blocks[0].attentions[0].transformer_blocks[ + 0 + ].attn2.processor.__class__.__name__ in ( + "IPAdapterAttnProcessor", + "IPAdapterAttnProcessor2_0", + ) + with torch.no_grad(): + sample2 = model(**inputs_dict).sample + + # forward pass with ip_adapter_2 + model._load_ip_adapter_weights(ip_adapter_2) + with torch.no_grad(): + sample3 = model(**inputs_dict).sample + + # forward pass with ip_adapter_1 again + model._load_ip_adapter_weights(ip_adapter_1) + with torch.no_grad(): + sample4 = model(**inputs_dict).sample + + assert not sample1.allclose(sample2, atol=1e-4, rtol=1e-4) + assert not sample2.allclose(sample3, atol=1e-4, rtol=1e-4) + assert sample2.allclose(sample4, atol=1e-4, rtol=1e-4) + + def test_ip_adapter_plus(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + + init_dict["attention_head_dim"] = (8, 16) + + model = self.model_class(**init_dict) + model.to(torch_device) + + # forward pass without ip-adapter + with torch.no_grad(): + sample1 = model(**inputs_dict).sample + + # update inputs_dict for ip-adapter + batch_size = inputs_dict["encoder_hidden_states"].shape[0] + image_embeds = floats_tensor((batch_size, 1, model.cross_attention_dim)).to( + torch_device + ) + inputs_dict["added_cond_kwargs"] = {"image_embeds": image_embeds} + + # make ip_adapter_1 and ip_adapter_2 + ip_adapter_1 = create_ip_adapter_plus_state_dict(model) + + image_proj_state_dict_2 = { + k: w + 1.0 for k, w in ip_adapter_1["image_proj"].items() + } + cross_attn_state_dict_2 = { + k: w + 1.0 for k, w in ip_adapter_1["ip_adapter"].items() + } ip_adapter_2 = {} - ip_adapter_2.update({"image_proj": image_proj_state_dict_2, "ip_adapter": cross_attn_state_dict_2}) + ip_adapter_2.update( + { + "image_proj": image_proj_state_dict_2, + "ip_adapter": cross_attn_state_dict_2, + } + ) # forward pass ip_adapter_1 model._load_ip_adapter_weights(ip_adapter_1) assert model.config.encoder_hid_dim_type == "ip_image_proj" assert model.encoder_hid_proj is not None - assert model.down_blocks[0].attentions[0].transformer_blocks[0].attn2.processor.__class__.__name__ in ( + assert model.down_blocks[0].attentions[0].transformer_blocks[ + 0 + ].attn2.processor.__class__.__name__ in ( "IPAdapterAttnProcessor", "IPAdapterAttnProcessor2_0", ) @@ -738,7 +969,11 @@ def tearDown(self): def get_latents(self, seed=0, shape=(4, 4, 64, 64), fp16=False): dtype = torch.float16 if fp16 else torch.float32 - image = torch.from_numpy(load_hf_numpy(self.get_file_format(seed, shape))).to(torch_device).to(dtype) + image = ( + torch.from_numpy(load_hf_numpy(self.get_file_format(seed, shape))) + .to(torch_device) + .to(dtype) + ) return image def get_unet_model(self, fp16=False, model_id="CompVis/stable-diffusion-v1-4"): @@ -765,7 +1000,9 @@ def test_set_attention_slice_auto(self): timestep = 1 with torch.no_grad(): - _ = unet(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample + _ = unet( + latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states + ).sample mem_bytes = torch.cuda.max_memory_allocated() @@ -784,7 +1021,9 @@ def test_set_attention_slice_max(self): timestep = 1 with torch.no_grad(): - _ = unet(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample + _ = unet( + latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states + ).sample mem_bytes = torch.cuda.max_memory_allocated() @@ -803,7 +1042,9 @@ def test_set_attention_slice_int(self): timestep = 1 with torch.no_grad(): - _ = unet(latents, timestep=timestep, 
encoder_hidden_states=encoder_hidden_states).sample + _ = unet( + latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states + ).sample mem_bytes = torch.cuda.max_memory_allocated() @@ -824,7 +1065,9 @@ def test_set_attention_slice_list(self): timestep = 1 with torch.no_grad(): - _ = unet(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample + _ = unet( + latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states + ).sample mem_bytes = torch.cuda.max_memory_allocated() @@ -832,7 +1075,11 @@ def test_set_attention_slice_list(self): def get_encoder_hidden_states(self, seed=0, shape=(4, 77, 768), fp16=False): dtype = torch.float16 if fp16 else torch.float32 - hidden_states = torch.from_numpy(load_hf_numpy(self.get_file_format(seed, shape))).to(torch_device).to(dtype) + hidden_states = ( + torch.from_numpy(load_hf_numpy(self.get_file_format(seed, shape))) + .to(torch_device) + .to(dtype) + ) return hidden_states @parameterized.expand( @@ -854,7 +1101,9 @@ def test_compvis_sd_v1_4(self, seed, timestep, expected_slice): timestep = torch.tensor([timestep], dtype=torch.long, device=torch_device) with torch.no_grad(): - sample = model(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample + sample = model( + latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states + ).sample assert sample.shape == latents.shape @@ -882,7 +1131,9 @@ def test_compvis_sd_v1_4_fp16(self, seed, timestep, expected_slice): timestep = torch.tensor([timestep], dtype=torch.long, device=torch_device) with torch.no_grad(): - sample = model(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample + sample = model( + latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states + ).sample assert sample.shape == latents.shape @@ -910,7 +1161,9 @@ def test_compvis_sd_v1_5(self, seed, timestep, expected_slice): timestep = torch.tensor([timestep], dtype=torch.long, device=torch_device) with torch.no_grad(): - sample = model(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample + sample = model( + latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states + ).sample assert sample.shape == latents.shape @@ -931,14 +1184,18 @@ def test_compvis_sd_v1_5(self, seed, timestep, expected_slice): ) @require_torch_gpu def test_compvis_sd_v1_5_fp16(self, seed, timestep, expected_slice): - model = self.get_unet_model(model_id="runwayml/stable-diffusion-v1-5", fp16=True) + model = self.get_unet_model( + model_id="runwayml/stable-diffusion-v1-5", fp16=True + ) latents = self.get_latents(seed, fp16=True) encoder_hidden_states = self.get_encoder_hidden_states(seed, fp16=True) timestep = torch.tensor([timestep], dtype=torch.long, device=torch_device) with torch.no_grad(): - sample = model(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample + sample = model( + latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states + ).sample assert sample.shape == latents.shape @@ -966,7 +1223,9 @@ def test_compvis_sd_inpaint(self, seed, timestep, expected_slice): timestep = torch.tensor([timestep], dtype=torch.long, device=torch_device) with torch.no_grad(): - sample = model(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample + sample = model( + latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states + ).sample assert sample.shape == (4, 4, 64, 64) @@ -987,14 +1246,18 @@ def test_compvis_sd_inpaint(self, seed, 
timestep, expected_slice): ) @require_torch_gpu def test_compvis_sd_inpaint_fp16(self, seed, timestep, expected_slice): - model = self.get_unet_model(model_id="runwayml/stable-diffusion-inpainting", fp16=True) + model = self.get_unet_model( + model_id="runwayml/stable-diffusion-inpainting", fp16=True + ) latents = self.get_latents(seed, shape=(4, 9, 64, 64), fp16=True) encoder_hidden_states = self.get_encoder_hidden_states(seed, fp16=True) timestep = torch.tensor([timestep], dtype=torch.long, device=torch_device) with torch.no_grad(): - sample = model(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample + sample = model( + latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states + ).sample assert sample.shape == (4, 4, 64, 64) @@ -1015,14 +1278,20 @@ def test_compvis_sd_inpaint_fp16(self, seed, timestep, expected_slice): ) @require_torch_gpu def test_stabilityai_sd_v2_fp16(self, seed, timestep, expected_slice): - model = self.get_unet_model(model_id="stabilityai/stable-diffusion-2", fp16=True) + model = self.get_unet_model( + model_id="stabilityai/stable-diffusion-2", fp16=True + ) latents = self.get_latents(seed, shape=(4, 4, 96, 96), fp16=True) - encoder_hidden_states = self.get_encoder_hidden_states(seed, shape=(4, 77, 1024), fp16=True) + encoder_hidden_states = self.get_encoder_hidden_states( + seed, shape=(4, 77, 1024), fp16=True + ) timestep = torch.tensor([timestep], dtype=torch.long, device=torch_device) with torch.no_grad(): - sample = model(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample + sample = model( + latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states + ).sample assert sample.shape == latents.shape diff --git a/tests/pipelines/ip_adapters/test_ip_adapter_plus_stable_diffusion.py b/tests/pipelines/ip_adapters/test_ip_adapter_plus_stable_diffusion.py new file mode 100644 index 000000000000..15f9b9f2db52 --- /dev/null +++ b/tests/pipelines/ip_adapters/test_ip_adapter_plus_stable_diffusion.py @@ -0,0 +1,292 @@ +# coding=utf-8 +# Copyright 2023 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
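The integration tests below exercise the pipeline-side change introduced earlier in this patch: `encode_image` keeps the pooled `image_embeds` path for a plain IP-Adapter (when `unet.encoder_hid_proj` is an `ImageProjection`) and feeds the penultimate CLIP hidden states to the resampler for IP-Adapter Plus. A standalone sketch of those two encoder outputs, using the same image encoder checkpoint and test image as the tests; the shape comments are indicative rather than taken from this patch.

import torch
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection

from diffusers.utils import load_image

image_encoder = CLIPVisionModelWithProjection.from_pretrained(
    "h94/IP-Adapter", subfolder="models/image_encoder"
)
processor = CLIPImageProcessor()
image = load_image(
    "https://user-images.githubusercontent.com/24734142/266492875-2d50d223-8475-44f0-a7c6-08b51cb53572.png"
)
pixel_values = processor(image, return_tensors="pt").pixel_values

with torch.no_grad():
    # Plain IP-Adapter path: pooled, projected embedding, roughly (batch, projection_dim).
    image_embeds = image_encoder(pixel_values).image_embeds
    # IP-Adapter Plus path: penultimate hidden states, roughly (batch, seq_len, hidden_size);
    # the Resampler later condenses them into num_queries image tokens.
    hidden_states = image_encoder(
        pixel_values, output_hidden_states=True
    ).hidden_states[-2]

print(image_embeds.shape, hidden_states.shape)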
+ +import gc +import unittest + +import numpy as np +import torch +from transformers import ( + CLIPImageProcessor, + CLIPVisionModelWithProjection, +) + +from diffusers import ( + StableDiffusionImg2ImgPipeline, + StableDiffusionInpaintPipeline, + StableDiffusionPipeline, + StableDiffusionXLImg2ImgPipeline, + StableDiffusionXLInpaintPipeline, + StableDiffusionXLPipeline, +) +from diffusers.utils import load_image +from diffusers.utils.testing_utils import ( + enable_full_determinism, + require_torch_gpu, + slow, + torch_device, +) + + +enable_full_determinism() + + +class IPAdapterNightlyTestsMixin(unittest.TestCase): + dtype = torch.float16 + + def tearDown(self): + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + def get_image_encoder(self, repo_id, subfolder): + image_encoder = CLIPVisionModelWithProjection.from_pretrained( + repo_id, subfolder=subfolder, torch_dtype=self.dtype + ).to(torch_device) + return image_encoder + + def get_image_processor(self, repo_id): + image_processor = CLIPImageProcessor.from_pretrained(repo_id) + return image_processor + + def get_dummy_inputs( + self, for_image_to_image=False, for_inpainting=False, for_sdxl=False + ): + image = load_image( + "https://user-images.githubusercontent.com/24734142/266492875-2d50d223-8475-44f0-a7c6-08b51cb53572.png" + ) + if for_sdxl: + image = image.resize((1024, 1024)) + + input_kwargs = { + "prompt": "best quality, high quality", + "negative_prompt": "monochrome, lowres, bad anatomy, worst quality, low quality", + "num_inference_steps": 5, + "generator": torch.Generator(device="cpu").manual_seed(33), + "ip_adapter_image": image, + "output_type": "np", + } + if for_image_to_image: + image = load_image( + "https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/vermeer.jpg" + ) + ip_image = load_image( + "https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/river.png" + ) + + if for_sdxl: + image = image.resize((1024, 1024)) + ip_image = ip_image.resize((1024, 1024)) + + input_kwargs.update({"image": image, "ip_adapter_image": ip_image}) + + elif for_inpainting: + image = load_image( + "https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/inpaint_image.png" + ) + mask = load_image( + "https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/mask.png" + ) + ip_image = load_image( + "https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/girl.png" + ) + + if for_sdxl: + image = image.resize((1024, 1024)) + mask = mask.resize((1024, 1024)) + ip_image = ip_image.resize((1024, 1024)) + + input_kwargs.update( + {"image": image, "mask_image": mask, "ip_adapter_image": ip_image} + ) + + return input_kwargs + + +@slow +@require_torch_gpu +class IPAdapterSDIntegrationTests(IPAdapterNightlyTestsMixin): + def test_text_to_image(self): + image_encoder = self.get_image_encoder( + repo_id="h94/IP-Adapter", subfolder="models/image_encoder" + ) + pipeline = StableDiffusionPipeline.from_pretrained( + "runwayml/stable-diffusion-v1-5", + image_encoder=image_encoder, + safety_checker=None, + torch_dtype=self.dtype, + ) + pipeline.to(torch_device) + pipeline.load_ip_adapter( + "h94/IP-Adapter", subfolder="models", weight_name="ip-adapter-plus_sd15.bin" + ) + + inputs = self.get_dummy_inputs() + images = pipeline(**inputs).images + image_slice = images[0, :3, :3, -1].flatten() + + expected_slice = np.array( + [0.3015, 0.2615, 0.2200, 0.2725, 0.2510, 0.2021, 0.2498, 0.2415, 0.2131] + ) + + assert np.allclose(image_slice, expected_slice, atol=1e-4, rtol=1e-4) + + def 
test_image_to_image(self): + image_encoder = self.get_image_encoder( + repo_id="h94/IP-Adapter", subfolder="models/image_encoder" + ) + pipeline = StableDiffusionImg2ImgPipeline.from_pretrained( + "runwayml/stable-diffusion-v1-5", + image_encoder=image_encoder, + safety_checker=None, + torch_dtype=self.dtype, + ) + pipeline.to(torch_device) + pipeline.load_ip_adapter( + "h94/IP-Adapter", subfolder="models", weight_name="ip-adapter-plus_sd15.bin" + ) + + inputs = self.get_dummy_inputs(for_image_to_image=True) + images = pipeline(**inputs).images + image_slice = images[0, :3, :3, -1].flatten() + + expected_slice = np.array( + [0.3518, 0.2554, 0.2495, 0.2363, 0.1836, 0.3823, 0.1414, 0.1868, 0.5386] + ) + + assert np.allclose(image_slice, expected_slice, atol=1e-4, rtol=1e-4) + + def test_inpainting(self): + image_encoder = self.get_image_encoder( + repo_id="h94/IP-Adapter", subfolder="models/image_encoder" + ) + pipeline = StableDiffusionInpaintPipeline.from_pretrained( + "runwayml/stable-diffusion-v1-5", + image_encoder=image_encoder, + safety_checker=None, + torch_dtype=self.dtype, + ) + pipeline.to(torch_device) + pipeline.load_ip_adapter( + "h94/IP-Adapter", subfolder="models", weight_name="ip-adapter-plus_sd15.bin" + ) + + inputs = self.get_dummy_inputs(for_inpainting=True) + images = pipeline(**inputs).images + image_slice = images[0, :3, :3, -1].flatten() + + expected_slice = np.array( + [0.2756, 0.2422, 0.2214, 0.2346, 0.2102, 0.2060, 0.2188, 0.2043, 0.1941] + ) + + assert np.allclose(image_slice, expected_slice, atol=1e-4, rtol=1e-4) + + +@slow +@require_torch_gpu +class IPAdapterSDXLIntegrationTests(IPAdapterNightlyTestsMixin): + def test_text_to_image_sdxl(self): + image_encoder = self.get_image_encoder( + repo_id="h94/IP-Adapter", subfolder="models/image_encoder" + ) + feature_extractor = self.get_image_processor( + "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k" + ) + + pipeline = StableDiffusionXLPipeline.from_pretrained( + "stabilityai/stable-diffusion-xl-base-1.0", + image_encoder=image_encoder, + feature_extractor=feature_extractor, + torch_dtype=self.dtype, + ) + pipeline.to(torch_device) + pipeline.load_ip_adapter( + "h94/IP-Adapter", + subfolder="sdxl_models", + weight_name="ip-adapter-plus_sdxl_vit-h.bin", + ) + + inputs = self.get_dummy_inputs() + images = pipeline(**inputs).images + image_slice = images[0, :3, :3, -1].flatten() + + expected_slice = np.array( + [0.0587, 0.0567, 0.0454, 0.0537, 0.0553, 0.0518, 0.0494, 0.0535, 0.0497] + ) + + assert np.allclose(image_slice, expected_slice, atol=1e-4, rtol=1e-4) + + def test_image_to_image_sdxl(self): + image_encoder = self.get_image_encoder( + repo_id="h94/IP-Adapter", subfolder="models/image_encoder" + ) + feature_extractor = self.get_image_processor( + "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k" + ) + + pipeline = StableDiffusionXLImg2ImgPipeline.from_pretrained( + "stabilityai/stable-diffusion-xl-base-1.0", + image_encoder=image_encoder, + feature_extractor=feature_extractor, + torch_dtype=self.dtype, + ) + pipeline.to(torch_device) + pipeline.load_ip_adapter( + "h94/IP-Adapter", + subfolder="sdxl_models", + weight_name="ip-adapter-plus_sdxl_vit-h.bin", + ) + + inputs = self.get_dummy_inputs(for_image_to_image=True) + images = pipeline(**inputs).images + image_slice = images[0, :3, :3, -1].flatten() + + expected_slice = np.array( + [0.0711, 0.0700, 0.0734, 0.0758, 0.0742, 0.0688, 0.0751, 0.0827, 0.0851] + ) + + assert np.allclose(image_slice, expected_slice, atol=1e-4, rtol=1e-4) + + def test_inpainting_sdxl(self): + 
image_encoder = self.get_image_encoder( + repo_id="h94/IP-Adapter", subfolder="models/image_encoder" + ) + feature_extractor = self.get_image_processor( + "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k" + ) + + pipeline = StableDiffusionXLInpaintPipeline.from_pretrained( + "stabilityai/stable-diffusion-xl-base-1.0", + image_encoder=image_encoder, + feature_extractor=feature_extractor, + torch_dtype=self.dtype, + ) + pipeline.to(torch_device) + pipeline.load_ip_adapter( + "h94/IP-Adapter", + subfolder="sdxl_models", + weight_name="ip-adapter-plus_sdxl_vit-h.bin", + ) + + inputs = self.get_dummy_inputs(for_inpainting=True) + images = pipeline(**inputs).images + image_slice = images[0, :3, :3, -1].flatten() + image_slice.tolist() + + expected_slice = np.array( + [0.1398, 0.1476, 0.1407, 0.1441, 0.1470, 0.1480, 0.1448, 0.1481, 0.1494] + ) + + assert np.allclose(image_slice, expected_slice, atol=1e-4, rtol=1e-4) From d39d7639523d5ef928fff15c1faa32471321a60f Mon Sep 17 00:00:00 2001 From: okotaku Date: Fri, 24 Nov 2023 02:36:44 +0000 Subject: [PATCH 02/16] fix format --- src/diffusers/loaders/unet.py | 138 ++----- src/diffusers/models/__init__.py | 4 +- src/diffusers/models/embeddings.py | 22 +- .../pipeline_alt_diffusion_img2img.py | 212 +++-------- .../animatediff/pipeline_animatediff.py | 133 ++----- .../controlnet/pipeline_controlnet.py | 253 ++++--------- .../controlnet/pipeline_controlnet_sd_xl.py | 332 +++++----------- .../pipeline_stable_diffusion.py | 193 +++------- .../pipeline_stable_diffusion_img2img.py | 212 +++-------- .../pipeline_stable_diffusion_inpaint.py | 286 ++++---------- .../pipeline_stable_diffusion_xl.py | 250 ++++-------- .../pipeline_stable_diffusion_xl_img2img.py | 298 ++++----------- .../pipeline_stable_diffusion_xl_inpaint.py | 357 +++++------------- tests/models/test_models_unet_2d_condition.py | 250 +++--------- .../test_ip_adapter_plus_stable_diffusion.py | 100 ++--- 15 files changed, 782 insertions(+), 2258 deletions(-) diff --git a/src/diffusers/loaders/unet.py b/src/diffusers/loaders/unet.py index 71a9aa86b304..cf818df8ce7c 100644 --- a/src/diffusers/loaders/unet.py +++ b/src/diffusers/loaders/unet.py @@ -228,17 +228,12 @@ def load_attn_procs( # fill attn processors lora_layers_list = [] - is_lora = ( - all(("lora" in k or k.endswith(".alpha")) for k in state_dict.keys()) - and not USE_PEFT_BACKEND - ) + is_lora = all(("lora" in k or k.endswith(".alpha")) for k in state_dict.keys()) and not USE_PEFT_BACKEND is_custom_diffusion = any("custom_diffusion" in k for k in state_dict.keys()) if is_lora: # correct keys - state_dict, network_alphas = self.convert_state_dict_legacy_attn_format( - state_dict, network_alphas - ) + state_dict, network_alphas = self.convert_state_dict_legacy_attn_format(state_dict, network_alphas) if network_alphas is not None: network_alphas_keys = list(network_alphas.keys()) @@ -250,18 +245,14 @@ def load_attn_procs( all_keys = list(state_dict.keys()) for key in all_keys: value = state_dict.pop(key) - attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join( - key.split(".")[-3:] - ) + attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:]) lora_grouped_dict[attn_processor_key][sub_key] = value # Create another `mapped_network_alphas` dictionary so that we can properly map them. 
if network_alphas is not None: for k in network_alphas_keys: if k.replace(".alpha", "") in key: - mapped_network_alphas.update( - {attn_processor_key: network_alphas.get(k)} - ) + mapped_network_alphas.update({attn_processor_key: network_alphas.get(k)}) used_network_alphas_keys.add(k) if not is_network_alphas_none: @@ -310,9 +301,7 @@ def load_attn_procs( mapped_network_alphas.get(key), ) else: - raise ValueError( - f"Module {key} is not a LoRACompatibleConv or LoRACompatibleLinear module." - ) + raise ValueError(f"Module {key} is not a LoRACompatibleConv or LoRACompatibleLinear module.") value_dict = {k.replace("lora.", ""): v for k, v in value_dict.items()} lora_layers_list.append((attn_processor, lora)) @@ -320,9 +309,7 @@ def load_attn_procs( if low_cpu_mem_usage: device = next(iter(value_dict.values())).device dtype = next(iter(value_dict.values())).dtype - load_model_dict_into_meta( - lora, value_dict, device=device, dtype=dtype - ) + load_model_dict_into_meta(lora, value_dict, device=device, dtype=dtype) else: lora.load_state_dict(value_dict) @@ -334,13 +321,9 @@ def load_attn_procs( custom_diffusion_grouped_dict[key] = {} else: if "to_out" in key: - attn_processor_key, sub_key = ".".join( - key.split(".")[:-3] - ), ".".join(key.split(".")[-3:]) + attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:]) else: - attn_processor_key, sub_key = ".".join( - key.split(".")[:-2] - ), ".".join(key.split(".")[-2:]) + attn_processor_key, sub_key = ".".join(key.split(".")[:-2]), ".".join(key.split(".")[-2:]) custom_diffusion_grouped_dict[attn_processor_key][sub_key] = value for key, value_dict in custom_diffusion_grouped_dict.items(): @@ -352,13 +335,9 @@ def load_attn_procs( cross_attention_dim=None, ) else: - cross_attention_dim = value_dict[ - "to_k_custom_diffusion.weight" - ].shape[1] + cross_attention_dim = value_dict["to_k_custom_diffusion.weight"].shape[1] hidden_size = value_dict["to_k_custom_diffusion.weight"].shape[0] - train_q_out = ( - True if "to_q_custom_diffusion.weight" in value_dict else False - ) + train_q_out = True if "to_q_custom_diffusion.weight" in value_dict else False attn_processors[key] = CustomDiffusionAttnProcessor( train_kv=True, train_q_out=train_q_out, @@ -385,22 +364,14 @@ def load_attn_procs( if not USE_PEFT_BACKEND: if _pipeline is not None: for _, component in _pipeline.components.items(): - if isinstance(component, nn.Module) and hasattr( - component, "_hf_hook" - ): - is_model_cpu_offload = isinstance( - getattr(component, "_hf_hook"), CpuOffload - ) - is_sequential_cpu_offload = isinstance( - getattr(component, "_hf_hook"), AlignDevicesHook - ) + if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"): + is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload) + is_sequential_cpu_offload = isinstance(getattr(component, "_hf_hook"), AlignDevicesHook) logger.info( "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again." 
) - remove_hook_from_module( - component, recurse=is_sequential_cpu_offload - ) + remove_hook_from_module(component, recurse=is_sequential_cpu_offload) # only custom diffusion needs to set attn processors if is_custom_diffusion: @@ -421,23 +392,16 @@ def load_attn_procs( def convert_state_dict_legacy_attn_format(self, state_dict, network_alphas): is_new_lora_format = all( - key.startswith(self.unet_name) or key.startswith(self.text_encoder_name) - for key in state_dict.keys() + key.startswith(self.unet_name) or key.startswith(self.text_encoder_name) for key in state_dict.keys() ) if is_new_lora_format: # Strip the `"unet"` prefix. - is_text_encoder_present = any( - key.startswith(self.text_encoder_name) for key in state_dict.keys() - ) + is_text_encoder_present = any(key.startswith(self.text_encoder_name) for key in state_dict.keys()) if is_text_encoder_present: warn_message = "The state_dict contains LoRA params corresponding to the text encoder which are not being used here. To use both UNet and text encoder related LoRA params, use [`pipe.load_lora_weights()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraLoaderMixin.load_lora_weights)." logger.warn(warn_message) unet_keys = [k for k in state_dict.keys() if k.startswith(self.unet_name)] - state_dict = { - k.replace(f"{self.unet_name}.", ""): v - for k, v in state_dict.items() - if k in unet_keys - } + state_dict = {k.replace(f"{self.unet_name}.", ""): v for k, v in state_dict.items() if k in unet_keys} # change processor format to 'pure' LoRACompatibleLinear format if any("processor" in k.split(".") for k in state_dict.keys()): @@ -445,20 +409,12 @@ def convert_state_dict_legacy_attn_format(self, state_dict, network_alphas): def format_to_lora_compatible(key): if "processor" not in key.split("."): return key - return ( - key.replace(".processor", "") - .replace("to_out_lora", "to_out.0.lora") - .replace("_lora", ".lora") - ) + return key.replace(".processor", "").replace("to_out_lora", "to_out.0.lora").replace("_lora", ".lora") - state_dict = { - format_to_lora_compatible(k): v for k, v in state_dict.items() - } + state_dict = {format_to_lora_compatible(k): v for k, v in state_dict.items()} if network_alphas is not None: - network_alphas = { - format_to_lora_compatible(k): v for k, v in network_alphas.items() - } + network_alphas = {format_to_lora_compatible(k): v for k, v in network_alphas.items()} return state_dict, network_alphas def save_attn_procs( @@ -509,18 +465,14 @@ def save_attn_procs( ) if os.path.isfile(save_directory): - logger.error( - f"Provided path ({save_directory}) should be a directory, not a file" - ) + logger.error(f"Provided path ({save_directory}) should be a directory, not a file") return if save_function is None: if safe_serialization: def save_function(weights, filename): - return safetensors.torch.save_file( - weights, filename, metadata={"format": "pt"} - ) + return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"}) else: save_function = torch.save @@ -563,23 +515,13 @@ def save_function(weights, filename): if weight_name is None: if safe_serialization: - weight_name = ( - CUSTOM_DIFFUSION_WEIGHT_NAME_SAFE - if is_custom_diffusion - else LORA_WEIGHT_NAME_SAFE - ) + weight_name = CUSTOM_DIFFUSION_WEIGHT_NAME_SAFE if is_custom_diffusion else LORA_WEIGHT_NAME_SAFE else: - weight_name = ( - CUSTOM_DIFFUSION_WEIGHT_NAME - if is_custom_diffusion - else LORA_WEIGHT_NAME - ) + weight_name = CUSTOM_DIFFUSION_WEIGHT_NAME if is_custom_diffusion else 
LORA_WEIGHT_NAME # Save the model save_function(state_dict, os.path.join(save_directory, weight_name)) - logger.info( - f"Model weights saved in {os.path.join(save_directory, weight_name)}" - ) + logger.info(f"Model weights saved in {os.path.join(save_directory, weight_name)}") def fuse_lora(self, lora_scale=1.0, safe_fusing=False): self.lora_scale = lora_scale @@ -645,9 +587,7 @@ def set_adapters( if not USE_PEFT_BACKEND: raise ValueError("PEFT backend is required for `set_adapters()`.") - adapter_names = ( - [adapter_names] if isinstance(adapter_names, str) else adapter_names - ) + adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names if weights is None: weights = [1.0] * len(adapter_names) @@ -762,11 +702,7 @@ def _load_ip_adapter_weights(self, state_dict): attn_procs = {} key_id = 1 for name in self.attn_processors.keys(): - cross_attention_dim = ( - None - if name.endswith("attn1.processor") - else self.config.cross_attention_dim - ) + cross_attention_dim = None if name.endswith("attn1.processor") else self.config.cross_attention_dim if name.startswith("mid_block"): hidden_size = self.config.block_out_channels[-1] elif name.startswith("up_blocks"): @@ -777,16 +713,12 @@ def _load_ip_adapter_weights(self, state_dict): hidden_size = self.config.block_out_channels[block_id] if cross_attention_dim is None or "motion_modules" in name: attn_processor_class = ( - AttnProcessor2_0 - if hasattr(F, "scaled_dot_product_attention") - else AttnProcessor + AttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else AttnProcessor ) attn_procs[name] = attn_processor_class() else: attn_processor_class = ( - IPAdapterAttnProcessor2_0 - if hasattr(F, "scaled_dot_product_attention") - else IPAdapterAttnProcessor + IPAdapterAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else IPAdapterAttnProcessor ) attn_procs[name] = attn_processor_class( hidden_size=hidden_size, @@ -797,9 +729,7 @@ def _load_ip_adapter_weights(self, state_dict): value_dict = {} for k, w in attn_procs[name].state_dict().items(): - value_dict.update( - {f"{k}": state_dict["ip_adapter"][f"{key_id}.{k}"]} - ) + value_dict.update({f"{k}": state_dict["ip_adapter"][f"{key_id}.{k}"]}) attn_procs[name].load_state_dict(value_dict) key_id += 2 @@ -837,9 +767,7 @@ def _load_ip_adapter_weights(self, state_dict): embed_dims = state_dict["image_proj"]["proj_in.weight"].shape[1] output_dims = state_dict["image_proj"]["proj_out.weight"].shape[0] hidden_dims = state_dict["image_proj"]["latents"].shape[2] - num_heads = ( - state_dict["image_proj"]["layers.0.0.to_q.weight"].shape[0] // 64 - ) + num_heads = state_dict["image_proj"]["layers.0.0.to_q.weight"].shape[0] // 64 image_projection = Resampler( embed_dims=embed_dims, @@ -852,9 +780,7 @@ def _load_ip_adapter_weights(self, state_dict): image_proj_state_dict = state_dict["image_proj"] image_projection.load_state_dict(image_proj_state_dict) - self.encoder_hid_proj = image_projection.to( - device=self.device, dtype=self.dtype - ) + self.encoder_hid_proj = image_projection.to(device=self.device, dtype=self.dtype) self.config.encoder_hid_dim_type = "ip_image_proj" delete_adapter_layers diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index e0e89daf0e7d..bf7cc7ddfe05 100644 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -81,6 +81,4 @@ else: import sys - sys.modules[__name__] = _LazyModule( - __name__, globals()["__file__"], _import_structure, module_spec=__spec__ - ) + 
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 45627f401515..9184cbab5543 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -802,10 +802,7 @@ class PerceiverAttention(nn.Module): num_heads (int): Parallel attention heads. Defaults to 16. """ - def __init__(self, - embed_dims: int, - head_dims=64, - num_heads: int = 16) -> None: + def __init__(self, embed_dims: int, head_dims=64, num_heads: int = 16) -> None: super().__init__() self.head_dims = head_dims self.num_heads = num_heads @@ -896,8 +893,7 @@ def __init__( ) -> None: super().__init__() - self.latents = nn.Parameter( - torch.randn(1, num_queries, hidden_dims) / hidden_dims**0.5) + self.latents = nn.Parameter(torch.randn(1, num_queries, hidden_dims) / hidden_dims**0.5) self.proj_in = nn.Linear(embed_dims, hidden_dims) @@ -907,13 +903,13 @@ def __init__( self.layers = nn.ModuleList([]) for _ in range(depth): self.layers.append( - nn.ModuleList([ - PerceiverAttention( - embed_dims=hidden_dims, - head_dims=head_dims, - num_heads=num_heads), - self._get_ffn(embed_dims=hidden_dims, ffn_ratio=ffn_ratio), - ])) + nn.ModuleList( + [ + PerceiverAttention(embed_dims=hidden_dims, head_dims=head_dims, num_heads=num_heads), + self._get_ffn(embed_dims=hidden_dims, ffn_ratio=ffn_ratio), + ] + ) + ) def _get_ffn(self, embed_dims, ffn_ratio=4) -> nn.Sequential: """Get feedforward network.""" diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py index c5706a1f126d..7069d4534a9e 100644 --- a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py +++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py @@ -107,10 +107,7 @@ def preprocess(image): w, h = image[0].size w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8 - image = [ - np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] - for i in image - ] + image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image] image = np.concatenate(image, axis=0) image = np.array(image).astype(np.float32) / 255.0 image = image.transpose(0, 3, 1, 2) @@ -181,10 +178,7 @@ def __init__( ): super().__init__() - if ( - hasattr(scheduler.config, "steps_offset") - and scheduler.config.steps_offset != 1 - ): + if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: deprecation_message = ( f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " @@ -193,17 +187,12 @@ def __init__( " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" " file" ) - deprecate( - "steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False - ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) new_config = dict(scheduler.config) new_config["steps_offset"] = 1 scheduler._internal_dict = FrozenDict(new_config) - if ( - hasattr(scheduler.config, "clip_sample") - and scheduler.config.clip_sample is True - ): + if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: deprecation_message = ( f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." 
" `clip_sample` should be set to False in the configuration file. Please make sure to update the" @@ -211,9 +200,7 @@ def __init__( " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" ) - deprecate( - "clip_sample not set", "1.0.0", deprecation_message, standard_warn=False - ) + deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) new_config = dict(scheduler.config) new_config["clip_sample"] = False scheduler._internal_dict = FrozenDict(new_config) @@ -234,16 +221,10 @@ def __init__( " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." ) - is_unet_version_less_0_9_0 = hasattr( - unet.config, "_diffusers_version" - ) and version.parse( + is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( version.parse(unet.config._diffusers_version).base_version - ) < version.parse( - "0.9.0.dev0" - ) - is_unet_sample_size_less_64 = ( - hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 - ) + ) < version.parse("0.9.0.dev0") + is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: deprecation_message = ( "The configuration file of the unet has set the default `sample_size` to smaller than" @@ -256,9 +237,7 @@ def __init__( " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" " the `unet/config.json` file" ) - deprecate( - "sample_size<64", "1.0.0", deprecation_message, standard_warn=False - ) + deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) new_config = dict(unet.config) new_config["sample_size"] = 64 unet._internal_dict = FrozenDict(new_config) @@ -381,13 +360,11 @@ def encode_prompt( return_tensors="pt", ) text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer( - prompt, padding="longest", return_tensors="pt" - ).input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids - if untruncated_ids.shape[-1] >= text_input_ids.shape[ - -1 - ] and not torch.equal(text_input_ids, untruncated_ids): + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): removed_text = self.tokenizer.batch_decode( untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] ) @@ -396,18 +373,13 @@ def encode_prompt( f" {self.tokenizer.model_max_length} tokens: {removed_text}" ) - if ( - hasattr(self.text_encoder.config, "use_attention_mask") - and self.text_encoder.config.use_attention_mask - ): + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: attention_mask = text_inputs.attention_mask.to(device) else: attention_mask = None if clip_skip is None: - prompt_embeds = self.text_encoder( - text_input_ids.to(device), attention_mask=attention_mask - ) + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) prompt_embeds = prompt_embeds[0] else: prompt_embeds = self.text_encoder( @@ -423,9 +395,7 @@ def encode_prompt( # representations. The `last_hidden_states` that we typically use for # obtaining the final prompt representations passes through the LayerNorm # layer. 
- prompt_embeds = self.text_encoder.text_model.final_layer_norm( - prompt_embeds - ) + prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) if self.text_encoder is not None: prompt_embeds_dtype = self.text_encoder.dtype @@ -439,9 +409,7 @@ def encode_prompt( bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view( - bs_embed * num_images_per_prompt, seq_len, -1 - ) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) # get unconditional embeddings for classifier free guidance if do_classifier_free_guidance and negative_prompt_embeds is None: @@ -477,10 +445,7 @@ def encode_prompt( return_tensors="pt", ) - if ( - hasattr(self.text_encoder.config, "use_attention_mask") - and self.text_encoder.config.use_attention_mask - ): + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: attention_mask = uncond_input.attention_mask.to(device) else: attention_mask = None @@ -495,16 +460,10 @@ def encode_prompt( # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] - negative_prompt_embeds = negative_prompt_embeds.to( - dtype=prompt_embeds_dtype, device=device - ) + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) - negative_prompt_embeds = negative_prompt_embeds.repeat( - 1, num_images_per_prompt, 1 - ) - negative_prompt_embeds = negative_prompt_embeds.view( - batch_size * num_images_per_prompt, seq_len, -1 - ) + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND: # Retrieve the original scale by scaling back the LoRA layers @@ -526,16 +485,12 @@ def encode_image(self, image, device, num_images_per_prompt): uncond_image_embeds = torch.zeros_like(image_embeds) else: # IP-Adapter Plus - image_embeds = self.image_encoder( - image, output_hidden_states=True - ).hidden_states[-2] + image_embeds = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_embeds = self.image_encoder( - torch.zeros_like(image), output_hidden_states=True - ).hidden_states[-2] - uncond_image_embeds = uncond_image_embeds.repeat_interleave( - num_images_per_prompt, dim=0 - ) + uncond_image_embeds = self.image_encoder(torch.zeros_like(image), output_hidden_states=True).hidden_states[ + -2 + ] + uncond_image_embeds = uncond_image_embeds.repeat_interleave(num_images_per_prompt, dim=0) return image_embeds, uncond_image_embeds def run_safety_checker(self, image, device, dtype): @@ -543,14 +498,10 @@ def run_safety_checker(self, image, device, dtype): has_nsfw_concept = None else: if torch.is_tensor(image): - feature_extractor_input = self.image_processor.postprocess( - image, output_type="pil" - ) + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") else: feature_extractor_input = self.image_processor.numpy_to_pil(image) - safety_checker_input = self.feature_extractor( - feature_extractor_input, return_tensors="pt" - ).to(device) + safety_checker_input = self.feature_extractor(feature_extractor_input, 
return_tensors="pt").to(device) image, has_nsfw_concept = self.safety_checker( images=image, clip_input=safety_checker_input.pixel_values.to(dtype) ) @@ -573,17 +524,13 @@ def prepare_extra_step_kwargs(self, generator, eta): # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 # and should be between [0, 1] - accepts_eta = "eta" in set( - inspect.signature(self.scheduler.step).parameters.keys() - ) + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) extra_step_kwargs = {} if accepts_eta: extra_step_kwargs["eta"] = eta # check if the scheduler accepts generator - accepts_generator = "generator" in set( - inspect.signature(self.scheduler.step).parameters.keys() - ) + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) if accepts_generator: extra_step_kwargs["generator"] = generator return extra_step_kwargs @@ -599,21 +546,16 @@ def check_inputs( callback_on_step_end_tensor_inputs=None, ): if strength < 0 or strength > 1: - raise ValueError( - f"The value of strength should in [0.0, 1.0] but is {strength}" - ) + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") - if callback_steps is not None and ( - not isinstance(callback_steps, int) or callback_steps <= 0 - ): + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): raise ValueError( f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}." ) if callback_on_step_end_tensor_inputs is not None and not all( - k in self._callback_tensor_inputs - for k in callback_on_step_end_tensor_inputs + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs ): raise ValueError( f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" @@ -627,12 +569,8 @@ def check_inputs( raise ValueError( "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." 
) - elif prompt is not None and ( - not isinstance(prompt, str) and not isinstance(prompt, list) - ): - raise ValueError( - f"`prompt` has to be of type `str` or `list` but is {type(prompt)}" - ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") if negative_prompt is not None and negative_prompt_embeds is not None: raise ValueError( @@ -688,23 +626,16 @@ def prepare_latents( elif isinstance(generator, list): init_latents = [ - retrieve_latents( - self.vae.encode(image[i : i + 1]), generator=generator[i] - ) + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) for i in range(batch_size) ] init_latents = torch.cat(init_latents, dim=0) else: - init_latents = retrieve_latents( - self.vae.encode(image), generator=generator - ) + init_latents = retrieve_latents(self.vae.encode(image), generator=generator) init_latents = self.vae.config.scaling_factor * init_latents - if ( - batch_size > init_latents.shape[0] - and batch_size % init_latents.shape[0] == 0 - ): + if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: # expand init_latents for batch_size deprecation_message = ( f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial" @@ -719,13 +650,8 @@ def prepare_latents( standard_warn=False, ) additional_image_per_prompt = batch_size // init_latents.shape[0] - init_latents = torch.cat( - [init_latents] * additional_image_per_prompt, dim=0 - ) - elif ( - batch_size > init_latents.shape[0] - and batch_size % init_latents.shape[0] != 0 - ): + init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0) + elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: raise ValueError( f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." ) @@ -957,9 +883,7 @@ def __call__( # 3. Encode input prompt text_encoder_lora_scale = ( - self.cross_attention_kwargs.get("scale", None) - if self.cross_attention_kwargs is not None - else None + self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None ) prompt_embeds, negative_prompt_embeds = self.encode_prompt( prompt, @@ -979,9 +903,7 @@ def __call__( prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) if ip_adapter_image is not None: - image_embeds, negative_image_embeds = self.encode_image( - ip_adapter_image, device, num_images_per_prompt - ) + image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt) if self.do_classifier_free_guidance: image_embeds = torch.cat([negative_image_embeds, image_embeds]) @@ -990,9 +912,7 @@ def __call__( # 5. set timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) - timesteps, num_inference_steps = self.get_timesteps( - num_inference_steps, strength, device - ) + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) # 6. 
Prepare latent variables @@ -1010,16 +930,12 @@ def __call__( extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # 7.1 Add image embeds for IP-Adapter - added_cond_kwargs = ( - {"image_embeds": image_embeds} if ip_adapter_image is not None else None - ) + added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None # 7.2 Optionally get Guidance Scale Embedding timestep_cond = None if self.unet.config.time_cond_proj_dim is not None: - guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat( - batch_size * num_images_per_prompt - ) + guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) timestep_cond = self.get_guidance_scale_embedding( guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim ).to(device=device, dtype=latents.dtype) @@ -1030,14 +946,8 @@ def __call__( with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): # expand the latents if we are doing classifier free guidance - latent_model_input = ( - torch.cat([latents] * 2) - if self.do_classifier_free_guidance - else latents - ) - latent_model_input = self.scheduler.scale_model_input( - latent_model_input, t - ) + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # predict the noise residual noise_pred = self.unet( @@ -1053,14 +963,10 @@ def __call__( # perform guidance if self.do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + self.guidance_scale * ( - noise_pred_text - noise_pred_uncond - ) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step( - noise_pred, t, latents, **extra_step_kwargs, return_dict=False - )[0] + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] if callback_on_step_end is not None: callback_kwargs = {} @@ -1070,14 +976,10 @@ def __call__( latents = callback_outputs.pop("latents", latents) prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) - negative_prompt_embeds = callback_outputs.pop( - "negative_prompt_embeds", negative_prompt_embeds - ) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) # call the callback, if provided - if i == len(timesteps) - 1 or ( - (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0 - ): + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() if callback is not None and i % callback_steps == 0: step_idx = i // getattr(self.scheduler, "order", 1) @@ -1089,9 +991,7 @@ def __call__( return_dict=False, generator=generator, )[0] - image, has_nsfw_concept = self.run_safety_checker( - image, device, prompt_embeds.dtype - ) + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) else: image = latents has_nsfw_concept = None @@ -1101,9 +1001,7 @@ def __call__( else: do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] - image = self.image_processor.postprocess( - image, output_type=output_type, do_denormalize=do_denormalize - ) + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) # Offload all models 
self.maybe_free_model_hooks() @@ -1111,6 +1009,4 @@ def __call__( if not return_dict: return (image, has_nsfw_concept) - return AltDiffusionPipelineOutput( - images=image, nsfw_content_detected=has_nsfw_concept - ) + return AltDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py index 645856a72bb9..5b1c34fdf831 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py @@ -93,9 +93,7 @@ class AnimateDiffPipelineOutput(BaseOutput): frames: Union[torch.Tensor, np.ndarray] -class AnimateDiffPipeline( - DiffusionPipeline, TextualInversionLoaderMixin, IPAdapterMixin, LoraLoaderMixin -): +class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdapterMixin, LoraLoaderMixin): r""" Pipeline for text-to-video generation. @@ -228,13 +226,11 @@ def encode_prompt( return_tensors="pt", ) text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer( - prompt, padding="longest", return_tensors="pt" - ).input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids - if untruncated_ids.shape[-1] >= text_input_ids.shape[ - -1 - ] and not torch.equal(text_input_ids, untruncated_ids): + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): removed_text = self.tokenizer.batch_decode( untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] ) @@ -243,18 +239,13 @@ def encode_prompt( f" {self.tokenizer.model_max_length} tokens: {removed_text}" ) - if ( - hasattr(self.text_encoder.config, "use_attention_mask") - and self.text_encoder.config.use_attention_mask - ): + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: attention_mask = text_inputs.attention_mask.to(device) else: attention_mask = None if clip_skip is None: - prompt_embeds = self.text_encoder( - text_input_ids.to(device), attention_mask=attention_mask - ) + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) prompt_embeds = prompt_embeds[0] else: prompt_embeds = self.text_encoder( @@ -270,9 +261,7 @@ def encode_prompt( # representations. The `last_hidden_states` that we typically use for # obtaining the final prompt representations passes through the LayerNorm # layer. 
- prompt_embeds = self.text_encoder.text_model.final_layer_norm( - prompt_embeds - ) + prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) if self.text_encoder is not None: prompt_embeds_dtype = self.text_encoder.dtype @@ -286,9 +275,7 @@ def encode_prompt( bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view( - bs_embed * num_images_per_prompt, seq_len, -1 - ) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) # get unconditional embeddings for classifier free guidance if do_classifier_free_guidance and negative_prompt_embeds is None: @@ -324,10 +311,7 @@ def encode_prompt( return_tensors="pt", ) - if ( - hasattr(self.text_encoder.config, "use_attention_mask") - and self.text_encoder.config.use_attention_mask - ): + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: attention_mask = uncond_input.attention_mask.to(device) else: attention_mask = None @@ -342,16 +326,10 @@ def encode_prompt( # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] - negative_prompt_embeds = negative_prompt_embeds.to( - dtype=prompt_embeds_dtype, device=device - ) + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) - negative_prompt_embeds = negative_prompt_embeds.repeat( - 1, num_images_per_prompt, 1 - ) - negative_prompt_embeds = negative_prompt_embeds.view( - batch_size * num_images_per_prompt, seq_len, -1 - ) + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND: # Retrieve the original scale by scaling back the LoRA layers @@ -374,16 +352,12 @@ def encode_image(self, image, device, num_images_per_prompt): uncond_image_embeds = torch.zeros_like(image_embeds) else: # IP-Adapter Plus - image_embeds = self.image_encoder( - image, output_hidden_states=True - ).hidden_states[-2] + image_embeds = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_embeds = self.image_encoder( - torch.zeros_like(image), output_hidden_states=True - ).hidden_states[-2] - uncond_image_embeds = uncond_image_embeds.repeat_interleave( - num_images_per_prompt, dim=0 - ) + uncond_image_embeds = self.image_encoder(torch.zeros_like(image), output_hidden_states=True).hidden_states[ + -2 + ] + uncond_image_embeds = uncond_image_embeds.repeat_interleave(num_images_per_prompt, dim=0) return image_embeds, uncond_image_embeds # Copied from diffusers.pipelines.text_to_video_synthesis/pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents @@ -391,9 +365,7 @@ def decode_latents(self, latents): latents = 1 / self.vae.config.scaling_factor * latents batch_size, channels, num_frames, height, width = latents.shape - latents = latents.permute(0, 2, 1, 3, 4).reshape( - batch_size * num_frames, channels, height, width - ) + latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width) image = self.vae.decode(latents).sample video = ( @@ -480,17 +452,13 @@ def prepare_extra_step_kwargs(self, 
generator, eta): # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 # and should be between [0, 1] - accepts_eta = "eta" in set( - inspect.signature(self.scheduler.step).parameters.keys() - ) + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) extra_step_kwargs = {} if accepts_eta: extra_step_kwargs["eta"] = eta # check if the scheduler accepts generator - accepts_generator = "generator" in set( - inspect.signature(self.scheduler.step).parameters.keys() - ) + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) if accepts_generator: extra_step_kwargs["generator"] = generator return extra_step_kwargs @@ -508,20 +476,15 @@ def check_inputs( callback_on_step_end_tensor_inputs=None, ): if height % 8 != 0 or width % 8 != 0: - raise ValueError( - f"`height` and `width` have to be divisible by 8 but are {height} and {width}." - ) + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") - if callback_steps is not None and ( - not isinstance(callback_steps, int) or callback_steps <= 0 - ): + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): raise ValueError( f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}." ) if callback_on_step_end_tensor_inputs is not None and not all( - k in self._callback_tensor_inputs - for k in callback_on_step_end_tensor_inputs + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs ): raise ValueError( f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" @@ -536,12 +499,8 @@ def check_inputs( raise ValueError( "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." ) - elif prompt is not None and ( - not isinstance(prompt, str) and not isinstance(prompt, list) - ): - raise ValueError( - f"`prompt` has to be of type `str` or `list` but is {type(prompt)}" - ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") if negative_prompt is not None and negative_prompt_embeds is not None: raise ValueError( @@ -584,9 +543,7 @@ def prepare_latents( ) if latents is None: - latents = randn_tensor( - shape, generator=generator, device=device, dtype=dtype - ) + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) else: latents = latents.to(device) @@ -716,9 +673,7 @@ def __call__( # 3. Encode input prompt text_encoder_lora_scale = ( - cross_attention_kwargs.get("scale", None) - if cross_attention_kwargs is not None - else None + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None ) prompt_embeds, negative_prompt_embeds = self.encode_prompt( prompt, @@ -738,9 +693,7 @@ def __call__( prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) if ip_adapter_image is not None: - image_embeds, negative_image_embeds = self.encode_image( - ip_adapter_image, device, num_videos_per_prompt - ) + image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_videos_per_prompt) if do_classifier_free_guidance: image_embeds = torch.cat([negative_image_embeds, image_embeds]) @@ -765,21 +718,15 @@ def __call__( # 6. Prepare extra step kwargs. 
TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # 7 Add image embeds for IP-Adapter - added_cond_kwargs = ( - {"image_embeds": image_embeds} if ip_adapter_image is not None else None - ) + added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None # Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): # expand the latents if we are doing classifier free guidance - latent_model_input = ( - torch.cat([latents] * 2) if do_classifier_free_guidance else latents - ) - latent_model_input = self.scheduler.scale_model_input( - latent_model_input, t - ) + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # predict the noise residual noise_pred = self.unet( @@ -793,19 +740,13 @@ def __call__( # perform guidance if do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + guidance_scale * ( - noise_pred_text - noise_pred_uncond - ) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step( - noise_pred, t, latents, **extra_step_kwargs - ).prev_sample + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample # call the callback, if provided - if i == len(timesteps) - 1 or ( - (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0 - ): + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() if callback is not None and i % callback_steps == 0: callback(i, t, latents) @@ -819,9 +760,7 @@ def __call__( if output_type == "pt": video = video_tensor else: - video = tensor2vid( - video_tensor, self.image_processor, output_type=output_type - ) + video = tensor2vid(video_tensor, self.image_processor, output_type=output_type) # Offload all models self.maybe_free_model_hooks() diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py index 160a90332e34..95093b277439 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py @@ -203,9 +203,7 @@ def __init__( image_encoder=image_encoder, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) - self.image_processor = VaeImageProcessor( - vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True - ) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True) self.control_image_processor = VaeImageProcessor( vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, @@ -352,13 +350,11 @@ def encode_prompt( return_tensors="pt", ) text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer( - prompt, padding="longest", return_tensors="pt" - ).input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids - if untruncated_ids.shape[-1] >= text_input_ids.shape[ - -1 - ] and not torch.equal(text_input_ids, untruncated_ids): + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): removed_text = 
self.tokenizer.batch_decode( untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] ) @@ -367,18 +363,13 @@ def encode_prompt( f" {self.tokenizer.model_max_length} tokens: {removed_text}" ) - if ( - hasattr(self.text_encoder.config, "use_attention_mask") - and self.text_encoder.config.use_attention_mask - ): + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: attention_mask = text_inputs.attention_mask.to(device) else: attention_mask = None if clip_skip is None: - prompt_embeds = self.text_encoder( - text_input_ids.to(device), attention_mask=attention_mask - ) + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) prompt_embeds = prompt_embeds[0] else: prompt_embeds = self.text_encoder( @@ -394,9 +385,7 @@ def encode_prompt( # representations. The `last_hidden_states` that we typically use for # obtaining the final prompt representations passes through the LayerNorm # layer. - prompt_embeds = self.text_encoder.text_model.final_layer_norm( - prompt_embeds - ) + prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) if self.text_encoder is not None: prompt_embeds_dtype = self.text_encoder.dtype @@ -410,9 +399,7 @@ def encode_prompt( bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view( - bs_embed * num_images_per_prompt, seq_len, -1 - ) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) # get unconditional embeddings for classifier free guidance if do_classifier_free_guidance and negative_prompt_embeds is None: @@ -448,10 +435,7 @@ def encode_prompt( return_tensors="pt", ) - if ( - hasattr(self.text_encoder.config, "use_attention_mask") - and self.text_encoder.config.use_attention_mask - ): + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: attention_mask = uncond_input.attention_mask.to(device) else: attention_mask = None @@ -466,16 +450,10 @@ def encode_prompt( # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] - negative_prompt_embeds = negative_prompt_embeds.to( - dtype=prompt_embeds_dtype, device=device - ) + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) - negative_prompt_embeds = negative_prompt_embeds.repeat( - 1, num_images_per_prompt, 1 - ) - negative_prompt_embeds = negative_prompt_embeds.view( - batch_size * num_images_per_prompt, seq_len, -1 - ) + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND: # Retrieve the original scale by scaling back the LoRA layers @@ -498,16 +476,12 @@ def encode_image(self, image, device, num_images_per_prompt): uncond_image_embeds = torch.zeros_like(image_embeds) else: # IP-Adapter Plus - image_embeds = self.image_encoder( - image, output_hidden_states=True - ).hidden_states[-2] + image_embeds = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_embeds = self.image_encoder( - torch.zeros_like(image), output_hidden_states=True - 
).hidden_states[-2] - uncond_image_embeds = uncond_image_embeds.repeat_interleave( - num_images_per_prompt, dim=0 - ) + uncond_image_embeds = self.image_encoder(torch.zeros_like(image), output_hidden_states=True).hidden_states[ + -2 + ] + uncond_image_embeds = uncond_image_embeds.repeat_interleave(num_images_per_prompt, dim=0) return image_embeds, uncond_image_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker @@ -516,14 +490,10 @@ def run_safety_checker(self, image, device, dtype): has_nsfw_concept = None else: if torch.is_tensor(image): - feature_extractor_input = self.image_processor.postprocess( - image, output_type="pil" - ) + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") else: feature_extractor_input = self.image_processor.numpy_to_pil(image) - safety_checker_input = self.feature_extractor( - feature_extractor_input, return_tensors="pt" - ).to(device) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) image, has_nsfw_concept = self.safety_checker( images=image, clip_input=safety_checker_input.pixel_values.to(dtype) ) @@ -548,17 +518,13 @@ def prepare_extra_step_kwargs(self, generator, eta): # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 # and should be between [0, 1] - accepts_eta = "eta" in set( - inspect.signature(self.scheduler.step).parameters.keys() - ) + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) extra_step_kwargs = {} if accepts_eta: extra_step_kwargs["eta"] = eta # check if the scheduler accepts generator - accepts_generator = "generator" in set( - inspect.signature(self.scheduler.step).parameters.keys() - ) + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) if accepts_generator: extra_step_kwargs["generator"] = generator return extra_step_kwargs @@ -576,17 +542,14 @@ def check_inputs( control_guidance_end=1.0, callback_on_step_end_tensor_inputs=None, ): - if callback_steps is not None and ( - not isinstance(callback_steps, int) or callback_steps <= 0 - ): + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): raise ValueError( f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}." ) if callback_on_step_end_tensor_inputs is not None and not all( - k in self._callback_tensor_inputs - for k in callback_on_step_end_tensor_inputs + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs ): raise ValueError( f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" @@ -601,12 +564,8 @@ def check_inputs( raise ValueError( "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." ) - elif prompt is not None and ( - not isinstance(prompt, str) and not isinstance(prompt, list) - ): - raise ValueError( - f"`prompt` has to be of type `str` or `list` but is {type(prompt)}" - ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") if negative_prompt is not None and negative_prompt_embeds is not None: raise ValueError( @@ -652,9 +611,7 @@ def check_inputs( # When `image` is a nested list: # (e.g. 
[[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]]) elif any(isinstance(i, list) for i in image): - raise ValueError( - "A single batch of multiple conditionings are supported at the moment." - ) + raise ValueError("A single batch of multiple conditionings are supported at the moment.") elif len(image) != len(self.controlnet.nets): raise ValueError( f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets." @@ -672,9 +629,7 @@ def check_inputs( and isinstance(self.controlnet._orig_mod, ControlNetModel) ): if not isinstance(controlnet_conditioning_scale, float): - raise TypeError( - "For single controlnet: `controlnet_conditioning_scale` must be type `float`." - ) + raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.") elif ( isinstance(self.controlnet, MultiControlNetModel) or is_compiled @@ -682,12 +637,10 @@ def check_inputs( ): if isinstance(controlnet_conditioning_scale, list): if any(isinstance(i, list) for i in controlnet_conditioning_scale): - raise ValueError( - "A single batch of multiple conditionings are supported at the moment." - ) - elif isinstance(controlnet_conditioning_scale, list) and len( - controlnet_conditioning_scale - ) != len(self.controlnet.nets): + raise ValueError("A single batch of multiple conditionings are supported at the moment.") + elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len( + self.controlnet.nets + ): raise ValueError( "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have" " the same length as the number of controlnets" @@ -718,24 +671,16 @@ def check_inputs( f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}." ) if start < 0.0: - raise ValueError( - f"control guidance start: {start} can't be smaller than 0." - ) + raise ValueError(f"control guidance start: {start} can't be smaller than 0.") if end > 1.0: - raise ValueError( - f"control guidance end: {end} can't be larger than 1.0." 
- ) + raise ValueError(f"control guidance end: {end} can't be larger than 1.0.") def check_image(self, image, prompt, prompt_embeds): image_is_pil = isinstance(image, PIL.Image.Image) image_is_tensor = isinstance(image, torch.Tensor) image_is_np = isinstance(image, np.ndarray) - image_is_pil_list = isinstance(image, list) and isinstance( - image[0], PIL.Image.Image - ) - image_is_tensor_list = isinstance(image, list) and isinstance( - image[0], torch.Tensor - ) + image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image) + image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor) image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray) if ( @@ -779,9 +724,7 @@ def prepare_image( do_classifier_free_guidance=False, guess_mode=False, ): - image = self.control_image_processor.preprocess( - image, height=height, width=width - ).to(dtype=torch.float32) + image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) image_batch_size = image.shape[0] if image_batch_size == 1: @@ -824,9 +767,7 @@ def prepare_latents( ) if latents is None: - latents = randn_tensor( - shape, generator=generator, device=device, dtype=dtype - ) + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) else: latents = latents.to(device) @@ -1054,31 +995,15 @@ def __call__( "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", ) - controlnet = ( - self.controlnet._orig_mod - if is_compiled_module(self.controlnet) - else self.controlnet - ) + controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet # align format for control guidance - if not isinstance(control_guidance_start, list) and isinstance( - control_guidance_end, list - ): - control_guidance_start = len(control_guidance_end) * [ - control_guidance_start - ] - elif not isinstance(control_guidance_end, list) and isinstance( - control_guidance_start, list - ): + if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): + control_guidance_start = len(control_guidance_end) * [control_guidance_start] + elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): control_guidance_end = len(control_guidance_start) * [control_guidance_end] - elif not isinstance(control_guidance_start, list) and not isinstance( - control_guidance_end, list - ): - mult = ( - len(controlnet.nets) - if isinstance(controlnet, MultiControlNetModel) - else 1 - ) + elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list): + mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1 control_guidance_start, control_guidance_end = ( mult * [control_guidance_start], mult * [control_guidance_end], @@ -1112,12 +1037,8 @@ def __call__( device = self._execution_device - if isinstance(controlnet, MultiControlNetModel) and isinstance( - controlnet_conditioning_scale, float - ): - controlnet_conditioning_scale = [controlnet_conditioning_scale] * len( - controlnet.nets - ) + if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float): + controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets) global_pool_conditions = ( controlnet.config.global_pool_conditions @@ -1128,9 +1049,7 @@ def __call__( # 3. 
Encode input prompt text_encoder_lora_scale = ( - self.cross_attention_kwargs.get("scale", None) - if self.cross_attention_kwargs is not None - else None + self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None ) prompt_embeds, negative_prompt_embeds = self.encode_prompt( prompt, @@ -1150,9 +1069,7 @@ def __call__( prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) if ip_adapter_image is not None: - image_embeds, negative_image_embeds = self.encode_image( - ip_adapter_image, device, num_images_per_prompt - ) + image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt) if self.do_classifier_free_guidance: image_embeds = torch.cat([negative_image_embeds, image_embeds]) @@ -1214,9 +1131,7 @@ def __call__( # 6.5 Optionally get Guidance Scale Embedding timestep_cond = None if self.unet.config.time_cond_proj_dim is not None: - guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat( - batch_size * num_images_per_prompt - ) + guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) timestep_cond = self.get_guidance_scale_embedding( guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim ).to(device=device, dtype=latents.dtype) @@ -1225,9 +1140,7 @@ def __call__( extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # 7.1 Add image embeds for IP-Adapter - added_cond_kwargs = ( - {"image_embeds": image_embeds} if ip_adapter_image is not None else None - ) + added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None # 7.2 Create tensor stating which controlnets to keep controlnet_keep = [] @@ -1236,9 +1149,7 @@ def __call__( 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) for s, e in zip(control_guidance_start, control_guidance_end) ] - controlnet_keep.append( - keeps[0] if isinstance(controlnet, ControlNetModel) else keeps - ) + controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps) # 8. Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order @@ -1249,39 +1160,24 @@ def __call__( for i, t in enumerate(timesteps): # Relevant thread: # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428 - if ( - is_unet_compiled and is_controlnet_compiled - ) and is_torch_higher_equal_2_1: + if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1: torch._inductor.cudagraph_mark_step_begin() # expand the latents if we are doing classifier free guidance - latent_model_input = ( - torch.cat([latents] * 2) - if self.do_classifier_free_guidance - else latents - ) - latent_model_input = self.scheduler.scale_model_input( - latent_model_input, t - ) + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # controlnet(s) inference if guess_mode and self.do_classifier_free_guidance: # Infer ControlNet only for the conditional batch. 
control_model_input = latents - control_model_input = self.scheduler.scale_model_input( - control_model_input, t - ) + control_model_input = self.scheduler.scale_model_input(control_model_input, t) controlnet_prompt_embeds = prompt_embeds.chunk(2)[1] else: control_model_input = latent_model_input controlnet_prompt_embeds = prompt_embeds if isinstance(controlnet_keep[i], list): - cond_scale = [ - c * s - for c, s in zip( - controlnet_conditioning_scale, controlnet_keep[i] - ) - ] + cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])] else: controlnet_cond_scale = controlnet_conditioning_scale if isinstance(controlnet_cond_scale, list): @@ -1302,13 +1198,8 @@ def __call__( # Infered ControlNet only for the conditional batch. # To apply the output of ControlNet to both the unconditional and conditional batches, # add 0 to the unconditional batch to keep it unchanged. - down_block_res_samples = [ - torch.cat([torch.zeros_like(d), d]) - for d in down_block_res_samples - ] - mid_block_res_sample = torch.cat( - [torch.zeros_like(mid_block_res_sample), mid_block_res_sample] - ) + down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples] + mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample]) # predict the noise residual noise_pred = self.unet( @@ -1326,14 +1217,10 @@ def __call__( # perform guidance if self.do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + self.guidance_scale * ( - noise_pred_text - noise_pred_uncond - ) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step( - noise_pred, t, latents, **extra_step_kwargs, return_dict=False - )[0] + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] if callback_on_step_end is not None: callback_kwargs = {} @@ -1343,14 +1230,10 @@ def __call__( latents = callback_outputs.pop("latents", latents) prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) - negative_prompt_embeds = callback_outputs.pop( - "negative_prompt_embeds", negative_prompt_embeds - ) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) # call the callback, if provided - if i == len(timesteps) - 1 or ( - (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0 - ): + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() if callback is not None and i % callback_steps == 0: step_idx = i // getattr(self.scheduler, "order", 1) @@ -1369,9 +1252,7 @@ def __call__( return_dict=False, generator=generator, )[0] - image, has_nsfw_concept = self.run_safety_checker( - image, device, prompt_embeds.dtype - ) + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) else: image = latents has_nsfw_concept = None @@ -1381,9 +1262,7 @@ def __call__( else: do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] - image = self.image_processor.postprocess( - image, output_type=output_type, do_denormalize=do_denormalize - ) + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) # Offload all models self.maybe_free_model_hooks() @@ -1391,6 +1270,4 @@ def __call__( if not return_dict: return (image, has_nsfw_concept) - return 
StableDiffusionPipelineOutput( - images=image, nsfw_content_detected=has_nsfw_concept - ) + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py index 7d4fa0ca3950..9ebc35772b93 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py @@ -217,28 +217,20 @@ def __init__( image_encoder=image_encoder, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) - self.image_processor = VaeImageProcessor( - vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True - ) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True) self.control_image_processor = VaeImageProcessor( vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False, ) - add_watermarker = ( - add_watermarker - if add_watermarker is not None - else is_invisible_watermark_available() - ) + add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available() if add_watermarker: self.watermark = StableDiffusionXLWatermarker() else: self.watermark = None - self.register_to_config( - force_zeros_for_empty_prompt=force_zeros_for_empty_prompt - ) + self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing def enable_vae_slicing(self): @@ -336,9 +328,7 @@ def encode_prompt( # set lora scale so that monkey patched LoRA # function of text encoder can correctly access it - if lora_scale is not None and isinstance( - self, StableDiffusionXLLoraLoaderMixin - ): + if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin): self._lora_scale = lora_scale # dynamically adjust the LoRA scale @@ -362,15 +352,9 @@ def encode_prompt( batch_size = prompt_embeds.shape[0] # Define tokenizers and text encoders - tokenizers = ( - [self.tokenizer, self.tokenizer_2] - if self.tokenizer is not None - else [self.tokenizer_2] - ) + tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2] text_encoders = ( - [self.text_encoder, self.text_encoder_2] - if self.text_encoder is not None - else [self.text_encoder_2] + [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2] ) if prompt_embeds is None: @@ -380,9 +364,7 @@ def encode_prompt( # textual inversion: procecss multi-vector tokens if necessary prompt_embeds_list = [] prompts = [prompt, prompt_2] - for prompt, tokenizer, text_encoder in zip( - prompts, tokenizers, text_encoders - ): + for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders): if isinstance(self, TextualInversionLoaderMixin): prompt = self.maybe_convert_prompt(prompt, tokenizer) @@ -395,24 +377,18 @@ def encode_prompt( ) text_input_ids = text_inputs.input_ids - untruncated_ids = tokenizer( - prompt, padding="longest", return_tensors="pt" - ).input_ids - - if untruncated_ids.shape[-1] >= text_input_ids.shape[ - -1 - ] and not torch.equal(text_input_ids, untruncated_ids): - removed_text = tokenizer.batch_decode( - untruncated_ids[:, tokenizer.model_max_length - 1 : -1] - ) + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if 
untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {tokenizer.model_max_length} tokens: {removed_text}" ) - prompt_embeds = text_encoder( - text_input_ids.to(device), output_hidden_states=True - ) + prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) # We are only ALWAYS interested in the pooled output of the final text encoder pooled_prompt_embeds = prompt_embeds[0] @@ -427,14 +403,8 @@ def encode_prompt( prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) # get unconditional embeddings for classifier free guidance - zero_out_negative_prompt = ( - negative_prompt is None and self.config.force_zeros_for_empty_prompt - ) - if ( - do_classifier_free_guidance - and negative_prompt_embeds is None - and zero_out_negative_prompt - ): + zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt + if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt: negative_prompt_embeds = torch.zeros_like(prompt_embeds) negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) elif do_classifier_free_guidance and negative_prompt_embeds is None: @@ -442,15 +412,9 @@ def encode_prompt( negative_prompt_2 = negative_prompt_2 or negative_prompt # normalize str to list - negative_prompt = ( - batch_size * [negative_prompt] - if isinstance(negative_prompt, str) - else negative_prompt - ) + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt negative_prompt_2 = ( - batch_size * [negative_prompt_2] - if isinstance(negative_prompt_2, str) - else negative_prompt_2 + batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 ) uncond_tokens: List[str] @@ -469,13 +433,9 @@ def encode_prompt( uncond_tokens = [negative_prompt, negative_prompt_2] negative_prompt_embeds_list = [] - for negative_prompt, tokenizer, text_encoder in zip( - uncond_tokens, tokenizers, text_encoders - ): + for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders): if isinstance(self, TextualInversionLoaderMixin): - negative_prompt = self.maybe_convert_prompt( - negative_prompt, tokenizer - ) + negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer) max_length = prompt_embeds.shape[1] uncond_input = tokenizer( @@ -499,46 +459,34 @@ def encode_prompt( negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) if self.text_encoder_2 is not None: - prompt_embeds = prompt_embeds.to( - dtype=self.text_encoder_2.dtype, device=device - ) + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) else: prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device) bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view( - bs_embed * num_images_per_prompt, seq_len, -1 - ) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) if do_classifier_free_guidance: # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = 
negative_prompt_embeds.shape[1] if self.text_encoder_2 is not None: - negative_prompt_embeds = negative_prompt_embeds.to( - dtype=self.text_encoder_2.dtype, device=device - ) + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) else: - negative_prompt_embeds = negative_prompt_embeds.to( - dtype=self.unet.dtype, device=device - ) + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device) - negative_prompt_embeds = negative_prompt_embeds.repeat( - 1, num_images_per_prompt, 1 - ) - negative_prompt_embeds = negative_prompt_embeds.view( - batch_size * num_images_per_prompt, seq_len, -1 - ) + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - pooled_prompt_embeds = pooled_prompt_embeds.repeat( - 1, num_images_per_prompt - ).view(bs_embed * num_images_per_prompt, -1) + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) if do_classifier_free_guidance: - negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat( - 1, num_images_per_prompt - ).view(bs_embed * num_images_per_prompt, -1) + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) if self.text_encoder is not None: if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: @@ -572,16 +520,12 @@ def encode_image(self, image, device, num_images_per_prompt): uncond_image_embeds = torch.zeros_like(image_embeds) else: # IP-Adapter Plus - image_embeds = self.image_encoder( - image, output_hidden_states=True - ).hidden_states[-2] + image_embeds = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_embeds = self.image_encoder( - torch.zeros_like(image), output_hidden_states=True - ).hidden_states[-2] - uncond_image_embeds = uncond_image_embeds.repeat_interleave( - num_images_per_prompt, dim=0 - ) + uncond_image_embeds = self.image_encoder(torch.zeros_like(image), output_hidden_states=True).hidden_states[ + -2 + ] + uncond_image_embeds = uncond_image_embeds.repeat_interleave(num_images_per_prompt, dim=0) return image_embeds, uncond_image_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs @@ -591,17 +535,13 @@ def prepare_extra_step_kwargs(self, generator, eta): # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 # and should be between [0, 1] - accepts_eta = "eta" in set( - inspect.signature(self.scheduler.step).parameters.keys() - ) + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) extra_step_kwargs = {} if accepts_eta: extra_step_kwargs["eta"] = eta # check if the scheduler accepts generator - accepts_generator = "generator" in set( - inspect.signature(self.scheduler.step).parameters.keys() - ) + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) if accepts_generator: extra_step_kwargs["generator"] = generator return extra_step_kwargs @@ -623,17 +563,14 @@ def check_inputs( control_guidance_end=1.0, callback_on_step_end_tensor_inputs=None, ): - if callback_steps is not None and ( - not isinstance(callback_steps, int) or callback_steps <= 0 - ): + if 
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): raise ValueError( f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}." ) if callback_on_step_end_tensor_inputs is not None and not all( - k in self._callback_tensor_inputs - for k in callback_on_step_end_tensor_inputs + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs ): raise ValueError( f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" @@ -653,18 +590,10 @@ def check_inputs( raise ValueError( "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." ) - elif prompt is not None and ( - not isinstance(prompt, str) and not isinstance(prompt, list) - ): - raise ValueError( - f"`prompt` has to be of type `str` or `list` but is {type(prompt)}" - ) - elif prompt_2 is not None and ( - not isinstance(prompt_2, str) and not isinstance(prompt_2, list) - ): - raise ValueError( - f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}" - ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") if negative_prompt is not None and negative_prompt_embeds is not None: raise ValueError( @@ -725,9 +654,7 @@ def check_inputs( # When `image` is a nested list: # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]]) elif any(isinstance(i, list) for i in image): - raise ValueError( - "A single batch of multiple conditionings are supported at the moment." - ) + raise ValueError("A single batch of multiple conditionings are supported at the moment.") elif len(image) != len(self.controlnet.nets): raise ValueError( f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets." @@ -745,9 +672,7 @@ def check_inputs( and isinstance(self.controlnet._orig_mod, ControlNetModel) ): if not isinstance(controlnet_conditioning_scale, float): - raise TypeError( - "For single controlnet: `controlnet_conditioning_scale` must be type `float`." - ) + raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.") elif ( isinstance(self.controlnet, MultiControlNetModel) or is_compiled @@ -755,12 +680,10 @@ def check_inputs( ): if isinstance(controlnet_conditioning_scale, list): if any(isinstance(i, list) for i in controlnet_conditioning_scale): - raise ValueError( - "A single batch of multiple conditionings are supported at the moment." 
- ) - elif isinstance(controlnet_conditioning_scale, list) and len( - controlnet_conditioning_scale - ) != len(self.controlnet.nets): + raise ValueError("A single batch of multiple conditionings are supported at the moment.") + elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len( + self.controlnet.nets + ): raise ValueError( "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have" " the same length as the number of controlnets" @@ -791,25 +714,17 @@ def check_inputs( f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}." ) if start < 0.0: - raise ValueError( - f"control guidance start: {start} can't be smaller than 0." - ) + raise ValueError(f"control guidance start: {start} can't be smaller than 0.") if end > 1.0: - raise ValueError( - f"control guidance end: {end} can't be larger than 1.0." - ) + raise ValueError(f"control guidance end: {end} can't be larger than 1.0.") # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.check_image def check_image(self, image, prompt, prompt_embeds): image_is_pil = isinstance(image, PIL.Image.Image) image_is_tensor = isinstance(image, torch.Tensor) image_is_np = isinstance(image, np.ndarray) - image_is_pil_list = isinstance(image, list) and isinstance( - image[0], PIL.Image.Image - ) - image_is_tensor_list = isinstance(image, list) and isinstance( - image[0], torch.Tensor - ) + image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image) + image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor) image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray) if ( @@ -854,9 +769,7 @@ def prepare_image( do_classifier_free_guidance=False, guess_mode=False, ): - image = self.control_image_processor.preprocess( - image, height=height, width=width - ).to(dtype=torch.float32) + image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) image_batch_size = image.shape[0] if image_batch_size == 1: @@ -899,9 +812,7 @@ def prepare_latents( ) if latents is None: - latents = randn_tensor( - shape, generator=generator, device=device, dtype=dtype - ) + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) else: latents = latents.to(device) @@ -921,8 +832,7 @@ def _get_add_time_ids( add_time_ids = list(original_size + crops_coords_top_left + target_size) passed_add_embed_dim = ( - self.unet.config.addition_time_embed_dim * len(add_time_ids) - + text_encoder_projection_dim + self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim ) expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features @@ -1222,31 +1132,15 @@ def __call__( "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", ) - controlnet = ( - self.controlnet._orig_mod - if is_compiled_module(self.controlnet) - else self.controlnet - ) + controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet # align format for control guidance - if not isinstance(control_guidance_start, list) and isinstance( - control_guidance_end, list - ): - control_guidance_start = len(control_guidance_end) * [ - control_guidance_start - ] - elif not isinstance(control_guidance_end, list) and isinstance( - control_guidance_start, list - ): + if not 
isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): + control_guidance_start = len(control_guidance_end) * [control_guidance_start] + elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): control_guidance_end = len(control_guidance_start) * [control_guidance_end] - elif not isinstance(control_guidance_start, list) and not isinstance( - control_guidance_end, list - ): - mult = ( - len(controlnet.nets) - if isinstance(controlnet, MultiControlNetModel) - else 1 - ) + elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list): + mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1 control_guidance_start, control_guidance_end = ( mult * [control_guidance_start], mult * [control_guidance_end], @@ -1284,12 +1178,8 @@ def __call__( device = self._execution_device - if isinstance(controlnet, MultiControlNetModel) and isinstance( - controlnet_conditioning_scale, float - ): - controlnet_conditioning_scale = [controlnet_conditioning_scale] * len( - controlnet.nets - ) + if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float): + controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets) global_pool_conditions = ( controlnet.config.global_pool_conditions @@ -1300,9 +1190,7 @@ def __call__( # 3.1 Encode input prompt text_encoder_lora_scale = ( - self.cross_attention_kwargs.get("scale", None) - if self.cross_attention_kwargs is not None - else None + self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None ) ( prompt_embeds, @@ -1327,9 +1215,7 @@ def __call__( # 3.2 Encode ip_adapter_image if ip_adapter_image is not None: - image_embeds, negative_image_embeds = self.encode_image( - ip_adapter_image, device, num_images_per_prompt - ) + image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt) if self.do_classifier_free_guidance: image_embeds = torch.cat([negative_image_embeds, image_embeds]) @@ -1391,9 +1277,7 @@ def __call__( # 6.5 Optionally get Guidance Scale Embedding timestep_cond = None if self.unet.config.time_cond_proj_dim is not None: - guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat( - batch_size * num_images_per_prompt - ) + guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) timestep_cond = self.get_guidance_scale_embedding( guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim ).to(device=device, dtype=latents.dtype) @@ -1408,9 +1292,7 @@ def __call__( 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) for s, e in zip(control_guidance_start, control_guidance_end) ] - controlnet_keep.append( - keeps[0] if isinstance(controlnet, ControlNetModel) else keeps - ) + controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps) # 7.2 Prepare added time ids & embeddings if isinstance(image, list): @@ -1446,16 +1328,12 @@ def __call__( if self.do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) - add_text_embeds = torch.cat( - [negative_pooled_prompt_embeds, add_text_embeds], dim=0 - ) + add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0) prompt_embeds = prompt_embeds.to(device) add_text_embeds = 
add_text_embeds.to(device) - add_time_ids = add_time_ids.to(device).repeat( - batch_size * num_images_per_prompt, 1 - ) + add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) # 8. Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order @@ -1466,19 +1344,11 @@ def __call__( for i, t in enumerate(timesteps): # Relevant thread: # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428 - if ( - is_unet_compiled and is_controlnet_compiled - ) and is_torch_higher_equal_2_1: + if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1: torch._inductor.cudagraph_mark_step_begin() # expand the latents if we are doing classifier free guidance - latent_model_input = ( - torch.cat([latents] * 2) - if self.do_classifier_free_guidance - else latents - ) - latent_model_input = self.scheduler.scale_model_input( - latent_model_input, t - ) + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) added_cond_kwargs = { "text_embeds": add_text_embeds, @@ -1489,9 +1359,7 @@ def __call__( if guess_mode and self.do_classifier_free_guidance: # Infer ControlNet only for the conditional batch. control_model_input = latents - control_model_input = self.scheduler.scale_model_input( - control_model_input, t - ) + control_model_input = self.scheduler.scale_model_input(control_model_input, t) controlnet_prompt_embeds = prompt_embeds.chunk(2)[1] controlnet_added_cond_kwargs = { "text_embeds": add_text_embeds.chunk(2)[1], @@ -1503,12 +1371,7 @@ def __call__( controlnet_added_cond_kwargs = added_cond_kwargs if isinstance(controlnet_keep[i], list): - cond_scale = [ - c * s - for c, s in zip( - controlnet_conditioning_scale, controlnet_keep[i] - ) - ] + cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])] else: controlnet_cond_scale = controlnet_conditioning_scale if isinstance(controlnet_cond_scale, list): @@ -1530,13 +1393,8 @@ def __call__( # Infered ControlNet only for the conditional batch. # To apply the output of ControlNet to both the unconditional and conditional batches, # add 0 to the unconditional batch to keep it unchanged. 
- down_block_res_samples = [ - torch.cat([torch.zeros_like(d), d]) - for d in down_block_res_samples - ] - mid_block_res_sample = torch.cat( - [torch.zeros_like(mid_block_res_sample), mid_block_res_sample] - ) + down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples] + mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample]) if ip_adapter_image is not None: added_cond_kwargs["image_embeds"] = image_embeds @@ -1557,14 +1415,10 @@ def __call__( # perform guidance if self.do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + guidance_scale * ( - noise_pred_text - noise_pred_uncond - ) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step( - noise_pred, t, latents, **extra_step_kwargs, return_dict=False - )[0] + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] if callback_on_step_end is not None: callback_kwargs = {} @@ -1574,14 +1428,10 @@ def __call__( latents = callback_outputs.pop("latents", latents) prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) - negative_prompt_embeds = callback_outputs.pop( - "negative_prompt_embeds", negative_prompt_embeds - ) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) # call the callback, if provided - if i == len(timesteps) - 1 or ( - (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0 - ): + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() if callback is not None and i % callback_steps == 0: step_idx = i // getattr(self.scheduler, "order", 1) @@ -1590,25 +1440,17 @@ def __call__( # manually for max memory savings if self.vae.dtype == torch.float16 and self.vae.config.force_upcast: self.upcast_vae() - latents = latents.to( - next(iter(self.vae.post_quant_conv.parameters())).dtype - ) + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) if not output_type == "latent": # make sure the VAE is in float32 mode, as it overflows in float16 - needs_upcasting = ( - self.vae.dtype == torch.float16 and self.vae.config.force_upcast - ) + needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast if needs_upcasting: self.upcast_vae() - latents = latents.to( - next(iter(self.vae.post_quant_conv.parameters())).dtype - ) + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) - image = self.vae.decode( - latents / self.vae.config.scaling_factor, return_dict=False - )[0] + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] # cast back to fp16 if needed if needs_upcasting: diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index f66b0ce9141f..6e7a867f1078 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -71,16 +71,12 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). 
See Section 3.4 """ - std_text = noise_pred_text.std( - dim=list(range(1, noise_pred_text.ndim)), keepdim=True - ) + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) # rescale the results from guidance (fixes overexposure) noise_pred_rescaled = noise_cfg * (std_text / std_cfg) # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images - noise_cfg = ( - guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg - ) + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg return noise_cfg @@ -143,10 +139,7 @@ def __init__( ): super().__init__() - if ( - hasattr(scheduler.config, "steps_offset") - and scheduler.config.steps_offset != 1 - ): + if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: deprecation_message = ( f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " @@ -155,17 +148,12 @@ def __init__( " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" " file" ) - deprecate( - "steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False - ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) new_config = dict(scheduler.config) new_config["steps_offset"] = 1 scheduler._internal_dict = FrozenDict(new_config) - if ( - hasattr(scheduler.config, "clip_sample") - and scheduler.config.clip_sample is True - ): + if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: deprecation_message = ( f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." " `clip_sample` should be set to False in the configuration file. Please make sure to update the" @@ -173,9 +161,7 @@ def __init__( " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" ) - deprecate( - "clip_sample not set", "1.0.0", deprecation_message, standard_warn=False - ) + deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) new_config = dict(scheduler.config) new_config["clip_sample"] = False scheduler._internal_dict = FrozenDict(new_config) @@ -196,16 +182,10 @@ def __init__( " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." 
) - is_unet_version_less_0_9_0 = hasattr( - unet.config, "_diffusers_version" - ) and version.parse( + is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( version.parse(unet.config._diffusers_version).base_version - ) < version.parse( - "0.9.0.dev0" - ) - is_unet_sample_size_less_64 = ( - hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 - ) + ) < version.parse("0.9.0.dev0") + is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: deprecation_message = ( "The configuration file of the unet has set the default `sample_size` to smaller than" @@ -218,9 +198,7 @@ def __init__( " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" " the `unet/config.json` file" ) - deprecate( - "sample_size<64", "1.0.0", deprecation_message, standard_warn=False - ) + deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) new_config = dict(unet.config) new_config["sample_size"] = 64 unet._internal_dict = FrozenDict(new_config) @@ -372,13 +350,11 @@ def encode_prompt( return_tensors="pt", ) text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer( - prompt, padding="longest", return_tensors="pt" - ).input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids - if untruncated_ids.shape[-1] >= text_input_ids.shape[ - -1 - ] and not torch.equal(text_input_ids, untruncated_ids): + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): removed_text = self.tokenizer.batch_decode( untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] ) @@ -387,18 +363,13 @@ def encode_prompt( f" {self.tokenizer.model_max_length} tokens: {removed_text}" ) - if ( - hasattr(self.text_encoder.config, "use_attention_mask") - and self.text_encoder.config.use_attention_mask - ): + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: attention_mask = text_inputs.attention_mask.to(device) else: attention_mask = None if clip_skip is None: - prompt_embeds = self.text_encoder( - text_input_ids.to(device), attention_mask=attention_mask - ) + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) prompt_embeds = prompt_embeds[0] else: prompt_embeds = self.text_encoder( @@ -414,9 +385,7 @@ def encode_prompt( # representations. The `last_hidden_states` that we typically use for # obtaining the final prompt representations passes through the LayerNorm # layer. 
- prompt_embeds = self.text_encoder.text_model.final_layer_norm( - prompt_embeds - ) + prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) if self.text_encoder is not None: prompt_embeds_dtype = self.text_encoder.dtype @@ -430,9 +399,7 @@ def encode_prompt( bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view( - bs_embed * num_images_per_prompt, seq_len, -1 - ) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) # get unconditional embeddings for classifier free guidance if do_classifier_free_guidance and negative_prompt_embeds is None: @@ -468,10 +435,7 @@ def encode_prompt( return_tensors="pt", ) - if ( - hasattr(self.text_encoder.config, "use_attention_mask") - and self.text_encoder.config.use_attention_mask - ): + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: attention_mask = uncond_input.attention_mask.to(device) else: attention_mask = None @@ -486,16 +450,10 @@ def encode_prompt( # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] - negative_prompt_embeds = negative_prompt_embeds.to( - dtype=prompt_embeds_dtype, device=device - ) + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) - negative_prompt_embeds = negative_prompt_embeds.repeat( - 1, num_images_per_prompt, 1 - ) - negative_prompt_embeds = negative_prompt_embeds.view( - batch_size * num_images_per_prompt, seq_len, -1 - ) + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND: # Retrieve the original scale by scaling back the LoRA layers @@ -517,16 +475,12 @@ def encode_image(self, image, device, num_images_per_prompt): uncond_image_embeds = torch.zeros_like(image_embeds) else: # IP-Adapter Plus - image_embeds = self.image_encoder( - image, output_hidden_states=True - ).hidden_states[-2] + image_embeds = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_embeds = self.image_encoder( - torch.zeros_like(image), output_hidden_states=True - ).hidden_states[-2] - uncond_image_embeds = uncond_image_embeds.repeat_interleave( - num_images_per_prompt, dim=0 - ) + uncond_image_embeds = self.image_encoder(torch.zeros_like(image), output_hidden_states=True).hidden_states[ + -2 + ] + uncond_image_embeds = uncond_image_embeds.repeat_interleave(num_images_per_prompt, dim=0) return image_embeds, uncond_image_embeds def run_safety_checker(self, image, device, dtype): @@ -534,14 +488,10 @@ def run_safety_checker(self, image, device, dtype): has_nsfw_concept = None else: if torch.is_tensor(image): - feature_extractor_input = self.image_processor.postprocess( - image, output_type="pil" - ) + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") else: feature_extractor_input = self.image_processor.numpy_to_pil(image) - safety_checker_input = self.feature_extractor( - feature_extractor_input, return_tensors="pt" - ).to(device) + safety_checker_input = self.feature_extractor(feature_extractor_input, 
return_tensors="pt").to(device) image, has_nsfw_concept = self.safety_checker( images=image, clip_input=safety_checker_input.pixel_values.to(dtype) ) @@ -564,17 +514,13 @@ def prepare_extra_step_kwargs(self, generator, eta): # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 # and should be between [0, 1] - accepts_eta = "eta" in set( - inspect.signature(self.scheduler.step).parameters.keys() - ) + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) extra_step_kwargs = {} if accepts_eta: extra_step_kwargs["eta"] = eta # check if the scheduler accepts generator - accepts_generator = "generator" in set( - inspect.signature(self.scheduler.step).parameters.keys() - ) + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) if accepts_generator: extra_step_kwargs["generator"] = generator return extra_step_kwargs @@ -591,20 +537,15 @@ def check_inputs( callback_on_step_end_tensor_inputs=None, ): if height % 8 != 0 or width % 8 != 0: - raise ValueError( - f"`height` and `width` have to be divisible by 8 but are {height} and {width}." - ) + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") - if callback_steps is not None and ( - not isinstance(callback_steps, int) or callback_steps <= 0 - ): + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): raise ValueError( f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}." ) if callback_on_step_end_tensor_inputs is not None and not all( - k in self._callback_tensor_inputs - for k in callback_on_step_end_tensor_inputs + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs ): raise ValueError( f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" @@ -619,12 +560,8 @@ def check_inputs( raise ValueError( "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." ) - elif prompt is not None and ( - not isinstance(prompt, str) and not isinstance(prompt, list) - ): - raise ValueError( - f"`prompt` has to be of type `str` or `list` but is {type(prompt)}" - ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") if negative_prompt is not None and negative_prompt_embeds is not None: raise ValueError( @@ -664,9 +601,7 @@ def prepare_latents( ) if latents is None: - latents = randn_tensor( - shape, generator=generator, device=device, dtype=dtype - ) + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) else: latents = latents.to(device) @@ -905,9 +840,7 @@ def __call__( # 3. 
Encode input prompt lora_scale = ( - self.cross_attention_kwargs.get("scale", None) - if self.cross_attention_kwargs is not None - else None + self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None ) prompt_embeds, negative_prompt_embeds = self.encode_prompt( @@ -929,9 +862,7 @@ def __call__( prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) if ip_adapter_image is not None: - image_embeds, negative_image_embeds = self.encode_image( - ip_adapter_image, device, num_images_per_prompt - ) + image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt) if self.do_classifier_free_guidance: image_embeds = torch.cat([negative_image_embeds, image_embeds]) @@ -956,16 +887,12 @@ def __call__( extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # 6.1 Add image embeds for IP-Adapter - added_cond_kwargs = ( - {"image_embeds": image_embeds} if ip_adapter_image is not None else None - ) + added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None # 6.2 Optionally get Guidance Scale Embedding timestep_cond = None if self.unet.config.time_cond_proj_dim is not None: - guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat( - batch_size * num_images_per_prompt - ) + guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) timestep_cond = self.get_guidance_scale_embedding( guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim ).to(device=device, dtype=latents.dtype) @@ -976,14 +903,8 @@ def __call__( with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): # expand the latents if we are doing classifier free guidance - latent_model_input = ( - torch.cat([latents] * 2) - if self.do_classifier_free_guidance - else latents - ) - latent_model_input = self.scheduler.scale_model_input( - latent_model_input, t - ) + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # predict the noise residual noise_pred = self.unet( @@ -999,9 +920,7 @@ def __call__( # perform guidance if self.do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + self.guidance_scale * ( - noise_pred_text - noise_pred_uncond - ) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) if self.do_classifier_free_guidance and self.guidance_rescale > 0.0: # Based on 3.4. 
in https://arxiv.org/pdf/2305.08891.pdf @@ -1012,9 +931,7 @@ def __call__( ) # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step( - noise_pred, t, latents, **extra_step_kwargs, return_dict=False - )[0] + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] if callback_on_step_end is not None: callback_kwargs = {} @@ -1024,14 +941,10 @@ def __call__( latents = callback_outputs.pop("latents", latents) prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) - negative_prompt_embeds = callback_outputs.pop( - "negative_prompt_embeds", negative_prompt_embeds - ) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) # call the callback, if provided - if i == len(timesteps) - 1 or ( - (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0 - ): + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() if callback is not None and i % callback_steps == 0: step_idx = i // getattr(self.scheduler, "order", 1) @@ -1043,9 +956,7 @@ def __call__( return_dict=False, generator=generator, )[0] - image, has_nsfw_concept = self.run_safety_checker( - image, device, prompt_embeds.dtype - ) + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) else: image = latents has_nsfw_concept = None @@ -1055,9 +966,7 @@ def __call__( else: do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] - image = self.image_processor.postprocess( - image, output_type=output_type, do_denormalize=do_denormalize - ) + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) # Offload all models self.maybe_free_model_hooks() @@ -1065,6 +974,4 @@ def __call__( if not return_dict: return (image, has_nsfw_concept) - return StableDiffusionPipelineOutput( - images=image, nsfw_content_detected=has_nsfw_concept - ) + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py index 797802b37043..ad45ba2beee6 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py @@ -104,10 +104,7 @@ def preprocess(image): w, h = image[0].size w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8 - image = [ - np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] - for i in image - ] + image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image] image = np.concatenate(image, axis=0) image = np.array(image).astype(np.float32) / 255.0 image = image.transpose(0, 3, 1, 2) @@ -177,10 +174,7 @@ def __init__( ): super().__init__() - if ( - hasattr(scheduler.config, "steps_offset") - and scheduler.config.steps_offset != 1 - ): + if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: deprecation_message = ( f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" f" should be set to 1 instead of {scheduler.config.steps_offset}. 
Please make sure " @@ -189,17 +183,12 @@ def __init__( " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" " file" ) - deprecate( - "steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False - ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) new_config = dict(scheduler.config) new_config["steps_offset"] = 1 scheduler._internal_dict = FrozenDict(new_config) - if ( - hasattr(scheduler.config, "clip_sample") - and scheduler.config.clip_sample is True - ): + if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: deprecation_message = ( f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." " `clip_sample` should be set to False in the configuration file. Please make sure to update the" @@ -207,9 +196,7 @@ def __init__( " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" ) - deprecate( - "clip_sample not set", "1.0.0", deprecation_message, standard_warn=False - ) + deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) new_config = dict(scheduler.config) new_config["clip_sample"] = False scheduler._internal_dict = FrozenDict(new_config) @@ -230,16 +217,10 @@ def __init__( " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." ) - is_unet_version_less_0_9_0 = hasattr( - unet.config, "_diffusers_version" - ) and version.parse( + is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( version.parse(unet.config._diffusers_version).base_version - ) < version.parse( - "0.9.0.dev0" - ) - is_unet_sample_size_less_64 = ( - hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 - ) + ) < version.parse("0.9.0.dev0") + is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: deprecation_message = ( "The configuration file of the unet has set the default `sample_size` to smaller than" @@ -252,9 +233,7 @@ def __init__( " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" " the `unet/config.json` file" ) - deprecate( - "sample_size<64", "1.0.0", deprecation_message, standard_warn=False - ) + deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) new_config = dict(unet.config) new_config["sample_size"] = 64 unet._internal_dict = FrozenDict(new_config) @@ -379,13 +358,11 @@ def encode_prompt( return_tensors="pt", ) text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer( - prompt, padding="longest", return_tensors="pt" - ).input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids - if untruncated_ids.shape[-1] >= text_input_ids.shape[ - -1 - ] and not torch.equal(text_input_ids, untruncated_ids): + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): removed_text = self.tokenizer.batch_decode( untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] ) @@ -394,18 +371,13 @@ def encode_prompt( f" {self.tokenizer.model_max_length} tokens: {removed_text}" ) - if ( - hasattr(self.text_encoder.config, "use_attention_mask") - and self.text_encoder.config.use_attention_mask - ): + if 
hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: attention_mask = text_inputs.attention_mask.to(device) else: attention_mask = None if clip_skip is None: - prompt_embeds = self.text_encoder( - text_input_ids.to(device), attention_mask=attention_mask - ) + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) prompt_embeds = prompt_embeds[0] else: prompt_embeds = self.text_encoder( @@ -421,9 +393,7 @@ def encode_prompt( # representations. The `last_hidden_states` that we typically use for # obtaining the final prompt representations passes through the LayerNorm # layer. - prompt_embeds = self.text_encoder.text_model.final_layer_norm( - prompt_embeds - ) + prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) if self.text_encoder is not None: prompt_embeds_dtype = self.text_encoder.dtype @@ -437,9 +407,7 @@ def encode_prompt( bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view( - bs_embed * num_images_per_prompt, seq_len, -1 - ) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) # get unconditional embeddings for classifier free guidance if do_classifier_free_guidance and negative_prompt_embeds is None: @@ -475,10 +443,7 @@ def encode_prompt( return_tensors="pt", ) - if ( - hasattr(self.text_encoder.config, "use_attention_mask") - and self.text_encoder.config.use_attention_mask - ): + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: attention_mask = uncond_input.attention_mask.to(device) else: attention_mask = None @@ -493,16 +458,10 @@ def encode_prompt( # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] - negative_prompt_embeds = negative_prompt_embeds.to( - dtype=prompt_embeds_dtype, device=device - ) + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) - negative_prompt_embeds = negative_prompt_embeds.repeat( - 1, num_images_per_prompt, 1 - ) - negative_prompt_embeds = negative_prompt_embeds.view( - batch_size * num_images_per_prompt, seq_len, -1 - ) + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND: # Retrieve the original scale by scaling back the LoRA layers @@ -525,16 +484,12 @@ def encode_image(self, image, device, num_images_per_prompt): uncond_image_embeds = torch.zeros_like(image_embeds) else: # IP-Adapter Plus - image_embeds = self.image_encoder( - image, output_hidden_states=True - ).hidden_states[-2] + image_embeds = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_embeds = self.image_encoder( - torch.zeros_like(image), output_hidden_states=True - ).hidden_states[-2] - uncond_image_embeds = uncond_image_embeds.repeat_interleave( - num_images_per_prompt, dim=0 - ) + uncond_image_embeds = self.image_encoder(torch.zeros_like(image), output_hidden_states=True).hidden_states[ + -2 + ] + uncond_image_embeds = uncond_image_embeds.repeat_interleave(num_images_per_prompt, 
dim=0) return image_embeds, uncond_image_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker @@ -543,14 +498,10 @@ def run_safety_checker(self, image, device, dtype): has_nsfw_concept = None else: if torch.is_tensor(image): - feature_extractor_input = self.image_processor.postprocess( - image, output_type="pil" - ) + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") else: feature_extractor_input = self.image_processor.numpy_to_pil(image) - safety_checker_input = self.feature_extractor( - feature_extractor_input, return_tensors="pt" - ).to(device) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) image, has_nsfw_concept = self.safety_checker( images=image, clip_input=safety_checker_input.pixel_values.to(dtype) ) @@ -575,17 +526,13 @@ def prepare_extra_step_kwargs(self, generator, eta): # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 # and should be between [0, 1] - accepts_eta = "eta" in set( - inspect.signature(self.scheduler.step).parameters.keys() - ) + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) extra_step_kwargs = {} if accepts_eta: extra_step_kwargs["eta"] = eta # check if the scheduler accepts generator - accepts_generator = "generator" in set( - inspect.signature(self.scheduler.step).parameters.keys() - ) + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) if accepts_generator: extra_step_kwargs["generator"] = generator return extra_step_kwargs @@ -601,21 +548,16 @@ def check_inputs( callback_on_step_end_tensor_inputs=None, ): if strength < 0 or strength > 1: - raise ValueError( - f"The value of strength should in [0.0, 1.0] but is {strength}" - ) + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") - if callback_steps is not None and ( - not isinstance(callback_steps, int) or callback_steps <= 0 - ): + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): raise ValueError( f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}." ) if callback_on_step_end_tensor_inputs is not None and not all( - k in self._callback_tensor_inputs - for k in callback_on_step_end_tensor_inputs + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs ): raise ValueError( f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" @@ -629,12 +571,8 @@ def check_inputs( raise ValueError( "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." 
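The `encode_image` hunk above is the core of the IP-Adapter Plus path: instead of the pooled `image_embeds`, it takes the penultimate hidden states of the vision encoder, and it builds the unconditional embedding by re-encoding an all-zeros image. A standalone sketch of the two branches, assuming `image_encoder` is a `transformers` `CLIPVisionModelWithProjection`, `pixel_values` is an already-preprocessed image batch, and the `is_plus` flag stands in for however the pipeline detects a Plus-style projection:

```py
import torch

def encode_image_sketch(image_encoder, pixel_values, num_images_per_prompt, is_plus):
    if not is_plus:
        # Vanilla IP-Adapter: one pooled, projected embedding per image,
        # with a zero tensor as the unconditional counterpart.
        image_embeds = image_encoder(pixel_values).image_embeds
        image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
        uncond_image_embeds = torch.zeros_like(image_embeds)
    else:
        # IP-Adapter Plus: per-patch tokens from the penultimate hidden layer,
        # and the unconditional branch encodes an all-zeros image instead.
        image_embeds = image_encoder(pixel_values, output_hidden_states=True).hidden_states[-2]
        image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
        uncond_image_embeds = image_encoder(
            torch.zeros_like(pixel_values), output_hidden_states=True
        ).hidden_states[-2]
        uncond_image_embeds = uncond_image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
    return image_embeds, uncond_image_embeds
```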
) - elif prompt is not None and ( - not isinstance(prompt, str) and not isinstance(prompt, list) - ): - raise ValueError( - f"`prompt` has to be of type `str` or `list` but is {type(prompt)}" - ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") if negative_prompt is not None and negative_prompt_embeds is not None: raise ValueError( @@ -690,23 +628,16 @@ def prepare_latents( elif isinstance(generator, list): init_latents = [ - retrieve_latents( - self.vae.encode(image[i : i + 1]), generator=generator[i] - ) + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) for i in range(batch_size) ] init_latents = torch.cat(init_latents, dim=0) else: - init_latents = retrieve_latents( - self.vae.encode(image), generator=generator - ) + init_latents = retrieve_latents(self.vae.encode(image), generator=generator) init_latents = self.vae.config.scaling_factor * init_latents - if ( - batch_size > init_latents.shape[0] - and batch_size % init_latents.shape[0] == 0 - ): + if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: # expand init_latents for batch_size deprecation_message = ( f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial" @@ -721,13 +652,8 @@ def prepare_latents( standard_warn=False, ) additional_image_per_prompt = batch_size // init_latents.shape[0] - init_latents = torch.cat( - [init_latents] * additional_image_per_prompt, dim=0 - ) - elif ( - batch_size > init_latents.shape[0] - and batch_size % init_latents.shape[0] != 0 - ): + init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0) + elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: raise ValueError( f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." ) @@ -962,9 +888,7 @@ def __call__( # 3. Encode input prompt text_encoder_lora_scale = ( - self.cross_attention_kwargs.get("scale", None) - if self.cross_attention_kwargs is not None - else None + self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None ) prompt_embeds, negative_prompt_embeds = self.encode_prompt( prompt, @@ -984,9 +908,7 @@ def __call__( prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) if ip_adapter_image is not None: - image_embeds, negative_image_embeds = self.encode_image( - ip_adapter_image, device, num_images_per_prompt - ) + image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt) if self.do_classifier_free_guidance: image_embeds = torch.cat([negative_image_embeds, image_embeds]) @@ -995,9 +917,7 @@ def __call__( # 5. set timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) - timesteps, num_inference_steps = self.get_timesteps( - num_inference_steps, strength, device - ) + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) # 6. 
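As a rough guide to how `strength` interacts with the timestep schedule retrieved by `get_timesteps` above (a sketch of the usual diffusers behaviour; the helper itself is not changed in this hunk, and the exact slicing is an assumption):

```py
# With 50 inference steps and strength = 0.6, roughly the last
# int(50 * 0.6) = 30 scheduler timesteps are actually run, and the initial
# latents are the VAE-encoded image noised to timesteps[0], not pure noise.
num_inference_steps, strength = 50, 0.6
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
t_start = max(num_inference_steps - init_timestep, 0)
print(init_timestep, t_start)  # 30 20 -> 30 denoising steps remain
```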
Prepare latent variables @@ -1015,16 +935,12 @@ def __call__( extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # 7.1 Add image embeds for IP-Adapter - added_cond_kwargs = ( - {"image_embeds": image_embeds} if ip_adapter_image is not None else None - ) + added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None # 7.2 Optionally get Guidance Scale Embedding timestep_cond = None if self.unet.config.time_cond_proj_dim is not None: - guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat( - batch_size * num_images_per_prompt - ) + guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) timestep_cond = self.get_guidance_scale_embedding( guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim ).to(device=device, dtype=latents.dtype) @@ -1035,14 +951,8 @@ def __call__( with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): # expand the latents if we are doing classifier free guidance - latent_model_input = ( - torch.cat([latents] * 2) - if self.do_classifier_free_guidance - else latents - ) - latent_model_input = self.scheduler.scale_model_input( - latent_model_input, t - ) + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # predict the noise residual noise_pred = self.unet( @@ -1058,14 +968,10 @@ def __call__( # perform guidance if self.do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + self.guidance_scale * ( - noise_pred_text - noise_pred_uncond - ) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step( - noise_pred, t, latents, **extra_step_kwargs, return_dict=False - )[0] + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] if callback_on_step_end is not None: callback_kwargs = {} @@ -1075,14 +981,10 @@ def __call__( latents = callback_outputs.pop("latents", latents) prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) - negative_prompt_embeds = callback_outputs.pop( - "negative_prompt_embeds", negative_prompt_embeds - ) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) # call the callback, if provided - if i == len(timesteps) - 1 or ( - (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0 - ): + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() if callback is not None and i % callback_steps == 0: step_idx = i // getattr(self.scheduler, "order", 1) @@ -1094,9 +996,7 @@ def __call__( return_dict=False, generator=generator, )[0] - image, has_nsfw_concept = self.run_safety_checker( - image, device, prompt_embeds.dtype - ) + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) else: image = latents has_nsfw_concept = None @@ -1106,9 +1006,7 @@ def __call__( else: do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] - image = self.image_processor.postprocess( - image, output_type=output_type, do_denormalize=do_denormalize - ) + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) # Offload all models 
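A hedged end-to-end sketch of what the `ip_adapter_image` wiring above enables for the img2img pipeline. The checkpoint id, adapter repo, subfolder, weight name, and image URLs are assumptions (based on the public `h94/IP-Adapter` layout), not taken from this diff:

```py
import torch
from diffusers import StableDiffusionImg2ImgPipeline
from diffusers.utils import load_image

pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
# Assumed weight layout: the "plus" checkpoint sits next to the standard one.
pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter-plus_sd15.bin")

init_image = load_image("https://example.com/init.png")      # placeholder URL
adapter_image = load_image("https://example.com/style.png")  # placeholder URL

result = pipe(
    prompt="a photo of a cat",
    image=init_image,
    ip_adapter_image=adapter_image,  # routed through encode_image above
    strength=0.6,
).images[0]
```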
self.maybe_free_model_hooks() @@ -1116,6 +1014,4 @@ def __call__( if not return_dict: return (image, has_nsfw_concept) - return StableDiffusionPipelineOutput( - images=image, nsfw_content_detected=has_nsfw_concept - ) + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index 737a2d6938b7..78de0b69b303 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -58,9 +58,7 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name -def prepare_mask_and_masked_image( - image, mask, height, width, return_image: bool = False -): +def prepare_mask_and_masked_image(image, mask, height, width, return_image: bool = False): """ Prepares a pair (image, mask) to be consumed by the Stable Diffusion pipeline. This means that those inputs will be converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for the @@ -102,15 +100,11 @@ def prepare_mask_and_masked_image( if isinstance(image, torch.Tensor): if not isinstance(mask, torch.Tensor): - raise TypeError( - f"`image` is a torch.Tensor but `mask` (type: {type(mask)} is not" - ) + raise TypeError(f"`image` is a torch.Tensor but `mask` (type: {type(mask)} is not") # Batch single image if image.ndim == 3: - assert ( - image.shape[0] == 3 - ), "Image outside a batch should be of shape (3, H, W)" + assert image.shape[0] == 3, "Image outside a batch should be of shape (3, H, W)" image = image.unsqueeze(0) # Batch and add channel dim for single mask @@ -127,15 +121,9 @@ def prepare_mask_and_masked_image( else: mask = mask.unsqueeze(1) - assert ( - image.ndim == 4 and mask.ndim == 4 - ), "Image and Mask must have 4 dimensions" - assert ( - image.shape[-2:] == mask.shape[-2:] - ), "Image and Mask must have the same spatial dimensions" - assert ( - image.shape[0] == mask.shape[0] - ), "Image and Mask must have the same batch size" + assert image.ndim == 4 and mask.ndim == 4, "Image and Mask must have 4 dimensions" + assert image.shape[-2:] == mask.shape[-2:], "Image and Mask must have the same spatial dimensions" + assert image.shape[0] == mask.shape[0], "Image and Mask must have the same batch size" # Check image is in [-1, 1] if image.min() < -1 or image.max() > 1: @@ -152,18 +140,14 @@ def prepare_mask_and_masked_image( # Image as float32 image = image.to(dtype=torch.float32) elif isinstance(mask, torch.Tensor): - raise TypeError( - f"`mask` is a torch.Tensor but `image` (type: {type(image)} is not" - ) + raise TypeError(f"`mask` is a torch.Tensor but `image` (type: {type(image)} is not") else: # preprocess image if isinstance(image, (PIL.Image.Image, np.ndarray)): image = [image] if isinstance(image, list) and isinstance(image[0], PIL.Image.Image): # resize all images w.r.t passed height an width - image = [ - i.resize((width, height), resample=PIL.Image.LANCZOS) for i in image - ] + image = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in image] image = [np.array(i.convert("RGB"))[None, :] for i in image] image = np.concatenate(image, axis=0) elif isinstance(image, list) and isinstance(image[0], np.ndarray): @@ -178,9 +162,7 @@ def prepare_mask_and_masked_image( if isinstance(mask, list) and isinstance(mask[0], PIL.Image.Image): mask = [i.resize((width, 
height), resample=PIL.Image.LANCZOS) for i in mask] - mask = np.concatenate( - [np.array(m.convert("L"))[None, None, :] for m in mask], axis=0 - ) + mask = np.concatenate([np.array(m.convert("L"))[None, None, :] for m in mask], axis=0) mask = mask.astype(np.float32) / 255.0 elif isinstance(mask, list) and isinstance(mask[0], np.ndarray): mask = np.concatenate([m[None, None, :] for m in mask], axis=0) @@ -272,10 +254,7 @@ def __init__( ): super().__init__() - if ( - hasattr(scheduler.config, "steps_offset") - and scheduler.config.steps_offset != 1 - ): + if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: deprecation_message = ( f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " @@ -284,17 +263,12 @@ def __init__( " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" " file" ) - deprecate( - "steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False - ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) new_config = dict(scheduler.config) new_config["steps_offset"] = 1 scheduler._internal_dict = FrozenDict(new_config) - if ( - hasattr(scheduler.config, "skip_prk_steps") - and scheduler.config.skip_prk_steps is False - ): + if hasattr(scheduler.config, "skip_prk_steps") and scheduler.config.skip_prk_steps is False: deprecation_message = ( f"The configuration file of this scheduler: {scheduler} has not set the configuration" " `skip_prk_steps`. `skip_prk_steps` should be set to True in the configuration file. Please make" @@ -329,16 +303,10 @@ def __init__( " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." ) - is_unet_version_less_0_9_0 = hasattr( - unet.config, "_diffusers_version" - ) and version.parse( + is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( version.parse(unet.config._diffusers_version).base_version - ) < version.parse( - "0.9.0.dev0" - ) - is_unet_sample_size_less_64 = ( - hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 - ) + ) < version.parse("0.9.0.dev0") + is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: deprecation_message = ( "The configuration file of the unet has set the default `sample_size` to smaller than" @@ -351,18 +319,14 @@ def __init__( " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" " the `unet/config.json` file" ) - deprecate( - "sample_size<64", "1.0.0", deprecation_message, standard_warn=False - ) + deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) new_config = dict(unet.config) new_config["sample_size"] = 64 unet._internal_dict = FrozenDict(new_config) # Check shapes, assume num_channels_latents == 4, num_channels_mask == 1, num_channels_masked == 4 if unet.config.in_channels != 9: - logger.info( - f"You have loaded a UNet with {unet.config.in_channels} input channels which." 
- ) + logger.info(f"You have loaded a UNet with {unet.config.in_channels} input channels which.") self.register_modules( vae=vae, @@ -490,13 +454,11 @@ def encode_prompt( return_tensors="pt", ) text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer( - prompt, padding="longest", return_tensors="pt" - ).input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids - if untruncated_ids.shape[-1] >= text_input_ids.shape[ - -1 - ] and not torch.equal(text_input_ids, untruncated_ids): + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): removed_text = self.tokenizer.batch_decode( untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] ) @@ -505,18 +467,13 @@ def encode_prompt( f" {self.tokenizer.model_max_length} tokens: {removed_text}" ) - if ( - hasattr(self.text_encoder.config, "use_attention_mask") - and self.text_encoder.config.use_attention_mask - ): + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: attention_mask = text_inputs.attention_mask.to(device) else: attention_mask = None if clip_skip is None: - prompt_embeds = self.text_encoder( - text_input_ids.to(device), attention_mask=attention_mask - ) + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) prompt_embeds = prompt_embeds[0] else: prompt_embeds = self.text_encoder( @@ -532,9 +489,7 @@ def encode_prompt( # representations. The `last_hidden_states` that we typically use for # obtaining the final prompt representations passes through the LayerNorm # layer. - prompt_embeds = self.text_encoder.text_model.final_layer_norm( - prompt_embeds - ) + prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) if self.text_encoder is not None: prompt_embeds_dtype = self.text_encoder.dtype @@ -548,9 +503,7 @@ def encode_prompt( bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view( - bs_embed * num_images_per_prompt, seq_len, -1 - ) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) # get unconditional embeddings for classifier free guidance if do_classifier_free_guidance and negative_prompt_embeds is None: @@ -586,10 +539,7 @@ def encode_prompt( return_tensors="pt", ) - if ( - hasattr(self.text_encoder.config, "use_attention_mask") - and self.text_encoder.config.use_attention_mask - ): + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: attention_mask = uncond_input.attention_mask.to(device) else: attention_mask = None @@ -604,16 +554,10 @@ def encode_prompt( # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] - negative_prompt_embeds = negative_prompt_embeds.to( - dtype=prompt_embeds_dtype, device=device - ) + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) - negative_prompt_embeds = negative_prompt_embeds.repeat( - 1, num_images_per_prompt, 1 - ) - negative_prompt_embeds = negative_prompt_embeds.view( - batch_size * num_images_per_prompt, seq_len, -1 - ) + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * 
num_images_per_prompt, seq_len, -1) if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND: # Retrieve the original scale by scaling back the LoRA layers @@ -636,16 +580,12 @@ def encode_image(self, image, device, num_images_per_prompt): uncond_image_embeds = torch.zeros_like(image_embeds) else: # IP-Adapter Plus - image_embeds = self.image_encoder( - image, output_hidden_states=True - ).hidden_states[-2] + image_embeds = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_embeds = self.image_encoder( - torch.zeros_like(image), output_hidden_states=True - ).hidden_states[-2] - uncond_image_embeds = uncond_image_embeds.repeat_interleave( - num_images_per_prompt, dim=0 - ) + uncond_image_embeds = self.image_encoder(torch.zeros_like(image), output_hidden_states=True).hidden_states[ + -2 + ] + uncond_image_embeds = uncond_image_embeds.repeat_interleave(num_images_per_prompt, dim=0) return image_embeds, uncond_image_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker @@ -654,14 +594,10 @@ def run_safety_checker(self, image, device, dtype): has_nsfw_concept = None else: if torch.is_tensor(image): - feature_extractor_input = self.image_processor.postprocess( - image, output_type="pil" - ) + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") else: feature_extractor_input = self.image_processor.numpy_to_pil(image) - safety_checker_input = self.feature_extractor( - feature_extractor_input, return_tensors="pt" - ).to(device) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) image, has_nsfw_concept = self.safety_checker( images=image, clip_input=safety_checker_input.pixel_values.to(dtype) ) @@ -674,17 +610,13 @@ def prepare_extra_step_kwargs(self, generator, eta): # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 # and should be between [0, 1] - accepts_eta = "eta" in set( - inspect.signature(self.scheduler.step).parameters.keys() - ) + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) extra_step_kwargs = {} if accepts_eta: extra_step_kwargs["eta"] = eta # check if the scheduler accepts generator - accepts_generator = "generator" in set( - inspect.signature(self.scheduler.step).parameters.keys() - ) + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) if accepts_generator: extra_step_kwargs["generator"] = generator return extra_step_kwargs @@ -702,26 +634,19 @@ def check_inputs( callback_on_step_end_tensor_inputs=None, ): if strength < 0 or strength > 1: - raise ValueError( - f"The value of strength should in [0.0, 1.0] but is {strength}" - ) + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") if height % self.vae_scale_factor != 0 or width % self.vae_scale_factor != 0: - raise ValueError( - f"`height` and `width` have to be divisible by 8 but are {height} and {width}." - ) + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") - if callback_steps is not None and ( - not isinstance(callback_steps, int) or callback_steps <= 0 - ): + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): raise ValueError( f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}." 
) if callback_on_step_end_tensor_inputs is not None and not all( - k in self._callback_tensor_inputs - for k in callback_on_step_end_tensor_inputs + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs ): raise ValueError( f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" @@ -736,12 +661,8 @@ def check_inputs( raise ValueError( "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." ) - elif prompt is not None and ( - not isinstance(prompt, str) and not isinstance(prompt, list) - ): - raise ValueError( - f"`prompt` has to be of type `str` or `list` but is {type(prompt)}" - ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") if negative_prompt is not None and negative_prompt_embeds is not None: raise ValueError( @@ -798,24 +719,14 @@ def prepare_latents( image_latents = image else: image_latents = self._encode_vae_image(image=image, generator=generator) - image_latents = image_latents.repeat( - batch_size // image_latents.shape[0], 1, 1, 1 - ) + image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1) if latents is None: noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) # if strength is 1. then initialise the latents to noise, else initial to image + noise - latents = ( - noise - if is_strength_max - else self.scheduler.add_noise(image_latents, noise, timestep) - ) + latents = noise if is_strength_max else self.scheduler.add_noise(image_latents, noise, timestep) # if pure noise then scale the initial latents by the Scheduler's init sigma - latents = ( - latents * self.scheduler.init_noise_sigma - if is_strength_max - else latents - ) + latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents else: noise = latents.to(device) latents = noise * self.scheduler.init_noise_sigma @@ -833,16 +744,12 @@ def prepare_latents( def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): if isinstance(generator, list): image_latents = [ - retrieve_latents( - self.vae.encode(image[i : i + 1]), generator=generator[i] - ) + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) for i in range(image.shape[0]) ] image_latents = torch.cat(image_latents, dim=0) else: - image_latents = retrieve_latents( - self.vae.encode(image), generator=generator - ) + image_latents = retrieve_latents(self.vae.encode(image), generator=generator) image_latents = self.vae.config.scaling_factor * image_latents @@ -873,9 +780,7 @@ def prepare_mask_latents( if masked_image.shape[1] == 4: masked_image_latents = masked_image else: - masked_image_latents = self._encode_vae_image( - masked_image, generator=generator - ) + masked_image_latents = self._encode_vae_image(masked_image, generator=generator) # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method if mask.shape[0] < batch_size: @@ -893,15 +798,11 @@ def prepare_mask_latents( f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed." " Make sure the number of images that you pass is divisible by the total requested batch size." 
) - masked_image_latents = masked_image_latents.repeat( - batch_size // masked_image_latents.shape[0], 1, 1, 1 - ) + masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1) mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask masked_image_latents = ( - torch.cat([masked_image_latents] * 2) - if do_classifier_free_guidance - else masked_image_latents + torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents ) # aligning device to prevent device errors when concating it with the latent model input @@ -1192,9 +1093,7 @@ def __call__( # 3. Encode input prompt text_encoder_lora_scale = ( - cross_attention_kwargs.get("scale", None) - if cross_attention_kwargs is not None - else None + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None ) prompt_embeds, negative_prompt_embeds = self.encode_prompt( prompt, @@ -1214,9 +1113,7 @@ def __call__( prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) if ip_adapter_image is not None: - image_embeds, negative_image_embeds = self.encode_image( - ip_adapter_image, device, num_images_per_prompt - ) + image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt) if self.do_classifier_free_guidance: image_embeds = torch.cat([negative_image_embeds, image_embeds]) @@ -1268,9 +1165,7 @@ def __call__( latents, noise = latents_outputs # 7. Prepare mask latent variables - mask_condition = self.mask_processor.preprocess( - mask_image, height=height, width=width - ) + mask_condition = self.mask_processor.preprocess(mask_image, height=height, width=width) if masked_image_latents is None: masked_image = init_image * (mask_condition < 0.5) @@ -1294,10 +1189,7 @@ def __call__( # default case for runwayml/stable-diffusion-inpainting num_channels_mask = mask.shape[1] num_channels_masked_image = masked_image_latents.shape[1] - if ( - num_channels_latents + num_channels_mask + num_channels_masked_image - != self.unet.config.in_channels - ): + if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels: raise ValueError( f"Incorrect configuration settings! 
The config of `pipeline.unet`: {self.unet.config} expects" f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" @@ -1314,16 +1206,12 @@ def __call__( extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # 9.1 Add image embeds for IP-Adapter - added_cond_kwargs = ( - {"image_embeds": image_embeds} if ip_adapter_image is not None else None - ) + added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None # 9.2 Optionally get Guidance Scale Embedding timestep_cond = None if self.unet.config.time_cond_proj_dim is not None: - guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat( - batch_size * num_images_per_prompt - ) + guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) timestep_cond = self.get_guidance_scale_embedding( guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim ).to(device=device, dtype=latents.dtype) @@ -1334,21 +1222,13 @@ def __call__( with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): # expand the latents if we are doing classifier free guidance - latent_model_input = ( - torch.cat([latents] * 2) - if self.do_classifier_free_guidance - else latents - ) + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents # concat latents, mask, masked_image_latents in the channel dimension - latent_model_input = self.scheduler.scale_model_input( - latent_model_input, t - ) + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) if num_channels_unet == 9: - latent_model_input = torch.cat( - [latent_model_input, mask, masked_image_latents], dim=1 - ) + latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1) # predict the noise residual noise_pred = self.unet( @@ -1364,14 +1244,10 @@ def __call__( # perform guidance if self.do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + self.guidance_scale * ( - noise_pred_text - noise_pred_uncond - ) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step( - noise_pred, t, latents, **extra_step_kwargs, return_dict=False - )[0] + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] if num_channels_unet == 4: init_latents_proper = image_latents if self.do_classifier_free_guidance: @@ -1385,9 +1261,7 @@ def __call__( init_latents_proper, noise, torch.tensor([noise_timestep]) ) - latents = ( - 1 - init_mask - ) * init_latents_proper + init_mask * latents + latents = (1 - init_mask) * init_latents_proper + init_mask * latents if callback_on_step_end is not None: callback_kwargs = {} @@ -1397,18 +1271,12 @@ def __call__( latents = callback_outputs.pop("latents", latents) prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) - negative_prompt_embeds = callback_outputs.pop( - "negative_prompt_embeds", negative_prompt_embeds - ) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) mask = callback_outputs.pop("mask", mask) - masked_image_latents = callback_outputs.pop( - "masked_image_latents", masked_image_latents - ) + masked_image_latents = callback_outputs.pop("masked_image_latents", masked_image_latents) # call the callback, if provided - if i == 
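The channel bookkeeping enforced above, where 4 latent + 1 mask + 4 masked-image-latent channels must match `unet.config.in_channels`, as a minimal self-contained shape check (shapes are illustrative):

```py
import torch

latents = torch.randn(2, 4, 64, 64)               # num_channels_latents = 4
mask = torch.rand(2, 1, 64, 64)                   # num_channels_mask = 1
masked_image_latents = torch.randn(2, 4, 64, 64)  # num_channels_masked_image = 4

latent_model_input = torch.cat([latents, mask, masked_image_latents], dim=1)
assert latent_model_input.shape[1] == 9  # matches in_channels of inpainting UNets
```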
len(timesteps) - 1 or ( - (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0 - ): + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() if callback is not None and i % callback_steps == 0: step_idx = i // getattr(self.scheduler, "order", 1) @@ -1417,14 +1285,10 @@ def __call__( if not output_type == "latent": condition_kwargs = {} if isinstance(self.vae, AsymmetricAutoencoderKL): - init_image = init_image.to( - device=device, dtype=masked_image_latents.dtype - ) + init_image = init_image.to(device=device, dtype=masked_image_latents.dtype) init_image_condition = init_image.clone() init_image = self._encode_vae_image(init_image, generator=generator) - mask_condition = mask_condition.to( - device=device, dtype=masked_image_latents.dtype - ) + mask_condition = mask_condition.to(device=device, dtype=masked_image_latents.dtype) condition_kwargs = { "image": init_image_condition, "mask": mask_condition, @@ -1435,9 +1299,7 @@ def __call__( generator=generator, **condition_kwargs, )[0] - image, has_nsfw_concept = self.run_safety_checker( - image, device, prompt_embeds.dtype - ) + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) else: image = latents has_nsfw_concept = None @@ -1447,9 +1309,7 @@ def __call__( else: do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] - image = self.image_processor.postprocess( - image, output_type=output_type, do_denormalize=do_denormalize - ) + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) # Offload all models self.maybe_free_model_hooks() @@ -1457,6 +1317,4 @@ def __call__( if not return_dict: return (image, has_nsfw_concept) - return StableDiffusionPipelineOutput( - images=image, nsfw_content_detected=has_nsfw_concept - ) + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py index 37f56c80ec15..e742d8007462 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py @@ -91,16 +91,12 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). 
See Section 3.4 """ - std_text = noise_pred_text.std( - dim=list(range(1, noise_pred_text.ndim)), keepdim=True - ) + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) # rescale the results from guidance (fixes overexposure) noise_pred_rescaled = noise_cfg * (std_text / std_cfg) # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images - noise_cfg = ( - guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg - ) + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg return noise_cfg @@ -202,19 +198,13 @@ def __init__( image_encoder=image_encoder, feature_extractor=feature_extractor, ) - self.register_to_config( - force_zeros_for_empty_prompt=force_zeros_for_empty_prompt - ) + self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.default_sample_size = self.unet.config.sample_size - add_watermarker = ( - add_watermarker - if add_watermarker is not None - else is_invisible_watermark_available() - ) + add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available() if add_watermarker: self.watermark = StableDiffusionXLWatermarker() @@ -316,9 +306,7 @@ def encode_prompt( # set lora scale so that monkey patched LoRA # function of text encoder can correctly access it - if lora_scale is not None and isinstance( - self, StableDiffusionXLLoraLoaderMixin - ): + if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin): self._lora_scale = lora_scale # dynamically adjust the LoRA scale @@ -342,15 +330,9 @@ def encode_prompt( batch_size = prompt_embeds.shape[0] # Define tokenizers and text encoders - tokenizers = ( - [self.tokenizer, self.tokenizer_2] - if self.tokenizer is not None - else [self.tokenizer_2] - ) + tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2] text_encoders = ( - [self.text_encoder, self.text_encoder_2] - if self.text_encoder is not None - else [self.text_encoder_2] + [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2] ) if prompt_embeds is None: @@ -360,9 +342,7 @@ def encode_prompt( # textual inversion: procecss multi-vector tokens if necessary prompt_embeds_list = [] prompts = [prompt, prompt_2] - for prompt, tokenizer, text_encoder in zip( - prompts, tokenizers, text_encoders - ): + for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders): if isinstance(self, TextualInversionLoaderMixin): prompt = self.maybe_convert_prompt(prompt, tokenizer) @@ -375,24 +355,18 @@ def encode_prompt( ) text_input_ids = text_inputs.input_ids - untruncated_ids = tokenizer( - prompt, padding="longest", return_tensors="pt" - ).input_ids - - if untruncated_ids.shape[-1] >= text_input_ids.shape[ - -1 - ] and not torch.equal(text_input_ids, untruncated_ids): - removed_text = tokenizer.batch_decode( - untruncated_ids[:, tokenizer.model_max_length - 1 : -1] - ) + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = tokenizer.batch_decode(untruncated_ids[:, 
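The `rescale_noise_cfg` helper above implements Section 3.4 of arXiv:2305.08891: it matches the standard deviation of the guided prediction to that of the text-only prediction and then blends by `guidance_rescale`. A small numeric sketch of the same computation on dummy tensors (values are illustrative only):

```py
import torch

noise_pred_text = torch.randn(1, 4, 64, 64)
noise_cfg = 7.5 * noise_pred_text  # pretend CFG inflated the magnitude

std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
noise_pred_rescaled = noise_cfg * (std_text / std_cfg)  # undo the overexposure
guidance_rescale = 0.7
out = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
```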
tokenizer.model_max_length - 1 : -1]) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {tokenizer.model_max_length} tokens: {removed_text}" ) - prompt_embeds = text_encoder( - text_input_ids.to(device), output_hidden_states=True - ) + prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) # We are only ALWAYS interested in the pooled output of the final text encoder pooled_prompt_embeds = prompt_embeds[0] @@ -407,14 +381,8 @@ def encode_prompt( prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) # get unconditional embeddings for classifier free guidance - zero_out_negative_prompt = ( - negative_prompt is None and self.config.force_zeros_for_empty_prompt - ) - if ( - do_classifier_free_guidance - and negative_prompt_embeds is None - and zero_out_negative_prompt - ): + zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt + if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt: negative_prompt_embeds = torch.zeros_like(prompt_embeds) negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) elif do_classifier_free_guidance and negative_prompt_embeds is None: @@ -422,15 +390,9 @@ def encode_prompt( negative_prompt_2 = negative_prompt_2 or negative_prompt # normalize str to list - negative_prompt = ( - batch_size * [negative_prompt] - if isinstance(negative_prompt, str) - else negative_prompt - ) + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt negative_prompt_2 = ( - batch_size * [negative_prompt_2] - if isinstance(negative_prompt_2, str) - else negative_prompt_2 + batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 ) uncond_tokens: List[str] @@ -449,13 +411,9 @@ def encode_prompt( uncond_tokens = [negative_prompt, negative_prompt_2] negative_prompt_embeds_list = [] - for negative_prompt, tokenizer, text_encoder in zip( - uncond_tokens, tokenizers, text_encoders - ): + for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders): if isinstance(self, TextualInversionLoaderMixin): - negative_prompt = self.maybe_convert_prompt( - negative_prompt, tokenizer - ) + negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer) max_length = prompt_embeds.shape[1] uncond_input = tokenizer( @@ -479,46 +437,34 @@ def encode_prompt( negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) if self.text_encoder_2 is not None: - prompt_embeds = prompt_embeds.to( - dtype=self.text_encoder_2.dtype, device=device - ) + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) else: prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device) bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view( - bs_embed * num_images_per_prompt, seq_len, -1 - ) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) if do_classifier_free_guidance: # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] if self.text_encoder_2 is not None: - negative_prompt_embeds = negative_prompt_embeds.to( - dtype=self.text_encoder_2.dtype, device=device - ) + 
negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) else: - negative_prompt_embeds = negative_prompt_embeds.to( - dtype=self.unet.dtype, device=device - ) + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device) - negative_prompt_embeds = negative_prompt_embeds.repeat( - 1, num_images_per_prompt, 1 - ) - negative_prompt_embeds = negative_prompt_embeds.view( - batch_size * num_images_per_prompt, seq_len, -1 - ) + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - pooled_prompt_embeds = pooled_prompt_embeds.repeat( - 1, num_images_per_prompt - ).view(bs_embed * num_images_per_prompt, -1) + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) if do_classifier_free_guidance: - negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat( - 1, num_images_per_prompt - ).view(bs_embed * num_images_per_prompt, -1) + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) if self.text_encoder is not None: if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: @@ -552,16 +498,12 @@ def encode_image(self, image, device, num_images_per_prompt): uncond_image_embeds = torch.zeros_like(image_embeds) else: # IP-Adapter Plus - image_embeds = self.image_encoder( - image, output_hidden_states=True - ).hidden_states[-2] + image_embeds = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_embeds = self.image_encoder( - torch.zeros_like(image), output_hidden_states=True - ).hidden_states[-2] - uncond_image_embeds = uncond_image_embeds.repeat_interleave( - num_images_per_prompt, dim=0 - ) + uncond_image_embeds = self.image_encoder(torch.zeros_like(image), output_hidden_states=True).hidden_states[ + -2 + ] + uncond_image_embeds = uncond_image_embeds.repeat_interleave(num_images_per_prompt, dim=0) return image_embeds, uncond_image_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs @@ -571,17 +513,13 @@ def prepare_extra_step_kwargs(self, generator, eta): # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 # and should be between [0, 1] - accepts_eta = "eta" in set( - inspect.signature(self.scheduler.step).parameters.keys() - ) + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) extra_step_kwargs = {} if accepts_eta: extra_step_kwargs["eta"] = eta # check if the scheduler accepts generator - accepts_generator = "generator" in set( - inspect.signature(self.scheduler.step).parameters.keys() - ) + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) if accepts_generator: extra_step_kwargs["generator"] = generator return extra_step_kwargs @@ -602,21 +540,16 @@ def check_inputs( callback_on_step_end_tensor_inputs=None, ): if height % 8 != 0 or width % 8 != 0: - raise ValueError( - f"`height` and `width` have to be divisible by 8 but are {height} and {width}." 
- ) + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") - if callback_steps is not None and ( - not isinstance(callback_steps, int) or callback_steps <= 0 - ): + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): raise ValueError( f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}." ) if callback_on_step_end_tensor_inputs is not None and not all( - k in self._callback_tensor_inputs - for k in callback_on_step_end_tensor_inputs + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs ): raise ValueError( f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" @@ -636,18 +569,10 @@ def check_inputs( raise ValueError( "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." ) - elif prompt is not None and ( - not isinstance(prompt, str) and not isinstance(prompt, list) - ): - raise ValueError( - f"`prompt` has to be of type `str` or `list` but is {type(prompt)}" - ) - elif prompt_2 is not None and ( - not isinstance(prompt_2, str) and not isinstance(prompt_2, list) - ): - raise ValueError( - f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}" - ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") if negative_prompt is not None and negative_prompt_embeds is not None: raise ValueError( @@ -703,9 +628,7 @@ def prepare_latents( ) if latents is None: - latents = randn_tensor( - shape, generator=generator, device=device, dtype=dtype - ) + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) else: latents = latents.to(device) @@ -724,8 +647,7 @@ def _get_add_time_ids( add_time_ids = list(original_size + crops_coords_top_left + target_size) passed_add_embed_dim = ( - self.unet.config.addition_time_embed_dim * len(add_time_ids) - + text_encoder_projection_dim + self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim ) expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features @@ -1071,9 +993,7 @@ def __call__( # 3. 
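The `_get_add_time_ids` hunk above assembles SDXL's micro-conditioning vector; a minimal sketch of the resulting ids and the dimension check, with illustrative config values:

```py
# Six ids per image: (orig_h, orig_w, crop_top, crop_left, target_h, target_w).
original_size = (1024, 1024)
crops_coords_top_left = (0, 0)
target_size = (1024, 1024)
add_time_ids = list(original_size + crops_coords_top_left + target_size)
# -> [1024, 1024, 0, 0, 1024, 1024]

addition_time_embed_dim = 256       # assumed unet.config.addition_time_embed_dim
text_encoder_projection_dim = 1280  # assumed pooled text-embedding width
passed_add_embed_dim = addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
assert passed_add_embed_dim == 2816  # typical add_embedding.linear_1.in_features for SDXL
```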
Encode input prompt lora_scale = ( - self.cross_attention_kwargs.get("scale", None) - if self.cross_attention_kwargs is not None - else None + self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None ) ( @@ -1145,29 +1065,21 @@ def __call__( if self.do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) - add_text_embeds = torch.cat( - [negative_pooled_prompt_embeds, add_text_embeds], dim=0 - ) + add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0) prompt_embeds = prompt_embeds.to(device) add_text_embeds = add_text_embeds.to(device) - add_time_ids = add_time_ids.to(device).repeat( - batch_size * num_images_per_prompt, 1 - ) + add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) if ip_adapter_image is not None: - image_embeds, negative_image_embeds = self.encode_image( - ip_adapter_image, device, num_images_per_prompt - ) + image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt) if self.do_classifier_free_guidance: image_embeds = torch.cat([negative_image_embeds, image_embeds]) image_embeds = image_embeds.to(device) # 8. Denoising loop - num_warmup_steps = max( - len(timesteps) - num_inference_steps * self.scheduler.order, 0 - ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) # 8.1 Apply denoising_end if ( @@ -1182,17 +1094,13 @@ def __call__( - (self.denoising_end * self.scheduler.config.num_train_timesteps) ) ) - num_inference_steps = len( - list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)) - ) + num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) timesteps = timesteps[:num_inference_steps] # 9. Optionally get Guidance Scale Embedding timestep_cond = None if self.unet.config.time_cond_proj_dim is not None: - guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat( - batch_size * num_images_per_prompt - ) + guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) timestep_cond = self.get_guidance_scale_embedding( guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim ).to(device=device, dtype=latents.dtype) @@ -1201,15 +1109,9 @@ def __call__( with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): # expand the latents if we are doing classifier free guidance - latent_model_input = ( - torch.cat([latents] * 2) - if self.do_classifier_free_guidance - else latents - ) + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents - latent_model_input = self.scheduler.scale_model_input( - latent_model_input, t - ) + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # predict the noise residual added_cond_kwargs = { @@ -1231,9 +1133,7 @@ def __call__( # perform guidance if self.do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + self.guidance_scale * ( - noise_pred_text - noise_pred_uncond - ) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) if self.do_classifier_free_guidance and self.guidance_rescale > 0.0: # Based on 3.4. 
in https://arxiv.org/pdf/2305.08891.pdf @@ -1244,9 +1144,7 @@ def __call__( ) # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step( - noise_pred, t, latents, **extra_step_kwargs, return_dict=False - )[0] + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] if callback_on_step_end is not None: callback_kwargs = {} @@ -1256,24 +1154,16 @@ def __call__( latents = callback_outputs.pop("latents", latents) prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) - negative_prompt_embeds = callback_outputs.pop( - "negative_prompt_embeds", negative_prompt_embeds - ) - add_text_embeds = callback_outputs.pop( - "add_text_embeds", add_text_embeds - ) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds) negative_pooled_prompt_embeds = callback_outputs.pop( "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds ) add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids) - negative_add_time_ids = callback_outputs.pop( - "negative_add_time_ids", negative_add_time_ids - ) + negative_add_time_ids = callback_outputs.pop("negative_add_time_ids", negative_add_time_ids) # call the callback, if provided - if i == len(timesteps) - 1 or ( - (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0 - ): + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() if callback is not None and i % callback_steps == 0: step_idx = i // getattr(self.scheduler, "order", 1) @@ -1284,19 +1174,13 @@ def __call__( if not output_type == "latent": # make sure the VAE is in float32 mode, as it overflows in float16 - needs_upcasting = ( - self.vae.dtype == torch.float16 and self.vae.config.force_upcast - ) + needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast if needs_upcasting: self.upcast_vae() - latents = latents.to( - next(iter(self.vae.post_quant_conv.parameters())).dtype - ) + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) - image = self.vae.decode( - latents / self.vae.config.scaling_factor, return_dict=False - )[0] + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] # cast back to fp16 if needed if needs_upcasting: diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py index b00aa4fdd0e9..9b2b354170f0 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py @@ -95,16 +95,12 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). 
See Section 3.4 """ - std_text = noise_pred_text.std( - dim=list(range(1, noise_pred_text.ndim)), keepdim=True - ) + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) # rescale the results from guidance (fixes overexposure) noise_pred_rescaled = noise_cfg * (std_text / std_cfg) # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images - noise_cfg = ( - guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg - ) + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg return noise_cfg @@ -220,18 +216,12 @@ def __init__( feature_extractor=feature_extractor, scheduler=scheduler, ) - self.register_to_config( - force_zeros_for_empty_prompt=force_zeros_for_empty_prompt - ) + self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) self.register_to_config(requires_aesthetics_score=requires_aesthetics_score) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) - add_watermarker = ( - add_watermarker - if add_watermarker is not None - else is_invisible_watermark_available() - ) + add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available() if add_watermarker: self.watermark = StableDiffusionXLWatermarker() @@ -334,9 +324,7 @@ def encode_prompt( # set lora scale so that monkey patched LoRA # function of text encoder can correctly access it - if lora_scale is not None and isinstance( - self, StableDiffusionXLLoraLoaderMixin - ): + if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin): self._lora_scale = lora_scale # dynamically adjust the LoRA scale @@ -360,15 +348,9 @@ def encode_prompt( batch_size = prompt_embeds.shape[0] # Define tokenizers and text encoders - tokenizers = ( - [self.tokenizer, self.tokenizer_2] - if self.tokenizer is not None - else [self.tokenizer_2] - ) + tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2] text_encoders = ( - [self.text_encoder, self.text_encoder_2] - if self.text_encoder is not None - else [self.text_encoder_2] + [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2] ) if prompt_embeds is None: @@ -378,9 +360,7 @@ def encode_prompt( # textual inversion: procecss multi-vector tokens if necessary prompt_embeds_list = [] prompts = [prompt, prompt_2] - for prompt, tokenizer, text_encoder in zip( - prompts, tokenizers, text_encoders - ): + for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders): if isinstance(self, TextualInversionLoaderMixin): prompt = self.maybe_convert_prompt(prompt, tokenizer) @@ -393,24 +373,18 @@ def encode_prompt( ) text_input_ids = text_inputs.input_ids - untruncated_ids = tokenizer( - prompt, padding="longest", return_tensors="pt" - ).input_ids - - if untruncated_ids.shape[-1] >= text_input_ids.shape[ - -1 - ] and not torch.equal(text_input_ids, untruncated_ids): - removed_text = tokenizer.batch_decode( - untruncated_ids[:, tokenizer.model_max_length - 1 : -1] - ) + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = 
tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {tokenizer.model_max_length} tokens: {removed_text}" ) - prompt_embeds = text_encoder( - text_input_ids.to(device), output_hidden_states=True - ) + prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) # We are only ALWAYS interested in the pooled output of the final text encoder pooled_prompt_embeds = prompt_embeds[0] @@ -425,14 +399,8 @@ def encode_prompt( prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) # get unconditional embeddings for classifier free guidance - zero_out_negative_prompt = ( - negative_prompt is None and self.config.force_zeros_for_empty_prompt - ) - if ( - do_classifier_free_guidance - and negative_prompt_embeds is None - and zero_out_negative_prompt - ): + zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt + if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt: negative_prompt_embeds = torch.zeros_like(prompt_embeds) negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) elif do_classifier_free_guidance and negative_prompt_embeds is None: @@ -440,15 +408,9 @@ def encode_prompt( negative_prompt_2 = negative_prompt_2 or negative_prompt # normalize str to list - negative_prompt = ( - batch_size * [negative_prompt] - if isinstance(negative_prompt, str) - else negative_prompt - ) + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt negative_prompt_2 = ( - batch_size * [negative_prompt_2] - if isinstance(negative_prompt_2, str) - else negative_prompt_2 + batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 ) uncond_tokens: List[str] @@ -467,13 +429,9 @@ def encode_prompt( uncond_tokens = [negative_prompt, negative_prompt_2] negative_prompt_embeds_list = [] - for negative_prompt, tokenizer, text_encoder in zip( - uncond_tokens, tokenizers, text_encoders - ): + for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders): if isinstance(self, TextualInversionLoaderMixin): - negative_prompt = self.maybe_convert_prompt( - negative_prompt, tokenizer - ) + negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer) max_length = prompt_embeds.shape[1] uncond_input = tokenizer( @@ -497,46 +455,34 @@ def encode_prompt( negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) if self.text_encoder_2 is not None: - prompt_embeds = prompt_embeds.to( - dtype=self.text_encoder_2.dtype, device=device - ) + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) else: prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device) bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view( - bs_embed * num_images_per_prompt, seq_len, -1 - ) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) if do_classifier_free_guidance: # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] if self.text_encoder_2 is not None: - negative_prompt_embeds = negative_prompt_embeds.to( - 
dtype=self.text_encoder_2.dtype, device=device - ) + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) else: - negative_prompt_embeds = negative_prompt_embeds.to( - dtype=self.unet.dtype, device=device - ) + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device) - negative_prompt_embeds = negative_prompt_embeds.repeat( - 1, num_images_per_prompt, 1 - ) - negative_prompt_embeds = negative_prompt_embeds.view( - batch_size * num_images_per_prompt, seq_len, -1 - ) + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - pooled_prompt_embeds = pooled_prompt_embeds.repeat( - 1, num_images_per_prompt - ).view(bs_embed * num_images_per_prompt, -1) + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) if do_classifier_free_guidance: - negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat( - 1, num_images_per_prompt - ).view(bs_embed * num_images_per_prompt, -1) + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) if self.text_encoder is not None: if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: @@ -562,17 +508,13 @@ def prepare_extra_step_kwargs(self, generator, eta): # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 # and should be between [0, 1] - accepts_eta = "eta" in set( - inspect.signature(self.scheduler.step).parameters.keys() - ) + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) extra_step_kwargs = {} if accepts_eta: extra_step_kwargs["eta"] = eta # check if the scheduler accepts generator - accepts_generator = "generator" in set( - inspect.signature(self.scheduler.step).parameters.keys() - ) + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) if accepts_generator: extra_step_kwargs["generator"] = generator return extra_step_kwargs @@ -591,9 +533,7 @@ def check_inputs( callback_on_step_end_tensor_inputs=None, ): if strength < 0 or strength > 1: - raise ValueError( - f"The value of strength should in [0.0, 1.0] but is {strength}" - ) + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") if num_inference_steps is None: raise ValueError("`num_inference_steps` cannot be None.") elif not isinstance(num_inference_steps, int) or num_inference_steps <= 0: @@ -601,17 +541,14 @@ def check_inputs( f"`num_inference_steps` has to be a positive integer but is {num_inference_steps} of type" f" {type(num_inference_steps)}." ) - if callback_steps is not None and ( - not isinstance(callback_steps, int) or callback_steps <= 0 - ): + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): raise ValueError( f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}." 
) if callback_on_step_end_tensor_inputs is not None and not all( - k in self._callback_tensor_inputs - for k in callback_on_step_end_tensor_inputs + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs ): raise ValueError( f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" @@ -631,18 +568,10 @@ def check_inputs( raise ValueError( "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." ) - elif prompt is not None and ( - not isinstance(prompt, str) and not isinstance(prompt, list) - ): - raise ValueError( - f"`prompt` has to be of type `str` or `list` but is {type(prompt)}" - ) - elif prompt_2 is not None and ( - not isinstance(prompt_2, str) and not isinstance(prompt_2, list) - ): - raise ValueError( - f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}" - ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") if negative_prompt is not None and negative_prompt_embeds is not None: raise ValueError( @@ -663,14 +592,10 @@ def check_inputs( f" {negative_prompt_embeds.shape}." ) - def get_timesteps( - self, num_inference_steps, strength, device, denoising_start=None - ): + def get_timesteps(self, num_inference_steps, strength, device, denoising_start=None): # get the original timestep using init_timestep if denoising_start is None: - init_timestep = min( - int(num_inference_steps * strength), num_inference_steps - ) + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) t_start = max(num_inference_steps - init_timestep, 0) else: t_start = 0 @@ -745,16 +670,12 @@ def prepare_latents( elif isinstance(generator, list): init_latents = [ - retrieve_latents( - self.vae.encode(image[i : i + 1]), generator=generator[i] - ) + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) for i in range(batch_size) ] init_latents = torch.cat(init_latents, dim=0) else: - init_latents = retrieve_latents( - self.vae.encode(image), generator=generator - ) + init_latents = retrieve_latents(self.vae.encode(image), generator=generator) if self.vae.config.force_upcast: self.vae.to(dtype) @@ -762,19 +683,11 @@ def prepare_latents( init_latents = init_latents.to(dtype) init_latents = self.vae.config.scaling_factor * init_latents - if ( - batch_size > init_latents.shape[0] - and batch_size % init_latents.shape[0] == 0 - ): + if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: # expand init_latents for batch_size additional_image_per_prompt = batch_size // init_latents.shape[0] - init_latents = torch.cat( - [init_latents] * additional_image_per_prompt, dim=0 - ) - elif ( - batch_size > init_latents.shape[0] - and batch_size % init_latents.shape[0] != 0 - ): + init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0) + elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: raise ValueError( f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." 
) @@ -806,16 +719,12 @@ def encode_image(self, image, device, num_images_per_prompt): uncond_image_embeds = torch.zeros_like(image_embeds) else: # IP-Adapter Plus - image_embeds = self.image_encoder( - image, output_hidden_states=True - ).hidden_states[-2] + image_embeds = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_embeds = self.image_encoder( - torch.zeros_like(image), output_hidden_states=True - ).hidden_states[-2] - uncond_image_embeds = uncond_image_embeds.repeat_interleave( - num_images_per_prompt, dim=0 - ) + uncond_image_embeds = self.image_encoder(torch.zeros_like(image), output_hidden_states=True).hidden_states[ + -2 + ] + uncond_image_embeds = uncond_image_embeds.repeat_interleave(num_images_per_prompt, dim=0) return image_embeds, uncond_image_embeds def _get_add_time_ids( @@ -832,38 +741,29 @@ def _get_add_time_ids( text_encoder_projection_dim=None, ): if self.config.requires_aesthetics_score: - add_time_ids = list( - original_size + crops_coords_top_left + (aesthetic_score,) - ) + add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,)) add_neg_time_ids = list( - negative_original_size - + negative_crops_coords_top_left - + (negative_aesthetic_score,) + negative_original_size + negative_crops_coords_top_left + (negative_aesthetic_score,) ) else: add_time_ids = list(original_size + crops_coords_top_left + target_size) - add_neg_time_ids = list( - negative_original_size + crops_coords_top_left + negative_target_size - ) + add_neg_time_ids = list(negative_original_size + crops_coords_top_left + negative_target_size) passed_add_embed_dim = ( - self.unet.config.addition_time_embed_dim * len(add_time_ids) - + text_encoder_projection_dim + self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim ) expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features if ( expected_add_embed_dim > passed_add_embed_dim - and (expected_add_embed_dim - passed_add_embed_dim) - == self.unet.config.addition_time_embed_dim + and (expected_add_embed_dim - passed_add_embed_dim) == self.unet.config.addition_time_embed_dim ): raise ValueError( f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to enable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=True)` to make sure `aesthetic_score` {aesthetic_score} and `negative_aesthetic_score` {negative_aesthetic_score} is correctly used by the model." ) elif ( expected_add_embed_dim < passed_add_embed_dim - and (passed_add_embed_dim - expected_add_embed_dim) - == self.unet.config.addition_time_embed_dim + and (passed_add_embed_dim - expected_add_embed_dim) == self.unet.config.addition_time_embed_dim ): raise ValueError( f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to disable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=False)` to make sure `target_size` {target_size} is correctly used by the model." @@ -1229,9 +1129,7 @@ def __call__( # 3. 
Encode input prompt text_encoder_lora_scale = ( - self.cross_attention_kwargs.get("scale", None) - if self.cross_attention_kwargs is not None - else None + self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None ) ( prompt_embeds, @@ -1320,12 +1218,8 @@ def denoising_value_valid(dnv): if self.do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) - add_text_embeds = torch.cat( - [negative_pooled_prompt_embeds, add_text_embeds], dim=0 - ) - add_neg_time_ids = add_neg_time_ids.repeat( - batch_size * num_images_per_prompt, 1 - ) + add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) + add_neg_time_ids = add_neg_time_ids.repeat(batch_size * num_images_per_prompt, 1) add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0) prompt_embeds = prompt_embeds.to(device) @@ -1333,17 +1227,13 @@ def denoising_value_valid(dnv): add_time_ids = add_time_ids.to(device) if ip_adapter_image is not None: - image_embeds, negative_image_embeds = self.encode_image( - ip_adapter_image, device, num_images_per_prompt - ) + image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt) if self.do_classifier_free_guidance: image_embeds = torch.cat([negative_image_embeds, image_embeds]) image_embeds = image_embeds.to(device) # 9. Denoising loop - num_warmup_steps = max( - len(timesteps) - num_inference_steps * self.scheduler.order, 0 - ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) # 9.1 Apply denoising_end if ( @@ -1357,26 +1247,20 @@ def denoising_value_valid(dnv): f"`denoising_start`: {self.denoising_start} cannot be larger than or equal to `denoising_end`: " + f" {self.denoising_end} when using type float." 
) - elif self.denoising_end is not None and denoising_value_valid( - self.denoising_end - ): + elif self.denoising_end is not None and denoising_value_valid(self.denoising_end): discrete_timestep_cutoff = int( round( self.scheduler.config.num_train_timesteps - (self.denoising_end * self.scheduler.config.num_train_timesteps) ) ) - num_inference_steps = len( - list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)) - ) + num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) timesteps = timesteps[:num_inference_steps] # 9.2 Optionally get Guidance Scale Embedding timestep_cond = None if self.unet.config.time_cond_proj_dim is not None: - guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat( - batch_size * num_images_per_prompt - ) + guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) timestep_cond = self.get_guidance_scale_embedding( guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim ).to(device=device, dtype=latents.dtype) @@ -1385,15 +1269,9 @@ def denoising_value_valid(dnv): with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): # expand the latents if we are doing classifier free guidance - latent_model_input = ( - torch.cat([latents] * 2) - if self.do_classifier_free_guidance - else latents - ) + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents - latent_model_input = self.scheduler.scale_model_input( - latent_model_input, t - ) + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # predict the noise residual added_cond_kwargs = { @@ -1415,9 +1293,7 @@ def denoising_value_valid(dnv): # perform guidance if self.do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + self.guidance_scale * ( - noise_pred_text - noise_pred_uncond - ) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) if self.do_classifier_free_guidance and self.guidance_rescale > 0.0: # Based on 3.4. 
in https://arxiv.org/pdf/2305.08891.pdf @@ -1428,9 +1304,7 @@ def denoising_value_valid(dnv): ) # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step( - noise_pred, t, latents, **extra_step_kwargs, return_dict=False - )[0] + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] if callback_on_step_end is not None: callback_kwargs = {} @@ -1440,24 +1314,16 @@ def denoising_value_valid(dnv): latents = callback_outputs.pop("latents", latents) prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) - negative_prompt_embeds = callback_outputs.pop( - "negative_prompt_embeds", negative_prompt_embeds - ) - add_text_embeds = callback_outputs.pop( - "add_text_embeds", add_text_embeds - ) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds) negative_pooled_prompt_embeds = callback_outputs.pop( "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds ) add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids) - add_neg_time_ids = callback_outputs.pop( - "add_neg_time_ids", add_neg_time_ids - ) + add_neg_time_ids = callback_outputs.pop("add_neg_time_ids", add_neg_time_ids) # call the callback, if provided - if i == len(timesteps) - 1 or ( - (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0 - ): + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() if callback is not None and i % callback_steps == 0: step_idx = i // getattr(self.scheduler, "order", 1) @@ -1468,19 +1334,13 @@ def denoising_value_valid(dnv): if not output_type == "latent": # make sure the VAE is in float32 mode, as it overflows in float16 - needs_upcasting = ( - self.vae.dtype == torch.float16 and self.vae.config.force_upcast - ) + needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast if needs_upcasting: self.upcast_vae() - latents = latents.to( - next(iter(self.vae.post_quant_conv.parameters())).dtype - ) + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) - image = self.vae.decode( - latents / self.vae.config.scaling_factor, return_dict=False - )[0] + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] # cast back to fp16 if needed if needs_upcasting: diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py index 560973ece087..6c7443e52c70 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py @@ -106,16 +106,12 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). 
See Section 3.4 """ - std_text = noise_pred_text.std( - dim=list(range(1, noise_pred_text.ndim)), keepdim=True - ) + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) # rescale the results from guidance (fixes overexposure) noise_pred_rescaled = noise_cfg * (std_text / std_cfg) # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images - noise_cfg = ( - guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg - ) + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg return noise_cfg @@ -126,9 +122,7 @@ def mask_pil_to_torch(mask, height, width): if isinstance(mask, list) and isinstance(mask[0], PIL.Image.Image): mask = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in mask] - mask = np.concatenate( - [np.array(m.convert("L"))[None, None, :] for m in mask], axis=0 - ) + mask = np.concatenate([np.array(m.convert("L"))[None, None, :] for m in mask], axis=0) mask = mask.astype(np.float32) / 255.0 elif isinstance(mask, list) and isinstance(mask[0], np.ndarray): mask = np.concatenate([m[None, None, :] for m in mask], axis=0) @@ -137,9 +131,7 @@ def mask_pil_to_torch(mask, height, width): return mask -def prepare_mask_and_masked_image( - image, mask, height, width, return_image: bool = False -): +def prepare_mask_and_masked_image(image, mask, height, width, return_image: bool = False): """ Prepares a pair (image, mask) to be consumed by the Stable Diffusion pipeline. This means that those inputs will be converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for the @@ -202,13 +194,9 @@ def prepare_mask_and_masked_image( else: mask = mask.unsqueeze(1) - assert ( - image.ndim == 4 and mask.ndim == 4 - ), "Image and Mask must have 4 dimensions" + assert image.ndim == 4 and mask.ndim == 4, "Image and Mask must have 4 dimensions" # assert image.shape[-2:] == mask.shape[-2:], "Image and Mask must have the same spatial dimensions" - assert ( - image.shape[0] == mask.shape[0] - ), "Image and Mask must have the same batch size" + assert image.shape[0] == mask.shape[0], "Image and Mask must have the same batch size" # Check image is in [-1, 1] # if image.min() < -1 or image.max() > 1: @@ -225,18 +213,14 @@ def prepare_mask_and_masked_image( # Image as float32 image = image.to(dtype=torch.float32) elif isinstance(mask, torch.Tensor): - raise TypeError( - f"`mask` is a torch.Tensor but `image` (type: {type(image)} is not" - ) + raise TypeError(f"`mask` is a torch.Tensor but `image` (type: {type(image)} is not") else: # preprocess image if isinstance(image, (PIL.Image.Image, np.ndarray)): image = [image] if isinstance(image, list) and isinstance(image[0], PIL.Image.Image): # resize all images w.r.t passed height an width - image = [ - i.resize((width, height), resample=PIL.Image.LANCZOS) for i in image - ] + image = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in image] image = [np.array(i.convert("RGB"))[None, :] for i in image] image = np.concatenate(image, axis=0) elif isinstance(image, list) and isinstance(image[0], np.ndarray): @@ -380,9 +364,7 @@ def __init__( feature_extractor=feature_extractor, scheduler=scheduler, ) - self.register_to_config( - force_zeros_for_empty_prompt=force_zeros_for_empty_prompt - ) + self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) 
self.register_to_config(requires_aesthetics_score=requires_aesthetics_score) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) @@ -393,11 +375,7 @@ def __init__( do_convert_grayscale=True, ) - add_watermarker = ( - add_watermarker - if add_watermarker is not None - else is_invisible_watermark_available() - ) + add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available() if add_watermarker: self.watermark = StableDiffusionXLWatermarker() @@ -452,16 +430,12 @@ def encode_image(self, image, device, num_images_per_prompt): uncond_image_embeds = torch.zeros_like(image_embeds) else: # IP-Adapter Plus - image_embeds = self.image_encoder( - image, output_hidden_states=True - ).hidden_states[-2] + image_embeds = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_embeds = self.image_encoder( - torch.zeros_like(image), output_hidden_states=True - ).hidden_states[-2] - uncond_image_embeds = uncond_image_embeds.repeat_interleave( - num_images_per_prompt, dim=0 - ) + uncond_image_embeds = self.image_encoder(torch.zeros_like(image), output_hidden_states=True).hidden_states[ + -2 + ] + uncond_image_embeds = uncond_image_embeds.repeat_interleave(num_images_per_prompt, dim=0) return image_embeds, uncond_image_embeds # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt @@ -527,9 +501,7 @@ def encode_prompt( # set lora scale so that monkey patched LoRA # function of text encoder can correctly access it - if lora_scale is not None and isinstance( - self, StableDiffusionXLLoraLoaderMixin - ): + if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin): self._lora_scale = lora_scale # dynamically adjust the LoRA scale @@ -553,15 +525,9 @@ def encode_prompt( batch_size = prompt_embeds.shape[0] # Define tokenizers and text encoders - tokenizers = ( - [self.tokenizer, self.tokenizer_2] - if self.tokenizer is not None - else [self.tokenizer_2] - ) + tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2] text_encoders = ( - [self.text_encoder, self.text_encoder_2] - if self.text_encoder is not None - else [self.text_encoder_2] + [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2] ) if prompt_embeds is None: @@ -571,9 +537,7 @@ def encode_prompt( # textual inversion: procecss multi-vector tokens if necessary prompt_embeds_list = [] prompts = [prompt, prompt_2] - for prompt, tokenizer, text_encoder in zip( - prompts, tokenizers, text_encoders - ): + for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders): if isinstance(self, TextualInversionLoaderMixin): prompt = self.maybe_convert_prompt(prompt, tokenizer) @@ -586,24 +550,18 @@ def encode_prompt( ) text_input_ids = text_inputs.input_ids - untruncated_ids = tokenizer( - prompt, padding="longest", return_tensors="pt" - ).input_ids - - if untruncated_ids.shape[-1] >= text_input_ids.shape[ - -1 - ] and not torch.equal(text_input_ids, untruncated_ids): - removed_text = tokenizer.batch_decode( - untruncated_ids[:, tokenizer.model_max_length - 1 : -1] - ) + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not 
torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {tokenizer.model_max_length} tokens: {removed_text}" ) - prompt_embeds = text_encoder( - text_input_ids.to(device), output_hidden_states=True - ) + prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) # We are only ALWAYS interested in the pooled output of the final text encoder pooled_prompt_embeds = prompt_embeds[0] @@ -618,14 +576,8 @@ def encode_prompt( prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) # get unconditional embeddings for classifier free guidance - zero_out_negative_prompt = ( - negative_prompt is None and self.config.force_zeros_for_empty_prompt - ) - if ( - do_classifier_free_guidance - and negative_prompt_embeds is None - and zero_out_negative_prompt - ): + zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt + if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt: negative_prompt_embeds = torch.zeros_like(prompt_embeds) negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) elif do_classifier_free_guidance and negative_prompt_embeds is None: @@ -633,15 +585,9 @@ def encode_prompt( negative_prompt_2 = negative_prompt_2 or negative_prompt # normalize str to list - negative_prompt = ( - batch_size * [negative_prompt] - if isinstance(negative_prompt, str) - else negative_prompt - ) + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt negative_prompt_2 = ( - batch_size * [negative_prompt_2] - if isinstance(negative_prompt_2, str) - else negative_prompt_2 + batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 ) uncond_tokens: List[str] @@ -660,13 +606,9 @@ def encode_prompt( uncond_tokens = [negative_prompt, negative_prompt_2] negative_prompt_embeds_list = [] - for negative_prompt, tokenizer, text_encoder in zip( - uncond_tokens, tokenizers, text_encoders - ): + for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders): if isinstance(self, TextualInversionLoaderMixin): - negative_prompt = self.maybe_convert_prompt( - negative_prompt, tokenizer - ) + negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer) max_length = prompt_embeds.shape[1] uncond_input = tokenizer( @@ -690,46 +632,34 @@ def encode_prompt( negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) if self.text_encoder_2 is not None: - prompt_embeds = prompt_embeds.to( - dtype=self.text_encoder_2.dtype, device=device - ) + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) else: prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device) bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view( - bs_embed * num_images_per_prompt, seq_len, -1 - ) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) if do_classifier_free_guidance: # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] if self.text_encoder_2 is not None: - 
negative_prompt_embeds = negative_prompt_embeds.to( - dtype=self.text_encoder_2.dtype, device=device - ) + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) else: - negative_prompt_embeds = negative_prompt_embeds.to( - dtype=self.unet.dtype, device=device - ) + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device) - negative_prompt_embeds = negative_prompt_embeds.repeat( - 1, num_images_per_prompt, 1 - ) - negative_prompt_embeds = negative_prompt_embeds.view( - batch_size * num_images_per_prompt, seq_len, -1 - ) + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - pooled_prompt_embeds = pooled_prompt_embeds.repeat( - 1, num_images_per_prompt - ).view(bs_embed * num_images_per_prompt, -1) + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) if do_classifier_free_guidance: - negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat( - 1, num_images_per_prompt - ).view(bs_embed * num_images_per_prompt, -1) + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) if self.text_encoder is not None: if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: @@ -755,17 +685,13 @@ def prepare_extra_step_kwargs(self, generator, eta): # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 # and should be between [0, 1] - accepts_eta = "eta" in set( - inspect.signature(self.scheduler.step).parameters.keys() - ) + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) extra_step_kwargs = {} if accepts_eta: extra_step_kwargs["eta"] = eta # check if the scheduler accepts generator - accepts_generator = "generator" in set( - inspect.signature(self.scheduler.step).parameters.keys() - ) + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) if accepts_generator: extra_step_kwargs["generator"] = generator return extra_step_kwargs @@ -785,26 +711,19 @@ def check_inputs( callback_on_step_end_tensor_inputs=None, ): if strength < 0 or strength > 1: - raise ValueError( - f"The value of strength should in [0.0, 1.0] but is {strength}" - ) + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") if height % 8 != 0 or width % 8 != 0: - raise ValueError( - f"`height` and `width` have to be divisible by 8 but are {height} and {width}." - ) + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") - if callback_steps is not None and ( - not isinstance(callback_steps, int) or callback_steps <= 0 - ): + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): raise ValueError( f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}." 
) if callback_on_step_end_tensor_inputs is not None and not all( - k in self._callback_tensor_inputs - for k in callback_on_step_end_tensor_inputs + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs ): raise ValueError( f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" @@ -824,18 +743,10 @@ def check_inputs( raise ValueError( "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." ) - elif prompt is not None and ( - not isinstance(prompt, str) and not isinstance(prompt, list) - ): - raise ValueError( - f"`prompt` has to be of type `str` or `list` but is {type(prompt)}" - ) - elif prompt_2 is not None and ( - not isinstance(prompt_2, str) and not isinstance(prompt_2, list) - ): - raise ValueError( - f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}" - ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") if negative_prompt is not None and negative_prompt_embeds is not None: raise ValueError( @@ -893,30 +804,18 @@ def prepare_latents( if image.shape[1] == 4: image_latents = image.to(device=device, dtype=dtype) - image_latents = image_latents.repeat( - batch_size // image_latents.shape[0], 1, 1, 1 - ) + image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1) elif return_image_latents or (latents is None and not is_strength_max): image = image.to(device=device, dtype=dtype) image_latents = self._encode_vae_image(image=image, generator=generator) - image_latents = image_latents.repeat( - batch_size // image_latents.shape[0], 1, 1, 1 - ) + image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1) if latents is None and add_noise: noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) # if strength is 1. 
then initialise the latents to noise, else initial to image + noise - latents = ( - noise - if is_strength_max - else self.scheduler.add_noise(image_latents, noise, timestep) - ) + latents = noise if is_strength_max else self.scheduler.add_noise(image_latents, noise, timestep) # if pure noise then scale the initial latents by the Scheduler's init sigma - latents = ( - latents * self.scheduler.init_noise_sigma - if is_strength_max - else latents - ) + latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents elif add_noise: noise = latents.to(device) latents = noise * self.scheduler.init_noise_sigma @@ -942,16 +841,12 @@ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): if isinstance(generator, list): image_latents = [ - retrieve_latents( - self.vae.encode(image[i : i + 1]), generator=generator[i] - ) + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) for i in range(image.shape[0]) ] image_latents = torch.cat(image_latents, dim=0) else: - image_latents = retrieve_latents( - self.vae.encode(image), generator=generator - ) + image_latents = retrieve_latents(self.vae.encode(image), generator=generator) if self.vae.config.force_upcast: self.vae.to(dtype) @@ -1001,9 +896,7 @@ def prepare_mask_latents( if masked_image is not None: if masked_image_latents is None: masked_image = masked_image.to(device=device, dtype=dtype) - masked_image_latents = self._encode_vae_image( - masked_image, generator=generator - ) + masked_image_latents = self._encode_vae_image(masked_image, generator=generator) if masked_image_latents.shape[0] < batch_size: if not batch_size % masked_image_latents.shape[0] == 0: @@ -1017,9 +910,7 @@ def prepare_mask_latents( ) masked_image_latents = ( - torch.cat([masked_image_latents] * 2) - if do_classifier_free_guidance - else masked_image_latents + torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents ) # aligning device to prevent device errors when concating it with the latent model input @@ -1028,14 +919,10 @@ def prepare_mask_latents( return mask, masked_image_latents # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.get_timesteps - def get_timesteps( - self, num_inference_steps, strength, device, denoising_start=None - ): + def get_timesteps(self, num_inference_steps, strength, device, denoising_start=None): # get the original timestep using init_timestep if denoising_start is None: - init_timestep = min( - int(num_inference_steps * strength), num_inference_steps - ) + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) t_start = max(num_inference_steps - init_timestep, 0) else: t_start = 0 @@ -1083,38 +970,29 @@ def _get_add_time_ids( text_encoder_projection_dim=None, ): if self.config.requires_aesthetics_score: - add_time_ids = list( - original_size + crops_coords_top_left + (aesthetic_score,) - ) + add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,)) add_neg_time_ids = list( - negative_original_size - + negative_crops_coords_top_left - + (negative_aesthetic_score,) + negative_original_size + negative_crops_coords_top_left + (negative_aesthetic_score,) ) else: add_time_ids = list(original_size + crops_coords_top_left + target_size) - add_neg_time_ids = list( - negative_original_size + crops_coords_top_left + negative_target_size - ) + add_neg_time_ids = list(negative_original_size + crops_coords_top_left + negative_target_size) 
passed_add_embed_dim = ( - self.unet.config.addition_time_embed_dim * len(add_time_ids) - + text_encoder_projection_dim + self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim ) expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features if ( expected_add_embed_dim > passed_add_embed_dim - and (expected_add_embed_dim - passed_add_embed_dim) - == self.unet.config.addition_time_embed_dim + and (expected_add_embed_dim - passed_add_embed_dim) == self.unet.config.addition_time_embed_dim ): raise ValueError( f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to enable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=True)` to make sure `aesthetic_score` {aesthetic_score} and `negative_aesthetic_score` {negative_aesthetic_score} is correctly used by the model." ) elif ( expected_add_embed_dim < passed_add_embed_dim - and (passed_add_embed_dim - expected_add_embed_dim) - == self.unet.config.addition_time_embed_dim + and (passed_add_embed_dim - expected_add_embed_dim) == self.unet.config.addition_time_embed_dim ): raise ValueError( f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to disable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=False)` to make sure `target_size` {target_size} is correctly used by the model." @@ -1501,9 +1379,7 @@ def __call__( # 3. Encode input prompt text_encoder_lora_scale = ( - self.cross_attention_kwargs.get("scale", None) - if self.cross_attention_kwargs is not None - else None + self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None ) ( @@ -1609,10 +1485,7 @@ def denoising_value_valid(dnv): # default case for runwayml/stable-diffusion-inpainting num_channels_mask = mask.shape[1] num_channels_masked_image = masked_image_latents.shape[1] - if ( - num_channels_latents + num_channels_mask + num_channels_masked_image - != self.unet.config.in_channels - ): + if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels: raise ValueError( f"Incorrect configuration settings! 
The config of `pipeline.unet`: {self.unet.config} expects" f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" @@ -1663,12 +1536,8 @@ def denoising_value_valid(dnv): if self.do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) - add_text_embeds = torch.cat( - [negative_pooled_prompt_embeds, add_text_embeds], dim=0 - ) - add_neg_time_ids = add_neg_time_ids.repeat( - batch_size * num_images_per_prompt, 1 - ) + add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) + add_neg_time_ids = add_neg_time_ids.repeat(batch_size * num_images_per_prompt, 1) add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0) prompt_embeds = prompt_embeds.to(device) @@ -1676,17 +1545,13 @@ def denoising_value_valid(dnv): add_time_ids = add_time_ids.to(device) if ip_adapter_image is not None: - image_embeds, negative_image_embeds = self.encode_image( - ip_adapter_image, device, num_images_per_prompt - ) + image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt) if self.do_classifier_free_guidance: image_embeds = torch.cat([negative_image_embeds, image_embeds]) image_embeds = image_embeds.to(device) # 11. Denoising loop - num_warmup_steps = max( - len(timesteps) - num_inference_steps * self.scheduler.order, 0 - ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) if ( self.denoising_end is not None @@ -1699,26 +1564,20 @@ def denoising_value_valid(dnv): f"`denoising_start`: {self.denoising_start} cannot be larger than or equal to `denoising_end`: " + f" {self.denoising_end} when using type float." ) - elif self.denoising_end is not None and denoising_value_valid( - self.denoising_end - ): + elif self.denoising_end is not None and denoising_value_valid(self.denoising_end): discrete_timestep_cutoff = int( round( self.scheduler.config.num_train_timesteps - (self.denoising_end * self.scheduler.config.num_train_timesteps) ) ) - num_inference_steps = len( - list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)) - ) + num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) timesteps = timesteps[:num_inference_steps] # 11.1 Optionally get Guidance Scale Embedding timestep_cond = None if self.unet.config.time_cond_proj_dim is not None: - guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat( - batch_size * num_images_per_prompt - ) + guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) timestep_cond = self.get_guidance_scale_embedding( guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim ).to(device=device, dtype=latents.dtype) @@ -1727,21 +1586,13 @@ def denoising_value_valid(dnv): with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): # expand the latents if we are doing classifier free guidance - latent_model_input = ( - torch.cat([latents] * 2) - if self.do_classifier_free_guidance - else latents - ) + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents # concat latents, mask, masked_image_latents in the channel dimension - latent_model_input = self.scheduler.scale_model_input( - latent_model_input, t - ) + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) if num_channels_unet == 9: - latent_model_input = torch.cat( - [latent_model_input, mask, 
masked_image_latents], dim=1 - ) + latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1) # predict the noise residual added_cond_kwargs = { @@ -1763,9 +1614,7 @@ def denoising_value_valid(dnv): # perform guidance if self.do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + self.guidance_scale * ( - noise_pred_text - noise_pred_uncond - ) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) if self.do_classifier_free_guidance and self.guidance_rescale > 0.0: # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf @@ -1776,9 +1625,7 @@ def denoising_value_valid(dnv): ) # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step( - noise_pred, t, latents, **extra_step_kwargs, return_dict=False - )[0] + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] if num_channels_unet == 4: init_latents_proper = image_latents @@ -1793,9 +1640,7 @@ def denoising_value_valid(dnv): init_latents_proper, noise, torch.tensor([noise_timestep]) ) - latents = ( - 1 - init_mask - ) * init_latents_proper + init_mask * latents + latents = (1 - init_mask) * init_latents_proper + init_mask * latents if callback_on_step_end is not None: callback_kwargs = {} @@ -1805,28 +1650,18 @@ def denoising_value_valid(dnv): latents = callback_outputs.pop("latents", latents) prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) - negative_prompt_embeds = callback_outputs.pop( - "negative_prompt_embeds", negative_prompt_embeds - ) - add_text_embeds = callback_outputs.pop( - "add_text_embeds", add_text_embeds - ) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds) negative_pooled_prompt_embeds = callback_outputs.pop( "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds ) add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids) - add_neg_time_ids = callback_outputs.pop( - "add_neg_time_ids", add_neg_time_ids - ) + add_neg_time_ids = callback_outputs.pop("add_neg_time_ids", add_neg_time_ids) mask = callback_outputs.pop("mask", mask) - masked_image_latents = callback_outputs.pop( - "masked_image_latents", masked_image_latents - ) + masked_image_latents = callback_outputs.pop("masked_image_latents", masked_image_latents) # call the callback, if provided - if i == len(timesteps) - 1 or ( - (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0 - ): + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() if callback is not None and i % callback_steps == 0: step_idx = i // getattr(self.scheduler, "order", 1) @@ -1837,19 +1672,13 @@ def denoising_value_valid(dnv): if not output_type == "latent": # make sure the VAE is in float32 mode, as it overflows in float16 - needs_upcasting = ( - self.vae.dtype == torch.float16 and self.vae.config.force_upcast - ) + needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast if needs_upcasting: self.upcast_vae() - latents = latents.to( - next(iter(self.vae.post_quant_conv.parameters())).dtype - ) + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) - image = self.vae.decode( - latents / self.vae.config.scaling_factor, return_dict=False - )[0] + image = self.vae.decode(latents / self.vae.config.scaling_factor, 
return_dict=False)[0] # cast back to fp16 if needed if needs_upcasting: diff --git a/tests/models/test_models_unet_2d_condition.py b/tests/models/test_models_unet_2d_condition.py index fc349a182dc4..25f4b801af4b 100644 --- a/tests/models/test_models_unet_2d_condition.py +++ b/tests/models/test_models_unet_2d_condition.py @@ -55,11 +55,7 @@ def create_ip_adapter_state_dict(model): key_id = 1 for name in model.attn_processors.keys(): - cross_attention_dim = ( - None - if name.endswith("attn1.processor") - else model.config.cross_attention_dim - ) + cross_attention_dim = None if name.endswith("attn1.processor") else model.config.cross_attention_dim if name.startswith("mid_block"): hidden_size = model.config.block_out_channels[-1] elif name.startswith("up_blocks"): @@ -119,11 +115,7 @@ def create_ip_adapter_plus_state_dict(model): key_id = 1 for name in model.attn_processors.keys(): - cross_attention_dim = ( - None - if name.endswith("attn1.processor") - else model.config.cross_attention_dim - ) + cross_attention_dim = None if name.endswith("attn1.processor") else model.config.cross_attention_dim if name.startswith("mid_block"): hidden_size = model.config.block_out_channels[-1] elif name.startswith("up_blocks"): @@ -176,11 +168,7 @@ def create_custom_diffusion_layers(model, mock_weights: bool = True): st = model.state_dict() for name, _ in model.attn_processors.items(): - cross_attention_dim = ( - None - if name.endswith("attn1.processor") - else model.config.cross_attention_dim - ) + cross_attention_dim = None if name.endswith("attn1.processor") else model.config.cross_attention_dim if name.startswith("mid_block"): hidden_size = model.config.block_out_channels[-1] elif name.startswith("up_blocks"): @@ -196,12 +184,8 @@ def create_custom_diffusion_layers(model, mock_weights: bool = True): } if train_q_out: weights["to_q_custom_diffusion.weight"] = st[layer_name + ".to_q.weight"] - weights["to_out_custom_diffusion.0.weight"] = st[ - layer_name + ".to_out.0.weight" - ] - weights["to_out_custom_diffusion.0.bias"] = st[ - layer_name + ".to_out.0.bias" - ] + weights["to_out_custom_diffusion.0.weight"] = st[layer_name + ".to_out.0.weight"] + weights["to_out_custom_diffusion.0.bias"] = st[layer_name + ".to_out.0.bias"] if cross_attention_dim is not None: custom_diffusion_attn_procs[name] = CustomDiffusionAttnProcessor( train_kv=train_kv, @@ -280,9 +264,7 @@ def test_xformers_enable_works(self): model.enable_xformers_memory_efficient_attention() assert ( - model.mid_block.attentions[0] - .transformer_blocks[0] - .attn1.processor.__class__.__name__ + model.mid_block.attentions[0].transformer_blocks[0].attn1.processor.__class__.__name__ == "XFormersAttnProcessor" ), "xformers is not enabled" @@ -325,11 +307,7 @@ def test_gradient_checkpointing(self): named_params = dict(model.named_parameters()) named_params_2 = dict(model_2.named_parameters()) for name, param in named_params.items(): - self.assertTrue( - torch_all_close( - param.grad.data, named_params_2[name].grad.data, atol=5e-5 - ) - ) + self.assertTrue(torch_all_close(param.grad.data, named_params_2[name].grad.data, atol=5e-5)) def test_model_with_attention_head_dim_tuple(self): init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() @@ -348,9 +326,7 @@ def test_model_with_attention_head_dim_tuple(self): self.assertIsNotNone(output) expected_shape = inputs_dict["sample"].shape - self.assertEqual( - output.shape, expected_shape, "Input and output shapes do not match" - ) + self.assertEqual(output.shape, expected_shape, "Input and 
output shapes do not match") def test_model_with_use_linear_projection(self): init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() @@ -369,9 +345,7 @@ def test_model_with_use_linear_projection(self): self.assertIsNotNone(output) expected_shape = inputs_dict["sample"].shape - self.assertEqual( - output.shape, expected_shape, "Input and output shapes do not match" - ) + self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match") def test_model_with_cross_attention_dim_tuple(self): init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() @@ -390,9 +364,7 @@ def test_model_with_cross_attention_dim_tuple(self): self.assertIsNotNone(output) expected_shape = inputs_dict["sample"].shape - self.assertEqual( - output.shape, expected_shape, "Input and output shapes do not match" - ) + self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match") def test_model_with_simple_projection(self): init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() @@ -402,9 +374,7 @@ def test_model_with_simple_projection(self): init_dict["class_embed_type"] = "simple_projection" init_dict["projection_class_embeddings_input_dim"] = sample_size - inputs_dict["class_labels"] = floats_tensor((batch_size, sample_size)).to( - torch_device - ) + inputs_dict["class_labels"] = floats_tensor((batch_size, sample_size)).to(torch_device) model = self.model_class(**init_dict) model.to(torch_device) @@ -418,9 +388,7 @@ def test_model_with_simple_projection(self): self.assertIsNotNone(output) expected_shape = inputs_dict["sample"].shape - self.assertEqual( - output.shape, expected_shape, "Input and output shapes do not match" - ) + self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match") def test_model_with_class_embeddings_concat(self): init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() @@ -431,9 +399,7 @@ def test_model_with_class_embeddings_concat(self): init_dict["projection_class_embeddings_input_dim"] = sample_size init_dict["class_embeddings_concat"] = True - inputs_dict["class_labels"] = floats_tensor((batch_size, sample_size)).to( - torch_device - ) + inputs_dict["class_labels"] = floats_tensor((batch_size, sample_size)).to(torch_device) model = self.model_class(**init_dict) model.to(torch_device) @@ -447,9 +413,7 @@ def test_model_with_class_embeddings_concat(self): self.assertIsNotNone(output) expected_shape = inputs_dict["sample"].shape - self.assertEqual( - output.shape, expected_shape, "Input and output shapes do not match" - ) + self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match") def test_model_attention_slicing(self): init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() @@ -547,17 +511,11 @@ def __call__( number=None, ): batch_size, sequence_length, _ = hidden_states.shape - attention_mask = attn.prepare_attention_mask( - attention_mask, sequence_length, batch_size - ) + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) query = attn.to_q(hidden_states) - encoder_hidden_states = ( - encoder_hidden_states - if encoder_hidden_states is not None - else hidden_states - ) + encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states key = attn.to_k(encoder_hidden_states) value = attn.to_v(encoder_hidden_states) @@ -620,33 +578,21 @@ def test_model_xattn_mask(self, mask_dtype): full_cond_out = model(**inputs_dict).sample assert full_cond_out 
is not None - keepall_mask = torch.ones( - *cond.shape[:-1], device=cond.device, dtype=mask_dtype - ) - full_cond_keepallmask_out = model( - **{**inputs_dict, "encoder_attention_mask": keepall_mask} - ).sample + keepall_mask = torch.ones(*cond.shape[:-1], device=cond.device, dtype=mask_dtype) + full_cond_keepallmask_out = model(**{**inputs_dict, "encoder_attention_mask": keepall_mask}).sample assert full_cond_keepallmask_out.allclose( full_cond_out, rtol=1e-05, atol=1e-05 ), "a 'keep all' mask should give the same result as no mask" trunc_cond = cond[:, :-1, :] - trunc_cond_out = model( - **{**inputs_dict, "encoder_hidden_states": trunc_cond} - ).sample + trunc_cond_out = model(**{**inputs_dict, "encoder_hidden_states": trunc_cond}).sample assert not trunc_cond_out.allclose( full_cond_out, rtol=1e-05, atol=1e-05 ), "discarding the last token from our cond should change the result" batch, tokens, _ = cond.shape - mask_last = ( - (torch.arange(tokens) < tokens - 1) - .expand(batch, -1) - .to(cond.device, mask_dtype) - ) - masked_cond_out = model( - **{**inputs_dict, "encoder_attention_mask": mask_last} - ).sample + mask_last = (torch.arange(tokens) < tokens - 1).expand(batch, -1).to(cond.device, mask_dtype) + masked_cond_out = model(**{**inputs_dict, "encoder_attention_mask": mask_last}).sample assert masked_cond_out.allclose( trunc_cond_out, rtol=1e-05, atol=1e-05 ), "masking the last token from our cond should be equivalent to truncating that token out of the condition" @@ -671,24 +617,12 @@ def test_model_xattn_padding(self): assert full_cond_out is not None batch, tokens, _ = cond.shape - keeplast_mask = ( - (torch.arange(tokens) == tokens - 1) - .expand(batch, -1) - .to(cond.device, torch.bool) - ) - keeplast_out = model( - **{**inputs_dict, "encoder_attention_mask": keeplast_mask} - ).sample - assert not keeplast_out.allclose( - full_cond_out - ), "a 'keep last token' mask should change the result" - - trunc_mask = torch.zeros( - batch, tokens - 1, device=cond.device, dtype=torch.bool - ) - trunc_mask_out = model( - **{**inputs_dict, "encoder_attention_mask": trunc_mask} - ).sample + keeplast_mask = (torch.arange(tokens) == tokens - 1).expand(batch, -1).to(cond.device, torch.bool) + keeplast_out = model(**{**inputs_dict, "encoder_attention_mask": keeplast_mask}).sample + assert not keeplast_out.allclose(full_cond_out), "a 'keep last token' mask should change the result" + + trunc_mask = torch.zeros(batch, tokens - 1, device=cond.device, dtype=torch.bool) + trunc_mask_out = model(**{**inputs_dict, "encoder_attention_mask": trunc_mask}).sample assert trunc_mask_out.allclose( keeplast_out ), "a mask with fewer tokens than condition, will be padded with 'keep' tokens. a 'discard-all' mask missing the final token is thus equivalent to a 'keep last' mask." 
@@ -705,9 +639,7 @@ def test_custom_diffusion_processors(self): with torch.no_grad(): sample1 = model(**inputs_dict).sample - custom_diffusion_attn_procs = create_custom_diffusion_layers( - model, mock_weights=False - ) + custom_diffusion_attn_procs = create_custom_diffusion_layers(model, mock_weights=False) # make sure we can set a list of attention processors model.set_attn_processor(custom_diffusion_attn_procs) @@ -734,9 +666,7 @@ def test_custom_diffusion_save_load(self): with torch.no_grad(): old_sample = model(**inputs_dict).sample - custom_diffusion_attn_procs = create_custom_diffusion_layers( - model, mock_weights=False - ) + custom_diffusion_attn_procs = create_custom_diffusion_layers(model, mock_weights=False) model.set_attn_processor(custom_diffusion_attn_procs) with torch.no_grad(): @@ -744,16 +674,10 @@ def test_custom_diffusion_save_load(self): with tempfile.TemporaryDirectory() as tmpdirname: model.save_attn_procs(tmpdirname, safe_serialization=False) - self.assertTrue( - os.path.isfile( - os.path.join(tmpdirname, "pytorch_custom_diffusion_weights.bin") - ) - ) + self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_custom_diffusion_weights.bin"))) torch.manual_seed(0) new_model = self.model_class(**init_dict) - new_model.load_attn_procs( - tmpdirname, weight_name="pytorch_custom_diffusion_weights.bin" - ) + new_model.load_attn_procs(tmpdirname, weight_name="pytorch_custom_diffusion_weights.bin") new_model.to(torch_device) with torch.no_grad(): @@ -777,9 +701,7 @@ def test_custom_diffusion_xformers_on_off(self): torch.manual_seed(0) model = self.model_class(**init_dict) model.to(torch_device) - custom_diffusion_attn_procs = create_custom_diffusion_layers( - model, mock_weights=False - ) + custom_diffusion_attn_procs = create_custom_diffusion_layers(model, mock_weights=False) model.set_attn_processor(custom_diffusion_attn_procs) # default @@ -825,9 +747,7 @@ def test_asymmetrical_unet(self): expected_shape = inputs_dict["sample"].shape # Check if input and output shapes are the same - self.assertEqual( - output.shape, expected_shape, "Input and output shapes do not match" - ) + self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match") def test_ip_adapter(self): init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() @@ -843,20 +763,14 @@ def test_ip_adapter(self): # update inputs_dict for ip-adapter batch_size = inputs_dict["encoder_hidden_states"].shape[0] - image_embeds = floats_tensor((batch_size, 1, model.cross_attention_dim)).to( - torch_device - ) + image_embeds = floats_tensor((batch_size, 1, model.cross_attention_dim)).to(torch_device) inputs_dict["added_cond_kwargs"] = {"image_embeds": image_embeds} # make ip_adapter_1 and ip_adapter_2 ip_adapter_1 = create_ip_adapter_state_dict(model) - image_proj_state_dict_2 = { - k: w + 1.0 for k, w in ip_adapter_1["image_proj"].items() - } - cross_attn_state_dict_2 = { - k: w + 1.0 for k, w in ip_adapter_1["ip_adapter"].items() - } + image_proj_state_dict_2 = {k: w + 1.0 for k, w in ip_adapter_1["image_proj"].items()} + cross_attn_state_dict_2 = {k: w + 1.0 for k, w in ip_adapter_1["ip_adapter"].items()} ip_adapter_2 = {} ip_adapter_2.update( { @@ -869,9 +783,7 @@ def test_ip_adapter(self): model._load_ip_adapter_weights(ip_adapter_1) assert model.config.encoder_hid_dim_type == "ip_image_proj" assert model.encoder_hid_proj is not None - assert model.down_blocks[0].attentions[0].transformer_blocks[ - 0 - ].attn2.processor.__class__.__name__ in ( + assert 
model.down_blocks[0].attentions[0].transformer_blocks[0].attn2.processor.__class__.__name__ in ( "IPAdapterAttnProcessor", "IPAdapterAttnProcessor2_0", ) @@ -906,20 +818,14 @@ def test_ip_adapter_plus(self): # update inputs_dict for ip-adapter batch_size = inputs_dict["encoder_hidden_states"].shape[0] - image_embeds = floats_tensor((batch_size, 1, model.cross_attention_dim)).to( - torch_device - ) + image_embeds = floats_tensor((batch_size, 1, model.cross_attention_dim)).to(torch_device) inputs_dict["added_cond_kwargs"] = {"image_embeds": image_embeds} # make ip_adapter_1 and ip_adapter_2 ip_adapter_1 = create_ip_adapter_plus_state_dict(model) - image_proj_state_dict_2 = { - k: w + 1.0 for k, w in ip_adapter_1["image_proj"].items() - } - cross_attn_state_dict_2 = { - k: w + 1.0 for k, w in ip_adapter_1["ip_adapter"].items() - } + image_proj_state_dict_2 = {k: w + 1.0 for k, w in ip_adapter_1["image_proj"].items()} + cross_attn_state_dict_2 = {k: w + 1.0 for k, w in ip_adapter_1["ip_adapter"].items()} ip_adapter_2 = {} ip_adapter_2.update( { @@ -932,9 +838,7 @@ def test_ip_adapter_plus(self): model._load_ip_adapter_weights(ip_adapter_1) assert model.config.encoder_hid_dim_type == "ip_image_proj" assert model.encoder_hid_proj is not None - assert model.down_blocks[0].attentions[0].transformer_blocks[ - 0 - ].attn2.processor.__class__.__name__ in ( + assert model.down_blocks[0].attentions[0].transformer_blocks[0].attn2.processor.__class__.__name__ in ( "IPAdapterAttnProcessor", "IPAdapterAttnProcessor2_0", ) @@ -969,11 +873,7 @@ def tearDown(self): def get_latents(self, seed=0, shape=(4, 4, 64, 64), fp16=False): dtype = torch.float16 if fp16 else torch.float32 - image = ( - torch.from_numpy(load_hf_numpy(self.get_file_format(seed, shape))) - .to(torch_device) - .to(dtype) - ) + image = torch.from_numpy(load_hf_numpy(self.get_file_format(seed, shape))).to(torch_device).to(dtype) return image def get_unet_model(self, fp16=False, model_id="CompVis/stable-diffusion-v1-4"): @@ -1000,9 +900,7 @@ def test_set_attention_slice_auto(self): timestep = 1 with torch.no_grad(): - _ = unet( - latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states - ).sample + _ = unet(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample mem_bytes = torch.cuda.max_memory_allocated() @@ -1021,9 +919,7 @@ def test_set_attention_slice_max(self): timestep = 1 with torch.no_grad(): - _ = unet( - latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states - ).sample + _ = unet(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample mem_bytes = torch.cuda.max_memory_allocated() @@ -1042,9 +938,7 @@ def test_set_attention_slice_int(self): timestep = 1 with torch.no_grad(): - _ = unet( - latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states - ).sample + _ = unet(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample mem_bytes = torch.cuda.max_memory_allocated() @@ -1065,9 +959,7 @@ def test_set_attention_slice_list(self): timestep = 1 with torch.no_grad(): - _ = unet( - latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states - ).sample + _ = unet(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample mem_bytes = torch.cuda.max_memory_allocated() @@ -1075,11 +967,7 @@ def test_set_attention_slice_list(self): def get_encoder_hidden_states(self, seed=0, shape=(4, 77, 768), fp16=False): dtype = torch.float16 if fp16 else torch.float32 - hidden_states = ( - 
torch.from_numpy(load_hf_numpy(self.get_file_format(seed, shape))) - .to(torch_device) - .to(dtype) - ) + hidden_states = torch.from_numpy(load_hf_numpy(self.get_file_format(seed, shape))).to(torch_device).to(dtype) return hidden_states @parameterized.expand( @@ -1101,9 +989,7 @@ def test_compvis_sd_v1_4(self, seed, timestep, expected_slice): timestep = torch.tensor([timestep], dtype=torch.long, device=torch_device) with torch.no_grad(): - sample = model( - latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states - ).sample + sample = model(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample assert sample.shape == latents.shape @@ -1131,9 +1017,7 @@ def test_compvis_sd_v1_4_fp16(self, seed, timestep, expected_slice): timestep = torch.tensor([timestep], dtype=torch.long, device=torch_device) with torch.no_grad(): - sample = model( - latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states - ).sample + sample = model(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample assert sample.shape == latents.shape @@ -1161,9 +1045,7 @@ def test_compvis_sd_v1_5(self, seed, timestep, expected_slice): timestep = torch.tensor([timestep], dtype=torch.long, device=torch_device) with torch.no_grad(): - sample = model( - latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states - ).sample + sample = model(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample assert sample.shape == latents.shape @@ -1184,18 +1066,14 @@ def test_compvis_sd_v1_5(self, seed, timestep, expected_slice): ) @require_torch_gpu def test_compvis_sd_v1_5_fp16(self, seed, timestep, expected_slice): - model = self.get_unet_model( - model_id="runwayml/stable-diffusion-v1-5", fp16=True - ) + model = self.get_unet_model(model_id="runwayml/stable-diffusion-v1-5", fp16=True) latents = self.get_latents(seed, fp16=True) encoder_hidden_states = self.get_encoder_hidden_states(seed, fp16=True) timestep = torch.tensor([timestep], dtype=torch.long, device=torch_device) with torch.no_grad(): - sample = model( - latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states - ).sample + sample = model(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample assert sample.shape == latents.shape @@ -1223,9 +1101,7 @@ def test_compvis_sd_inpaint(self, seed, timestep, expected_slice): timestep = torch.tensor([timestep], dtype=torch.long, device=torch_device) with torch.no_grad(): - sample = model( - latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states - ).sample + sample = model(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample assert sample.shape == (4, 4, 64, 64) @@ -1246,18 +1122,14 @@ def test_compvis_sd_inpaint(self, seed, timestep, expected_slice): ) @require_torch_gpu def test_compvis_sd_inpaint_fp16(self, seed, timestep, expected_slice): - model = self.get_unet_model( - model_id="runwayml/stable-diffusion-inpainting", fp16=True - ) + model = self.get_unet_model(model_id="runwayml/stable-diffusion-inpainting", fp16=True) latents = self.get_latents(seed, shape=(4, 9, 64, 64), fp16=True) encoder_hidden_states = self.get_encoder_hidden_states(seed, fp16=True) timestep = torch.tensor([timestep], dtype=torch.long, device=torch_device) with torch.no_grad(): - sample = model( - latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states - ).sample + sample = model(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample 
assert sample.shape == (4, 4, 64, 64) @@ -1278,20 +1150,14 @@ def test_compvis_sd_inpaint_fp16(self, seed, timestep, expected_slice): ) @require_torch_gpu def test_stabilityai_sd_v2_fp16(self, seed, timestep, expected_slice): - model = self.get_unet_model( - model_id="stabilityai/stable-diffusion-2", fp16=True - ) + model = self.get_unet_model(model_id="stabilityai/stable-diffusion-2", fp16=True) latents = self.get_latents(seed, shape=(4, 4, 96, 96), fp16=True) - encoder_hidden_states = self.get_encoder_hidden_states( - seed, shape=(4, 77, 1024), fp16=True - ) + encoder_hidden_states = self.get_encoder_hidden_states(seed, shape=(4, 77, 1024), fp16=True) timestep = torch.tensor([timestep], dtype=torch.long, device=torch_device) with torch.no_grad(): - sample = model( - latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states - ).sample + sample = model(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample assert sample.shape == latents.shape diff --git a/tests/pipelines/ip_adapters/test_ip_adapter_plus_stable_diffusion.py b/tests/pipelines/ip_adapters/test_ip_adapter_plus_stable_diffusion.py index 15f9b9f2db52..cdf6c86d9c7f 100644 --- a/tests/pipelines/ip_adapters/test_ip_adapter_plus_stable_diffusion.py +++ b/tests/pipelines/ip_adapters/test_ip_adapter_plus_stable_diffusion.py @@ -61,9 +61,7 @@ def get_image_processor(self, repo_id): image_processor = CLIPImageProcessor.from_pretrained(repo_id) return image_processor - def get_dummy_inputs( - self, for_image_to_image=False, for_inpainting=False, for_sdxl=False - ): + def get_dummy_inputs(self, for_image_to_image=False, for_inpainting=False, for_sdxl=False): image = load_image( "https://user-images.githubusercontent.com/24734142/266492875-2d50d223-8475-44f0-a7c6-08b51cb53572.png" ) @@ -79,12 +77,8 @@ def get_dummy_inputs( "output_type": "np", } if for_image_to_image: - image = load_image( - "https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/vermeer.jpg" - ) - ip_image = load_image( - "https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/river.png" - ) + image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/vermeer.jpg") + ip_image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/river.png") if for_sdxl: image = image.resize((1024, 1024)) @@ -93,24 +87,16 @@ def get_dummy_inputs( input_kwargs.update({"image": image, "ip_adapter_image": ip_image}) elif for_inpainting: - image = load_image( - "https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/inpaint_image.png" - ) - mask = load_image( - "https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/mask.png" - ) - ip_image = load_image( - "https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/girl.png" - ) + image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/inpaint_image.png") + mask = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/mask.png") + ip_image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/girl.png") if for_sdxl: image = image.resize((1024, 1024)) mask = mask.resize((1024, 1024)) ip_image = ip_image.resize((1024, 1024)) - input_kwargs.update( - {"image": image, "mask_image": mask, "ip_adapter_image": ip_image} - ) + input_kwargs.update({"image": image, "mask_image": mask, "ip_adapter_image": ip_image}) return input_kwargs @@ -119,9 +105,7 @@ def get_dummy_inputs( @require_torch_gpu class 
IPAdapterSDIntegrationTests(IPAdapterNightlyTestsMixin): def test_text_to_image(self): - image_encoder = self.get_image_encoder( - repo_id="h94/IP-Adapter", subfolder="models/image_encoder" - ) + image_encoder = self.get_image_encoder(repo_id="h94/IP-Adapter", subfolder="models/image_encoder") pipeline = StableDiffusionPipeline.from_pretrained( "runwayml/stable-diffusion-v1-5", image_encoder=image_encoder, @@ -129,24 +113,18 @@ def test_text_to_image(self): torch_dtype=self.dtype, ) pipeline.to(torch_device) - pipeline.load_ip_adapter( - "h94/IP-Adapter", subfolder="models", weight_name="ip-adapter-plus_sd15.bin" - ) + pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter-plus_sd15.bin") inputs = self.get_dummy_inputs() images = pipeline(**inputs).images image_slice = images[0, :3, :3, -1].flatten() - expected_slice = np.array( - [0.3015, 0.2615, 0.2200, 0.2725, 0.2510, 0.2021, 0.2498, 0.2415, 0.2131] - ) + expected_slice = np.array([0.3015, 0.2615, 0.2200, 0.2725, 0.2510, 0.2021, 0.2498, 0.2415, 0.2131]) assert np.allclose(image_slice, expected_slice, atol=1e-4, rtol=1e-4) def test_image_to_image(self): - image_encoder = self.get_image_encoder( - repo_id="h94/IP-Adapter", subfolder="models/image_encoder" - ) + image_encoder = self.get_image_encoder(repo_id="h94/IP-Adapter", subfolder="models/image_encoder") pipeline = StableDiffusionImg2ImgPipeline.from_pretrained( "runwayml/stable-diffusion-v1-5", image_encoder=image_encoder, @@ -154,24 +132,18 @@ def test_image_to_image(self): torch_dtype=self.dtype, ) pipeline.to(torch_device) - pipeline.load_ip_adapter( - "h94/IP-Adapter", subfolder="models", weight_name="ip-adapter-plus_sd15.bin" - ) + pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter-plus_sd15.bin") inputs = self.get_dummy_inputs(for_image_to_image=True) images = pipeline(**inputs).images image_slice = images[0, :3, :3, -1].flatten() - expected_slice = np.array( - [0.3518, 0.2554, 0.2495, 0.2363, 0.1836, 0.3823, 0.1414, 0.1868, 0.5386] - ) + expected_slice = np.array([0.3518, 0.2554, 0.2495, 0.2363, 0.1836, 0.3823, 0.1414, 0.1868, 0.5386]) assert np.allclose(image_slice, expected_slice, atol=1e-4, rtol=1e-4) def test_inpainting(self): - image_encoder = self.get_image_encoder( - repo_id="h94/IP-Adapter", subfolder="models/image_encoder" - ) + image_encoder = self.get_image_encoder(repo_id="h94/IP-Adapter", subfolder="models/image_encoder") pipeline = StableDiffusionInpaintPipeline.from_pretrained( "runwayml/stable-diffusion-v1-5", image_encoder=image_encoder, @@ -179,17 +151,13 @@ def test_inpainting(self): torch_dtype=self.dtype, ) pipeline.to(torch_device) - pipeline.load_ip_adapter( - "h94/IP-Adapter", subfolder="models", weight_name="ip-adapter-plus_sd15.bin" - ) + pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter-plus_sd15.bin") inputs = self.get_dummy_inputs(for_inpainting=True) images = pipeline(**inputs).images image_slice = images[0, :3, :3, -1].flatten() - expected_slice = np.array( - [0.2756, 0.2422, 0.2214, 0.2346, 0.2102, 0.2060, 0.2188, 0.2043, 0.1941] - ) + expected_slice = np.array([0.2756, 0.2422, 0.2214, 0.2346, 0.2102, 0.2060, 0.2188, 0.2043, 0.1941]) assert np.allclose(image_slice, expected_slice, atol=1e-4, rtol=1e-4) @@ -198,12 +166,8 @@ def test_inpainting(self): @require_torch_gpu class IPAdapterSDXLIntegrationTests(IPAdapterNightlyTestsMixin): def test_text_to_image_sdxl(self): - image_encoder = self.get_image_encoder( - repo_id="h94/IP-Adapter", 
subfolder="models/image_encoder" - ) - feature_extractor = self.get_image_processor( - "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k" - ) + image_encoder = self.get_image_encoder(repo_id="h94/IP-Adapter", subfolder="models/image_encoder") + feature_extractor = self.get_image_processor("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k") pipeline = StableDiffusionXLPipeline.from_pretrained( "stabilityai/stable-diffusion-xl-base-1.0", @@ -222,19 +186,13 @@ def test_text_to_image_sdxl(self): images = pipeline(**inputs).images image_slice = images[0, :3, :3, -1].flatten() - expected_slice = np.array( - [0.0587, 0.0567, 0.0454, 0.0537, 0.0553, 0.0518, 0.0494, 0.0535, 0.0497] - ) + expected_slice = np.array([0.0587, 0.0567, 0.0454, 0.0537, 0.0553, 0.0518, 0.0494, 0.0535, 0.0497]) assert np.allclose(image_slice, expected_slice, atol=1e-4, rtol=1e-4) def test_image_to_image_sdxl(self): - image_encoder = self.get_image_encoder( - repo_id="h94/IP-Adapter", subfolder="models/image_encoder" - ) - feature_extractor = self.get_image_processor( - "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k" - ) + image_encoder = self.get_image_encoder(repo_id="h94/IP-Adapter", subfolder="models/image_encoder") + feature_extractor = self.get_image_processor("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k") pipeline = StableDiffusionXLImg2ImgPipeline.from_pretrained( "stabilityai/stable-diffusion-xl-base-1.0", @@ -253,19 +211,13 @@ def test_image_to_image_sdxl(self): images = pipeline(**inputs).images image_slice = images[0, :3, :3, -1].flatten() - expected_slice = np.array( - [0.0711, 0.0700, 0.0734, 0.0758, 0.0742, 0.0688, 0.0751, 0.0827, 0.0851] - ) + expected_slice = np.array([0.0711, 0.0700, 0.0734, 0.0758, 0.0742, 0.0688, 0.0751, 0.0827, 0.0851]) assert np.allclose(image_slice, expected_slice, atol=1e-4, rtol=1e-4) def test_inpainting_sdxl(self): - image_encoder = self.get_image_encoder( - repo_id="h94/IP-Adapter", subfolder="models/image_encoder" - ) - feature_extractor = self.get_image_processor( - "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k" - ) + image_encoder = self.get_image_encoder(repo_id="h94/IP-Adapter", subfolder="models/image_encoder") + feature_extractor = self.get_image_processor("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k") pipeline = StableDiffusionXLInpaintPipeline.from_pretrained( "stabilityai/stable-diffusion-xl-base-1.0", @@ -285,8 +237,6 @@ def test_inpainting_sdxl(self): image_slice = images[0, :3, :3, -1].flatten() image_slice.tolist() - expected_slice = np.array( - [0.1398, 0.1476, 0.1407, 0.1441, 0.1470, 0.1480, 0.1448, 0.1481, 0.1494] - ) + expected_slice = np.array([0.1398, 0.1476, 0.1407, 0.1441, 0.1470, 0.1480, 0.1448, 0.1481, 0.1494]) assert np.allclose(image_slice, expected_slice, atol=1e-4, rtol=1e-4) From 07dbb437a2e0dcfb14ea30e8d9162c1c5adb9b81 Mon Sep 17 00:00:00 2001 From: okotaku Date: Fri, 24 Nov 2023 02:41:49 +0000 Subject: [PATCH 03/16] restore before black format --- src/diffusers/loaders/unet.py | 29 +++++------------------------ 1 file changed, 5 insertions(+), 24 deletions(-) diff --git a/src/diffusers/loaders/unet.py b/src/diffusers/loaders/unet.py index cf818df8ce7c..0daae8bbf29e 100644 --- a/src/diffusers/loaders/unet.py +++ b/src/diffusers/loaders/unet.py @@ -22,10 +22,7 @@ from torch import nn from ..models.embeddings import ImageProjection, Resampler -from ..models.modeling_utils import ( - _LOW_CPU_MEM_USAGE_DEFAULT, - load_model_dict_into_meta, -) +from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta from ..utils import ( DIFFUSERS_CACHE, HF_HUB_OFFLINE, 
@@ -65,11 +62,7 @@ class UNet2DConditionLoadersMixin: text_encoder_name = TEXT_ENCODER_NAME unet_name = UNET_NAME - def load_attn_procs( - self, - pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], - **kwargs, - ): + def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs): r""" Load pretrained attention processor layers into [`UNet2DConditionModel`]. Attention processor layers have to be defined in @@ -135,12 +128,7 @@ def load_attn_procs( ``` """ from ..models.attention_processor import CustomDiffusionAttnProcessor - from ..models.lora import ( - LoRACompatibleConv, - LoRACompatibleLinear, - LoRAConv2dLayer, - LoRALinearLayer, - ) + from ..models.lora import LoRACompatibleConv, LoRACompatibleLinear, LoRAConv2dLayer, LoRALinearLayer cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) force_download = kwargs.pop("force_download", False) @@ -329,10 +317,7 @@ def load_attn_procs( for key, value_dict in custom_diffusion_grouped_dict.items(): if len(value_dict) == 0: attn_processors[key] = CustomDiffusionAttnProcessor( - train_kv=False, - train_q_out=False, - hidden_size=None, - cross_attention_dim=None, + train_kv=False, train_q_out=False, hidden_size=None, cross_attention_dim=None ) else: cross_attention_dim = value_dict["to_k_custom_diffusion.weight"].shape[1] @@ -482,11 +467,7 @@ def save_function(weights, filename): is_custom_diffusion = any( isinstance( x, - ( - CustomDiffusionAttnProcessor, - CustomDiffusionAttnProcessor2_0, - CustomDiffusionXFormersAttnProcessor, - ), + (CustomDiffusionAttnProcessor, CustomDiffusionAttnProcessor2_0, CustomDiffusionXFormersAttnProcessor), ) for (_, x) in self.attn_processors.items() ) From 9096c37f136b5259a5a107e0a48144385654ac8c Mon Sep 17 00:00:00 2001 From: okotaku Date: Fri, 24 Nov 2023 02:49:55 +0000 Subject: [PATCH 04/16] restore before black format --- src/diffusers/models/__init__.py | 7 +- .../pipeline_alt_diffusion_img2img.py | 49 ++--------- .../animatediff/pipeline_animatediff.py | 45 ++-------- .../controlnet/pipeline_controlnet.py | 69 +++------------ .../controlnet/pipeline_controlnet_sd_xl.py | 56 ++---------- .../pipeline_stable_diffusion.py | 57 +++--------- .../pipeline_stable_diffusion_img2img.py | 50 ++--------- .../pipeline_stable_diffusion_inpaint.py | 87 +++---------------- .../pipeline_stable_diffusion_xl.py | 44 ++-------- .../pipeline_stable_diffusion_xl_img2img.py | 28 +----- .../pipeline_stable_diffusion_xl_inpaint.py | 41 ++------- tests/models/test_models_unet_2d_condition.py | 66 +++----------- 12 files changed, 95 insertions(+), 504 deletions(-) diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index bf7cc7ddfe05..7dc945bf6203 100644 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -14,12 +14,7 @@ from typing import TYPE_CHECKING -from ..utils import ( - DIFFUSERS_SLOW_IMPORT, - _LazyModule, - is_flax_available, - is_torch_available, -) +from ..utils import DIFFUSERS_SLOW_IMPORT, _LazyModule, is_flax_available, is_torch_available _import_structure = {} diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py index 7069d4534a9e..212ef2d5af35 100644 --- a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py +++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py @@ -19,20 +19,11 @@ import PIL.Image import torch from packaging import version 
-from transformers import ( - CLIPImageProcessor, - CLIPVisionModelWithProjection, - XLMRobertaTokenizer, -) +from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection, XLMRobertaTokenizer from ...configuration_utils import FrozenDict from ...image_processor import PipelineImageInput, VaeImageProcessor -from ...loaders import ( - FromSingleFileMixin, - IPAdapterMixin, - LoraLoaderMixin, - TextualInversionLoaderMixin, -) +from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers @@ -120,11 +111,7 @@ def preprocess(image): # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline with Stable->Alt, CLIPTextModel->RobertaSeriesModelWithTransformation, CLIPTokenizer->XLMRobertaTokenizer, AltDiffusionSafetyChecker->StableDiffusionSafetyChecker class AltDiffusionImg2ImgPipeline( - DiffusionPipeline, - TextualInversionLoaderMixin, - IPAdapterMixin, - LoraLoaderMixin, - FromSingleFileMixin, + DiffusionPipeline, TextualInversionLoaderMixin, IPAdapterMixin, LoraLoaderMixin, FromSingleFileMixin ): r""" Pipeline for text-guided image-to-image generation using Alt Diffusion. @@ -383,9 +370,7 @@ def encode_prompt( prompt_embeds = prompt_embeds[0] else: prompt_embeds = self.text_encoder( - text_input_ids.to(device), - attention_mask=attention_mask, - output_hidden_states=True, + text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True ) # Access the `hidden_states` first, that contains a tuple of # all the hidden states from the encoder layers. Then index into @@ -595,16 +580,7 @@ def get_timesteps(self, num_inference_steps, strength, device): return timesteps, num_inference_steps - t_start - def prepare_latents( - self, - image, - timestep, - batch_size, - num_images_per_prompt, - dtype, - device, - generator=None, - ): + def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None): if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): raise ValueError( f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" @@ -643,12 +619,7 @@ def prepare_latents( " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update" " your script to pass as many initial images as text prompts to suppress this warning." 
) - deprecate( - "len(prompt) != len(image)", - "1.0.0", - deprecation_message, - standard_warn=False, - ) + deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False) additional_image_per_prompt = batch_size // init_latents.shape[0] init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0) elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: @@ -986,11 +957,9 @@ def __call__( callback(step_idx, t, latents) if not output_type == "latent": - image = self.vae.decode( - latents / self.vae.config.scaling_factor, - return_dict=False, - generator=generator, - )[0] + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[ + 0 + ] image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) else: image = latents diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py index 5b1c34fdf831..b6d1bf639c20 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py @@ -18,21 +18,11 @@ import numpy as np import torch -from transformers import ( - CLIPImageProcessor, - CLIPTextModel, - CLIPTokenizer, - CLIPVisionModelWithProjection, -) +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin -from ...models import ( - AutoencoderKL, - ImageProjection, - UNet2DConditionModel, - UNetMotionModel, -) +from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel, UNetMotionModel from ...models.lora import adjust_lora_scale_text_encoder from ...models.unet_motion_model import MotionAdapter from ...schedulers import ( @@ -43,13 +33,7 @@ LMSDiscreteScheduler, PNDMScheduler, ) -from ...utils import ( - USE_PEFT_BACKEND, - BaseOutput, - logging, - scale_lora_layers, - unscale_lora_layers, -) +from ...utils import USE_PEFT_BACKEND, BaseOutput, logging, scale_lora_layers, unscale_lora_layers from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline @@ -249,9 +233,7 @@ def encode_prompt( prompt_embeds = prompt_embeds[0] else: prompt_embeds = self.text_encoder( - text_input_ids.to(device), - attention_mask=attention_mask, - output_hidden_states=True, + text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True ) # Access the `hidden_states` first, that contains a tuple of # all the hidden states from the encoder layers. Then index into @@ -518,16 +500,7 @@ def check_inputs( # Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_synth.TextToVideoSDPipeline.prepare_latents def prepare_latents( - self, - batch_size, - num_channels_latents, - num_frames, - height, - width, - dtype, - device, - generator, - latents=None, + self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None ): shape = ( batch_size, @@ -648,13 +621,7 @@ def __call__( # 1. Check inputs. Raise error if not correct self.check_inputs( - prompt, - height, - width, - callback_steps, - negative_prompt, - prompt_embeds, - negative_prompt_embeds, + prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds ) # 2. 
Define call parameters diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py index 95093b277439..e7a2145b0303 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py @@ -20,26 +20,11 @@ import PIL.Image import torch import torch.nn.functional as F -from transformers import ( - CLIPImageProcessor, - CLIPTextModel, - CLIPTokenizer, - CLIPVisionModelWithProjection, -) +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection from ...image_processor import PipelineImageInput, VaeImageProcessor -from ...loaders import ( - FromSingleFileMixin, - IPAdapterMixin, - LoraLoaderMixin, - TextualInversionLoaderMixin, -) -from ...models import ( - AutoencoderKL, - ControlNetModel, - ImageProjection, - UNet2DConditionModel, -) +from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( @@ -107,11 +92,7 @@ class StableDiffusionControlNetPipeline( - DiffusionPipeline, - TextualInversionLoaderMixin, - LoraLoaderMixin, - IPAdapterMixin, - FromSingleFileMixin, + DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, IPAdapterMixin, FromSingleFileMixin ): r""" Pipeline for text-to-image generation using Stable Diffusion with ControlNet guidance. @@ -158,12 +139,7 @@ def __init__( text_encoder: CLIPTextModel, tokenizer: CLIPTokenizer, unet: UNet2DConditionModel, - controlnet: Union[ - ControlNetModel, - List[ControlNetModel], - Tuple[ControlNetModel], - MultiControlNetModel, - ], + controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel], scheduler: KarrasDiffusionSchedulers, safety_checker: StableDiffusionSafetyChecker, feature_extractor: CLIPImageProcessor, @@ -205,9 +181,7 @@ def __init__( self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True) self.control_image_processor = VaeImageProcessor( - vae_scale_factor=self.vae_scale_factor, - do_convert_rgb=True, - do_normalize=False, + vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False ) self.register_to_config(requires_safety_checker=requires_safety_checker) @@ -373,9 +347,7 @@ def encode_prompt( prompt_embeds = prompt_embeds[0] else: prompt_embeds = self.text_encoder( - text_input_ids.to(device), - attention_mask=attention_mask, - output_hidden_states=True, + text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True ) # Access the `hidden_states` first, that contains a tuple of # all the hidden states from the encoder layers. 
Then index into @@ -743,23 +715,8 @@ def prepare_image( return image # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents - def prepare_latents( - self, - batch_size, - num_channels_latents, - height, - width, - dtype, - device, - generator, - latents=None, - ): - shape = ( - batch_size, - num_channels_latents, - height // self.vae_scale_factor, - width // self.vae_scale_factor, - ) + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" @@ -1247,11 +1204,9 @@ def __call__( torch.cuda.empty_cache() if not output_type == "latent": - image = self.vae.decode( - latents / self.vae.config.scaling_factor, - return_dict=False, - generator=generator, - )[0] + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[ + 0 + ] image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) else: image = latents diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py index 9ebc35772b93..5a1e2b6e0e26 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py @@ -37,12 +37,7 @@ StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin, ) -from ...models import ( - AutoencoderKL, - ControlNetModel, - ImageProjection, - UNet2DConditionModel, -) +from ...models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel from ...models.attention_processor import ( AttnProcessor2_0, LoRAAttnProcessor2_0, @@ -187,12 +182,7 @@ def __init__( tokenizer: CLIPTokenizer, tokenizer_2: CLIPTokenizer, unet: UNet2DConditionModel, - controlnet: Union[ - ControlNetModel, - List[ControlNetModel], - Tuple[ControlNetModel], - MultiControlNetModel, - ], + controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel], scheduler: KarrasDiffusionSchedulers, force_zeros_for_empty_prompt: bool = True, add_watermarker: Optional[bool] = None, @@ -219,9 +209,7 @@ def __init__( self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True) self.control_image_processor = VaeImageProcessor( - vae_scale_factor=self.vae_scale_factor, - do_convert_rgb=True, - do_normalize=False, + vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False ) add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available() @@ -498,12 +486,7 @@ def encode_prompt( # Retrieve the original scale by scaling back the LoRA layers unscale_lora_layers(self.text_encoder_2, lora_scale) - return ( - prompt_embeds, - negative_prompt_embeds, - pooled_prompt_embeds, - negative_pooled_prompt_embeds, - ) + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image def encode_image(self, image, device, num_images_per_prompt): @@ -788,23 
+771,8 @@ def prepare_image( return image # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents - def prepare_latents( - self, - batch_size, - num_channels_latents, - height, - width, - dtype, - device, - generator, - latents=None, - ): - shape = ( - batch_size, - num_channels_latents, - height // self.vae_scale_factor, - width // self.vae_scale_factor, - ) + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" @@ -822,12 +790,7 @@ def prepare_latents( # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids def _get_add_time_ids( - self, - original_size, - crops_coords_top_left, - target_size, - dtype, - text_encoder_projection_dim=None, + self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None ): add_time_ids = list(original_size + crops_coords_top_left + target_size) @@ -1350,10 +1313,7 @@ def __call__( latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - added_cond_kwargs = { - "text_embeds": add_text_embeds, - "time_ids": add_time_ids, - } + added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} # controlnet(s) inference if guess_mode and self.do_classifier_free_guidance: diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index 6e7a867f1078..f6332a5cdfdd 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -17,21 +17,11 @@ import torch from packaging import version -from transformers import ( - CLIPImageProcessor, - CLIPTextModel, - CLIPTokenizer, - CLIPVisionModelWithProjection, -) +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection from ...configuration_utils import FrozenDict from ...image_processor import PipelineImageInput, VaeImageProcessor -from ...loaders import ( - FromSingleFileMixin, - IPAdapterMixin, - LoraLoaderMixin, - TextualInversionLoaderMixin, -) +from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers @@ -81,11 +71,7 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): class StableDiffusionPipeline( - DiffusionPipeline, - TextualInversionLoaderMixin, - LoraLoaderMixin, - IPAdapterMixin, - FromSingleFileMixin, + DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, IPAdapterMixin, FromSingleFileMixin ): r""" Pipeline for text-to-image generation using Stable Diffusion. 
@@ -373,9 +359,7 @@ def encode_prompt( prompt_embeds = prompt_embeds[0] else: prompt_embeds = self.text_encoder( - text_input_ids.to(device), - attention_mask=attention_mask, - output_hidden_states=True, + text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True ) # Access the `hidden_states` first, that contains a tuple of # all the hidden states from the encoder layers. Then index into @@ -577,23 +561,8 @@ def check_inputs( f" {negative_prompt_embeds.shape}." ) - def prepare_latents( - self, - batch_size, - num_channels_latents, - height, - width, - dtype, - device, - generator, - latents=None, - ): - shape = ( - batch_size, - num_channels_latents, - height // self.vae_scale_factor, - width // self.vae_scale_factor, - ) + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" @@ -924,11 +893,7 @@ def __call__( if self.do_classifier_free_guidance and self.guidance_rescale > 0.0: # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf - noise_pred = rescale_noise_cfg( - noise_pred, - noise_pred_text, - guidance_rescale=self.guidance_rescale, - ) + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale) # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] @@ -951,11 +916,9 @@ def __call__( callback(step_idx, t, latents) if not output_type == "latent": - image = self.vae.decode( - latents / self.vae.config.scaling_factor, - return_dict=False, - generator=generator, - )[0] + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[ + 0 + ] image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) else: image = latents diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py index ad45ba2beee6..3062df003cd6 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py @@ -19,21 +19,11 @@ import PIL.Image import torch from packaging import version -from transformers import ( - CLIPImageProcessor, - CLIPTextModel, - CLIPTokenizer, - CLIPVisionModelWithProjection, -) +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection from ...configuration_utils import FrozenDict from ...image_processor import PipelineImageInput, VaeImageProcessor -from ...loaders import ( - FromSingleFileMixin, - IPAdapterMixin, - LoraLoaderMixin, - TextualInversionLoaderMixin, -) +from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers @@ -116,11 +106,7 @@ def preprocess(image): class StableDiffusionImg2ImgPipeline( - DiffusionPipeline, - TextualInversionLoaderMixin, - IPAdapterMixin, - LoraLoaderMixin, - FromSingleFileMixin, + 
DiffusionPipeline, TextualInversionLoaderMixin, IPAdapterMixin, LoraLoaderMixin, FromSingleFileMixin ): r""" Pipeline for text-guided image-to-image generation using Stable Diffusion. @@ -381,9 +367,7 @@ def encode_prompt( prompt_embeds = prompt_embeds[0] else: prompt_embeds = self.text_encoder( - text_input_ids.to(device), - attention_mask=attention_mask, - output_hidden_states=True, + text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True ) # Access the `hidden_states` first, that contains a tuple of # all the hidden states from the encoder layers. Then index into @@ -597,16 +581,7 @@ def get_timesteps(self, num_inference_steps, strength, device): return timesteps, num_inference_steps - t_start - def prepare_latents( - self, - image, - timestep, - batch_size, - num_images_per_prompt, - dtype, - device, - generator=None, - ): + def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None): if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): raise ValueError( f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" @@ -645,12 +620,7 @@ def prepare_latents( " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update" " your script to pass as many initial images as text prompts to suppress this warning." ) - deprecate( - "len(prompt) != len(image)", - "1.0.0", - deprecation_message, - standard_warn=False, - ) + deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False) additional_image_per_prompt = batch_size // init_latents.shape[0] init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0) elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: @@ -991,11 +961,9 @@ def __call__( callback(step_idx, t, latents) if not output_type == "latent": - image = self.vae.decode( - latents / self.vae.config.scaling_factor, - return_dict=False, - generator=generator, - )[0] + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[ + 0 + ] image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) else: image = latents diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index 78de0b69b303..d1d1b6bf4a59 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -19,36 +19,15 @@ import PIL.Image import torch from packaging import version -from transformers import ( - CLIPImageProcessor, - CLIPTextModel, - CLIPTokenizer, - CLIPVisionModelWithProjection, -) +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection from ...configuration_utils import FrozenDict from ...image_processor import PipelineImageInput, VaeImageProcessor -from ...loaders import ( - FromSingleFileMixin, - IPAdapterMixin, - LoraLoaderMixin, - TextualInversionLoaderMixin, -) -from ...models import ( - AsymmetricAutoencoderKL, - AutoencoderKL, - ImageProjection, - UNet2DConditionModel, -) +from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AsymmetricAutoencoderKL, AutoencoderKL, ImageProjection, UNet2DConditionModel from ...models.lora import 
adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers -from ...utils import ( - USE_PEFT_BACKEND, - deprecate, - logging, - scale_lora_layers, - unscale_lora_layers, -) +from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline from . import StableDiffusionPipelineOutput @@ -191,11 +170,7 @@ def retrieve_latents(encoder_output, generator): class StableDiffusionInpaintPipeline( - DiffusionPipeline, - TextualInversionLoaderMixin, - IPAdapterMixin, - LoraLoaderMixin, - FromSingleFileMixin, + DiffusionPipeline, TextualInversionLoaderMixin, IPAdapterMixin, LoraLoaderMixin, FromSingleFileMixin ): r""" Pipeline for text-guided image inpainting using Stable Diffusion. @@ -232,13 +207,7 @@ class StableDiffusionInpaintPipeline( model_cpu_offload_seq = "text_encoder->unet->vae" _optional_components = ["safety_checker", "feature_extractor", "image_encoder"] _exclude_from_cpu_offload = ["safety_checker"] - _callback_tensor_inputs = [ - "latents", - "prompt_embeds", - "negative_prompt_embeds", - "mask", - "masked_image_latents", - ] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "mask", "masked_image_latents"] def __init__( self, @@ -277,12 +246,7 @@ def __init__( " Hub, it would be very nice if you could open a Pull request for the" " `scheduler/scheduler_config.json` file" ) - deprecate( - "skip_prk_steps not set", - "1.0.0", - deprecation_message, - standard_warn=False, - ) + deprecate("skip_prk_steps not set", "1.0.0", deprecation_message, standard_warn=False) new_config = dict(scheduler.config) new_config["skip_prk_steps"] = True scheduler._internal_dict = FrozenDict(new_config) @@ -341,10 +305,7 @@ def __init__( self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.mask_processor = VaeImageProcessor( - vae_scale_factor=self.vae_scale_factor, - do_normalize=False, - do_binarize=True, - do_convert_grayscale=True, + vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True ) self.register_to_config(requires_safety_checker=requires_safety_checker) @@ -477,9 +438,7 @@ def encode_prompt( prompt_embeds = prompt_embeds[0] else: prompt_embeds = self.text_encoder( - text_input_ids.to(device), - attention_mask=attention_mask, - output_hidden_states=True, + text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True ) # Access the `hidden_states` first, that contains a tuple of # all the hidden states from the encoder layers. 
Then index into @@ -694,12 +653,7 @@ def prepare_latents( return_noise=False, return_image_latents=False, ): - shape = ( - batch_size, - num_channels_latents, - height // self.vae_scale_factor, - width // self.vae_scale_factor, - ) + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" @@ -756,16 +710,7 @@ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): return image_latents def prepare_mask_latents( - self, - mask, - masked_image, - batch_size, - height, - width, - dtype, - device, - generator, - do_classifier_free_guidance, + self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance ): # resize the mask to latents shape as we concatenate the mask to the latents # we do that before converting to dtype to avoid breaking in case we're using cpu_offload @@ -1289,15 +1234,9 @@ def __call__( init_image_condition = init_image.clone() init_image = self._encode_vae_image(init_image, generator=generator) mask_condition = mask_condition.to(device=device, dtype=masked_image_latents.dtype) - condition_kwargs = { - "image": init_image_condition, - "mask": mask_condition, - } + condition_kwargs = {"image": init_image_condition, "mask": mask_condition} image = self.vae.decode( - latents / self.vae.config.scaling_factor, - return_dict=False, - generator=generator, - **condition_kwargs, + latents / self.vae.config.scaling_factor, return_dict=False, generator=generator, **condition_kwargs )[0] image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) else: diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py index e742d8007462..7a4772ef5c98 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py @@ -476,12 +476,7 @@ def encode_prompt( # Retrieve the original scale by scaling back the LoRA layers unscale_lora_layers(self.text_encoder_2, lora_scale) - return ( - prompt_embeds, - negative_prompt_embeds, - pooled_prompt_embeds, - negative_pooled_prompt_embeds, - ) + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image def encode_image(self, image, device, num_images_per_prompt): @@ -604,23 +599,8 @@ def check_inputs( ) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents - def prepare_latents( - self, - batch_size, - num_channels_latents, - height, - width, - dtype, - device, - generator, - latents=None, - ): - shape = ( - batch_size, - num_channels_latents, - height // self.vae_scale_factor, - width // self.vae_scale_factor, - ) + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( f"You have passed a list of generators of length {len(generator)}, but requested an effective 
batch" @@ -637,12 +617,7 @@ def prepare_latents( return latents def _get_add_time_ids( - self, - original_size, - crops_coords_top_left, - target_size, - dtype, - text_encoder_projection_dim=None, + self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None ): add_time_ids = list(original_size + crops_coords_top_left + target_size) @@ -1114,10 +1089,7 @@ def __call__( latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # predict the noise residual - added_cond_kwargs = { - "text_embeds": add_text_embeds, - "time_ids": add_time_ids, - } + added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} if ip_adapter_image is not None: added_cond_kwargs["image_embeds"] = image_embeds noise_pred = self.unet( @@ -1137,11 +1109,7 @@ def __call__( if self.do_classifier_free_guidance and self.guidance_rescale > 0.0: # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf - noise_pred = rescale_noise_cfg( - noise_pred, - noise_pred_text, - guidance_rescale=self.guidance_rescale, - ) + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale) # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py index 9b2b354170f0..760e0c3eb6a3 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py @@ -494,12 +494,7 @@ def encode_prompt( # Retrieve the original scale by scaling back the LoRA layers unscale_lora_layers(self.text_encoder_2, lora_scale) - return ( - prompt_embeds, - negative_prompt_embeds, - pooled_prompt_embeds, - negative_pooled_prompt_embeds, - ) + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs def prepare_extra_step_kwargs(self, generator, eta): @@ -629,15 +624,7 @@ def get_timesteps(self, num_inference_steps, strength, device, denoising_start=N return timesteps, num_inference_steps - t_start def prepare_latents( - self, - image, - timestep, - batch_size, - num_images_per_prompt, - dtype, - device, - generator=None, - add_noise=True, + self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None, add_noise=True ): if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): raise ValueError( @@ -1274,10 +1261,7 @@ def denoising_value_valid(dnv): latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # predict the noise residual - added_cond_kwargs = { - "text_embeds": add_text_embeds, - "time_ids": add_time_ids, - } + added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} if ip_adapter_image is not None: added_cond_kwargs["image_embeds"] = image_embeds noise_pred = self.unet( @@ -1297,11 +1281,7 @@ def denoising_value_valid(dnv): if self.do_classifier_free_guidance and self.guidance_rescale > 0.0: # Based on 3.4. 
in https://arxiv.org/pdf/2305.08891.pdf - noise_pred = rescale_noise_cfg( - noise_pred, - noise_pred_text, - guidance_rescale=self.guidance_rescale, - ) + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale) # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py index 6c7443e52c70..5a388e5e0a73 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py @@ -369,10 +369,7 @@ def __init__( self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.mask_processor = VaeImageProcessor( - vae_scale_factor=self.vae_scale_factor, - do_normalize=False, - do_binarize=True, - do_convert_grayscale=True, + vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True ) add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available() @@ -671,12 +668,7 @@ def encode_prompt( # Retrieve the original scale by scaling back the LoRA layers unscale_lora_layers(self.text_encoder_2, lora_scale) - return ( - prompt_embeds, - negative_prompt_embeds, - pooled_prompt_embeds, - negative_pooled_prompt_embeds, - ) + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs def prepare_extra_step_kwargs(self, generator, eta): @@ -784,12 +776,7 @@ def prepare_latents( return_noise=False, return_image_latents=False, ): - shape = ( - batch_size, - num_channels_latents, - height // self.vae_scale_factor, - width // self.vae_scale_factor, - ) + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" @@ -857,16 +844,7 @@ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): return image_latents def prepare_mask_latents( - self, - mask, - masked_image, - batch_size, - height, - width, - dtype, - device, - generator, - do_classifier_free_guidance, + self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance ): # resize the mask to latents shape as we concatenate the mask to the latents # we do that before converting to dtype to avoid breaking in case we're using cpu_offload @@ -1595,10 +1573,7 @@ def denoising_value_valid(dnv): latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1) # predict the noise residual - added_cond_kwargs = { - "text_embeds": add_text_embeds, - "time_ids": add_time_ids, - } + added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} if ip_adapter_image is not None: added_cond_kwargs["image_embeds"] = image_embeds noise_pred = self.unet( @@ -1618,11 +1593,7 @@ def denoising_value_valid(dnv): if self.do_classifier_free_guidance and self.guidance_rescale > 0.0: # 
Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf - noise_pred = rescale_noise_cfg( - noise_pred, - noise_pred_text, - guidance_rescale=self.guidance_rescale, - ) + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale) # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] diff --git a/tests/models/test_models_unet_2d_condition.py b/tests/models/test_models_unet_2d_condition.py index 25f4b801af4b..62699ba8e141 100644 --- a/tests/models/test_models_unet_2d_condition.py +++ b/tests/models/test_models_unet_2d_condition.py @@ -24,10 +24,7 @@ from pytest import mark from diffusers import UNet2DConditionModel -from diffusers.models.attention_processor import ( - CustomDiffusionAttnProcessor, - IPAdapterAttnProcessor, -) +from diffusers.models.attention_processor import CustomDiffusionAttnProcessor, IPAdapterAttnProcessor from diffusers.models.embeddings import ImageProjection, Resampler from diffusers.utils import logging from diffusers.utils.import_utils import is_xformers_available @@ -66,9 +63,7 @@ def create_ip_adapter_state_dict(model): hidden_size = model.config.block_out_channels[block_id] if cross_attention_dim is not None: sd = IPAdapterAttnProcessor( - hidden_size=hidden_size, - cross_attention_dim=cross_attention_dim, - scale=1.0, + hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, scale=1.0 ).state_dict() ip_cross_attn_state_dict.update( { @@ -82,9 +77,7 @@ def create_ip_adapter_state_dict(model): # "image_proj" (ImageProjection layer weights) cross_attention_dim = model.config["cross_attention_dim"] image_projection = ImageProjection( - cross_attention_dim=cross_attention_dim, - image_embed_dim=cross_attention_dim, - num_image_text_embeds=4, + cross_attention_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, num_image_text_embeds=4 ) ip_image_projection_state_dict = {} @@ -100,12 +93,7 @@ def create_ip_adapter_state_dict(model): del sd ip_state_dict = {} - ip_state_dict.update( - { - "image_proj": ip_image_projection_state_dict, - "ip_adapter": ip_cross_attn_state_dict, - } - ) + ip_state_dict.update({"image_proj": ip_image_projection_state_dict, "ip_adapter": ip_cross_attn_state_dict}) return ip_state_dict @@ -126,9 +114,7 @@ def create_ip_adapter_plus_state_dict(model): hidden_size = model.config.block_out_channels[block_id] if cross_attention_dim is not None: sd = IPAdapterAttnProcessor( - hidden_size=hidden_size, - cross_attention_dim=cross_attention_dim, - scale=1.0, + hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, scale=1.0 ).state_dict() ip_cross_attn_state_dict.update( { @@ -142,22 +128,13 @@ def create_ip_adapter_plus_state_dict(model): # "image_proj" (ImageProjection layer weights) cross_attention_dim = model.config["cross_attention_dim"] image_projection = Resampler( - embed_dims=cross_attention_dim, - output_dims=cross_attention_dim, - hidden_dims=32, - num_heads=2, - num_queries=4, + embed_dims=cross_attention_dim, output_dims=cross_attention_dim, hidden_dims=32, num_heads=2, num_queries=4 ) ip_image_projection_state_dict = image_projection.state_dict() ip_state_dict = {} - ip_state_dict.update( - { - "image_proj": ip_image_projection_state_dict, - "ip_adapter": ip_cross_attn_state_dict, - } - ) + ip_state_dict.update({"image_proj": ip_image_projection_state_dict, "ip_adapter": ip_cross_attn_state_dict}) return ip_state_dict @@ -224,11 +201,7 @@ def dummy_input(self): time_step = 
torch.tensor([10]).to(torch_device) encoder_hidden_states = floats_tensor((batch_size, 4, 32)).to(torch_device) - return { - "sample": noise, - "timestep": time_step, - "encoder_hidden_states": encoder_hidden_states, - } + return {"sample": noise, "timestep": time_step, "encoder_hidden_states": encoder_hidden_states} @property def input_shape(self): @@ -502,14 +475,7 @@ def __init__(self, num): self.number = 0 self.counter = 0 - def __call__( - self, - attn, - hidden_states, - encoder_hidden_states=None, - attention_mask=None, - number=None, - ): + def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, number=None): batch_size, sequence_length, _ = hidden_states.shape attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) @@ -772,12 +738,7 @@ def test_ip_adapter(self): image_proj_state_dict_2 = {k: w + 1.0 for k, w in ip_adapter_1["image_proj"].items()} cross_attn_state_dict_2 = {k: w + 1.0 for k, w in ip_adapter_1["ip_adapter"].items()} ip_adapter_2 = {} - ip_adapter_2.update( - { - "image_proj": image_proj_state_dict_2, - "ip_adapter": cross_attn_state_dict_2, - } - ) + ip_adapter_2.update({"image_proj": image_proj_state_dict_2, "ip_adapter": cross_attn_state_dict_2}) # forward pass ip_adapter_1 model._load_ip_adapter_weights(ip_adapter_1) @@ -827,12 +788,7 @@ def test_ip_adapter_plus(self): image_proj_state_dict_2 = {k: w + 1.0 for k, w in ip_adapter_1["image_proj"].items()} cross_attn_state_dict_2 = {k: w + 1.0 for k, w in ip_adapter_1["ip_adapter"].items()} ip_adapter_2 = {} - ip_adapter_2.update( - { - "image_proj": image_proj_state_dict_2, - "ip_adapter": cross_attn_state_dict_2, - } - ) + ip_adapter_2.update({"image_proj": image_proj_state_dict_2, "ip_adapter": cross_attn_state_dict_2}) # forward pass ip_adapter_1 model._load_ip_adapter_weights(ip_adapter_1) From 21e6275155d7c5878bf4a96f444902942a98e37d Mon Sep 17 00:00:00 2001 From: okotaku Date: Thu, 30 Nov 2023 00:44:42 +0000 Subject: [PATCH 05/16] generic --- src/diffusers/loaders/ip_adapter.py | 4 ++++ .../alt_diffusion/pipeline_alt_diffusion.py | 24 ++++++++++--------- .../pipeline_alt_diffusion_img2img.py | 24 ++++++++++--------- .../animatediff/pipeline_animatediff.py | 24 ++++++++++--------- .../controlnet/pipeline_controlnet.py | 24 ++++++++++--------- .../controlnet/pipeline_controlnet_sd_xl.py | 24 ++++++++++--------- .../pipeline_stable_diffusion.py | 24 ++++++++++--------- .../pipeline_stable_diffusion_img2img.py | 24 ++++++++++--------- .../pipeline_stable_diffusion_inpaint.py | 24 ++++++++++--------- .../pipeline_stable_diffusion_xl.py | 24 ++++++++++--------- .../pipeline_stable_diffusion_xl_img2img.py | 24 ++++++++++--------- .../pipeline_stable_diffusion_xl_inpaint.py | 24 ++++++++++--------- 12 files changed, 147 insertions(+), 121 deletions(-) diff --git a/src/diffusers/loaders/ip_adapter.py b/src/diffusers/loaders/ip_adapter.py index 32c558554be2..85defe44002e 100644 --- a/src/diffusers/loaders/ip_adapter.py +++ b/src/diffusers/loaders/ip_adapter.py @@ -148,6 +148,10 @@ def load_ip_adapter( if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is None: self.feature_extractor = CLIPImageProcessor() + if "proj.weight" not in state_dict["image_proj"]: + # IP-Adapter Plus + self.image_encoder.config.output_hidden_states = True + # load ip-adapter into unet self.unet._load_ip_adapter_weights(state_dict) diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py 
b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py index fbf1887555f4..f425d6eb75d2 100644 --- a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py +++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py @@ -501,20 +501,22 @@ def encode_image(self, image, device, num_images_per_prompt): image = self.feature_extractor(image, return_tensors="pt").pixel_values image = image.to(device=device, dtype=dtype) - if isinstance(self.unet.encoder_hid_proj, ImageProjection): - # IP-Adapter + if self.image_encoder.config.output_hidden_states: + image_enc_hidden_states = self.image_encoder(image).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: image_embeds = self.image_encoder(image).image_embeds image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) uncond_image_embeds = torch.zeros_like(image_embeds) - else: - # IP-Adapter Plus - image_embeds = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] - image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_embeds = self.image_encoder(torch.zeros_like(image), output_hidden_states=True).hidden_states[ - -2 - ] - uncond_image_embeds = uncond_image_embeds.repeat_interleave(num_images_per_prompt, dim=0) - return image_embeds, uncond_image_embeds + + return image_embeds, uncond_image_embeds def run_safety_checker(self, image, device, dtype): if self.safety_checker is None: diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py index 27ed32864333..7b3887538725 100644 --- a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py +++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py @@ -512,20 +512,22 @@ def encode_image(self, image, device, num_images_per_prompt): image = self.feature_extractor(image, return_tensors="pt").pixel_values image = image.to(device=device, dtype=dtype) - if isinstance(self.unet.encoder_hid_proj, ImageProjection): - # IP-Adapter + if self.image_encoder.config.output_hidden_states: + image_enc_hidden_states = self.image_encoder(image).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: image_embeds = self.image_encoder(image).image_embeds image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) uncond_image_embeds = torch.zeros_like(image_embeds) - else: - # IP-Adapter Plus - image_embeds = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] - image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_embeds = self.image_encoder(torch.zeros_like(image), output_hidden_states=True).hidden_states[ - -2 - ] - uncond_image_embeds = 
uncond_image_embeds.repeat_interleave(num_images_per_prompt, dim=0) - return image_embeds, uncond_image_embeds + + return image_embeds, uncond_image_embeds def run_safety_checker(self, image, device, dtype): if self.safety_checker is None: diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py index b6d1bf639c20..16af07861b15 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py @@ -327,20 +327,22 @@ def encode_image(self, image, device, num_images_per_prompt): image = self.feature_extractor(image, return_tensors="pt").pixel_values image = image.to(device=device, dtype=dtype) - if isinstance(self.unet.encoder_hid_proj, ImageProjection): - # IP-Adapter + if self.image_encoder.config.output_hidden_states: + image_enc_hidden_states = self.image_encoder(image).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: image_embeds = self.image_encoder(image).image_embeds image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) uncond_image_embeds = torch.zeros_like(image_embeds) - else: - # IP-Adapter Plus - image_embeds = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] - image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_embeds = self.image_encoder(torch.zeros_like(image), output_hidden_states=True).hidden_states[ - -2 - ] - uncond_image_embeds = uncond_image_embeds.repeat_interleave(num_images_per_prompt, dim=0) - return image_embeds, uncond_image_embeds + + return image_embeds, uncond_image_embeds # Copied from diffusers.pipelines.text_to_video_synthesis/pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents def decode_latents(self, latents): diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py index 43499cbab24c..d6e500679fbf 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py @@ -486,20 +486,22 @@ def encode_image(self, image, device, num_images_per_prompt): image = self.feature_extractor(image, return_tensors="pt").pixel_values image = image.to(device=device, dtype=dtype) - if isinstance(self.unet.encoder_hid_proj, ImageProjection): - # IP-Adapter + if self.image_encoder.config.output_hidden_states: + image_enc_hidden_states = self.image_encoder(image).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: image_embeds = self.image_encoder(image).image_embeds image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) uncond_image_embeds = torch.zeros_like(image_embeds) - else: - # IP-Adapter Plus - 
image_embeds = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] - image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_embeds = self.image_encoder(torch.zeros_like(image), output_hidden_states=True).hidden_states[ - -2 - ] - uncond_image_embeds = uncond_image_embeds.repeat_interleave(num_images_per_prompt, dim=0) - return image_embeds, uncond_image_embeds + + return image_embeds, uncond_image_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker def run_safety_checker(self, image, device, dtype): diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py index 5a1e2b6e0e26..0883964fc3c4 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py @@ -496,20 +496,22 @@ def encode_image(self, image, device, num_images_per_prompt): image = self.feature_extractor(image, return_tensors="pt").pixel_values image = image.to(device=device, dtype=dtype) - if isinstance(self.unet.encoder_hid_proj, ImageProjection): - # IP-Adapter + if self.image_encoder.config.output_hidden_states: + image_enc_hidden_states = self.image_encoder(image).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: image_embeds = self.image_encoder(image).image_embeds image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) uncond_image_embeds = torch.zeros_like(image_embeds) - else: - # IP-Adapter Plus - image_embeds = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] - image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_embeds = self.image_encoder(torch.zeros_like(image), output_hidden_states=True).hidden_states[ - -2 - ] - uncond_image_embeds = uncond_image_embeds.repeat_interleave(num_images_per_prompt, dim=0) - return image_embeds, uncond_image_embeds + + return image_embeds, uncond_image_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs def prepare_extra_step_kwargs(self, generator, eta): diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index b007109480f4..845e9ddfb363 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -496,20 +496,22 @@ def encode_image(self, image, device, num_images_per_prompt): image = self.feature_extractor(image, return_tensors="pt").pixel_values image = image.to(device=device, dtype=dtype) - if isinstance(self.unet.encoder_hid_proj, ImageProjection): - # IP-Adapter + if self.image_encoder.config.output_hidden_states: + image_enc_hidden_states = self.image_encoder(image).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + 
torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: image_embeds = self.image_encoder(image).image_embeds image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) uncond_image_embeds = torch.zeros_like(image_embeds) - else: - # IP-Adapter Plus - image_embeds = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] - image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_embeds = self.image_encoder(torch.zeros_like(image), output_hidden_states=True).hidden_states[ - -2 - ] - uncond_image_embeds = uncond_image_embeds.repeat_interleave(num_images_per_prompt, dim=0) - return image_embeds, uncond_image_embeds + + return image_embeds, uncond_image_embeds def run_safety_checker(self, image, device, dtype): if self.safety_checker is None: diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py index eef1920221c1..42151d4999da 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py @@ -510,20 +510,22 @@ def encode_image(self, image, device, num_images_per_prompt): image = self.feature_extractor(image, return_tensors="pt").pixel_values image = image.to(device=device, dtype=dtype) - if isinstance(self.unet.encoder_hid_proj, ImageProjection): - # IP-Adapter + if self.image_encoder.config.output_hidden_states: + image_enc_hidden_states = self.image_encoder(image).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: image_embeds = self.image_encoder(image).image_embeds image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) uncond_image_embeds = torch.zeros_like(image_embeds) - else: - # IP-Adapter Plus - image_embeds = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] - image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_embeds = self.image_encoder(torch.zeros_like(image), output_hidden_states=True).hidden_states[ - -2 - ] - uncond_image_embeds = uncond_image_embeds.repeat_interleave(num_images_per_prompt, dim=0) - return image_embeds, uncond_image_embeds + + return image_embeds, uncond_image_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker def run_safety_checker(self, image, device, dtype): diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index c06d7888436d..dd3dca885306 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -581,20 +581,22 @@ def encode_image(self, image, device, num_images_per_prompt): image = 
self.feature_extractor(image, return_tensors="pt").pixel_values image = image.to(device=device, dtype=dtype) - if isinstance(self.unet.encoder_hid_proj, ImageProjection): - # IP-Adapter + if self.image_encoder.config.output_hidden_states: + image_enc_hidden_states = self.image_encoder(image).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: image_embeds = self.image_encoder(image).image_embeds image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) uncond_image_embeds = torch.zeros_like(image_embeds) - else: - # IP-Adapter Plus - image_embeds = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] - image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_embeds = self.image_encoder(torch.zeros_like(image), output_hidden_states=True).hidden_states[ - -2 - ] - uncond_image_embeds = uncond_image_embeds.repeat_interleave(num_images_per_prompt, dim=0) - return image_embeds, uncond_image_embeds + + return image_embeds, uncond_image_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker def run_safety_checker(self, image, device, dtype): diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py index e3854d6d57fa..e2be42dcec21 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py @@ -531,20 +531,22 @@ def encode_image(self, image, device, num_images_per_prompt): image = self.feature_extractor(image, return_tensors="pt").pixel_values image = image.to(device=device, dtype=dtype) - if isinstance(self.unet.encoder_hid_proj, ImageProjection): - # IP-Adapter + if self.image_encoder.config.output_hidden_states: + image_enc_hidden_states = self.image_encoder(image).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: image_embeds = self.image_encoder(image).image_embeds image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) uncond_image_embeds = torch.zeros_like(image_embeds) - else: - # IP-Adapter Plus - image_embeds = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] - image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_embeds = self.image_encoder(torch.zeros_like(image), output_hidden_states=True).hidden_states[ - -2 - ] - uncond_image_embeds = uncond_image_embeds.repeat_interleave(num_images_per_prompt, dim=0) - return image_embeds, uncond_image_embeds + + return image_embeds, uncond_image_embeds # Copied from 
diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs def prepare_extra_step_kwargs(self, generator, eta): diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py index bb1bdc37e256..752162a0755d 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py @@ -748,20 +748,22 @@ def encode_image(self, image, device, num_images_per_prompt): image = self.feature_extractor(image, return_tensors="pt").pixel_values image = image.to(device=device, dtype=dtype) - if isinstance(self.unet.encoder_hid_proj, ImageProjection): - # IP-Adapter + if self.image_encoder.config.output_hidden_states: + image_enc_hidden_states = self.image_encoder(image).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: image_embeds = self.image_encoder(image).image_embeds image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) uncond_image_embeds = torch.zeros_like(image_embeds) - else: - # IP-Adapter Plus - image_embeds = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] - image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_embeds = self.image_encoder(torch.zeros_like(image), output_hidden_states=True).hidden_states[ - -2 - ] - uncond_image_embeds = uncond_image_embeds.repeat_interleave(num_images_per_prompt, dim=0) - return image_embeds, uncond_image_embeds + + return image_embeds, uncond_image_embeds def _get_add_time_ids( self, diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py index 7003475419cd..fcbe941b3c66 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py @@ -469,20 +469,22 @@ def encode_image(self, image, device, num_images_per_prompt): image = self.feature_extractor(image, return_tensors="pt").pixel_values image = image.to(device=device, dtype=dtype) - if isinstance(self.unet.encoder_hid_proj, ImageProjection): - # IP-Adapter + if self.image_encoder.config.output_hidden_states: + image_enc_hidden_states = self.image_encoder(image).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: image_embeds = self.image_encoder(image).image_embeds image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) uncond_image_embeds = torch.zeros_like(image_embeds) - else: - # IP-Adapter Plus - image_embeds 
= self.image_encoder(image, output_hidden_states=True).hidden_states[-2] - image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_embeds = self.image_encoder(torch.zeros_like(image), output_hidden_states=True).hidden_states[ - -2 - ] - uncond_image_embeds = uncond_image_embeds.repeat_interleave(num_images_per_prompt, dim=0) - return image_embeds, uncond_image_embeds + + return image_embeds, uncond_image_embeds # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt def encode_prompt( From 87aa9d5765162ed37396c7ad0c6402a05d543bd1 Mon Sep 17 00:00:00 2001 From: okotaku Date: Thu, 30 Nov 2023 02:29:21 +0000 Subject: [PATCH 06/16] Refactor PerceiverAttention --- src/diffusers/loaders/unet.py | 24 +++++- src/diffusers/models/attention_processor.py | 69 +++++++++++++-- src/diffusers/models/embeddings.py | 95 ++++----------------- 3 files changed, 101 insertions(+), 87 deletions(-) diff --git a/src/diffusers/loaders/unet.py b/src/diffusers/loaders/unet.py index 0daae8bbf29e..e52a42233a6f 100644 --- a/src/diffusers/loaders/unet.py +++ b/src/diffusers/loaders/unet.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import os -from collections import defaultdict +from collections import defaultdict, OrderedDict from contextlib import nullcontext from typing import Callable, Dict, List, Optional, Union @@ -748,18 +748,34 @@ def _load_ip_adapter_weights(self, state_dict): embed_dims = state_dict["image_proj"]["proj_in.weight"].shape[1] output_dims = state_dict["image_proj"]["proj_out.weight"].shape[0] hidden_dims = state_dict["image_proj"]["latents"].shape[2] - num_heads = state_dict["image_proj"]["layers.0.0.to_q.weight"].shape[0] // 64 + heads = state_dict["image_proj"]["layers.0.0.to_q.weight"].shape[0] // 64 image_projection = Resampler( embed_dims=embed_dims, output_dims=output_dims, hidden_dims=hidden_dims, - num_heads=num_heads, + heads=heads, num_queries=num_image_text_embeds, ) image_proj_state_dict = state_dict["image_proj"] - image_projection.load_state_dict(image_proj_state_dict) + new_sd = OrderedDict() + for k, v in image_proj_state_dict.items(): + if "norm1" in k: + new_sd[k.replace("norm1", "norm_cross")] = v + elif "norm2" in k: + new_sd[k.replace("norm2", "layer_norm")] = v + elif "to_kv" in k: + v_chunk = v.chunk(2, dim=0) + new_sd[k.replace("to_kv", "to_k")] = v_chunk[0] + new_sd[k.replace("to_kv", "to_v")] = v_chunk[1] + elif "to_out" in k: + new_sd[k.replace("to_out", "to_out.0")] = v + else: + new_sd[k] = v + + image_projection.load_state_dict(new_sd) + del image_proj_state_dict self.encoder_hid_proj = image_projection.to(device=self.device, dtype=self.dtype) self.config.encoder_hid_dim_type = "ip_image_proj" diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 21eb3a32dc09..db7233a616ff 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -84,6 +84,13 @@ class Attention(nn.Module): processor (`AttnProcessor`, *optional*, defaults to `None`): The attention processor to use. If `None`, defaults to `AttnProcessor2_0` if `torch 2.x` is used and `AttnProcessor` otherwise. + query_layer_norm (`bool`, defaults to `False`): + Set to `True` to use layer norm for the query. + scale_qk_factor (`float`, *optional*, defaults to `None`): + A factor to scale the query and key by. 
If `None`, defaults to `1 / math.sqrt(query.size(-1))` + if `scale_qk` is `True`. + concat_kv_input (`bool`, defaults to `False`): + Set to `True` to concatenate the hidden_states and encoder_hidden_states for kv inputs. """ def __init__( @@ -109,6 +116,9 @@ def __init__( residual_connection: bool = False, _from_deprecated_attn_block: bool = False, processor: Optional["AttnProcessor"] = None, + query_layer_norm: bool = False, + scale_qk_factor: Optional[float] = None, + concat_kv_input: bool = False, ): super().__init__() self.inner_dim = dim_head * heads @@ -118,13 +128,19 @@ def __init__( self.rescale_output_factor = rescale_output_factor self.residual_connection = residual_connection self.dropout = dropout + self.concat_kv_input = concat_kv_input # we make use of this private variable to know whether this class is loaded # with an deprecated state dict so that we can convert it on the fly self._from_deprecated_attn_block = _from_deprecated_attn_block self.scale_qk = scale_qk - self.scale = dim_head**-0.5 if self.scale_qk else 1.0 + if scale_qk_factor is not None: + self.scale = scale_qk_factor + elif self.scale_qk: + self.scale = dim_head**-0.5 + else: + self.scale = 1.0 self.heads = heads # for slice_size > 0 the attention score computation @@ -150,6 +166,11 @@ def __init__( else: self.spatial_norm = None + if query_layer_norm: + self.layer_norm = nn.LayerNorm(query_dim) + else: + self.layer_norm = None + if cross_attention_norm is None: self.norm_cross = None elif cross_attention_norm == "layer_norm": @@ -726,6 +747,9 @@ def __call__( if attn.group_norm is not None: hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + if attn.layer_norm is not None: + hidden_states = attn.layer_norm(hidden_states) + query = attn.to_q(hidden_states, *args) if encoder_hidden_states is None: @@ -733,6 +757,9 @@ def __call__( elif attn.norm_cross: encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + if attn.concat_kv_input: + encoder_hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=-2) + key = attn.to_k(encoder_hidden_states, *args) value = attn.to_v(encoder_hidden_states, *args) @@ -986,7 +1013,7 @@ def __call__( # the output of sdp = (batch, num_heads, seq_len, head_dim) # TODO: add support for attn.scale when we move to Torch 2.1 hidden_states = F.scaled_dot_product_attention( - query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False, scale=attn.scale ) hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, residual.shape[1]) @@ -1127,6 +1154,9 @@ def __call__( if attn.group_norm is not None: hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + if attn.layer_norm is not None: + hidden_states = attn.layer_norm(hidden_states) + query = attn.to_q(hidden_states, *args) if encoder_hidden_states is None: @@ -1134,6 +1164,9 @@ def __call__( elif attn.norm_cross: encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + if attn.concat_kv_input: + encoder_hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=-2) + key = attn.to_k(encoder_hidden_states, *args) value = attn.to_v(encoder_hidden_states, *args) @@ -1207,6 +1240,9 @@ def __call__( if attn.group_norm is not None: hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + if attn.layer_norm is not None: + hidden_states = attn.layer_norm(hidden_states) + args = () if 
USE_PEFT_BACKEND else (scale,) query = attn.to_q(hidden_states, *args) @@ -1215,6 +1251,9 @@ def __call__( elif attn.norm_cross: encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + if attn.concat_kv_input: + encoder_hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=-2) + key = attn.to_k(encoder_hidden_states, *args) value = attn.to_v(encoder_hidden_states, *args) @@ -1229,7 +1268,7 @@ def __call__( # the output of sdp = (batch, num_heads, seq_len, head_dim) # TODO: add support for attn.scale when we move to Torch 2.1 hidden_states = F.scaled_dot_product_attention( - query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False, scale=attn.scale ) hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) @@ -1461,7 +1500,7 @@ def __call__( # the output of sdp = (batch, num_heads, seq_len, head_dim) # TODO: add support for attn.scale when we move to Torch 2.1 hidden_states = F.scaled_dot_product_attention( - query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False, scale=attn.scale ) hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) @@ -1517,6 +1556,9 @@ def __call__( if attn.group_norm is not None: hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + if attn.layer_norm is not None: + hidden_states = attn.layer_norm(hidden_states) + query = attn.to_q(hidden_states) dim = query.shape[-1] query = attn.head_to_batch_dim(query) @@ -1526,6 +1568,9 @@ def __call__( elif attn.norm_cross: encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + if attn.concat_kv_input: + encoder_hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=-2) + key = attn.to_k(encoder_hidden_states) value = attn.to_v(encoder_hidden_states) key = attn.head_to_batch_dim(key) @@ -2031,6 +2076,9 @@ def __call__( if attn.group_norm is not None: hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + if attn.layer_norm is not None: + hidden_states = attn.layer_norm(hidden_states) + query = attn.to_q(hidden_states) if encoder_hidden_states is None: @@ -2038,6 +2086,9 @@ def __call__( elif attn.norm_cross: encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + if attn.concat_kv_input: + encoder_hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=-2) + # split hidden states end_pos = encoder_hidden_states.shape[1] - self.num_tokens encoder_hidden_states, ip_hidden_states = ( @@ -2151,6 +2202,9 @@ def __call__( if attn.group_norm is not None: hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + if attn.layer_norm is not None: + hidden_states = attn.layer_norm(hidden_states) + query = attn.to_q(hidden_states) if encoder_hidden_states is None: @@ -2158,6 +2212,9 @@ def __call__( elif attn.norm_cross: encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + if attn.concat_kv_input: + encoder_hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=-2) + # split hidden states end_pos = encoder_hidden_states.shape[1] - self.num_tokens encoder_hidden_states, ip_hidden_states = ( @@ -2179,7 +2236,7 @@ def __call__( # the output of sdp = (batch, num_heads, seq_len, head_dim) # TODO: add support for attn.scale when we move to Torch 
2.1 hidden_states = F.scaled_dot_product_attention( - query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False, scale=attn.scale ) hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) @@ -2195,7 +2252,7 @@ def __call__( # the output of sdp = (batch, num_heads, seq_len, head_dim) # TODO: add support for attn.scale when we move to Torch 2.1 ip_hidden_states = F.scaled_dot_product_attention( - query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False + query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=attn.scale ) ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 9184cbab5543..6d735d00dd6f 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -21,6 +21,7 @@ from ..utils import USE_PEFT_BACKEND from .activations import get_activation from .lora import LoRACompatibleLinear +from .attention_processor import Attention def get_timestep_embedding( @@ -792,76 +793,6 @@ def forward(self, caption, force_drop_ids=None): return hidden_states -class PerceiverAttention(nn.Module): - """PerceiverAttention of IP-Adapter Plus. - - Args: - ---- - embed_dims (int): The feature dimension. - head_dims (int): The number of head channels. Defaults to 64. - num_heads (int): Parallel attention heads. Defaults to 16. - """ - - def __init__(self, embed_dims: int, head_dims=64, num_heads: int = 16) -> None: - super().__init__() - self.head_dims = head_dims - self.num_heads = num_heads - inner_dim = head_dims * num_heads - - self.norm1 = nn.LayerNorm(embed_dims) - self.norm2 = nn.LayerNorm(embed_dims) - - self.to_q = nn.Linear(embed_dims, inner_dim, bias=False) - self.to_kv = nn.Linear(embed_dims, inner_dim * 2, bias=False) - self.to_out = nn.Linear(inner_dim, embed_dims, bias=False) - - def _reshape_tensor(self, x, heads) -> torch.Tensor: - """Reshape tensor.""" - bs, length, _ = x.shape - # (bs, length, width) --> (bs, length, n_heads, dim_per_head) - x = x.view(bs, length, heads, -1) - # (bs, length, n_heads, dim_per_head) --> - # (bs, n_heads, length, dim_per_head) - x = x.transpose(1, 2) - # (bs, n_heads, length, dim_per_head) --> - # (bs*n_heads, length, dim_per_head) - return x.reshape(bs, heads, length, -1) - - def forward(self, x: torch.Tensor, latents: torch.Tensor) -> torch.Tensor: - """Forward pass. - - Args: - ---- - x (torch.Tensor): image features - shape (b, n1, D) - latents (torch.Tensor): latent features - shape (b, n2, D). - """ - x = self.norm1(x) - latents = self.norm2(latents) - - b, len_latents, _ = latents.shape - - q = self.to_q(latents) - kv_input = torch.cat((x, latents), dim=-2) - k, v = self.to_kv(kv_input).chunk(2, dim=-1) - - q = self._reshape_tensor(q, self.num_heads) - k = self._reshape_tensor(k, self.num_heads) - v = self._reshape_tensor(v, self.num_heads) - - # attention - scale = 1 / math.sqrt(math.sqrt(self.head_dims)) - # More stable with f16 than dividing afterwards - weight = (q * scale) @ (k * scale).transpose(-2, -1) - weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) - out = weight @ v - - out = out.permute(0, 2, 1, 3).reshape(b, len_latents, -1) - - return self.to_out(out) - - class Resampler(nn.Module): """Resampler of IP-Adapter Plus. 
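A minimal sketch of the refactor above: the generic `Attention` block, configured with the new flags, matches what `PerceiverAttention` used to do (layer-normed latents as queries, layer-normed image features, keys/values computed over the concatenation of the two, residual added back, no bias on the output projection), up to how the softmax scaling is applied in fp16. Note the `query_layer_norm` and `concat_kv_input` arguments exist only in this PR's tree, and all dimensions below are illustrative rather than taken from any checkpoint.

import torch
from diffusers.models.attention_processor import Attention

# Mirrors how the refactored Resampler configures each layer
# (hidden_dims=1280, dim_head=64, heads=16); sizes are illustrative only.
attn = Attention(
    query_dim=1280,
    dim_head=64,
    heads=16,
    query_layer_norm=True,              # old norm on the latents (queries)
    cross_attention_norm="layer_norm",  # old norm on the image features
    residual_connection=True,           # old `attn(x, latents) + latents`
    out_bias=False,                     # to_out had bias=False before
    concat_kv_input=True,               # k/v over cat([image_feats, latents], dim=-2)
)

latents = torch.randn(2, 4, 1280)        # learned queries (num_queries=4 here)
image_feats = torch.randn(2, 257, 1280)  # image features projected to hidden_dims
out = attn(latents, encoder_hidden_states=image_feats)  # same shape as `latents`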
@@ -873,8 +804,8 @@ class Resampler(nn.Module): `unet.config.cross_attention_dim`. Defaults to 1024. hidden_dims (int): The number of hidden channels. Defaults to 1280. depth (int): The number of blocks. Defaults to 8. - head_dims (int): The number of head channels. Defaults to 64. - num_heads (int): Parallel attention heads. Defaults to 16. + dim_head (int): The number of head channels. Defaults to 64. + heads (int): Parallel attention heads. Defaults to 16. num_queries (int): The number of queries. Defaults to 8. ffn_ratio (float): The expansion ratio of feedforward network hidden layer channels. Defaults to 4. @@ -886,13 +817,12 @@ def __init__( output_dims: int = 1024, hidden_dims: int = 1280, depth: int = 4, - head_dims: int = 64, - num_heads: int = 16, + dim_head: int = 64, + heads: int = 16, num_queries: int = 8, ffn_ratio: float = 4, ) -> None: super().__init__() - self.latents = nn.Parameter(torch.randn(1, num_queries, hidden_dims) / hidden_dims**0.5) self.proj_in = nn.Linear(embed_dims, hidden_dims) @@ -901,11 +831,22 @@ def __init__( self.norm_out = nn.LayerNorm(output_dims) self.layers = nn.ModuleList([]) + scale_qk_factor = 1 / math.sqrt(math.sqrt(dim_head)) for _ in range(depth): self.layers.append( nn.ModuleList( [ - PerceiverAttention(embed_dims=hidden_dims, head_dims=head_dims, num_heads=num_heads), + Attention( + query_dim=hidden_dims, + dim_head=dim_head, + heads=heads, + query_layer_norm=True, + cross_attention_norm="layer_norm", + scale_qk_factor=scale_qk_factor, + residual_connection=True, + out_bias=False, + concat_kv_input=True, + ), self._get_ffn(embed_dims=hidden_dims, ffn_ratio=ffn_ratio), ] ) @@ -937,7 +878,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.proj_in(x) for attn, ff in self.layers: - latents = attn(x, latents) + latents + latents = attn(latents, x) latents = ff(latents) + latents latents = self.proj_out(latents) From 775487feedba04f3c9a4c283196e41f904f498d3 Mon Sep 17 00:00:00 2001 From: okotaku Date: Thu, 30 Nov 2023 02:35:52 +0000 Subject: [PATCH 07/16] format --- src/diffusers/loaders/unet.py | 2 +- src/diffusers/models/embeddings.py | 2 +- src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py | 2 +- .../pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py | 2 +- src/diffusers/pipelines/animatediff/pipeline_animatediff.py | 2 +- src/diffusers/pipelines/controlnet/pipeline_controlnet.py | 2 +- src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py | 2 +- .../pipelines/stable_diffusion/pipeline_stable_diffusion.py | 2 +- .../stable_diffusion/pipeline_stable_diffusion_img2img.py | 2 +- .../stable_diffusion/pipeline_stable_diffusion_inpaint.py | 2 +- .../stable_diffusion_xl/pipeline_stable_diffusion_xl.py | 2 +- .../stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py | 2 +- .../stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py | 2 +- 13 files changed, 13 insertions(+), 13 deletions(-) diff --git a/src/diffusers/loaders/unet.py b/src/diffusers/loaders/unet.py index e52a42233a6f..32659d0f8307 100644 --- a/src/diffusers/loaders/unet.py +++ b/src/diffusers/loaders/unet.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import os -from collections import defaultdict, OrderedDict +from collections import OrderedDict, defaultdict from contextlib import nullcontext from typing import Callable, Dict, List, Optional, Union diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 6d735d00dd6f..0baf8de55abc 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -20,8 +20,8 @@ from ..utils import USE_PEFT_BACKEND from .activations import get_activation -from .lora import LoRACompatibleLinear from .attention_processor import Attention +from .lora import LoRACompatibleLinear def get_timestep_embedding( diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py index f425d6eb75d2..52b8f0e857b1 100644 --- a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py +++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py @@ -22,7 +22,7 @@ from ...configuration_utils import FrozenDict from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin -from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel +from ...models import AutoencoderKL, UNet2DConditionModel from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py index 7b3887538725..21ae6fe6d66e 100644 --- a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py +++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py @@ -24,7 +24,7 @@ from ...configuration_utils import FrozenDict from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin -from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel +from ...models import AutoencoderKL, UNet2DConditionModel from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py index 16af07861b15..2bed12e48e4a 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py @@ -22,7 +22,7 @@ from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin -from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel, UNetMotionModel +from ...models import AutoencoderKL, UNet2DConditionModel, UNetMotionModel from ...models.lora import adjust_lora_scale_text_encoder from ...models.unet_motion_model import MotionAdapter from ...schedulers import ( diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py index d6e500679fbf..7c0ba64036c1 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py @@ -24,7 +24,7 @@ from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import 
FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin -from ...models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel +from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py index 0883964fc3c4..c9f812f47fdf 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py @@ -37,7 +37,7 @@ StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin, ) -from ...models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel +from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel from ...models.attention_processor import ( AttnProcessor2_0, LoRAAttnProcessor2_0, diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index 845e9ddfb363..f185c8d8e666 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -22,7 +22,7 @@ from ...configuration_utils import FrozenDict from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin -from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel +from ...models import AutoencoderKL, UNet2DConditionModel from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py index 42151d4999da..f7a1380ab1c1 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py @@ -24,7 +24,7 @@ from ...configuration_utils import FrozenDict from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin -from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel +from ...models import AutoencoderKL, UNet2DConditionModel from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index dd3dca885306..33e81c9e2c71 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -24,7 +24,7 @@ from ...configuration_utils import FrozenDict from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin -from ...models import AsymmetricAutoencoderKL, AutoencoderKL, ImageProjection, UNet2DConditionModel +from ...models import AsymmetricAutoencoderKL, 
AutoencoderKL, UNet2DConditionModel from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py index e2be42dcec21..410fcfde41f1 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py @@ -31,7 +31,7 @@ StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin, ) -from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel +from ...models import AutoencoderKL, UNet2DConditionModel from ...models.attention_processor import ( AttnProcessor2_0, LoRAAttnProcessor2_0, diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py index 752162a0755d..2ad5e59fd09c 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py @@ -32,7 +32,7 @@ StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin, ) -from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel +from ...models import AutoencoderKL, UNet2DConditionModel from ...models.attention_processor import ( AttnProcessor2_0, LoRAAttnProcessor2_0, diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py index fcbe941b3c66..ab510be2fc56 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py @@ -33,7 +33,7 @@ StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin, ) -from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel +from ...models import AutoencoderKL, UNet2DConditionModel from ...models.attention_processor import ( AttnProcessor2_0, LoRAAttnProcessor2_0, From 69ea7053b51c1bf2168d15c55b0b9266e839d86a Mon Sep 17 00:00:00 2001 From: okotaku Date: Thu, 30 Nov 2023 05:16:47 +0000 Subject: [PATCH 08/16] fix test and refactor PerceiverAttention --- src/diffusers/loaders/unet.py | 4 ++++ src/diffusers/models/attention_processor.py | 21 ++++++------------- src/diffusers/models/embeddings.py | 2 -- tests/models/test_models_unet_2d_condition.py | 18 ++++++++++++++-- .../test_ip_adapter_plus_stable_diffusion.py | 12 +++++------ 5 files changed, 32 insertions(+), 25 deletions(-) diff --git a/src/diffusers/loaders/unet.py b/src/diffusers/loaders/unet.py index 32659d0f8307..e711f4ac3458 100644 --- a/src/diffusers/loaders/unet.py +++ b/src/diffusers/loaders/unet.py @@ -679,6 +679,10 @@ def _load_ip_adapter_weights(self, state_dict): # IP-Adapter Plus num_image_text_embeds = state_dict["image_proj"]["latents"].shape[1] + # Set encoder_hid_proj after loading ip_adapter weights, + # because `Resampler` also has `attn_processors`. 
+ self.encoder_hid_proj = None + # set ip-adapter cross-attention processors & load state_dict attn_procs = {} key_id = 1 diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index db7233a616ff..c2b95d1a53c5 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -86,9 +86,6 @@ class Attention(nn.Module): `AttnProcessor` otherwise. query_layer_norm (`bool`, defaults to `False`): Set to `True` to use layer norm for the query. - scale_qk_factor (`float`, *optional*, defaults to `None`): - A factor to scale the query and key by. If `None`, defaults to `1 / math.sqrt(query.size(-1))` - if `scale_qk` is `True`. concat_kv_input (`bool`, defaults to `False`): Set to `True` to concatenate the hidden_states and encoder_hidden_states for kv inputs. """ @@ -117,7 +114,6 @@ def __init__( _from_deprecated_attn_block: bool = False, processor: Optional["AttnProcessor"] = None, query_layer_norm: bool = False, - scale_qk_factor: Optional[float] = None, concat_kv_input: bool = False, ): super().__init__() @@ -135,12 +131,7 @@ def __init__( self._from_deprecated_attn_block = _from_deprecated_attn_block self.scale_qk = scale_qk - if scale_qk_factor is not None: - self.scale = scale_qk_factor - elif self.scale_qk: - self.scale = dim_head**-0.5 - else: - self.scale = 1.0 + self.scale = dim_head**-0.5 if self.scale_qk else 1.0 self.heads = heads # for slice_size > 0 the attention score computation @@ -1013,7 +1004,7 @@ def __call__( # the output of sdp = (batch, num_heads, seq_len, head_dim) # TODO: add support for attn.scale when we move to Torch 2.1 hidden_states = F.scaled_dot_product_attention( - query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False, scale=attn.scale + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False ) hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, residual.shape[1]) @@ -1268,7 +1259,7 @@ def __call__( # the output of sdp = (batch, num_heads, seq_len, head_dim) # TODO: add support for attn.scale when we move to Torch 2.1 hidden_states = F.scaled_dot_product_attention( - query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False, scale=attn.scale + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False ) hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) @@ -1500,7 +1491,7 @@ def __call__( # the output of sdp = (batch, num_heads, seq_len, head_dim) # TODO: add support for attn.scale when we move to Torch 2.1 hidden_states = F.scaled_dot_product_attention( - query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False, scale=attn.scale + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False ) hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) @@ -2236,7 +2227,7 @@ def __call__( # the output of sdp = (batch, num_heads, seq_len, head_dim) # TODO: add support for attn.scale when we move to Torch 2.1 hidden_states = F.scaled_dot_product_attention( - query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False, scale=attn.scale + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False ) hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) @@ -2252,7 +2243,7 @@ def __call__( # the output of sdp = (batch, num_heads, seq_len, head_dim) # TODO: add support for attn.scale when we move to Torch 2.1 
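Note on the `scale=attn.scale` removals in the hunks above: `F.scaled_dot_product_attention` only gained a `scale` keyword in PyTorch 2.1, and once the custom `scale_qk_factor` is dropped, `attn.scale` is simply `dim_head**-0.5` (whenever `scale_qk` is set), which is exactly the default `1 / sqrt(head_dim)` scaling SDPA applies on its own. A quick sanity check, purely illustrative and not part of the patch (assumes PyTorch >= 2.0):

import torch
import torch.nn.functional as F

# Compare SDPA's implicit 1/sqrt(head_dim) scaling with the explicit
# dim_head**-0.5 factor that used to be passed via `scale=attn.scale`.
head_dim = 8
q = torch.randn(1, 2, 4, head_dim)  # (batch, heads, seq_len, head_dim)
k = torch.randn(1, 2, 4, head_dim)
v = torch.randn(1, 2, 4, head_dim)

sdpa_out = F.scaled_dot_product_attention(q, k, v)  # default scaling, no `scale=` kwarg

scale = head_dim ** -0.5  # what `attn.scale` evaluates to without scale_qk_factor
weights = torch.softmax((q @ k.transpose(-2, -1)) * scale, dim=-1)
manual_out = weights @ v

assert torch.allclose(sdpa_out, manual_out, atol=1e-5)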
ip_hidden_states = F.scaled_dot_product_attention( - query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=attn.scale + query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False ) ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 0baf8de55abc..71041ee6478b 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -831,7 +831,6 @@ def __init__( self.norm_out = nn.LayerNorm(output_dims) self.layers = nn.ModuleList([]) - scale_qk_factor = 1 / math.sqrt(math.sqrt(dim_head)) for _ in range(depth): self.layers.append( nn.ModuleList( @@ -842,7 +841,6 @@ def __init__( heads=heads, query_layer_norm=True, cross_attention_norm="layer_norm", - scale_qk_factor=scale_qk_factor, residual_connection=True, out_bias=False, concat_kv_input=True, diff --git a/tests/models/test_models_unet_2d_condition.py b/tests/models/test_models_unet_2d_condition.py index 62699ba8e141..d6ac16b15d6b 100644 --- a/tests/models/test_models_unet_2d_condition.py +++ b/tests/models/test_models_unet_2d_condition.py @@ -18,6 +18,7 @@ import os import tempfile import unittest +from collections import OrderedDict import torch from parameterized import parameterized @@ -128,10 +129,23 @@ def create_ip_adapter_plus_state_dict(model): # "image_proj" (ImageProjection layer weights) cross_attention_dim = model.config["cross_attention_dim"] image_projection = Resampler( - embed_dims=cross_attention_dim, output_dims=cross_attention_dim, hidden_dims=32, num_heads=2, num_queries=4 + embed_dims=cross_attention_dim, output_dims=cross_attention_dim, dim_head=32, heads=2, num_queries=4 ) - ip_image_projection_state_dict = image_projection.state_dict() + ip_image_projection_state_dict = OrderedDict() + for k, v in image_projection.state_dict().items(): + if "norm_cross" in k: + ip_image_projection_state_dict[k.replace("norm_cross", "norm1")] = v + elif "layer_norm" in k: + ip_image_projection_state_dict[k.replace("layer_norm", "norm2")] = v + elif "to_k" in k: + ip_image_projection_state_dict[k.replace("to_k", "to_kv")] = torch.cat([v, v], dim=0) + elif "to_v" in k: + continue + elif "to_out.0" in k: + ip_image_projection_state_dict[k.replace("to_out.0", "to_out")] = v + else: + ip_image_projection_state_dict[k] = v ip_state_dict = {} ip_state_dict.update({"image_proj": ip_image_projection_state_dict, "ip_adapter": ip_cross_attn_state_dict}) diff --git a/tests/pipelines/ip_adapters/test_ip_adapter_plus_stable_diffusion.py b/tests/pipelines/ip_adapters/test_ip_adapter_plus_stable_diffusion.py index cdf6c86d9c7f..5eb8b18ad9e7 100644 --- a/tests/pipelines/ip_adapters/test_ip_adapter_plus_stable_diffusion.py +++ b/tests/pipelines/ip_adapters/test_ip_adapter_plus_stable_diffusion.py @@ -119,7 +119,7 @@ def test_text_to_image(self): images = pipeline(**inputs).images image_slice = images[0, :3, :3, -1].flatten() - expected_slice = np.array([0.3015, 0.2615, 0.2200, 0.2725, 0.2510, 0.2021, 0.2498, 0.2415, 0.2131]) + expected_slice = np.array([0.3013, 0.2615, 0.2202, 0.2722, 0.2510, 0.2023, 0.2498, 0.2415, 0.2139]) assert np.allclose(image_slice, expected_slice, atol=1e-4, rtol=1e-4) @@ -138,7 +138,7 @@ def test_image_to_image(self): images = pipeline(**inputs).images image_slice = images[0, :3, :3, -1].flatten() - expected_slice = np.array([0.3518, 0.2554, 0.2495, 0.2363, 0.1836, 0.3823, 0.1414, 0.1868, 0.5386]) + expected_slice = 
np.array([0.3550, 0.2600, 0.2520, 0.2412, 0.1870, 0.3831, 0.1453, 0.1880, 0.5371]) assert np.allclose(image_slice, expected_slice, atol=1e-4, rtol=1e-4) @@ -157,7 +157,7 @@ def test_inpainting(self): images = pipeline(**inputs).images image_slice = images[0, :3, :3, -1].flatten() - expected_slice = np.array([0.2756, 0.2422, 0.2214, 0.2346, 0.2102, 0.2060, 0.2188, 0.2043, 0.1941]) + expected_slice = np.array([0.2744, 0.2410, 0.2202, 0.2334, 0.2090, 0.2053, 0.2175, 0.2033, 0.1934]) assert np.allclose(image_slice, expected_slice, atol=1e-4, rtol=1e-4) @@ -186,7 +186,7 @@ def test_text_to_image_sdxl(self): images = pipeline(**inputs).images image_slice = images[0, :3, :3, -1].flatten() - expected_slice = np.array([0.0587, 0.0567, 0.0454, 0.0537, 0.0553, 0.0518, 0.0494, 0.0535, 0.0497]) + expected_slice = np.array([0.0592, 0.0573, 0.0459, 0.0542, 0.0559, 0.0523, 0.0500, 0.0540, 0.0501]) assert np.allclose(image_slice, expected_slice, atol=1e-4, rtol=1e-4) @@ -211,7 +211,7 @@ def test_image_to_image_sdxl(self): images = pipeline(**inputs).images image_slice = images[0, :3, :3, -1].flatten() - expected_slice = np.array([0.0711, 0.0700, 0.0734, 0.0758, 0.0742, 0.0688, 0.0751, 0.0827, 0.0851]) + expected_slice = np.array([0.0708, 0.0701, 0.0735, 0.0760, 0.0739, 0.0679, 0.0756, 0.0824, 0.0837]) assert np.allclose(image_slice, expected_slice, atol=1e-4, rtol=1e-4) @@ -237,6 +237,6 @@ def test_inpainting_sdxl(self): image_slice = images[0, :3, :3, -1].flatten() image_slice.tolist() - expected_slice = np.array([0.1398, 0.1476, 0.1407, 0.1441, 0.1470, 0.1480, 0.1448, 0.1481, 0.1494]) + expected_slice = np.array([0.1398, 0.1476, 0.1407, 0.1442, 0.1470, 0.1480, 0.1449, 0.1481, 0.1494]) assert np.allclose(image_slice, expected_slice, atol=1e-4, rtol=1e-4) From 59012a68b2bf1095c149b72c8982a24c7423c383 Mon Sep 17 00:00:00 2001 From: okotaku Date: Thu, 30 Nov 2023 05:49:36 +0000 Subject: [PATCH 09/16] generic encode_image --- src/diffusers/loaders/ip_adapter.py | 4 ---- .../alt_diffusion/pipeline_alt_diffusion.py | 13 ++++++++----- .../alt_diffusion/pipeline_alt_diffusion_img2img.py | 13 ++++++++----- .../pipelines/animatediff/pipeline_animatediff.py | 13 ++++++++----- .../pipelines/controlnet/pipeline_controlnet.py | 13 ++++++++----- .../controlnet/pipeline_controlnet_sd_xl.py | 13 ++++++++----- .../stable_diffusion/pipeline_stable_diffusion.py | 13 ++++++++----- .../pipeline_stable_diffusion_img2img.py | 13 ++++++++----- .../pipeline_stable_diffusion_inpaint.py | 13 ++++++++----- .../pipeline_stable_diffusion_xl.py | 13 ++++++++----- .../pipeline_stable_diffusion_xl_img2img.py | 13 ++++++++----- .../pipeline_stable_diffusion_xl_inpaint.py | 13 ++++++++----- 12 files changed, 88 insertions(+), 59 deletions(-) diff --git a/src/diffusers/loaders/ip_adapter.py b/src/diffusers/loaders/ip_adapter.py index 85defe44002e..32c558554be2 100644 --- a/src/diffusers/loaders/ip_adapter.py +++ b/src/diffusers/loaders/ip_adapter.py @@ -148,10 +148,6 @@ def load_ip_adapter( if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is None: self.feature_extractor = CLIPImageProcessor() - if "proj.weight" not in state_dict["image_proj"]: - # IP-Adapter Plus - self.image_encoder.config.output_hidden_states = True - # load ip-adapter into unet self.unet._load_ip_adapter_weights(state_dict) diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py index 52b8f0e857b1..faf1496384d0 100644 --- 
a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py +++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py @@ -22,7 +22,7 @@ from ...configuration_utils import FrozenDict from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin -from ...models import AutoencoderKL, UNet2DConditionModel +from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( @@ -494,15 +494,15 @@ def encode_prompt( return prompt_embeds, negative_prompt_embeds - def encode_image(self, image, device, num_images_per_prompt): + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states): dtype = next(self.image_encoder.parameters()).dtype if not isinstance(image, torch.Tensor): image = self.feature_extractor(image, return_tensors="pt").pixel_values image = image.to(device=device, dtype=dtype) - if self.image_encoder.config.output_hidden_states: - image_enc_hidden_states = self.image_encoder(image).hidden_states[-2] + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) uncond_image_enc_hidden_states = self.image_encoder( torch.zeros_like(image), output_hidden_states=True @@ -886,7 +886,10 @@ def __call__( prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) if ip_adapter_image is not None: - image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt) + output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True + image_embeds, negative_image_embeds = self.encode_image( + ip_adapter_image, device, num_images_per_prompt, output_hidden_state + ) if self.do_classifier_free_guidance: image_embeds = torch.cat([negative_image_embeds, image_embeds]) diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py index 21ae6fe6d66e..34cde3d51f4f 100644 --- a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py +++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py @@ -24,7 +24,7 @@ from ...configuration_utils import FrozenDict from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin -from ...models import AutoencoderKL, UNet2DConditionModel +from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( @@ -505,15 +505,15 @@ def encode_prompt( return prompt_embeds, negative_prompt_embeds - def encode_image(self, image, device, num_images_per_prompt): + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states): dtype = next(self.image_encoder.parameters()).dtype if not isinstance(image, torch.Tensor): image = self.feature_extractor(image, return_tensors="pt").pixel_values image = image.to(device=device, dtype=dtype) - if self.image_encoder.config.output_hidden_states: - image_enc_hidden_states = self.image_encoder(image).hidden_states[-2] + if 
output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) uncond_image_enc_hidden_states = self.image_encoder( torch.zeros_like(image), output_hidden_states=True @@ -930,7 +930,10 @@ def __call__( prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) if ip_adapter_image is not None: - image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt) + output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True + image_embeds, negative_image_embeds = self.encode_image( + ip_adapter_image, device, num_images_per_prompt, output_hidden_state + ) if self.do_classifier_free_guidance: image_embeds = torch.cat([negative_image_embeds, image_embeds]) diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py index 2bed12e48e4a..6bc825c42876 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py @@ -22,7 +22,7 @@ from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin -from ...models import AutoencoderKL, UNet2DConditionModel, UNetMotionModel +from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel, UNetMotionModel from ...models.lora import adjust_lora_scale_text_encoder from ...models.unet_motion_model import MotionAdapter from ...schedulers import ( @@ -320,15 +320,15 @@ def encode_prompt( return prompt_embeds, negative_prompt_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image - def encode_image(self, image, device, num_images_per_prompt): + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states): dtype = next(self.image_encoder.parameters()).dtype if not isinstance(image, torch.Tensor): image = self.feature_extractor(image, return_tensors="pt").pixel_values image = image.to(device=device, dtype=dtype) - if self.image_encoder.config.output_hidden_states: - image_enc_hidden_states = self.image_encoder(image).hidden_states[-2] + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) uncond_image_enc_hidden_states = self.image_encoder( torch.zeros_like(image), output_hidden_states=True @@ -662,7 +662,10 @@ def __call__( prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) if ip_adapter_image is not None: - image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_videos_per_prompt) + output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True + image_embeds, negative_image_embeds = self.encode_image( + ip_adapter_image, device, num_videos_per_prompt, output_hidden_state + ) if do_classifier_free_guidance: image_embeds = torch.cat([negative_image_embeds, image_embeds]) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py index 7c0ba64036c1..fa7795725a45 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py +++ 
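The same `encode_image` rewrite repeats in every pipeline below, so the pattern is worth stating once: a UNet whose `encoder_hid_proj` is a plain `ImageProjection` (vanilla IP-Adapter) keeps using the pooled CLIP image embedding, while IP-Adapter Plus installs a `Resampler` that consumes the penultimate hidden states, so the caller now passes an explicit `output_hidden_states` flag instead of mutating the image encoder's config. A condensed, illustrative sketch of the helper (not the verbatim per-pipeline code):

import torch

def encode_ip_adapter_image(image_encoder, feature_extractor, image, device, num_images_per_prompt, output_hidden_states):
    # Illustrative condensation of the per-pipeline `encode_image` helper.
    dtype = next(image_encoder.parameters()).dtype
    if not isinstance(image, torch.Tensor):
        image = feature_extractor(image, return_tensors="pt").pixel_values
    image = image.to(device=device, dtype=dtype)

    if output_hidden_states:
        # IP-Adapter Plus: the Resampler consumes the penultimate hidden states.
        embeds = image_encoder(image, output_hidden_states=True).hidden_states[-2]
        uncond = image_encoder(torch.zeros_like(image), output_hidden_states=True).hidden_states[-2]
    else:
        # Vanilla IP-Adapter: ImageProjection consumes the pooled image embedding.
        embeds = image_encoder(image).image_embeds
        uncond = torch.zeros_like(embeds)

    embeds = embeds.repeat_interleave(num_images_per_prompt, dim=0)
    uncond = uncond.repeat_interleave(num_images_per_prompt, dim=0)
    return embeds, uncond

Callers derive the flag from the projection module that `load_ip_adapter` installed, i.e. the equivalent of `not isinstance(unet.encoder_hid_proj, ImageProjection)`.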
b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py @@ -24,7 +24,7 @@ from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin -from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel +from ...models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( @@ -479,15 +479,15 @@ def encode_prompt( return prompt_embeds, negative_prompt_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image - def encode_image(self, image, device, num_images_per_prompt): + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states): dtype = next(self.image_encoder.parameters()).dtype if not isinstance(image, torch.Tensor): image = self.feature_extractor(image, return_tensors="pt").pixel_values image = image.to(device=device, dtype=dtype) - if self.image_encoder.config.output_hidden_states: - image_enc_hidden_states = self.image_encoder(image).hidden_states[-2] + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) uncond_image_enc_hidden_states = self.image_encoder( torch.zeros_like(image), output_hidden_states=True @@ -1078,7 +1078,10 @@ def __call__( prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) if ip_adapter_image is not None: - image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt) + output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True + image_embeds, negative_image_embeds = self.encode_image( + ip_adapter_image, device, num_images_per_prompt, output_hidden_state + ) if self.do_classifier_free_guidance: image_embeds = torch.cat([negative_image_embeds, image_embeds]) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py index c9f812f47fdf..4e3fe42b29a1 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py @@ -37,7 +37,7 @@ StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin, ) -from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel +from ...models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel from ...models.attention_processor import ( AttnProcessor2_0, LoRAAttnProcessor2_0, @@ -489,15 +489,15 @@ def encode_prompt( return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image - def encode_image(self, image, device, num_images_per_prompt): + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states): dtype = next(self.image_encoder.parameters()).dtype if not isinstance(image, torch.Tensor): image = self.feature_extractor(image, return_tensors="pt").pixel_values image = image.to(device=device, dtype=dtype) - if self.image_encoder.config.output_hidden_states: - image_enc_hidden_states = self.image_encoder(image).hidden_states[-2] + if 
output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) uncond_image_enc_hidden_states = self.image_encoder( torch.zeros_like(image), output_hidden_states=True @@ -1180,7 +1180,10 @@ def __call__( # 3.2 Encode ip_adapter_image if ip_adapter_image is not None: - image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt) + output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True + image_embeds, negative_image_embeds = self.encode_image( + ip_adapter_image, device, num_images_per_prompt, output_hidden_state + ) if self.do_classifier_free_guidance: image_embeds = torch.cat([negative_image_embeds, image_embeds]) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index f185c8d8e666..7d5c3c9957c0 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -22,7 +22,7 @@ from ...configuration_utils import FrozenDict from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin -from ...models import AutoencoderKL, UNet2DConditionModel +from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( @@ -489,15 +489,15 @@ def encode_prompt( return prompt_embeds, negative_prompt_embeds - def encode_image(self, image, device, num_images_per_prompt): + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states): dtype = next(self.image_encoder.parameters()).dtype if not isinstance(image, torch.Tensor): image = self.feature_extractor(image, return_tensors="pt").pixel_values image = image.to(device=device, dtype=dtype) - if self.image_encoder.config.output_hidden_states: - image_enc_hidden_states = self.image_encoder(image).hidden_states[-2] + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) uncond_image_enc_hidden_states = self.image_encoder( torch.zeros_like(image), output_hidden_states=True @@ -882,7 +882,10 @@ def __call__( prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) if ip_adapter_image is not None: - image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt) + output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True + image_embeds, negative_image_embeds = self.encode_image( + ip_adapter_image, device, num_images_per_prompt, output_hidden_state + ) if self.do_classifier_free_guidance: image_embeds = torch.cat([negative_image_embeds, image_embeds]) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py index f7a1380ab1c1..ca60a5b3796a 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +++ 
b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py @@ -24,7 +24,7 @@ from ...configuration_utils import FrozenDict from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin -from ...models import AutoencoderKL, UNet2DConditionModel +from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( @@ -503,15 +503,15 @@ def encode_prompt( return prompt_embeds, negative_prompt_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image - def encode_image(self, image, device, num_images_per_prompt): + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states): dtype = next(self.image_encoder.parameters()).dtype if not isinstance(image, torch.Tensor): image = self.feature_extractor(image, return_tensors="pt").pixel_values image = image.to(device=device, dtype=dtype) - if self.image_encoder.config.output_hidden_states: - image_enc_hidden_states = self.image_encoder(image).hidden_states[-2] + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) uncond_image_enc_hidden_states = self.image_encoder( torch.zeros_like(image), output_hidden_states=True @@ -934,7 +934,10 @@ def __call__( prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) if ip_adapter_image is not None: - image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt) + output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True + image_embeds, negative_image_embeds = self.encode_image( + ip_adapter_image, device, num_images_per_prompt, output_hidden_state + ) if self.do_classifier_free_guidance: image_embeds = torch.cat([negative_image_embeds, image_embeds]) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index 33e81c9e2c71..a118c9789a29 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -24,7 +24,7 @@ from ...configuration_utils import FrozenDict from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin -from ...models import AsymmetricAutoencoderKL, AutoencoderKL, UNet2DConditionModel +from ...models import AsymmetricAutoencoderKL, AutoencoderKL, ImageProjection, UNet2DConditionModel from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers @@ -574,15 +574,15 @@ def encode_prompt( return prompt_embeds, negative_prompt_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image - def encode_image(self, image, device, num_images_per_prompt): + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states): dtype = 
next(self.image_encoder.parameters()).dtype if not isinstance(image, torch.Tensor): image = self.feature_extractor(image, return_tensors="pt").pixel_values image = image.to(device=device, dtype=dtype) - if self.image_encoder.config.output_hidden_states: - image_enc_hidden_states = self.image_encoder(image).hidden_states[-2] + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) uncond_image_enc_hidden_states = self.image_encoder( torch.zeros_like(image), output_hidden_states=True @@ -1114,7 +1114,10 @@ def __call__( prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) if ip_adapter_image is not None: - image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt) + output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True + image_embeds, negative_image_embeds = self.encode_image( + ip_adapter_image, device, num_images_per_prompt, output_hidden_state + ) if self.do_classifier_free_guidance: image_embeds = torch.cat([negative_image_embeds, image_embeds]) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py index 410fcfde41f1..3d11471c502e 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py @@ -31,7 +31,7 @@ StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin, ) -from ...models import AutoencoderKL, UNet2DConditionModel +from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel from ...models.attention_processor import ( AttnProcessor2_0, LoRAAttnProcessor2_0, @@ -524,15 +524,15 @@ def encode_prompt( return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image - def encode_image(self, image, device, num_images_per_prompt): + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states): dtype = next(self.image_encoder.parameters()).dtype if not isinstance(image, torch.Tensor): image = self.feature_extractor(image, return_tensors="pt").pixel_values image = image.to(device=device, dtype=dtype) - if self.image_encoder.config.output_hidden_states: - image_enc_hidden_states = self.image_encoder(image).hidden_states[-2] + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) uncond_image_enc_hidden_states = self.image_encoder( torch.zeros_like(image), output_hidden_states=True @@ -1098,7 +1098,10 @@ def __call__( add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) if ip_adapter_image is not None: - image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt) + output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True + image_embeds, negative_image_embeds = self.encode_image( + ip_adapter_image, device, num_images_per_prompt, output_hidden_state + ) if self.do_classifier_free_guidance: image_embeds = torch.cat([negative_image_embeds, 
image_embeds]) image_embeds = image_embeds.to(device) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py index 2ad5e59fd09c..38b9dd4755d0 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py @@ -32,7 +32,7 @@ StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin, ) -from ...models import AutoencoderKL, UNet2DConditionModel +from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel from ...models.attention_processor import ( AttnProcessor2_0, LoRAAttnProcessor2_0, @@ -741,15 +741,15 @@ def prepare_latents( return latents # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image - def encode_image(self, image, device, num_images_per_prompt): + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states): dtype = next(self.image_encoder.parameters()).dtype if not isinstance(image, torch.Tensor): image = self.feature_extractor(image, return_tensors="pt").pixel_values image = image.to(device=device, dtype=dtype) - if self.image_encoder.config.output_hidden_states: - image_enc_hidden_states = self.image_encoder(image).hidden_states[-2] + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) uncond_image_enc_hidden_states = self.image_encoder( torch.zeros_like(image), output_hidden_states=True @@ -1270,7 +1270,10 @@ def denoising_value_valid(dnv): add_time_ids = add_time_ids.to(device) if ip_adapter_image is not None: - image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt) + output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True + image_embeds, negative_image_embeds = self.encode_image( + ip_adapter_image, device, num_images_per_prompt, output_hidden_state + ) if self.do_classifier_free_guidance: image_embeds = torch.cat([negative_image_embeds, image_embeds]) image_embeds = image_embeds.to(device) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py index ab510be2fc56..06708acb8756 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py @@ -33,7 +33,7 @@ StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin, ) -from ...models import AutoencoderKL, UNet2DConditionModel +from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel from ...models.attention_processor import ( AttnProcessor2_0, LoRAAttnProcessor2_0, @@ -462,15 +462,15 @@ def disable_vae_tiling(self): self.vae.disable_tiling() # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image - def encode_image(self, image, device, num_images_per_prompt): + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states): dtype = next(self.image_encoder.parameters()).dtype if not isinstance(image, torch.Tensor): image = self.feature_extractor(image, 
return_tensors="pt").pixel_values image = image.to(device=device, dtype=dtype) - if self.image_encoder.config.output_hidden_states: - image_enc_hidden_states = self.image_encoder(image).hidden_states[-2] + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) uncond_image_enc_hidden_states = self.image_encoder( torch.zeros_like(image), output_hidden_states=True @@ -1579,7 +1579,10 @@ def denoising_value_valid(dnv): add_time_ids = add_time_ids.to(device) if ip_adapter_image is not None: - image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt) + output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True + image_embeds, negative_image_embeds = self.encode_image( + ip_adapter_image, device, num_images_per_prompt, output_hidden_state + ) if self.do_classifier_free_guidance: image_embeds = torch.cat([negative_image_embeds, image_embeds]) image_embeds = image_embeds.to(device) From 90096f20c1a85c28d3e9cef19deaee45195a0957 Mon Sep 17 00:00:00 2001 From: okotaku Date: Thu, 30 Nov 2023 08:04:20 +0000 Subject: [PATCH 10/16] keep attention implementation --- src/diffusers/loaders/unet.py | 17 ++++++- src/diffusers/models/attention_processor.py | 48 ------------------- src/diffusers/models/embeddings.py | 15 +++--- tests/models/test_models_unet_2d_condition.py | 25 ++++++++++ 4 files changed, 49 insertions(+), 56 deletions(-) diff --git a/src/diffusers/loaders/unet.py b/src/diffusers/loaders/unet.py index e711f4ac3458..18fe2b4e073e 100644 --- a/src/diffusers/loaders/unet.py +++ b/src/diffusers/loaders/unet.py @@ -765,10 +765,23 @@ def _load_ip_adapter_weights(self, state_dict): image_proj_state_dict = state_dict["image_proj"] new_sd = OrderedDict() for k, v in image_proj_state_dict.items(): + if "0.to" in k: + k = k.replace("0.to", "2.to") + elif "1.0.weight" in k: + k = k.replace("1.0.weight", "3.0.weight") + elif "1.0.bias" in k: + k = k.replace("1.0.bias", "3.0.bias") + elif "1.0.weight" in k: + k = k.replace("1.0.weight", "3.0.weight") + elif "1.1.weight" in k: + k = k.replace("1.1.weight", "3.1.weight") + elif "1.3.weight" in k: + k = k.replace("1.3.weight", "3.3.weight") + if "norm1" in k: - new_sd[k.replace("norm1", "norm_cross")] = v + new_sd[k.replace("0.norm1", "0")] = v elif "norm2" in k: - new_sd[k.replace("norm2", "layer_norm")] = v + new_sd[k.replace("0.norm2", "1")] = v elif "to_kv" in k: v_chunk = v.chunk(2, dim=0) new_sd[k.replace("to_kv", "to_k")] = v_chunk[0] diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index c2b95d1a53c5..21eb3a32dc09 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -84,10 +84,6 @@ class Attention(nn.Module): processor (`AttnProcessor`, *optional*, defaults to `None`): The attention processor to use. If `None`, defaults to `AttnProcessor2_0` if `torch 2.x` is used and `AttnProcessor` otherwise. - query_layer_norm (`bool`, defaults to `False`): - Set to `True` to use layer norm for the query. - concat_kv_input (`bool`, defaults to `False`): - Set to `True` to concatenate the hidden_states and encoder_hidden_states for kv inputs. 
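The intent of this commit is to leave the generic `Attention` class alone: the Perceiver-specific behavior that the removed `query_layer_norm` and `concat_kv_input` flags enabled (pre-normalizing both the latents and the image features, attending over their concatenation for keys/values, and adding residual connections) moves into `Resampler.forward` in the embeddings hunk further down. As a standalone illustration of that block pattern, a hypothetical minimal module might look like the following; `nn.MultiheadAttention` stands in for diffusers' `Attention`, and the feed-forward is only a guess at the shape of the real helper:

import torch
from torch import nn

class PerceiverBlockSketch(nn.Module):
    """Hypothetical, minimal version of one Resampler layer, for illustration only."""

    def __init__(self, dim: int, heads: int = 8, ffn_ratio: int = 4):
        super().__init__()
        self.norm_features = nn.LayerNorm(dim)  # plays the role of ln0 in the patched forward
        self.norm_latents = nn.LayerNorm(dim)   # plays the role of ln1 in the patched forward
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.ff = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, dim * ffn_ratio, bias=False),
            nn.GELU(),
            nn.Linear(dim * ffn_ratio, dim, bias=False),
        )

    def forward(self, latents: torch.Tensor, features: torch.Tensor) -> torch.Tensor:
        residual = latents
        features = self.norm_features(features)
        latents = self.norm_latents(latents)
        # keys/values attend over [image features, latents] concatenated along the sequence axis
        kv = torch.cat([features, latents], dim=-2)
        latents, _ = self.attn(latents, kv, kv, need_weights=False)
        latents = latents + residual
        return self.ff(latents) + latents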
""" def __init__( @@ -113,8 +109,6 @@ def __init__( residual_connection: bool = False, _from_deprecated_attn_block: bool = False, processor: Optional["AttnProcessor"] = None, - query_layer_norm: bool = False, - concat_kv_input: bool = False, ): super().__init__() self.inner_dim = dim_head * heads @@ -124,7 +118,6 @@ def __init__( self.rescale_output_factor = rescale_output_factor self.residual_connection = residual_connection self.dropout = dropout - self.concat_kv_input = concat_kv_input # we make use of this private variable to know whether this class is loaded # with an deprecated state dict so that we can convert it on the fly @@ -157,11 +150,6 @@ def __init__( else: self.spatial_norm = None - if query_layer_norm: - self.layer_norm = nn.LayerNorm(query_dim) - else: - self.layer_norm = None - if cross_attention_norm is None: self.norm_cross = None elif cross_attention_norm == "layer_norm": @@ -738,9 +726,6 @@ def __call__( if attn.group_norm is not None: hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) - if attn.layer_norm is not None: - hidden_states = attn.layer_norm(hidden_states) - query = attn.to_q(hidden_states, *args) if encoder_hidden_states is None: @@ -748,9 +733,6 @@ def __call__( elif attn.norm_cross: encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) - if attn.concat_kv_input: - encoder_hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=-2) - key = attn.to_k(encoder_hidden_states, *args) value = attn.to_v(encoder_hidden_states, *args) @@ -1145,9 +1127,6 @@ def __call__( if attn.group_norm is not None: hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) - if attn.layer_norm is not None: - hidden_states = attn.layer_norm(hidden_states) - query = attn.to_q(hidden_states, *args) if encoder_hidden_states is None: @@ -1155,9 +1134,6 @@ def __call__( elif attn.norm_cross: encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) - if attn.concat_kv_input: - encoder_hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=-2) - key = attn.to_k(encoder_hidden_states, *args) value = attn.to_v(encoder_hidden_states, *args) @@ -1231,9 +1207,6 @@ def __call__( if attn.group_norm is not None: hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) - if attn.layer_norm is not None: - hidden_states = attn.layer_norm(hidden_states) - args = () if USE_PEFT_BACKEND else (scale,) query = attn.to_q(hidden_states, *args) @@ -1242,9 +1215,6 @@ def __call__( elif attn.norm_cross: encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) - if attn.concat_kv_input: - encoder_hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=-2) - key = attn.to_k(encoder_hidden_states, *args) value = attn.to_v(encoder_hidden_states, *args) @@ -1547,9 +1517,6 @@ def __call__( if attn.group_norm is not None: hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) - if attn.layer_norm is not None: - hidden_states = attn.layer_norm(hidden_states) - query = attn.to_q(hidden_states) dim = query.shape[-1] query = attn.head_to_batch_dim(query) @@ -1559,9 +1526,6 @@ def __call__( elif attn.norm_cross: encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) - if attn.concat_kv_input: - encoder_hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=-2) - key = attn.to_k(encoder_hidden_states) value = attn.to_v(encoder_hidden_states) key = attn.head_to_batch_dim(key) @@ 
-2067,9 +2031,6 @@ def __call__( if attn.group_norm is not None: hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) - if attn.layer_norm is not None: - hidden_states = attn.layer_norm(hidden_states) - query = attn.to_q(hidden_states) if encoder_hidden_states is None: @@ -2077,9 +2038,6 @@ def __call__( elif attn.norm_cross: encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) - if attn.concat_kv_input: - encoder_hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=-2) - # split hidden states end_pos = encoder_hidden_states.shape[1] - self.num_tokens encoder_hidden_states, ip_hidden_states = ( @@ -2193,9 +2151,6 @@ def __call__( if attn.group_norm is not None: hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) - if attn.layer_norm is not None: - hidden_states = attn.layer_norm(hidden_states) - query = attn.to_q(hidden_states) if encoder_hidden_states is None: @@ -2203,9 +2158,6 @@ def __call__( elif attn.norm_cross: encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) - if attn.concat_kv_input: - encoder_hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=-2) - # split hidden states end_pos = encoder_hidden_states.shape[1] - self.num_tokens encoder_hidden_states, ip_hidden_states = ( diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 71041ee6478b..506167fb78c1 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -835,15 +835,13 @@ def __init__( self.layers.append( nn.ModuleList( [ + nn.LayerNorm(hidden_dims), + nn.LayerNorm(hidden_dims), Attention( query_dim=hidden_dims, dim_head=dim_head, heads=heads, - query_layer_norm=True, - cross_attention_norm="layer_norm", - residual_connection=True, out_bias=False, - concat_kv_input=True, ), self._get_ffn(embed_dims=hidden_dims, ffn_ratio=ffn_ratio), ] @@ -875,8 +873,13 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.proj_in(x) - for attn, ff in self.layers: - latents = attn(latents, x) + for ln0, ln1, attn, ff in self.layers: + residual = latents + + encoder_hidden_states = ln0(x) + latents = ln1(latents) + encoder_hidden_states = torch.cat([encoder_hidden_states, latents], dim=-2) + latents = attn(latents, encoder_hidden_states) + residual latents = ff(latents) + latents latents = self.proj_out(latents) diff --git a/tests/models/test_models_unet_2d_condition.py b/tests/models/test_models_unet_2d_condition.py index d6ac16b15d6b..4596ba9bff88 100644 --- a/tests/models/test_models_unet_2d_condition.py +++ b/tests/models/test_models_unet_2d_condition.py @@ -134,6 +134,31 @@ def create_ip_adapter_plus_state_dict(model): ip_image_projection_state_dict = OrderedDict() for k, v in image_projection.state_dict().items(): + if "2.to" in k: + k = k.replace("2.to", "0.to") + elif "3.0.weight" in k: + k = k.replace("3.0.weight", "1.0.weight") + elif "3.0.bias" in k: + k = k.replace("3.0.bias", "1.0.bias") + elif "3.0.weight" in k: + k = k.replace("3.0.weight", "1.0.weight") + elif "3.1.weight" in k: + k = k.replace("3.1.weight", "1.1.weight") + elif "3.3.weight" in k: + k = k.replace("3.3.weight", "1.3.weight") + elif "layers.0.0" in k: + k = k.replace("layers.0.0", "layers.0.0.norm1") + elif "layers.0.1" in k: + k = k.replace("layers.0.1", "layers.0.0.norm2") + elif "layers.1.0" in k: + k = k.replace("layers.1.0", "layers.1.0.norm1") + elif "layers.1.1" in k: + k = k.replace("layers.1.1", "layers.1.0.norm2") + elif 
"layers.2.0" in k: + k = k.replace("layers.2.0", "layers.2.0.norm1") + elif "layers.2.1" in k: + k = k.replace("layers.2.1", "layers.2.0.norm2") + if "norm_cross" in k: ip_image_projection_state_dict[k.replace("norm_cross", "norm1")] = v elif "layer_norm" in k: From fe7d2326662d6e5097b2b38e3a6dcfe682ec0bb2 Mon Sep 17 00:00:00 2001 From: okotaku Date: Fri, 1 Dec 2023 02:46:41 +0000 Subject: [PATCH 11/16] merge tests --- .../test_ip_adapter_plus_stable_diffusion.py | 242 ------------------ .../test_ip_adapter_stable_diffusion.py | 114 ++++++++- 2 files changed, 108 insertions(+), 248 deletions(-) delete mode 100644 tests/pipelines/ip_adapters/test_ip_adapter_plus_stable_diffusion.py diff --git a/tests/pipelines/ip_adapters/test_ip_adapter_plus_stable_diffusion.py b/tests/pipelines/ip_adapters/test_ip_adapter_plus_stable_diffusion.py deleted file mode 100644 index 5eb8b18ad9e7..000000000000 --- a/tests/pipelines/ip_adapters/test_ip_adapter_plus_stable_diffusion.py +++ /dev/null @@ -1,242 +0,0 @@ -# coding=utf-8 -# Copyright 2023 HuggingFace Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import gc -import unittest - -import numpy as np -import torch -from transformers import ( - CLIPImageProcessor, - CLIPVisionModelWithProjection, -) - -from diffusers import ( - StableDiffusionImg2ImgPipeline, - StableDiffusionInpaintPipeline, - StableDiffusionPipeline, - StableDiffusionXLImg2ImgPipeline, - StableDiffusionXLInpaintPipeline, - StableDiffusionXLPipeline, -) -from diffusers.utils import load_image -from diffusers.utils.testing_utils import ( - enable_full_determinism, - require_torch_gpu, - slow, - torch_device, -) - - -enable_full_determinism() - - -class IPAdapterNightlyTestsMixin(unittest.TestCase): - dtype = torch.float16 - - def tearDown(self): - super().tearDown() - gc.collect() - torch.cuda.empty_cache() - - def get_image_encoder(self, repo_id, subfolder): - image_encoder = CLIPVisionModelWithProjection.from_pretrained( - repo_id, subfolder=subfolder, torch_dtype=self.dtype - ).to(torch_device) - return image_encoder - - def get_image_processor(self, repo_id): - image_processor = CLIPImageProcessor.from_pretrained(repo_id) - return image_processor - - def get_dummy_inputs(self, for_image_to_image=False, for_inpainting=False, for_sdxl=False): - image = load_image( - "https://user-images.githubusercontent.com/24734142/266492875-2d50d223-8475-44f0-a7c6-08b51cb53572.png" - ) - if for_sdxl: - image = image.resize((1024, 1024)) - - input_kwargs = { - "prompt": "best quality, high quality", - "negative_prompt": "monochrome, lowres, bad anatomy, worst quality, low quality", - "num_inference_steps": 5, - "generator": torch.Generator(device="cpu").manual_seed(33), - "ip_adapter_image": image, - "output_type": "np", - } - if for_image_to_image: - image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/vermeer.jpg") - ip_image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/river.png") - - if for_sdxl: - image = image.resize((1024, 
1024)) - ip_image = ip_image.resize((1024, 1024)) - - input_kwargs.update({"image": image, "ip_adapter_image": ip_image}) - - elif for_inpainting: - image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/inpaint_image.png") - mask = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/mask.png") - ip_image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/girl.png") - - if for_sdxl: - image = image.resize((1024, 1024)) - mask = mask.resize((1024, 1024)) - ip_image = ip_image.resize((1024, 1024)) - - input_kwargs.update({"image": image, "mask_image": mask, "ip_adapter_image": ip_image}) - - return input_kwargs - - -@slow -@require_torch_gpu -class IPAdapterSDIntegrationTests(IPAdapterNightlyTestsMixin): - def test_text_to_image(self): - image_encoder = self.get_image_encoder(repo_id="h94/IP-Adapter", subfolder="models/image_encoder") - pipeline = StableDiffusionPipeline.from_pretrained( - "runwayml/stable-diffusion-v1-5", - image_encoder=image_encoder, - safety_checker=None, - torch_dtype=self.dtype, - ) - pipeline.to(torch_device) - pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter-plus_sd15.bin") - - inputs = self.get_dummy_inputs() - images = pipeline(**inputs).images - image_slice = images[0, :3, :3, -1].flatten() - - expected_slice = np.array([0.3013, 0.2615, 0.2202, 0.2722, 0.2510, 0.2023, 0.2498, 0.2415, 0.2139]) - - assert np.allclose(image_slice, expected_slice, atol=1e-4, rtol=1e-4) - - def test_image_to_image(self): - image_encoder = self.get_image_encoder(repo_id="h94/IP-Adapter", subfolder="models/image_encoder") - pipeline = StableDiffusionImg2ImgPipeline.from_pretrained( - "runwayml/stable-diffusion-v1-5", - image_encoder=image_encoder, - safety_checker=None, - torch_dtype=self.dtype, - ) - pipeline.to(torch_device) - pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter-plus_sd15.bin") - - inputs = self.get_dummy_inputs(for_image_to_image=True) - images = pipeline(**inputs).images - image_slice = images[0, :3, :3, -1].flatten() - - expected_slice = np.array([0.3550, 0.2600, 0.2520, 0.2412, 0.1870, 0.3831, 0.1453, 0.1880, 0.5371]) - - assert np.allclose(image_slice, expected_slice, atol=1e-4, rtol=1e-4) - - def test_inpainting(self): - image_encoder = self.get_image_encoder(repo_id="h94/IP-Adapter", subfolder="models/image_encoder") - pipeline = StableDiffusionInpaintPipeline.from_pretrained( - "runwayml/stable-diffusion-v1-5", - image_encoder=image_encoder, - safety_checker=None, - torch_dtype=self.dtype, - ) - pipeline.to(torch_device) - pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter-plus_sd15.bin") - - inputs = self.get_dummy_inputs(for_inpainting=True) - images = pipeline(**inputs).images - image_slice = images[0, :3, :3, -1].flatten() - - expected_slice = np.array([0.2744, 0.2410, 0.2202, 0.2334, 0.2090, 0.2053, 0.2175, 0.2033, 0.1934]) - - assert np.allclose(image_slice, expected_slice, atol=1e-4, rtol=1e-4) - - -@slow -@require_torch_gpu -class IPAdapterSDXLIntegrationTests(IPAdapterNightlyTestsMixin): - def test_text_to_image_sdxl(self): - image_encoder = self.get_image_encoder(repo_id="h94/IP-Adapter", subfolder="models/image_encoder") - feature_extractor = self.get_image_processor("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k") - - pipeline = StableDiffusionXLPipeline.from_pretrained( - "stabilityai/stable-diffusion-xl-base-1.0", - image_encoder=image_encoder, - 
feature_extractor=feature_extractor, - torch_dtype=self.dtype, - ) - pipeline.to(torch_device) - pipeline.load_ip_adapter( - "h94/IP-Adapter", - subfolder="sdxl_models", - weight_name="ip-adapter-plus_sdxl_vit-h.bin", - ) - - inputs = self.get_dummy_inputs() - images = pipeline(**inputs).images - image_slice = images[0, :3, :3, -1].flatten() - - expected_slice = np.array([0.0592, 0.0573, 0.0459, 0.0542, 0.0559, 0.0523, 0.0500, 0.0540, 0.0501]) - - assert np.allclose(image_slice, expected_slice, atol=1e-4, rtol=1e-4) - - def test_image_to_image_sdxl(self): - image_encoder = self.get_image_encoder(repo_id="h94/IP-Adapter", subfolder="models/image_encoder") - feature_extractor = self.get_image_processor("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k") - - pipeline = StableDiffusionXLImg2ImgPipeline.from_pretrained( - "stabilityai/stable-diffusion-xl-base-1.0", - image_encoder=image_encoder, - feature_extractor=feature_extractor, - torch_dtype=self.dtype, - ) - pipeline.to(torch_device) - pipeline.load_ip_adapter( - "h94/IP-Adapter", - subfolder="sdxl_models", - weight_name="ip-adapter-plus_sdxl_vit-h.bin", - ) - - inputs = self.get_dummy_inputs(for_image_to_image=True) - images = pipeline(**inputs).images - image_slice = images[0, :3, :3, -1].flatten() - - expected_slice = np.array([0.0708, 0.0701, 0.0735, 0.0760, 0.0739, 0.0679, 0.0756, 0.0824, 0.0837]) - - assert np.allclose(image_slice, expected_slice, atol=1e-4, rtol=1e-4) - - def test_inpainting_sdxl(self): - image_encoder = self.get_image_encoder(repo_id="h94/IP-Adapter", subfolder="models/image_encoder") - feature_extractor = self.get_image_processor("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k") - - pipeline = StableDiffusionXLInpaintPipeline.from_pretrained( - "stabilityai/stable-diffusion-xl-base-1.0", - image_encoder=image_encoder, - feature_extractor=feature_extractor, - torch_dtype=self.dtype, - ) - pipeline.to(torch_device) - pipeline.load_ip_adapter( - "h94/IP-Adapter", - subfolder="sdxl_models", - weight_name="ip-adapter-plus_sdxl_vit-h.bin", - ) - - inputs = self.get_dummy_inputs(for_inpainting=True) - images = pipeline(**inputs).images - image_slice = images[0, :3, :3, -1].flatten() - image_slice.tolist() - - expected_slice = np.array([0.1398, 0.1476, 0.1407, 0.1442, 0.1470, 0.1480, 0.1449, 0.1481, 0.1494]) - - assert np.allclose(image_slice, expected_slice, atol=1e-4, rtol=1e-4) diff --git a/tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py b/tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py index 57eb49013c1f..7c6349ce2600 100644 --- a/tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py +++ b/tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py @@ -116,7 +116,17 @@ def test_text_to_image(self): images = pipeline(**inputs).images image_slice = images[0, :3, :3, -1].flatten() - expected_slice = np.array([0.8047, 0.8774, 0.9248, 0.9155, 0.9814, 1.0, 0.9678, 1.0, 1.0]) + expected_slice = np.array([0.8110, 0.8843, 0.9326, 0.9224, 0.9878, 1.0, 0.9736, 1.0, 1.0]) + + assert np.allclose(image_slice, expected_slice, atol=1e-4, rtol=1e-4) + + pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter-plus_sd15.bin") + + inputs = self.get_dummy_inputs() + images = pipeline(**inputs).images + image_slice = images[0, :3, :3, -1].flatten() + + expected_slice = np.array([0.3013, 0.2615, 0.2202, 0.2722, 0.2510, 0.2023, 0.2498, 0.2415, 0.2139]) assert np.allclose(image_slice, expected_slice, atol=1e-4, rtol=1e-4) @@ -132,7 +142,17 @@ def test_image_to_image(self): 
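For orientation, the end-to-end flow these merged tests exercise is: build the pipeline with an explicit `image_encoder`, load either the vanilla or the `-plus` checkpoint with `load_ip_adapter`, and pass `ip_adapter_image` at call time; switching adapters only changes `weight_name`. Roughly, mirroring the slow-test setup rather than prescribing exact settings:

import torch
from diffusers import StableDiffusionPipeline
from diffusers.utils import load_image
from transformers import CLIPVisionModelWithProjection

image_encoder = CLIPVisionModelWithProjection.from_pretrained(
    "h94/IP-Adapter", subfolder="models/image_encoder", torch_dtype=torch.float16
)
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    image_encoder=image_encoder,
    safety_checker=None,
    torch_dtype=torch.float16,
).to("cuda")

# IP-Adapter Plus; the vanilla adapter would use weight_name="ip-adapter_sd15.bin" instead.
pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter-plus_sd15.bin")

ip_image = load_image(
    "https://user-images.githubusercontent.com/24734142/266492875-2d50d223-8475-44f0-a7c6-08b51cb53572.png"
)
image = pipe(
    prompt="best quality, high quality",
    negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality",
    ip_adapter_image=ip_image,
    num_inference_steps=50,
    generator=torch.Generator(device="cpu").manual_seed(33),
).images[0]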
images = pipeline(**inputs).images image_slice = images[0, :3, :3, -1].flatten() - expected_slice = np.array([0.2307, 0.2341, 0.2305, 0.24, 0.2268, 0.25, 0.2322, 0.2588, 0.2935]) + expected_slice = np.array([0.2253, 0.2251, 0.2219, 0.2312, 0.2236, 0.2434, 0.2275, 0.2575, 0.2805]) + + assert np.allclose(image_slice, expected_slice, atol=1e-4, rtol=1e-4) + + pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter-plus_sd15.bin") + + inputs = self.get_dummy_inputs(for_image_to_image=True) + images = pipeline(**inputs).images + image_slice = images[0, :3, :3, -1].flatten() + + expected_slice = np.array([0.3550, 0.2600, 0.2520, 0.2412, 0.1870, 0.3831, 0.1453, 0.1880, 0.5371]) assert np.allclose(image_slice, expected_slice, atol=1e-4, rtol=1e-4) @@ -148,7 +168,17 @@ def test_inpainting(self): images = pipeline(**inputs).images image_slice = images[0, :3, :3, -1].flatten() - expected_slice = np.array([0.2705, 0.2395, 0.2209, 0.2312, 0.2102, 0.2104, 0.2178, 0.2065, 0.1997]) + expected_slice = np.array([0.2700, 0.2388, 0.2202, 0.2304, 0.2095, 0.2097, 0.2173, 0.2058, 0.1987]) + + assert np.allclose(image_slice, expected_slice, atol=1e-4, rtol=1e-4) + + pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter-plus_sd15.bin") + + inputs = self.get_dummy_inputs(for_inpainting=True) + images = pipeline(**inputs).images + image_slice = images[0, :3, :3, -1].flatten() + + expected_slice = np.array([0.2744, 0.2410, 0.2202, 0.2334, 0.2090, 0.2053, 0.2175, 0.2033, 0.1934]) assert np.allclose(image_slice, expected_slice, atol=1e-4, rtol=1e-4) @@ -173,7 +203,30 @@ def test_text_to_image_sdxl(self): images = pipeline(**inputs).images image_slice = images[0, :3, :3, -1].flatten() - expected_slice = np.array([0.0968, 0.0959, 0.0852, 0.0912, 0.0948, 0.093, 0.0893, 0.0932, 0.0923]) + expected_slice = np.array([0.0965, 0.0956, 0.0849, 0.0908, 0.0944, 0.0927, 0.0888, 0.0929, 0.0920]) + + assert np.allclose(image_slice, expected_slice, atol=1e-4, rtol=1e-4) + + image_encoder = self.get_image_encoder(repo_id="h94/IP-Adapter", subfolder="models/image_encoder") + + pipeline = StableDiffusionXLPipeline.from_pretrained( + "stabilityai/stable-diffusion-xl-base-1.0", + image_encoder=image_encoder, + feature_extractor=feature_extractor, + torch_dtype=self.dtype, + ) + pipeline.to(torch_device) + pipeline.load_ip_adapter( + "h94/IP-Adapter", + subfolder="sdxl_models", + weight_name="ip-adapter-plus_sdxl_vit-h.bin", + ) + + inputs = self.get_dummy_inputs() + images = pipeline(**inputs).images + image_slice = images[0, :3, :3, -1].flatten() + + expected_slice = np.array([0.0592, 0.0573, 0.0459, 0.0542, 0.0559, 0.0523, 0.0500, 0.0540, 0.0501]) assert np.allclose(image_slice, expected_slice, atol=1e-4, rtol=1e-4) @@ -194,7 +247,31 @@ def test_image_to_image_sdxl(self): images = pipeline(**inputs).images image_slice = images[0, :3, :3, -1].flatten() - expected_slice = np.array([0.0653, 0.0704, 0.0725, 0.0741, 0.0702, 0.0647, 0.0782, 0.0799, 0.0752]) + expected_slice = np.array([0.0652, 0.0698, 0.0723, 0.0744, 0.0699, 0.0636, 0.0784, 0.0803, 0.0742]) + + assert np.allclose(image_slice, expected_slice, atol=1e-4, rtol=1e-4) + + image_encoder = self.get_image_encoder(repo_id="h94/IP-Adapter", subfolder="models/image_encoder") + feature_extractor = self.get_image_processor("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k") + + pipeline = StableDiffusionXLImg2ImgPipeline.from_pretrained( + "stabilityai/stable-diffusion-xl-base-1.0", + image_encoder=image_encoder, + 
feature_extractor=feature_extractor, + torch_dtype=self.dtype, + ) + pipeline.to(torch_device) + pipeline.load_ip_adapter( + "h94/IP-Adapter", + subfolder="sdxl_models", + weight_name="ip-adapter-plus_sdxl_vit-h.bin", + ) + + inputs = self.get_dummy_inputs(for_image_to_image=True) + images = pipeline(**inputs).images + image_slice = images[0, :3, :3, -1].flatten() + + expected_slice = np.array([0.0708, 0.0701, 0.0735, 0.0760, 0.0739, 0.0679, 0.0756, 0.0824, 0.0837]) assert np.allclose(image_slice, expected_slice, atol=1e-4, rtol=1e-4) @@ -216,6 +293,31 @@ def test_inpainting_sdxl(self): image_slice = images[0, :3, :3, -1].flatten() image_slice.tolist() - expected_slice = np.array([0.1418, 0.1493, 0.1428, 0.146, 0.1491, 0.1501, 0.1473, 0.1501, 0.1516]) + expected_slice = np.array([0.1420, 0.1495, 0.1430, 0.1462, 0.1493, 0.1502, 0.1474, 0.1502, 0.1517]) + + assert np.allclose(image_slice, expected_slice, atol=1e-4, rtol=1e-4) + + image_encoder = self.get_image_encoder(repo_id="h94/IP-Adapter", subfolder="models/image_encoder") + feature_extractor = self.get_image_processor("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k") + + pipeline = StableDiffusionXLInpaintPipeline.from_pretrained( + "stabilityai/stable-diffusion-xl-base-1.0", + image_encoder=image_encoder, + feature_extractor=feature_extractor, + torch_dtype=self.dtype, + ) + pipeline.to(torch_device) + pipeline.load_ip_adapter( + "h94/IP-Adapter", + subfolder="sdxl_models", + weight_name="ip-adapter-plus_sdxl_vit-h.bin", + ) + + inputs = self.get_dummy_inputs(for_inpainting=True) + images = pipeline(**inputs).images + image_slice = images[0, :3, :3, -1].flatten() + image_slice.tolist() + + expected_slice = np.array([0.1398, 0.1476, 0.1407, 0.1442, 0.1470, 0.1480, 0.1449, 0.1481, 0.1494]) assert np.allclose(image_slice, expected_slice, atol=1e-4, rtol=1e-4) From 11cacbd1f8a75d6fc1fae90f4c104cf832ed7036 Mon Sep 17 00:00:00 2001 From: okotaku Date: Fri, 1 Dec 2023 23:50:28 +0000 Subject: [PATCH 12/16] encode_image backward compatible --- src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py | 2 +- .../pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py | 2 +- src/diffusers/pipelines/animatediff/pipeline_animatediff.py | 2 +- src/diffusers/pipelines/controlnet/pipeline_controlnet.py | 2 +- src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py | 2 +- .../pipelines/stable_diffusion/pipeline_stable_diffusion.py | 2 +- .../stable_diffusion/pipeline_stable_diffusion_img2img.py | 2 +- .../stable_diffusion/pipeline_stable_diffusion_inpaint.py | 2 +- .../stable_diffusion_xl/pipeline_stable_diffusion_xl.py | 2 +- .../stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py | 2 +- .../stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py | 2 +- 11 files changed, 11 insertions(+), 11 deletions(-) diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py index faf1496384d0..2121e9b81509 100644 --- a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py +++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py @@ -494,7 +494,7 @@ def encode_prompt( return prompt_embeds, negative_prompt_embeds - def encode_image(self, image, device, num_images_per_prompt, output_hidden_states): + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): dtype = next(self.image_encoder.parameters()).dtype if not isinstance(image, torch.Tensor): diff --git 
a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py index 34cde3d51f4f..401e6aef82b1 100644 --- a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py +++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py @@ -505,7 +505,7 @@ def encode_prompt( return prompt_embeds, negative_prompt_embeds - def encode_image(self, image, device, num_images_per_prompt, output_hidden_states): + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): dtype = next(self.image_encoder.parameters()).dtype if not isinstance(image, torch.Tensor): diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py index 6bc825c42876..32a08a0264bc 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py @@ -320,7 +320,7 @@ def encode_prompt( return prompt_embeds, negative_prompt_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image - def encode_image(self, image, device, num_images_per_prompt, output_hidden_states): + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): dtype = next(self.image_encoder.parameters()).dtype if not isinstance(image, torch.Tensor): diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py index fa7795725a45..bf6ef2125446 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py @@ -479,7 +479,7 @@ def encode_prompt( return prompt_embeds, negative_prompt_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image - def encode_image(self, image, device, num_images_per_prompt, output_hidden_states): + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): dtype = next(self.image_encoder.parameters()).dtype if not isinstance(image, torch.Tensor): diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py index 4e3fe42b29a1..8c8399809228 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py @@ -489,7 +489,7 @@ def encode_prompt( return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image - def encode_image(self, image, device, num_images_per_prompt, output_hidden_states): + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): dtype = next(self.image_encoder.parameters()).dtype if not isinstance(image, torch.Tensor): diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index 7d5c3c9957c0..f7f4a16f0aa4 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -489,7 +489,7 @@ def encode_prompt( return prompt_embeds, negative_prompt_embeds - def encode_image(self, image, device, 
num_images_per_prompt, output_hidden_states): + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): dtype = next(self.image_encoder.parameters()).dtype if not isinstance(image, torch.Tensor): diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py index ca60a5b3796a..c80178152a6e 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py @@ -503,7 +503,7 @@ def encode_prompt( return prompt_embeds, negative_prompt_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image - def encode_image(self, image, device, num_images_per_prompt, output_hidden_states): + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): dtype = next(self.image_encoder.parameters()).dtype if not isinstance(image, torch.Tensor): diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index a118c9789a29..375197cc9e4d 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -574,7 +574,7 @@ def encode_prompt( return prompt_embeds, negative_prompt_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image - def encode_image(self, image, device, num_images_per_prompt, output_hidden_states): + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): dtype = next(self.image_encoder.parameters()).dtype if not isinstance(image, torch.Tensor): diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py index 3d11471c502e..12d52aa076d4 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py @@ -524,7 +524,7 @@ def encode_prompt( return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image - def encode_image(self, image, device, num_images_per_prompt, output_hidden_states): + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): dtype = next(self.image_encoder.parameters()).dtype if not isinstance(image, torch.Tensor): diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py index 38b9dd4755d0..729924ec2e20 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py @@ -741,7 +741,7 @@ def prepare_latents( return latents # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image - def encode_image(self, image, device, num_images_per_prompt, output_hidden_states): + def encode_image(self, image, device, num_images_per_prompt, 
output_hidden_states=None): dtype = next(self.image_encoder.parameters()).dtype if not isinstance(image, torch.Tensor): diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py index 06708acb8756..7195b5f2521a 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py @@ -462,7 +462,7 @@ def disable_vae_tiling(self): self.vae.disable_tiling() # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image - def encode_image(self, image, device, num_images_per_prompt, output_hidden_states): + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): dtype = next(self.image_encoder.parameters()).dtype if not isinstance(image, torch.Tensor): From 8cb48a50668abae6bfaa90d94a7c2ae9750287d0 Mon Sep 17 00:00:00 2001 From: okotaku Date: Fri, 1 Dec 2023 23:51:59 +0000 Subject: [PATCH 13/16] code quality --- .../controlnet/pipeline_controlnet_inpaint.py | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py index 72c2250dd5ac..279c6e162df6 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py @@ -604,11 +604,22 @@ def encode_image(self, image, device, num_images_per_prompt): image = self.feature_extractor(image, return_tensors="pt").pixel_values image = image.to(device=device, dtype=dtype) - image_embeds = self.image_encoder(image).image_embeds - image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) - uncond_image_embeds = torch.zeros_like(image_embeds) - return image_embeds, uncond_image_embeds + return image_embeds, uncond_image_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker def run_safety_checker(self, image, device, dtype): From 24ed74a12e657d0851861267d6f19d50fce260f9 Mon Sep 17 00:00:00 2001 From: okotaku Date: Fri, 1 Dec 2023 23:56:18 +0000 Subject: [PATCH 14/16] fix controlnet inpaint pipeline --- .../pipelines/controlnet/pipeline_controlnet_inpaint.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py index 279c6e162df6..71e237ce4e02 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +++ 
b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py
@@ -25,7 +25,7 @@
 from ...image_processor import PipelineImageInput, VaeImageProcessor
 from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
-from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
+from ...models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel
 from ...models.lora import adjust_lora_scale_text_encoder
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
@@ -597,7 +597,7 @@ def encode_prompt(
         return prompt_embeds, negative_prompt_embeds

     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
-    def encode_image(self, image, device, num_images_per_prompt):
+    def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
         dtype = next(self.image_encoder.parameters()).dtype

         if not isinstance(image, torch.Tensor):
@@ -1295,7 +1295,10 @@ def __call__(
             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

         if ip_adapter_image is not None:
-            image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt)
+            output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
+            image_embeds, negative_image_embeds = self.encode_image(
+                ip_adapter_image, device, num_images_per_prompt, output_hidden_state
+            )
             if self.do_classifier_free_guidance:
                 image_embeds = torch.cat([negative_image_embeds, image_embeds])

From e27b6412886e55fddf531db0a04f5780e4e98d45 Mon Sep 17 00:00:00 2001
From: okotaku
Date: Sat, 2 Dec 2023 00:25:26 +0000
Subject: [PATCH 15/16] refactor FFN

---
 src/diffusers/loaders/unet.py                 |  7 +++----
 src/diffusers/models/activations.py           | 15 +++++++++------
 src/diffusers/models/attention.py             | 12 +++++++-----
 src/diffusers/models/embeddings.py            |  7 +++----
 tests/models/test_models_unet_2d_condition.py |  8 ++++----
 5 files changed, 26 insertions(+), 23 deletions(-)

diff --git a/src/diffusers/loaders/unet.py b/src/diffusers/loaders/unet.py
index 18fe2b4e073e..9d559a4b4af8 100644
--- a/src/diffusers/loaders/unet.py
+++ b/src/diffusers/loaders/unet.py
@@ -763,6 +763,7 @@ def _load_ip_adapter_weights(self, state_dict):
         )

         image_proj_state_dict = state_dict["image_proj"]
+        new_sd = OrderedDict()

         for k, v in image_proj_state_dict.items():
             if "0.to" in k:
@@ -771,12 +772,10 @@ def _load_ip_adapter_weights(self, state_dict):
                 k = k.replace("1.0.weight", "3.0.weight")
             elif "1.0.bias" in k:
                 k = k.replace("1.0.bias", "3.0.bias")
-            elif "1.0.weight" in k:
-                k = k.replace("1.0.weight", "3.0.weight")
             elif "1.1.weight" in k:
-                k = k.replace("1.1.weight", "3.1.weight")
+                k = k.replace("1.1.weight", "3.1.net.0.proj.weight")
             elif "1.3.weight" in k:
-                k = k.replace("1.3.weight", "3.3.weight")
+                k = k.replace("1.3.weight", "3.1.net.2.weight")

             if "norm1" in k:
                 new_sd[k.replace("0.norm1", "0")] = v
diff --git a/src/diffusers/models/activations.py b/src/diffusers/models/activations.py
index 8b75162ba597..47570eca8443 100644
--- a/src/diffusers/models/activations.py
+++ b/src/diffusers/models/activations.py
@@ -55,11 +55,12 @@ class GELU(nn.Module):
         dim_in (`int`): The number of channels in the input.
         dim_out (`int`): The number of channels in the output.
         approximate (`str`, *optional*, defaults to `"none"`): If `"tanh"`, use tanh approximation.
+        bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
     """

-    def __init__(self, dim_in: int, dim_out: int, approximate: str = "none"):
+    def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: bool = True):
         super().__init__()
-        self.proj = nn.Linear(dim_in, dim_out)
+        self.proj = nn.Linear(dim_in, dim_out, bias=bias)
         self.approximate = approximate

     def gelu(self, gate: torch.Tensor) -> torch.Tensor:
@@ -81,13 +82,14 @@ class GEGLU(nn.Module):
     Parameters:
         dim_in (`int`): The number of channels in the input.
         dim_out (`int`): The number of channels in the output.
+        bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
     """

-    def __init__(self, dim_in: int, dim_out: int):
+    def __init__(self, dim_in: int, dim_out: int, bias: bool = True):
         super().__init__()
         linear_cls = LoRACompatibleLinear if not USE_PEFT_BACKEND else nn.Linear

-        self.proj = linear_cls(dim_in, dim_out * 2)
+        self.proj = linear_cls(dim_in, dim_out * 2, bias=bias)

     def gelu(self, gate: torch.Tensor) -> torch.Tensor:
         if gate.device.type != "mps":
@@ -109,11 +111,12 @@ class ApproximateGELU(nn.Module):
     Parameters:
         dim_in (`int`): The number of channels in the input.
         dim_out (`int`): The number of channels in the output.
+        bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
     """

-    def __init__(self, dim_in: int, dim_out: int):
+    def __init__(self, dim_in: int, dim_out: int, bias: bool = True):
         super().__init__()
-        self.proj = nn.Linear(dim_in, dim_out)
+        self.proj = nn.Linear(dim_in, dim_out, bias=bias)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = self.proj(x)
diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py
index f02b5e249eee..08faaaf3e5bf 100644
--- a/src/diffusers/models/attention.py
+++ b/src/diffusers/models/attention.py
@@ -501,6 +501,7 @@ class FeedForward(nn.Module):
         dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
         activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
         final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
+        bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
     """

     def __init__(
@@ -511,6 +512,7 @@ def __init__(
         dropout: float = 0.0,
         activation_fn: str = "geglu",
         final_dropout: bool = False,
+        bias: bool = True,
     ):
         super().__init__()
         inner_dim = int(dim * mult)
@@ -518,13 +520,13 @@ def __init__(
         linear_cls = LoRACompatibleLinear if not USE_PEFT_BACKEND else nn.Linear

         if activation_fn == "gelu":
-            act_fn = GELU(dim, inner_dim)
+            act_fn = GELU(dim, inner_dim, bias=bias)
         if activation_fn == "gelu-approximate":
-            act_fn = GELU(dim, inner_dim, approximate="tanh")
+            act_fn = GELU(dim, inner_dim, approximate="tanh", bias=bias)
         elif activation_fn == "geglu":
-            act_fn = GEGLU(dim, inner_dim)
+            act_fn = GEGLU(dim, inner_dim, bias=bias)
         elif activation_fn == "geglu-approximate":
-            act_fn = ApproximateGELU(dim, inner_dim)
+            act_fn = ApproximateGELU(dim, inner_dim, bias=bias)

         self.net = nn.ModuleList([])
         # project in
@@ -532,7 +534,7 @@ def __init__(
         # project dropout
         self.net.append(nn.Dropout(dropout))
         # project out
-        self.net.append(linear_cls(inner_dim, dim_out))
+        self.net.append(linear_cls(inner_dim, dim_out, bias=bias))
         # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
         if final_dropout:
             self.net.append(nn.Dropout(dropout))
diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py
index 506167fb78c1..d33207bb8235 100644
--- a/src/diffusers/models/embeddings.py
+++ b/src/diffusers/models/embeddings.py
@@ -850,12 +850,11 @@ def __init__(

     def _get_ffn(self, embed_dims, ffn_ratio=4) -> nn.Sequential:
         """Get feedforward network."""
-        inner_dim = int(embed_dims * ffn_ratio)
+        from .attention import FeedForward  # Lazy import to avoid circular import
+
         return nn.Sequential(
             nn.LayerNorm(embed_dims),
-            nn.Linear(embed_dims, inner_dim, bias=False),
-            nn.GELU(),
-            nn.Linear(inner_dim, embed_dims, bias=False),
+            FeedForward(embed_dims, embed_dims, activation_fn="gelu", mult=ffn_ratio, bias=False),
         )

     def forward(self, x: torch.Tensor) -> torch.Tensor:
diff --git a/tests/models/test_models_unet_2d_condition.py b/tests/models/test_models_unet_2d_condition.py
index 4596ba9bff88..9ccd78f1fe47 100644
--- a/tests/models/test_models_unet_2d_condition.py
+++ b/tests/models/test_models_unet_2d_condition.py
@@ -142,10 +142,10 @@ def create_ip_adapter_plus_state_dict(model):
             k = k.replace("3.0.bias", "1.0.bias")
         elif "3.0.weight" in k:
             k = k.replace("3.0.weight", "1.0.weight")
-        elif "3.1.weight" in k:
-            k = k.replace("3.1.weight", "1.1.weight")
-        elif "3.3.weight" in k:
-            k = k.replace("3.3.weight", "1.3.weight")
+        elif "3.1.net.0.proj.weight" in k:
+            k = k.replace("3.1.net.0.proj.weight", "1.1.weight")
+        elif "3.net.2.weight" in k:
+            k = k.replace("3.net.2.weight", "1.3.weight")
         elif "layers.0.0" in k:
             k = k.replace("layers.0.0", "layers.0.0.norm1")
         elif "layers.0.1" in k:

From 390f3c01ad67c4a4be206689e9fb16ea96999c62 Mon Sep 17 00:00:00 2001
From: okotaku
Date: Sat, 2 Dec 2023 00:28:17 +0000
Subject: [PATCH 16/16] refactor FFN

---
 src/diffusers/models/embeddings.py | 16 ++++++----------
 1 file changed, 6 insertions(+), 10 deletions(-)

diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py
index d33207bb8235..bdd2930d20f9 100644
--- a/src/diffusers/models/embeddings.py
+++ b/src/diffusers/models/embeddings.py
@@ -823,6 +823,8 @@ def __init__(
         ffn_ratio: float = 4,
     ) -> None:
         super().__init__()
+        from .attention import FeedForward  # Lazy import to avoid circular import
+
         self.latents = nn.Parameter(torch.randn(1, num_queries, hidden_dims) / hidden_dims**0.5)

         self.proj_in = nn.Linear(embed_dims, hidden_dims)
@@ -843,20 +845,14 @@ def __init__(
                         heads=heads,
                         out_bias=False,
                     ),
-                    self._get_ffn(embed_dims=hidden_dims, ffn_ratio=ffn_ratio),
+                    nn.Sequential(
+                        nn.LayerNorm(hidden_dims),
+                        FeedForward(hidden_dims, hidden_dims, activation_fn="gelu", mult=ffn_ratio, bias=False),
+                    ),
                 ]
             )
         )

-    def _get_ffn(self, embed_dims, ffn_ratio=4) -> nn.Sequential:
-        """Get feedforward network."""
-        from .attention import FeedForward  # Lazy import to avoid circular import
-
-        return nn.Sequential(
-            nn.LayerNorm(embed_dims),
-            FeedForward(embed_dims, embed_dims, activation_fn="gelu", mult=ffn_ratio, bias=False),
-        )
-
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """Forward pass.
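
Taken together, this series routes IP-Adapter Plus checkpoints through the new Resampler-based image projection and lets `encode_image` return the penultimate CLIP hidden states when `output_hidden_states` is requested, while plain IP-Adapter checkpoints keep using the pooled image embedding. A minimal end-to-end sketch of the resulting user-facing flow follows; the repo IDs, subfolders, and weight names are the ones exercised by the tests in this series, while the prompt, seed, and step count are illustrative placeholders only:

import torch
from transformers import CLIPVisionModelWithProjection

from diffusers import StableDiffusionPipeline
from diffusers.utils import load_image

# Image encoder shipped alongside the IP-Adapter checkpoints (same repo/subfolder as in the tests).
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
    "h94/IP-Adapter", subfolder="models/image_encoder", torch_dtype=torch.float16
)

pipeline = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    image_encoder=image_encoder,
    safety_checker=None,
    torch_dtype=torch.float16,
).to("cuda")

# Loading a "-plus" weight file installs the Resampler projection, so encode_image
# is called with output_hidden_states=True under the hood.
pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter-plus_sd15.bin")

ip_image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/girl.png")

image = pipeline(
    prompt="a girl in a garden, best quality",  # illustrative prompt, not taken from the patch
    ip_adapter_image=ip_image,
    num_inference_steps=30,  # illustrative step count
    generator=torch.Generator("cuda").manual_seed(0),
).images[0]
image.save("ip_adapter_plus_sd15.png")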