diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py
index 2cc547be0178..695a22d955da 100644
--- a/src/diffusers/loaders.py
+++ b/src/diffusers/loaders.py
@@ -11,7 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import importlib
 import os
 import re
 from collections import defaultdict
@@ -32,15 +31,16 @@
 from .utils import (
     DIFFUSERS_CACHE,
     HF_HUB_OFFLINE,
+    USE_PEFT_BACKEND,
     _get_model_file,
     convert_state_dict_to_diffusers,
     convert_state_dict_to_peft,
+    convert_unet_state_dict_to_peft,
     deprecate,
     get_adapter_name,
     get_peft_kwargs,
     is_accelerate_available,
     is_omegaconf_available,
-    is_peft_available,
     is_transformers_available,
     logging,
     recurse_remove_peft_layers,
@@ -72,19 +72,6 @@
 CUSTOM_DIFFUSION_WEIGHT_NAME = "pytorch_custom_diffusion_weights.bin"
 CUSTOM_DIFFUSION_WEIGHT_NAME_SAFE = "pytorch_custom_diffusion_weights.safetensors"

-
-# Below should be `True` if the current version of `peft` and `transformers` are compatible with
-# PEFT backend. Will automatically fall back to PEFT backend if the correct versions of the libraries are
-# available.
-# For PEFT it is has to be greater than 0.6.0 and for transformers it has to be greater than 4.33.1.
-_required_peft_version = is_peft_available() and version.parse(
-    version.parse(importlib.metadata.version("peft")).base_version
-) > version.parse("0.5")
-_required_transformers_version = version.parse(
-    version.parse(importlib.metadata.version("transformers")).base_version
-) > version.parse("4.33")
-
-USE_PEFT_BACKEND = _required_peft_version and _required_transformers_version

 LORA_DEPRECATION_MESSAGE = "You are using an old version of the LoRA backend. This will be deprecated in the next releases in favor of PEFT. Make sure to install the latest PEFT and transformers packages in the future."

@@ -413,7 +400,7 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
         # fill attn processors
         lora_layers_list = []

-        is_lora = all(("lora" in k or k.endswith(".alpha")) for k in state_dict.keys())
+        is_lora = all(("lora" in k or k.endswith(".alpha")) for k in state_dict.keys()) and not USE_PEFT_BACKEND
         is_custom_diffusion = any("custom_diffusion" in k for k in state_dict.keys())

         if is_lora:
@@ -527,6 +514,10 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
                     cross_attention_dim=cross_attention_dim,
                 )
                 attn_processors[key].load_state_dict(value_dict)
+        elif USE_PEFT_BACKEND:
+            # In this case there is nothing to do, as loading the adapter weights is already handled above
+            # by `set_peft_model_state_dict` on the UNet
+            pass
         else:
             raise ValueError(
                 f"{model_file} does not seem to be in the correct format expected by LoRA or Custom Diffusion training."
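For reference, the module-level flag removed above now lives in `diffusers.utils` as `USE_PEFT_BACKEND`. Below is a minimal sketch of the equivalent check; it mirrors the removed code above rather than the exact `utils` implementation:

```python
# Illustrative sketch of the version gate that `USE_PEFT_BACKEND` encapsulates.
import importlib.metadata

from packaging import version


def _peft_backend_available() -> bool:
    try:
        peft_version = version.parse(importlib.metadata.version("peft")).base_version
        transformers_version = version.parse(importlib.metadata.version("transformers")).base_version
    except importlib.metadata.PackageNotFoundError:
        # Either library is missing entirely, so the PEFT backend cannot be used.
        return False
    # peft must be newer than 0.5.x and transformers newer than 4.33.0.
    return version.parse(peft_version) > version.parse("0.5") and version.parse(
        transformers_version
    ) > version.parse("4.33")
```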
@@ -537,33 +528,36 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
         # Now we remove any existing hooks to
         is_model_cpu_offload = False
         is_sequential_cpu_offload = False
-        if _pipeline is not None:
-            for _, component in _pipeline.components.items():
-                if isinstance(component, nn.Module):
-                    if hasattr(component, "_hf_hook"):
+
+        # For the PEFT backend the UNet is already offloaded at this stage as it is handled inside `load_lora_into_unet`
+        if not USE_PEFT_BACKEND:
+            if _pipeline is not None:
+                for _, component in _pipeline.components.items():
+                    if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
                         is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
                         is_sequential_cpu_offload = isinstance(getattr(component, "_hf_hook"), AlignDevicesHook)
+
                         logger.info(
                             "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
                         )
                         remove_hook_from_module(component, recurse=is_sequential_cpu_offload)

-        # only custom diffusion needs to set attn processors
-        if is_custom_diffusion:
-            self.set_attn_processor(attn_processors)
+            # only custom diffusion needs to set attn processors
+            if is_custom_diffusion:
+                self.set_attn_processor(attn_processors)

-        # set lora layers
-        for target_module, lora_layer in lora_layers_list:
-            target_module.set_lora_layer(lora_layer)
+            # set lora layers
+            for target_module, lora_layer in lora_layers_list:
+                target_module.set_lora_layer(lora_layer)

-        self.to(dtype=self.dtype, device=self.device)
+            self.to(dtype=self.dtype, device=self.device)

-        # Offload back.
-        if is_model_cpu_offload:
-            _pipeline.enable_model_cpu_offload()
-        elif is_sequential_cpu_offload:
-            _pipeline.enable_sequential_cpu_offload()
-        # Unsafe code />
+            # Offload back.
+            if is_model_cpu_offload:
+                _pipeline.enable_model_cpu_offload()
+            elif is_sequential_cpu_offload:
+                _pipeline.enable_sequential_cpu_offload()
+            # Unsafe code />

     def convert_state_dict_legacy_attn_format(self, state_dict, network_alphas):
         is_new_lora_format = all(
@@ -686,15 +680,77 @@ def fuse_lora(self, lora_scale=1.0, safe_fusing=False):
         self.apply(self._fuse_lora_apply)

     def _fuse_lora_apply(self, module):
-        if hasattr(module, "_fuse_lora"):
-            module._fuse_lora(self.lora_scale, self._safe_fusing)
+        if not USE_PEFT_BACKEND:
+            if hasattr(module, "_fuse_lora"):
+                module._fuse_lora(self.lora_scale, self._safe_fusing)
+        else:
+            from peft.tuners.tuners_utils import BaseTunerLayer
+
+            if isinstance(module, BaseTunerLayer):
+                if self.lora_scale != 1.0:
+                    module.scale_layer(self.lora_scale)
+                module.merge(safe_merge=self._safe_fusing)

     def unfuse_lora(self):
         self.apply(self._unfuse_lora_apply)

     def _unfuse_lora_apply(self, module):
-        if hasattr(module, "_unfuse_lora"):
-            module._unfuse_lora()
+        if not USE_PEFT_BACKEND:
+            if hasattr(module, "_unfuse_lora"):
+                module._unfuse_lora()
+        else:
+            from peft.tuners.tuners_utils import BaseTunerLayer
+
+            if isinstance(module, BaseTunerLayer):
+                module.unmerge()
+
+    def set_adapters(
+        self,
+        adapter_names: Union[List[str], str],
+        weights: Optional[Union[List[float], float]] = None,
+    ):
+        """
+        Sets the adapter layers for the unet.
+
+        Args:
+            adapter_names (`List[str]` or `str`):
+                The names of the adapters to use.
+            weights (`Union[List[float], float]`, *optional*):
+                The adapter(s) weights to use with the UNet. If `None`, the weights are set to `1.0` for all the
+                adapters.
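+
+        Example (a hypothetical sketch; it assumes two adapters named "pixel" and "toy" were already loaded
+        into the pipeline's UNet under those names):
+
+        ```py
+        # Activate both adapters, weighting "pixel" at half strength.
+        pipe.unet.set_adapters(["pixel", "toy"], weights=[0.5, 1.0])
+        ```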
+ """ + if not USE_PEFT_BACKEND: + raise ValueError("PEFT backend is required for `set_adapters()`.") + + adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names + + if weights is None: + weights = [1.0] * len(adapter_names) + elif isinstance(weights, float): + weights = [weights] * len(adapter_names) + + if len(adapter_names) != len(weights): + raise ValueError( + f"Length of adapter names {len(adapter_names)} is not equal to the length of their weights {len(weights)}." + ) + + set_weights_and_activate_adapters(self, adapter_names, weights) + + def disable_lora(self): + """ + Disables the active LoRA layers for the unet. + """ + if not USE_PEFT_BACKEND: + raise ValueError("PEFT backend is required for this method.") + set_adapter_layers(self, enabled=False) + + def enable_lora(self): + """ + Enables the active LoRA layers for the unet. + """ + if not USE_PEFT_BACKEND: + raise ValueError("PEFT backend is required for this method.") + set_adapter_layers(self, enabled=True) def load_textual_inversion_state_dicts(pretrained_model_name_or_paths, **kwargs): @@ -1113,7 +1169,6 @@ class LoraLoaderMixin: text_encoder_name = TEXT_ENCODER_NAME unet_name = UNET_NAME num_fused_loras = 0 - use_peft_backend = USE_PEFT_BACKEND def load_lora_weights( self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs @@ -1155,6 +1210,7 @@ def load_lora_weights( network_alphas=network_alphas, unet=self.unet, low_cpu_mem_usage=low_cpu_mem_usage, + adapter_name=adapter_name, _pipeline=self, ) self.load_lora_into_text_encoder( @@ -1464,7 +1520,40 @@ def _maybe_map_sgm_blocks_to_diffusers(cls, state_dict, unet_config, delimiter=" return new_state_dict @classmethod - def load_lora_into_unet(cls, state_dict, network_alphas, unet, low_cpu_mem_usage=None, _pipeline=None): + def _optionally_disable_offloading(cls, _pipeline): + """ + Optionally removes offloading in case the pipeline has been already sequentially offloaded to CPU. + + Args: + _pipeline (`DiffusionPipeline`): + The pipeline to disable offloading for. + + Returns: + tuple: + A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True. + """ + is_model_cpu_offload = False + is_sequential_cpu_offload = False + + if _pipeline is not None: + for _, component in _pipeline.components.items(): + if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"): + if not is_model_cpu_offload: + is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload) + if not is_sequential_cpu_offload: + is_sequential_cpu_offload = isinstance(component._hf_hook, AlignDevicesHook) + + logger.info( + "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again." + ) + remove_hook_from_module(component, recurse=is_sequential_cpu_offload) + + return (is_model_cpu_offload, is_sequential_cpu_offload) + + @classmethod + def load_lora_into_unet( + cls, state_dict, network_alphas, unet, low_cpu_mem_usage=None, adapter_name=None, _pipeline=None + ): """ This will load the LoRA layers specified in `state_dict` into `unet`. @@ -1482,6 +1571,9 @@ def load_lora_into_unet(cls, state_dict, network_alphas, unet, low_cpu_mem_usage tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model. Only supported for PyTorch >= 1.9.0. 
If you are using an older version of PyTorch, setting this argument to `True` will raise an error.
+            adapter_name (`str`, *optional*):
+                Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+                `default_{i}` where `i` is the total number of adapters being loaded.
         """
         low_cpu_mem_usage = low_cpu_mem_usage if low_cpu_mem_usage is not None else _LOW_CPU_MEM_USAGE_DEFAULT
         # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
@@ -1508,6 +1600,56 @@ def load_lora_into_unet(cls, state_dict, network_alphas, unet, low_cpu_mem_usage
             warn_message = "You have saved the LoRA weights using the old format. To convert the old LoRA weights to the new format, you can first load them in a dictionary and then create a new dictionary like the following: `new_state_dict = {f'unet.{module_name}': params for module_name, params in old_state_dict.items()}`."
             logger.warn(warn_message)

+        if USE_PEFT_BACKEND and len(state_dict.keys()) > 0:
+            from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
+
+            if adapter_name in getattr(unet, "peft_config", {}):
+                raise ValueError(
+                    f"Adapter name {adapter_name} already in use in the Unet - please select a new adapter name."
+                )
+
+            state_dict = convert_unet_state_dict_to_peft(state_dict)
+
+            if network_alphas is not None:
+                # The alphas state dict has the same structure as the UNet, so we convert it to PEFT format
+                # using the `convert_unet_state_dict_to_peft` method.
+                network_alphas = convert_unet_state_dict_to_peft(network_alphas)
+
+            rank = {}
+            for key, val in state_dict.items():
+                if "lora_B" in key:
+                    rank[key] = val.shape[1]
+
+            lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=True)
+            lora_config = LoraConfig(**lora_config_kwargs)
+
+            # adapter_name
+            if adapter_name is None:
+                adapter_name = get_adapter_name(unet)
+
+            # In case the pipeline has already been offloaded to CPU - temporarily remove the hooks,
+            # otherwise loading the LoRA weights will lead to an error
+            is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
+
+            inject_adapter_in_model(lora_config, unet, adapter_name=adapter_name)
+            incompatible_keys = set_peft_model_state_dict(unet, state_dict, adapter_name)
+
+            if incompatible_keys is not None:
+                # check only for unexpected keys
+                unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
+                if unexpected_keys:
+                    logger.warning(
+                        f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
+                        f" {unexpected_keys}. "
+                    )
+
+            # Offload back.
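+            # Restore whichever offloading mode `_optionally_disable_offloading` detected and removed
+            # above, so that loading the adapter leaves the pipeline's memory behavior unchanged.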
+            if is_model_cpu_offload:
+                _pipeline.enable_model_cpu_offload()
+            elif is_sequential_cpu_offload:
+                _pipeline.enable_sequential_cpu_offload()
+            # Unsafe code />
+
         unet.load_attn_procs(
             state_dict, network_alphas=network_alphas, low_cpu_mem_usage=low_cpu_mem_usage, _pipeline=_pipeline
         )
@@ -1570,7 +1712,7 @@ def load_lora_into_text_encoder(
             rank = {}
             text_encoder_lora_state_dict = convert_state_dict_to_diffusers(text_encoder_lora_state_dict)

-            if cls.use_peft_backend:
+            if USE_PEFT_BACKEND:
                 # convert state dict
                 text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict)

@@ -1583,6 +1725,7 @@
                 for name, _ in text_encoder_mlp_modules(text_encoder):
                     rank_key_fc1 = f"{name}.fc1.lora_B.weight"
                     rank_key_fc2 = f"{name}.fc2.lora_B.weight"
+
                     rank[rank_key_fc1] = text_encoder_lora_state_dict[rank_key_fc1].shape[1]
                     rank[rank_key_fc2] = text_encoder_lora_state_dict[rank_key_fc2].shape[1]
             else:
@@ -1606,10 +1749,12 @@
                     k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys
                 }

-            if cls.use_peft_backend:
+            if USE_PEFT_BACKEND:
                 from peft import LoraConfig

-                lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict)
+                lora_config_kwargs = get_peft_kwargs(
+                    rank, network_alphas, text_encoder_lora_state_dict, is_unet=False
+                )

                 lora_config = LoraConfig(**lora_config_kwargs)

@@ -1617,17 +1762,18 @@
                 if adapter_name is None:
                     adapter_name = get_adapter_name(text_encoder)

+                is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
+
                 # inject LoRA layers and load the state dict
+                # in transformers we automatically check whether the adapter name is already in use or not
                 text_encoder.load_adapter(
                     adapter_name=adapter_name,
                     adapter_state_dict=text_encoder_lora_state_dict,
                     peft_config=lora_config,
                 )
+
                 # scale LoRA layers with `lora_scale`
                 scale_lora_layers(text_encoder, weight=lora_scale)
-
-                is_model_cpu_offload = False
-                is_sequential_cpu_offload = False
             else:
                 cls._modify_text_encoder(
                     text_encoder,
@@ -1699,7 +1845,7 @@ def lora_scale(self) -> float:
         return self._lora_scale if hasattr(self, "_lora_scale") else 1.0

     def _remove_text_encoder_monkey_patch(self):
-        if self.use_peft_backend:
+        if USE_PEFT_BACKEND:
             remove_method = recurse_remove_peft_layers
         else:
             remove_method = self._remove_text_encoder_monkey_patch_classmethod

@@ -1707,12 +1853,13 @@
         if hasattr(self, "text_encoder"):
             remove_method(self.text_encoder)
-            if self.use_peft_backend:
+            # In case the text encoder has no LoRA attached
+            if USE_PEFT_BACKEND and getattr(self.text_encoder, "peft_config", None) is not None:
                 del self.text_encoder.peft_config
                 self.text_encoder._hf_peft_config_loaded = None
         if hasattr(self, "text_encoder_2"):
             remove_method(self.text_encoder_2)
-            if self.use_peft_backend:
+            if USE_PEFT_BACKEND:
                 del self.text_encoder_2.peft_config
                 self.text_encoder_2._hf_peft_config_loaded = None

@@ -2088,9 +2235,20 @@ def unload_lora_weights(self):
         >>> ...
         ```
         """
-        for _, module in self.unet.named_modules():
-            if hasattr(module, "set_lora_layer"):
-                module.set_lora_layer(None)
+        if not USE_PEFT_BACKEND:
+            if version.parse(__version__) > version.parse("0.23"):
+                logger.warn(
+                    "You are using `unload_lora_weights` to disable and unload lora weights. If you want to iteratively enable and disable adapter weights, "
+                    "you can use `pipe.enable_lora()` or `pipe.disable_lora()`. Make sure to install the latest version of PEFT."
+                )
+
+            for _, module in self.unet.named_modules():
+                if hasattr(module, "set_lora_layer"):
+                    module.set_lora_layer(None)
+        else:
+            recurse_remove_peft_layers(self.unet)
+            if hasattr(self.unet, "peft_config"):
+                del self.unet.peft_config

         # Safe to call the following regardless of LoRA.
         self._remove_text_encoder_monkey_patch()
@@ -2131,7 +2289,7 @@ def fuse_lora(
         if fuse_unet:
             self.unet.fuse_lora(lora_scale, safe_fusing=safe_fusing)

-        if self.use_peft_backend:
+        if USE_PEFT_BACKEND:
             from peft.tuners.tuners_utils import BaseTunerLayer

             def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False):
@@ -2184,9 +2342,16 @@ def unfuse_lora(self, unfuse_unet: bool = True, unfuse_text_encoder: bool = True
                 LoRA parameters then it won't have any effect.
         """
         if unfuse_unet:
-            self.unet.unfuse_lora()
+            if not USE_PEFT_BACKEND:
+                self.unet.unfuse_lora()
+            else:
+                from peft.tuners.tuners_utils import BaseTunerLayer

-        if self.use_peft_backend:
+                for module in self.unet.modules():
+                    if isinstance(module, BaseTunerLayer):
+                        module.unmerge()
+
+        if USE_PEFT_BACKEND:
             from peft.tuners.tuners_utils import BaseTunerLayer

             def unfuse_text_encoder_lora(text_encoder):
@@ -2219,7 +2384,7 @@ def unfuse_text_encoder_lora(text_encoder):

         self.num_fused_loras -= 1

-    def set_adapter_for_text_encoder(
+    def set_adapters_for_text_encoder(
         self,
         adapter_names: Union[List[str], str],
         text_encoder: Optional[PreTrainedModel] = None,
@@ -2237,7 +2402,7 @@
             text_encoder_weights (`List[float]`, *optional*):
                 The weights to use for the text encoder. If `None`, the weights are set to `1.0` for all the adapters.
         """
-        if not self.use_peft_backend:
+        if not USE_PEFT_BACKEND:
             raise ValueError("PEFT backend is required for this method.")

         def process_weights(adapter_names, weights):
@@ -2270,7 +2435,7 @@ def disable_lora_for_text_encoder(self, text_encoder: Optional[PreTrainedModel]
                 The text encoder module to disable the LoRA layers for. If `None`, it will try to get the
                 `text_encoder` attribute.
         """
-        if not self.use_peft_backend:
+        if not USE_PEFT_BACKEND:
             raise ValueError("PEFT backend is required for this method.")

         text_encoder = text_encoder or getattr(self, "text_encoder", None)
@@ -2287,13 +2452,146 @@ def enable_lora_for_text_encoder(self, text_encoder: Optional[PreTrainedModel] =
                 The text encoder module to enable the LoRA layers for. If `None`, it will try to get the
                 `text_encoder` attribute.
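+
+        Example (a hypothetical sketch; it assumes a LoRA was already loaded into the pipeline `pipe`):
+
+        ```py
+        pipe.disable_lora_for_text_encoder()  # run the text encoder without its LoRA layers
+        pipe.enable_lora_for_text_encoder()  # turn them back on
+        ```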
""" - if not self.use_peft_backend: + if not USE_PEFT_BACKEND: raise ValueError("PEFT backend is required for this method.") text_encoder = text_encoder or getattr(self, "text_encoder", None) if text_encoder is None: raise ValueError("Text Encoder not found.") set_adapter_layers(self.text_encoder, enabled=True) + def set_adapters( + self, + adapter_names: Union[List[str], str], + adapter_weights: Optional[List[float]] = None, + ): + # Handle the UNET + self.unet.set_adapters(adapter_names, adapter_weights) + + # Handle the Text Encoder + if hasattr(self, "text_encoder"): + self.set_adapters_for_text_encoder(adapter_names, self.text_encoder, adapter_weights) + if hasattr(self, "text_encoder_2"): + self.set_adapters_for_text_encoder(adapter_names, self.text_encoder_2, adapter_weights) + + def disable_lora(self): + if not USE_PEFT_BACKEND: + raise ValueError("PEFT backend is required for this method.") + + # Disable unet adapters + self.unet.disable_lora() + + # Disable text encoder adapters + if hasattr(self, "text_encoder"): + self.disable_lora_for_text_encoder(self.text_encoder) + if hasattr(self, "text_encoder_2"): + self.disable_lora_for_text_encoder(self.text_encoder_2) + + def enable_lora(self): + if not USE_PEFT_BACKEND: + raise ValueError("PEFT backend is required for this method.") + + # Enable unet adapters + self.unet.enable_lora() + + # Enable text encoder adapters + if hasattr(self, "text_encoder"): + self.enable_lora_for_text_encoder(self.text_encoder) + if hasattr(self, "text_encoder_2"): + self.enable_lora_for_text_encoder(self.text_encoder_2) + + def get_active_adapters(self) -> List[str]: + """ + Gets the list of the current active adapters. + + Example: + + ```python + from diffusers import DiffusionPipeline + + pipeline = DiffusionPipeline.from_pretrained( + "stabilityai/stable-diffusion-xl-base-1.0", + ).to("cuda") + pipeline.load_lora_weights("CiroN2022/toy-face", weight_name="toy_face_sdxl.safetensors", adapter_name="toy") + pipeline.get_active_adapters() + ``` + """ + if not USE_PEFT_BACKEND: + raise ValueError( + "PEFT backend is required for this method. Please install the latest version of PEFT `pip install -U peft`" + ) + + from peft.tuners.tuners_utils import BaseTunerLayer + + active_adapters = [] + + for module in self.unet.modules(): + if isinstance(module, BaseTunerLayer): + active_adapters = module.active_adapters + break + + return active_adapters + + def get_list_adapters(self) -> Dict[str, List[str]]: + """ + Gets the current list of all available adapters in the pipeline. + """ + if not USE_PEFT_BACKEND: + raise ValueError( + "PEFT backend is required for this method. Please install the latest version of PEFT `pip install -U peft`" + ) + + set_adapters = {} + + if hasattr(self, "text_encoder") and hasattr(self.text_encoder, "peft_config"): + set_adapters["text_encoder"] = list(self.text_encoder.peft_config.keys()) + + if hasattr(self, "text_encoder_2") and hasattr(self.text_encoder_2, "peft_config"): + set_adapters["text_encoder_2"] = list(self.text_encoder_2.peft_config.keys()) + + if hasattr(self, "unet") and hasattr(self.unet, "peft_config"): + set_adapters["unet"] = list(self.unet.peft_config.keys()) + + return set_adapters + + def set_lora_device(self, adapter_names: List[str], device: Union[torch.device, str, int]) -> None: + """ + Moves the LoRAs listed in `adapter_names` to a target device. Useful for offloading the LoRA to the CPU in case + you want to load multiple adapters and free some GPU memory. 
+
+        Args:
+            adapter_names (`List[str]`):
+                List of adapters to send to the device.
+            device (`Union[torch.device, str, int]`):
+                Device to send the adapters to. Can be either a torch device, a str or an integer.
+        """
+        if not USE_PEFT_BACKEND:
+            raise ValueError("PEFT backend is required for this method.")
+
+        from peft.tuners.tuners_utils import BaseTunerLayer
+
+        # Handle the UNET
+        for unet_module in self.unet.modules():
+            if isinstance(unet_module, BaseTunerLayer):
+                for adapter_name in adapter_names:
+                    unet_module.lora_A[adapter_name].to(device)
+                    unet_module.lora_B[adapter_name].to(device)
+
+        # Handle the text encoder
+        modules_to_process = []
+        if hasattr(self, "text_encoder"):
+            modules_to_process.append(self.text_encoder)
+
+        if hasattr(self, "text_encoder_2"):
+            modules_to_process.append(self.text_encoder_2)
+
+        for text_encoder in modules_to_process:
+            # loop over submodules
+            for text_encoder_module in text_encoder.modules():
+                if isinstance(text_encoder_module, BaseTunerLayer):
+                    for adapter_name in adapter_names:
+                        text_encoder_module.lora_A[adapter_name].to(device)
+                        text_encoder_module.lora_B[adapter_name].to(device)
+

 class FromSingleFileMixin:
     """
@@ -2878,7 +3176,12 @@ class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
     """This class overrides `LoraLoaderMixin` with LoRA loading/saving code that's specific to SDXL"""

     # Override to properly handle the loading and unloading of the additional text encoder.
-    def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
+    def load_lora_weights(
+        self,
+        pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+        adapter_name: Optional[str] = None,
+        **kwargs,
+    ):
         """
         Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.unet` and
         `self.text_encoder`.
@@ -2896,6 +3199,9 @@
         Parameters:
             pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
                 See [`~loaders.LoraLoaderMixin.lora_state_dict`].
+            adapter_name (`str`, *optional*):
+                Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+                `default_{i}` where `i` is the total number of adapters being loaded.
             kwargs (`dict`, *optional*):
                 See [`~loaders.LoraLoaderMixin.lora_state_dict`].
         """
@@ -2913,7 +3219,9 @@
         if not is_correct_format:
             raise ValueError("Invalid LoRA checkpoint.")

-        self.load_lora_into_unet(state_dict, network_alphas=network_alphas, unet=self.unet, _pipeline=self)
+        self.load_lora_into_unet(
+            state_dict, network_alphas=network_alphas, unet=self.unet, adapter_name=adapter_name, _pipeline=self
+        )

         text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder."
in k} if len(text_encoder_state_dict) > 0: self.load_lora_into_text_encoder( @@ -2922,6 +3230,7 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di text_encoder=self.text_encoder, prefix="text_encoder", lora_scale=self.lora_scale, + adapter_name=adapter_name, _pipeline=self, ) @@ -2933,6 +3242,7 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di text_encoder=self.text_encoder_2, prefix="text_encoder_2", lora_scale=self.lora_scale, + adapter_name=adapter_name, _pipeline=self, ) @@ -2999,16 +3309,17 @@ def pack_weights(layers, prefix): ) def _remove_text_encoder_monkey_patch(self): - if self.use_peft_backend: + if USE_PEFT_BACKEND: recurse_remove_peft_layers(self.text_encoder) # TODO: @younesbelkada handle this in transformers side - del self.text_encoder.peft_config - self.text_encoder._hf_peft_config_loaded = None + if getattr(self.text_encoder, "peft_config", None) is not None: + del self.text_encoder.peft_config + self.text_encoder._hf_peft_config_loaded = None recurse_remove_peft_layers(self.text_encoder_2) - - del self.text_encoder_2.peft_config - self.text_encoder_2._hf_peft_config_loaded = None + if getattr(self.text_encoder_2, "peft_config", None) is not None: + del self.text_encoder_2.peft_config + self.text_encoder_2._hf_peft_config_loaded = None else: self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder) self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder_2) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 6f5d1da6c6ae..47608005d374 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -17,6 +17,7 @@ import torch.nn.functional as F from torch import nn +from ..utils import USE_PEFT_BACKEND from ..utils.torch_utils import maybe_allow_in_graph from .activations import get_activation from .attention_processor import Attention @@ -300,6 +301,7 @@ def __init__( super().__init__() inner_dim = int(dim * mult) dim_out = dim_out if dim_out is not None else dim + linear_cls = LoRACompatibleLinear if not USE_PEFT_BACKEND else nn.Linear if activation_fn == "gelu": act_fn = GELU(dim, inner_dim) @@ -316,14 +318,15 @@ def __init__( # project dropout self.net.append(nn.Dropout(dropout)) # project out - self.net.append(LoRACompatibleLinear(inner_dim, dim_out)) + self.net.append(linear_cls(inner_dim, dim_out)) # FF as used in Vision Transformer, MLP-Mixer, etc. 
have a final dropout if final_dropout: self.net.append(nn.Dropout(dropout)) def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor: + compatible_cls = (GEGLU,) if USE_PEFT_BACKEND else (GEGLU, LoRACompatibleLinear) for module in self.net: - if isinstance(module, (LoRACompatibleLinear, GEGLU)): + if isinstance(module, compatible_cls): hidden_states = module(hidden_states, scale) else: hidden_states = module(hidden_states) @@ -368,7 +371,9 @@ class GEGLU(nn.Module): def __init__(self, dim_in: int, dim_out: int): super().__init__() - self.proj = LoRACompatibleLinear(dim_in, dim_out * 2) + linear_cls = LoRACompatibleLinear if not USE_PEFT_BACKEND else nn.Linear + + self.proj = linear_cls(dim_in, dim_out * 2) def gelu(self, gate: torch.Tensor) -> torch.Tensor: if gate.device.type != "mps": @@ -377,7 +382,8 @@ def gelu(self, gate: torch.Tensor) -> torch.Tensor: return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype) def forward(self, hidden_states, scale: float = 1.0): - hidden_states, gate = self.proj(hidden_states, scale).chunk(2, dim=-1) + args = () if USE_PEFT_BACKEND else (scale,) + hidden_states, gate = self.proj(hidden_states, *args).chunk(2, dim=-1) return hidden_states * self.gelu(gate) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 7cc3ea2aa07a..9856f3c7739c 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -18,7 +18,7 @@ import torch.nn.functional as F from torch import nn -from ..utils import deprecate, logging +from ..utils import USE_PEFT_BACKEND, deprecate, logging from ..utils.import_utils import is_xformers_available from ..utils.torch_utils import maybe_allow_in_graph from .lora import LoRACompatibleLinear, LoRALinearLayer @@ -137,22 +137,27 @@ def __init__( f"unknown cross_attention_norm: {cross_attention_norm}. 
Should be None, 'layer_norm' or 'group_norm'" ) - self.to_q = LoRACompatibleLinear(query_dim, self.inner_dim, bias=bias) + if USE_PEFT_BACKEND: + linear_cls = nn.Linear + else: + linear_cls = LoRACompatibleLinear + + self.to_q = linear_cls(query_dim, self.inner_dim, bias=bias) if not self.only_cross_attention: # only relevant for the `AddedKVProcessor` classes - self.to_k = LoRACompatibleLinear(self.cross_attention_dim, self.inner_dim, bias=bias) - self.to_v = LoRACompatibleLinear(self.cross_attention_dim, self.inner_dim, bias=bias) + self.to_k = linear_cls(self.cross_attention_dim, self.inner_dim, bias=bias) + self.to_v = linear_cls(self.cross_attention_dim, self.inner_dim, bias=bias) else: self.to_k = None self.to_v = None if self.added_kv_proj_dim is not None: - self.add_k_proj = LoRACompatibleLinear(added_kv_proj_dim, self.inner_dim) - self.add_v_proj = LoRACompatibleLinear(added_kv_proj_dim, self.inner_dim) + self.add_k_proj = linear_cls(added_kv_proj_dim, self.inner_dim) + self.add_v_proj = linear_cls(added_kv_proj_dim, self.inner_dim) self.to_out = nn.ModuleList([]) - self.to_out.append(LoRACompatibleLinear(self.inner_dim, query_dim, bias=out_bias)) + self.to_out.append(linear_cls(self.inner_dim, query_dim, bias=out_bias)) self.to_out.append(nn.Dropout(dropout)) # set attention processor @@ -545,6 +550,8 @@ def __call__( ): residual = hidden_states + args = () if USE_PEFT_BACKEND else (scale,) + if attn.spatial_norm is not None: hidden_states = attn.spatial_norm(hidden_states, temb) @@ -562,15 +569,15 @@ def __call__( if attn.group_norm is not None: hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) - query = attn.to_q(hidden_states, scale=scale) + query = attn.to_q(hidden_states, *args) if encoder_hidden_states is None: encoder_hidden_states = hidden_states elif attn.norm_cross: encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) - key = attn.to_k(encoder_hidden_states, scale=scale) - value = attn.to_v(encoder_hidden_states, scale=scale) + key = attn.to_k(encoder_hidden_states, *args) + value = attn.to_v(encoder_hidden_states, *args) query = attn.head_to_batch_dim(query) key = attn.head_to_batch_dim(key) @@ -581,7 +588,7 @@ def __call__( hidden_states = attn.batch_to_head_dim(hidden_states) # linear proj - hidden_states = attn.to_out[0](hidden_states, scale=scale) + hidden_states = attn.to_out[0](hidden_states, *args) # dropout hidden_states = attn.to_out[1](hidden_states) @@ -1007,15 +1014,20 @@ def __call__( if attn.group_norm is not None: hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) - query = attn.to_q(hidden_states, scale=scale) + args = () if USE_PEFT_BACKEND else (scale,) + query = attn.to_q(hidden_states, *args) if encoder_hidden_states is None: encoder_hidden_states = hidden_states elif attn.norm_cross: encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) - key = attn.to_k(encoder_hidden_states, scale=scale) - value = attn.to_v(encoder_hidden_states, scale=scale) + key = ( + attn.to_k(encoder_hidden_states, scale=scale) if not USE_PEFT_BACKEND else attn.to_k(encoder_hidden_states) + ) + value = ( + attn.to_v(encoder_hidden_states, scale=scale) if not USE_PEFT_BACKEND else attn.to_v(encoder_hidden_states) + ) inner_dim = key.shape[-1] head_dim = inner_dim // attn.heads @@ -1035,7 +1047,9 @@ def __call__( hidden_states = hidden_states.to(query.dtype) # linear proj - hidden_states = attn.to_out[0](hidden_states, scale=scale) + hidden_states = ( + 
attn.to_out[0](hidden_states, scale=scale) if not USE_PEFT_BACKEND else attn.to_out[0](hidden_states) + ) # dropout hidden_states = attn.to_out[1](hidden_states) diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index e05092de3d10..d3422c8f58b2 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -18,6 +18,7 @@ import torch from torch import nn +from ..utils import USE_PEFT_BACKEND from .activations import get_activation from .lora import LoRACompatibleLinear @@ -166,8 +167,9 @@ def __init__( cond_proj_dim=None, ): super().__init__() + linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear - self.linear_1 = LoRACompatibleLinear(in_channels, time_embed_dim) + self.linear_1 = linear_cls(in_channels, time_embed_dim) if cond_proj_dim is not None: self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False) @@ -180,7 +182,7 @@ def __init__( time_embed_dim_out = out_dim else: time_embed_dim_out = time_embed_dim - self.linear_2 = LoRACompatibleLinear(time_embed_dim, time_embed_dim_out) + self.linear_2 = linear_cls(time_embed_dim, time_embed_dim_out) if post_act_fn is None: self.post_act = None diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 67746ebacef2..7639f75152a5 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -32,10 +32,12 @@ DIFFUSERS_CACHE, FLAX_WEIGHTS_NAME, HF_HUB_OFFLINE, + MIN_PEFT_VERSION, SAFETENSORS_WEIGHTS_NAME, WEIGHTS_NAME, _add_variant, _get_model_file, + check_peft_version, deprecate, is_accelerate_available, is_torch_version, @@ -187,6 +189,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin): _automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"] _supports_gradient_checkpointing = False _keys_to_ignore_on_load_unexpected = None + _hf_peft_config_loaded = False def __init__(self): super().__init__() @@ -292,6 +295,153 @@ def disable_xformers_memory_efficient_attention(self): """ self.set_use_memory_efficient_attention_xformers(False) + def add_adapter(self, adapter_config, adapter_name: str = "default") -> None: + r""" + Adds a new adapter to the current model for training. If no adapter name is passed, a default name is assigned + to the adapter to follow the convention of the PEFT library. + + If you are not familiar with adapters and PEFT methods, we invite you to read more about them in the PEFT + [documentation](https://huggingface.co/docs/peft). + + Args: + adapter_config (`[~peft.PeftConfig]`): + The configuration of the adapter to add; supported adapters are non-prefix tuning and adaption prompt + methods. + adapter_name (`str`, *optional*, defaults to `"default"`): + The name of the adapter to add. If no name is passed, a default name is assigned to the adapter. + """ + check_peft_version(min_version=MIN_PEFT_VERSION) + + from peft import PeftConfig, inject_adapter_in_model + + if not self._hf_peft_config_loaded: + self._hf_peft_config_loaded = True + elif adapter_name in self.peft_config: + raise ValueError(f"Adapter with name {adapter_name} already exists. Please use a different name.") + + if not isinstance(adapter_config, PeftConfig): + raise ValueError( + f"adapter_config should be an instance of PeftConfig. Got {type(adapter_config)} instead." + ) + + # Unlike transformers, here we don't need to retrieve the name_or_path of the unet as the loading logic is + # handled by the `load_lora_layers` or `LoraLoaderMixin`. 
Therefore we set it to `None` here.
+        adapter_config.base_model_name_or_path = None
+        inject_adapter_in_model(adapter_config, self, adapter_name)
+        self.set_adapter(adapter_name)
+
+    def set_adapter(self, adapter_name: Union[str, List[str]]) -> None:
+        """
+        Sets a specific adapter by forcing the model to only use that adapter and disabling the other adapters.
+
+        If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
+        official documentation: https://huggingface.co/docs/peft
+
+        Args:
+            adapter_name (`Union[str, List[str]]`):
+                The list of adapters to set, or the adapter name in the case of a single adapter.
+        """
+        check_peft_version(min_version=MIN_PEFT_VERSION)
+
+        if not self._hf_peft_config_loaded:
+            raise ValueError("No adapter loaded. Please load an adapter first.")
+
+        if isinstance(adapter_name, str):
+            adapter_name = [adapter_name]
+
+        missing = set(adapter_name) - set(self.peft_config)
+        if len(missing) > 0:
+            raise ValueError(
+                f"Following adapter(s) could not be found: {', '.join(missing)}. Make sure you are passing the correct adapter name(s)."
+                f" Currently loaded adapters are: {list(self.peft_config.keys())}"
+            )
+
+        from peft.tuners.tuners_utils import BaseTunerLayer
+
+        _adapters_has_been_set = False
+
+        for _, module in self.named_modules():
+            if isinstance(module, BaseTunerLayer):
+                if hasattr(module, "set_adapter"):
+                    module.set_adapter(adapter_name)
+                # Previous versions of PEFT do not support multi-adapter inference
+                elif not hasattr(module, "set_adapter") and len(adapter_name) != 1:
+                    raise ValueError(
+                        "You are trying to set multiple adapters and you have a PEFT version that does not support multi-adapter inference. Please upgrade to the latest version of PEFT."
+                        " `pip install -U peft` or `pip install -U git+https://github.com/huggingface/peft.git`"
+                    )
+                else:
+                    module.active_adapter = adapter_name
+                _adapters_has_been_set = True
+
+        if not _adapters_has_been_set:
+            raise ValueError(
+                "Did not succeed in setting the adapter. Please make sure you are using a model that supports adapters."
+            )
+
+    def disable_adapters(self) -> None:
+        r"""
+        Disable all adapters attached to the model and fall back to inference with the base model only.
+
+        If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
+        official documentation: https://huggingface.co/docs/peft
+        """
+        check_peft_version(min_version=MIN_PEFT_VERSION)
+
+        if not self._hf_peft_config_loaded:
+            raise ValueError("No adapter loaded. Please load an adapter first.")
+
+        from peft.tuners.tuners_utils import BaseTunerLayer
+
+        for _, module in self.named_modules():
+            if isinstance(module, BaseTunerLayer):
+                if hasattr(module, "enable_adapters"):
+                    module.enable_adapters(enabled=False)
+                else:
+                    # support for older PEFT versions
+                    module.disable_adapters = True
+
+    def enable_adapters(self) -> None:
+        """
+        Enable adapters that are attached to the model. The model will use `self.active_adapters()` to retrieve the
+        list of adapters to enable.
+
+        If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
+        official documentation: https://huggingface.co/docs/peft
+        """
+        check_peft_version(min_version=MIN_PEFT_VERSION)
+
+        if not self._hf_peft_config_loaded:
+            raise ValueError("No adapter loaded.
Please load an adapter first.") + + from peft.tuners.tuners_utils import BaseTunerLayer + + for _, module in self.named_modules(): + if isinstance(module, BaseTunerLayer): + if hasattr(module, "enable_adapters"): + module.enable_adapters(enabled=True) + else: + # support for older PEFT versions + module.disable_adapters = False + + def active_adapters(self) -> List[str]: + """ + Gets the current list of active adapters of the model. + + If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT + official documentation: https://huggingface.co/docs/peft + """ + check_peft_version(min_version=MIN_PEFT_VERSION) + + if not self._hf_peft_config_loaded: + raise ValueError("No adapter loaded. Please load an adapter first.") + + from peft.tuners.tuners_utils import BaseTunerLayer + + for _, module in self.named_modules(): + if isinstance(module, BaseTunerLayer): + return module.active_adapter + def save_pretrained( self, save_directory: Union[str, os.PathLike], diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index 3972b438b076..80bf269fc4e3 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -20,6 +20,7 @@ import torch.nn as nn import torch.nn.functional as F +from ..utils import USE_PEFT_BACKEND from .activations import get_activation from .attention import AdaGroupNorm from .attention_processor import SpatialNorm @@ -149,12 +150,13 @@ def __init__( self.use_conv = use_conv self.use_conv_transpose = use_conv_transpose self.name = name + conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv conv = None if use_conv_transpose: conv = nn.ConvTranspose2d(channels, self.out_channels, 4, 2, 1) elif use_conv: - conv = LoRACompatibleConv(self.channels, self.out_channels, 3, padding=1) + conv = conv_cls(self.channels, self.out_channels, 3, padding=1) # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed if name == "conv": @@ -193,12 +195,12 @@ def forward(self, hidden_states: torch.Tensor, output_size: Optional[int] = None # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed if self.use_conv: if self.name == "conv": - if isinstance(self.conv, LoRACompatibleConv): + if isinstance(self.conv, LoRACompatibleConv) and not USE_PEFT_BACKEND: hidden_states = self.conv(hidden_states, scale) else: hidden_states = self.conv(hidden_states) else: - if isinstance(self.Conv2d_0, LoRACompatibleConv): + if isinstance(self.Conv2d_0, LoRACompatibleConv) and not USE_PEFT_BACKEND: hidden_states = self.Conv2d_0(hidden_states, scale) else: hidden_states = self.Conv2d_0(hidden_states) @@ -237,9 +239,10 @@ def __init__( self.padding = padding stride = 2 self.name = name + conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv if use_conv: - conv = LoRACompatibleConv(self.channels, self.out_channels, 3, stride=stride, padding=padding) + conv = conv_cls(self.channels, self.out_channels, 3, stride=stride, padding=padding) else: assert self.channels == self.out_channels conv = nn.AvgPool2d(kernel_size=stride, stride=stride) @@ -255,13 +258,18 @@ def __init__( def forward(self, hidden_states, scale: float = 1.0): assert hidden_states.shape[1] == self.channels + if self.use_conv and self.padding == 0: pad = (0, 1, 0, 1) hidden_states = F.pad(hidden_states, pad, mode="constant", value=0) assert hidden_states.shape[1] == self.channels - if isinstance(self.conv, LoRACompatibleConv): - hidden_states = self.conv(hidden_states, scale) + + if not USE_PEFT_BACKEND: + if 
isinstance(self.conv, LoRACompatibleConv): + hidden_states = self.conv(hidden_states, scale) + else: + hidden_states = self.conv(hidden_states) else: hidden_states = self.conv(hidden_states) @@ -608,6 +616,9 @@ def __init__( self.time_embedding_norm = time_embedding_norm self.skip_time_act = skip_time_act + linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear + conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv + if groups_out is None: groups_out = groups @@ -618,13 +629,13 @@ def __init__( else: self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True) - self.conv1 = LoRACompatibleConv(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + self.conv1 = conv_cls(in_channels, out_channels, kernel_size=3, stride=1, padding=1) if temb_channels is not None: if self.time_embedding_norm == "default": - self.time_emb_proj = LoRACompatibleLinear(temb_channels, out_channels) + self.time_emb_proj = linear_cls(temb_channels, out_channels) elif self.time_embedding_norm == "scale_shift": - self.time_emb_proj = LoRACompatibleLinear(temb_channels, 2 * out_channels) + self.time_emb_proj = linear_cls(temb_channels, 2 * out_channels) elif self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial": self.time_emb_proj = None else: @@ -641,7 +652,7 @@ def __init__( self.dropout = torch.nn.Dropout(dropout) conv_2d_out_channels = conv_2d_out_channels or out_channels - self.conv2 = LoRACompatibleConv(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1) + self.conv2 = conv_cls(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1) self.nonlinearity = get_activation(non_linearity) @@ -667,7 +678,7 @@ def __init__( self.conv_shortcut = None if self.use_in_shortcut: - self.conv_shortcut = LoRACompatibleConv( + self.conv_shortcut = conv_cls( in_channels, conv_2d_out_channels, kernel_size=1, stride=1, padding=0, bias=conv_shortcut_bias ) @@ -708,12 +719,16 @@ def forward(self, input_tensor, temb, scale: float = 1.0): else self.downsample(hidden_states) ) - hidden_states = self.conv1(hidden_states, scale) + hidden_states = self.conv1(hidden_states, scale) if not USE_PEFT_BACKEND else self.conv1(hidden_states) if self.time_emb_proj is not None: if not self.skip_time_act: temb = self.nonlinearity(temb) - temb = self.time_emb_proj(temb, scale)[:, :, None, None] + temb = ( + self.time_emb_proj(temb, scale)[:, :, None, None] + if not USE_PEFT_BACKEND + else self.time_emb_proj(temb)[:, :, None, None] + ) if temb is not None and self.time_embedding_norm == "default": hidden_states = hidden_states + temb @@ -730,10 +745,12 @@ def forward(self, input_tensor, temb, scale: float = 1.0): hidden_states = self.nonlinearity(hidden_states) hidden_states = self.dropout(hidden_states) - hidden_states = self.conv2(hidden_states, scale) + hidden_states = self.conv2(hidden_states, scale) if not USE_PEFT_BACKEND else self.conv2(hidden_states) if self.conv_shortcut is not None: - input_tensor = self.conv_shortcut(input_tensor, scale) + input_tensor = ( + self.conv_shortcut(input_tensor, scale) if not USE_PEFT_BACKEND else self.conv_shortcut(input_tensor) + ) output_tensor = (input_tensor + hidden_states) / self.output_scale_factor diff --git a/src/diffusers/models/transformer_2d.py b/src/diffusers/models/transformer_2d.py index e7780a7bca3d..0f00932f3014 100644 --- a/src/diffusers/models/transformer_2d.py +++ b/src/diffusers/models/transformer_2d.py @@ -20,7 +20,7 @@ from ..configuration_utils import 
ConfigMixin, register_to_config from ..models.embeddings import ImagePositionalEmbeddings -from ..utils import BaseOutput, deprecate +from ..utils import USE_PEFT_BACKEND, BaseOutput, deprecate from .attention import BasicTransformerBlock from .embeddings import PatchEmbed from .lora import LoRACompatibleConv, LoRACompatibleLinear @@ -100,6 +100,9 @@ def __init__( self.attention_head_dim = attention_head_dim inner_dim = num_attention_heads * attention_head_dim + conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv + linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear + # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)` # Define whether input is continuous or discrete depending on configuration self.is_input_continuous = (in_channels is not None) and (patch_size is None) @@ -139,9 +142,9 @@ def __init__( self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) if use_linear_projection: - self.proj_in = LoRACompatibleLinear(in_channels, inner_dim) + self.proj_in = linear_cls(in_channels, inner_dim) else: - self.proj_in = LoRACompatibleConv(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) + self.proj_in = conv_cls(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) elif self.is_input_vectorized: assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size" assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed" @@ -197,9 +200,9 @@ def __init__( if self.is_input_continuous: # TODO: should use out_channels for continuous projections if use_linear_projection: - self.proj_out = LoRACompatibleLinear(inner_dim, in_channels) + self.proj_out = linear_cls(inner_dim, in_channels) else: - self.proj_out = LoRACompatibleConv(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) + self.proj_out = conv_cls(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) elif self.is_input_vectorized: self.norm_out = nn.LayerNorm(inner_dim) self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1) @@ -292,13 +295,21 @@ def forward( hidden_states = self.norm(hidden_states) if not self.use_linear_projection: - hidden_states = self.proj_in(hidden_states, scale=lora_scale) + hidden_states = ( + self.proj_in(hidden_states, scale=lora_scale) + if not USE_PEFT_BACKEND + else self.proj_in(hidden_states) + ) inner_dim = hidden_states.shape[1] hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) else: inner_dim = hidden_states.shape[1] hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) - hidden_states = self.proj_in(hidden_states, scale=lora_scale) + hidden_states = ( + self.proj_in(hidden_states, scale=lora_scale) + if not USE_PEFT_BACKEND + else self.proj_in(hidden_states) + ) elif self.is_input_vectorized: hidden_states = self.latent_image_embedding(hidden_states) @@ -334,9 +345,17 @@ def forward( if self.is_input_continuous: if not self.use_linear_projection: hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() - hidden_states = self.proj_out(hidden_states, scale=lora_scale) + hidden_states = ( + self.proj_out(hidden_states, scale=lora_scale) + if not USE_PEFT_BACKEND + else self.proj_out(hidden_states) + ) else: - hidden_states = 
self.proj_out(hidden_states, scale=lora_scale) + hidden_states = ( + self.proj_out(hidden_states, scale=lora_scale) + if not USE_PEFT_BACKEND + else self.proj_out(hidden_states) + ) hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() output = hidden_states + residual diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index f858a7685360..4039fbfcc67a 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -20,7 +20,7 @@ from ..configuration_utils import ConfigMixin, register_to_config from ..loaders import UNet2DConditionLoadersMixin -from ..utils import BaseOutput, logging +from ..utils import USE_PEFT_BACKEND, BaseOutput, logging, scale_lora_layers, unscale_lora_layers from .activations import get_activation from .attention_processor import ( ADDED_KV_ATTENTION_PROCESSORS, @@ -995,6 +995,9 @@ def forward( # 3. down lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 + if USE_PEFT_BACKEND: + # weight the lora layers by setting `lora_scale` for each PEFT layer + scale_lora_layers(self, lora_scale) is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None is_adapter = mid_block_additional_residual is None and down_block_additional_residuals is not None @@ -1094,6 +1097,10 @@ def forward( sample = self.conv_act(sample) sample = self.conv_out(sample) + if USE_PEFT_BACKEND: + # remove `lora_scale` from each PEFT layer + unscale_lora_layers(self) + if not return_dict: return (sample,) diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py index ba3930f5da59..18518cc3783f 100644 --- a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py +++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py @@ -25,7 +25,14 @@ from ...models import AutoencoderKL, UNet2DConditionModel from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers -from ...utils import deprecate, logging, replace_example_docstring, scale_lora_layers, unscale_lora_layers +from ...utils import ( + USE_PEFT_BACKEND, + deprecate, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker @@ -304,7 +311,7 @@ def encode_prompt( self._lora_scale = lora_scale # dynamically adjust the LoRA scale - if not self.use_peft_backend: + if not USE_PEFT_BACKEND: adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) else: scale_lora_layers(self.text_encoder, lora_scale) @@ -432,7 +439,7 @@ def encode_prompt( negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - if isinstance(self, LoraLoaderMixin) and self.use_peft_backend: + if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND: # Retrieve the original scale by scaling back the LoRA layers unscale_lora_layers(self.text_encoder) @@ -668,6 +675,7 @@ def __call__( # 0. 
diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py
index ba3930f5da59..18518cc3783f 100644
--- a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py
+++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py
@@ -25,7 +25,14 @@
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...models.lora import adjust_lora_scale_text_encoder
 from ...schedulers import KarrasDiffusionSchedulers
-from ...utils import deprecate, logging, replace_example_docstring, scale_lora_layers, unscale_lora_layers
+from ...utils import (
+    USE_PEFT_BACKEND,
+    deprecate,
+    logging,
+    replace_example_docstring,
+    scale_lora_layers,
+    unscale_lora_layers,
+)
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline
 from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
@@ -304,7 +311,7 @@ def encode_prompt(
             self._lora_scale = lora_scale

             # dynamically adjust the LoRA scale
-            if not self.use_peft_backend:
+            if not USE_PEFT_BACKEND:
                 adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
             else:
                 scale_lora_layers(self.text_encoder, lora_scale)
@@ -432,7 +439,7 @@ def encode_prompt(
             negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
             negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

-        if isinstance(self, LoraLoaderMixin) and self.use_peft_backend:
+        if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
             # Retrieve the original scale by scaling back the LoRA layers
             unscale_lora_layers(self.text_encoder)
@@ -668,6 +675,7 @@ def __call__(
         # 0. Default height and width to unet
         height = height or self.unet.config.sample_size * self.vae_scale_factor
         width = width or self.unet.config.sample_size * self.vae_scale_factor
+        # to deal with lora scaling and other possible forward hooks

         # 1. Check inputs. Raise error if not correct
         self.check_inputs(
@@ -689,9 +697,8 @@ def __call__(
         do_classifier_free_guidance = guidance_scale > 1.0

         # 3. Encode input prompt
-        text_encoder_lora_scale = (
-            cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
-        )
+        lora_scale = cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
+
         prompt_embeds, negative_prompt_embeds = self.encode_prompt(
             prompt,
             device,
@@ -700,7 +707,7 @@ def __call__(
             negative_prompt,
             prompt_embeds=prompt_embeds,
             negative_prompt_embeds=negative_prompt_embeds,
-            lora_scale=text_encoder_lora_scale,
+            lora_scale=lora_scale,
             clip_skip=clip_skip,
         )
         # For classifier free guidance, we need to do two forward passes.
diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py
index 47fa019647d4..404a4277e7cc 100644
--- a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py
+++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py
@@ -29,6 +29,7 @@
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
     PIL_INTERPOLATION,
+    USE_PEFT_BACKEND,
     deprecate,
     logging,
     replace_example_docstring,
@@ -309,7 +310,7 @@ def encode_prompt(
             self._lora_scale = lora_scale

             # dynamically adjust the LoRA scale
-            if not self.use_peft_backend:
+            if not USE_PEFT_BACKEND:
                 adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
             else:
                 scale_lora_layers(self.text_encoder, lora_scale)
@@ -437,7 +438,7 @@ def encode_prompt(
             negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
             negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

-        if isinstance(self, LoraLoaderMixin) and self.use_peft_backend:
+        if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
             # Retrieve the original scale by scaling back the LoRA layers
             unscale_lora_layers(self.text_encoder)
diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py
index ad0060976440..87259378b8a2 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py
@@ -27,7 +27,14 @@
 from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
 from ...models.lora import adjust_lora_scale_text_encoder
 from ...schedulers import KarrasDiffusionSchedulers
-from ...utils import deprecate, logging, replace_example_docstring, scale_lora_layers, unscale_lora_layers
+from ...utils import (
+    USE_PEFT_BACKEND,
+    deprecate,
+    logging,
+    replace_example_docstring,
+    scale_lora_layers,
+    unscale_lora_layers,
+)
 from ...utils.torch_utils import is_compiled_module, randn_tensor
 from ..pipeline_utils import DiffusionPipeline
 from ..stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
@@ -287,7 +294,7 @@ def encode_prompt(
             self._lora_scale = lora_scale

             # dynamically adjust the LoRA scale
-            if not self.use_peft_backend:
+            if not USE_PEFT_BACKEND:
                 adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
             else:
                 scale_lora_layers(self.text_encoder, lora_scale)
@@ -415,7 +422,7 @@ def encode_prompt(
             negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
             negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

-        if isinstance(self, LoraLoaderMixin) and self.use_peft_backend:
+        if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
             # Retrieve the original scale by scaling back the LoRA layers
             unscale_lora_layers(self.text_encoder)
diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py
index ef34ad3ee70a..98075a3d7253 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py
@@ -27,6 +27,7 @@
 from ...models.lora import adjust_lora_scale_text_encoder
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
+    USE_PEFT_BACKEND,
     deprecate,
     logging,
     replace_example_docstring,
@@ -317,7 +318,7 @@ def encode_prompt(
             self._lora_scale = lora_scale

             # dynamically adjust the LoRA scale
-            if not self.use_peft_backend:
+            if not USE_PEFT_BACKEND:
                 adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
             else:
                 scale_lora_layers(self.text_encoder, lora_scale)
@@ -445,7 +446,7 @@ def encode_prompt(
             negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
             negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

-        if isinstance(self, LoraLoaderMixin) and self.use_peft_backend:
+        if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
             # Retrieve the original scale by scaling back the LoRA layers
             unscale_lora_layers(self.text_encoder)
diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py
index 640ca0a22e9c..324aa1e0f81c 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py
@@ -28,7 +28,14 @@
 from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
 from ...models.lora import adjust_lora_scale_text_encoder
 from ...schedulers import KarrasDiffusionSchedulers
-from ...utils import deprecate, logging, replace_example_docstring, scale_lora_layers, unscale_lora_layers
+from ...utils import (
+    USE_PEFT_BACKEND,
+    deprecate,
+    logging,
+    replace_example_docstring,
+    scale_lora_layers,
+    unscale_lora_layers,
+)
 from ...utils.torch_utils import is_compiled_module, randn_tensor
 from ..pipeline_utils import DiffusionPipeline
 from ..stable_diffusion import StableDiffusionPipelineOutput
@@ -438,7 +445,7 @@ def encode_prompt(
             self._lora_scale = lora_scale

             # dynamically adjust the LoRA scale
-            if not self.use_peft_backend:
+            if not USE_PEFT_BACKEND:
                 adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
             else:
                 scale_lora_layers(self.text_encoder, lora_scale)
@@ -566,7 +573,7 @@ def encode_prompt(
             negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
             negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

-        if isinstance(self, LoraLoaderMixin) and self.use_peft_backend:
+        if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
             # Retrieve the original scale by scaling back the LoRA layers
             unscale_lora_layers(self.text_encoder)
diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py
index 41b0d5434386..4a843a4ed883 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py
@@ -33,6 +33,7 @@
 from ...models.lora import adjust_lora_scale_text_encoder
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
+    USE_PEFT_BACKEND,
     is_invisible_watermark_available,
     logging,
     replace_example_docstring,
@@ -316,7 +317,7 @@ def encode_prompt(
             self._lora_scale = lora_scale

             # dynamically adjust the LoRA scale
-            if not self.use_peft_backend:
+            if not USE_PEFT_BACKEND:
                 adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
                 adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
             else:
@@ -458,7 +459,7 @@ def encode_prompt(
                 bs_embed * num_images_per_prompt, -1
             )

-        if isinstance(self, StableDiffusionXLLoraLoaderMixin) and self.use_peft_backend:
+        if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
             # Retrieve the original scale by scaling back the LoRA layers
             unscale_lora_layers(self.text_encoder)
             unscale_lora_layers(self.text_encoder_2)
diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
index 7f230c2ec058..f555ea49b3ab 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
@@ -35,7 +35,7 @@
 )
 from ...models.lora import adjust_lora_scale_text_encoder
 from ...schedulers import KarrasDiffusionSchedulers
-from ...utils import logging, replace_example_docstring, scale_lora_layers, unscale_lora_layers
+from ...utils import USE_PEFT_BACKEND, logging, replace_example_docstring, scale_lora_layers, unscale_lora_layers
 from ...utils.torch_utils import is_compiled_module, randn_tensor
 from ..pipeline_utils import DiffusionPipeline
 from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
@@ -285,7 +285,7 @@ def encode_prompt(
             self._lora_scale = lora_scale

             # dynamically adjust the LoRA scale
-            if not self.use_peft_backend:
+            if not USE_PEFT_BACKEND:
                 adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
                 adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
             else:
@@ -427,7 +427,7 @@ def encode_prompt(
                 bs_embed * num_images_per_prompt, -1
             )

-        if isinstance(self, StableDiffusionXLLoraLoaderMixin) and self.use_peft_backend:
+        if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
             # Retrieve the original scale by scaling back the LoRA layers
             unscale_lora_layers(self.text_encoder)
             unscale_lora_layers(self.text_encoder_2)
diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py
index aeffc219674d..849e34153285 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py
@@ -36,6 +36,7 @@
 from ...models.lora import adjust_lora_scale_text_encoder
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
+    USE_PEFT_BACKEND,
     logging,
     replace_example_docstring,
     scale_lora_layers,
@@ -328,7 +329,7 @@ def encode_prompt(
             self._lora_scale = lora_scale

             # dynamically adjust the LoRA scale
-            if not self.use_peft_backend:
+            if not USE_PEFT_BACKEND:
                 adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
                 adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
             else:
@@ -470,7 +471,7 @@ def encode_prompt(
                 bs_embed * num_images_per_prompt, -1
             )

-        if isinstance(self, StableDiffusionXLLoraLoaderMixin) and self.use_peft_backend:
+        if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
             # Retrieve the original scale by scaling back the LoRA layers
             unscale_lora_layers(self.text_encoder)
             unscale_lora_layers(self.text_encoder_2)
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py
index 6bcbbab135df..d45e35d5cba0 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py
@@ -27,7 +27,7 @@
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...models.lora import adjust_lora_scale_text_encoder
 from ...schedulers import DDIMScheduler
-from ...utils import PIL_INTERPOLATION, deprecate, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import PIL_INTERPOLATION, USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline
 from .pipeline_output import StableDiffusionPipelineOutput
@@ -308,7 +308,7 @@ def encode_prompt(
             self._lora_scale = lora_scale

             # dynamically adjust the LoRA scale
-            if not self.use_peft_backend:
+            if not USE_PEFT_BACKEND:
                 adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
             else:
                 scale_lora_layers(self.text_encoder, lora_scale)
@@ -436,7 +436,7 @@ def encode_prompt(
             negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
             negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

-        if isinstance(self, LoraLoaderMixin) and self.use_peft_backend:
+        if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
             # Retrieve the original scale by scaling back the LoRA layers
             unscale_lora_layers(self.text_encoder)
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
index 68cdbbe78b5a..a9d28144e543 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
@@ -25,7 +25,14 @@
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...models.lora import adjust_lora_scale_text_encoder
 from ...schedulers import KarrasDiffusionSchedulers
-from ...utils import deprecate, logging, replace_example_docstring, scale_lora_layers, unscale_lora_layers
+from ...utils import (
+    USE_PEFT_BACKEND,
+    deprecate,
+    logging,
+    replace_example_docstring,
+    scale_lora_layers,
+    unscale_lora_layers,
+)
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline
 from .pipeline_output import StableDiffusionPipelineOutput
@@ -297,7 +304,7 @@ def encode_prompt(
             self._lora_scale = lora_scale

             # dynamically adjust the LoRA scale
-            if not self.use_peft_backend:
+            if not USE_PEFT_BACKEND:
                 adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
             else:
                 scale_lora_layers(self.text_encoder, lora_scale)
@@ -425,7 +432,7 @@ def encode_prompt(
             negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
             negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

-        if isinstance(self, LoraLoaderMixin) and self.use_peft_backend:
+        if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
             # Retrieve the original scale by scaling back the LoRA layers
             unscale_lora_layers(self.text_encoder)
@@ -658,6 +665,7 @@ def __call__(
         # 0. Default height and width to unet
         height = height or self.unet.config.sample_size * self.vae_scale_factor
         width = width or self.unet.config.sample_size * self.vae_scale_factor
+        # to deal with lora scaling and other possible forward hooks

         # 1. Check inputs. Raise error if not correct
         self.check_inputs(
@@ -679,9 +687,8 @@ def __call__(
         do_classifier_free_guidance = guidance_scale > 1.0

         # 3. Encode input prompt
-        text_encoder_lora_scale = (
-            cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
-        )
+        lora_scale = cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
+
         prompt_embeds, negative_prompt_embeds = self.encode_prompt(
             prompt,
             device,
@@ -690,7 +697,7 @@ def __call__(
             negative_prompt,
             prompt_embeds=prompt_embeds,
             negative_prompt_embeds=negative_prompt_embeds,
-            lora_scale=text_encoder_lora_scale,
+            lora_scale=lora_scale,
             clip_skip=clip_skip,
         )
         # For classifier free guidance, we need to do two forward passes.
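The encode_prompt change above is mirrored verbatim in the remaining pipelines below; only the backend branch differs, the public API does not. A usage sketch, assuming a LoRA checkpoint is available (the repo id is a placeholder):

    import torch
    from diffusers import StableDiffusionPipeline

    pipe = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
    ).to("cuda")
    pipe.load_lora_weights("user/my-lora")  # placeholder repo id

    # "scale" is forwarded as lora_scale to encode_prompt (text encoder LoRA)
    # and picked up again inside the UNet forward shown earlier
    image = pipe("a fantasy landscape", cross_attention_kwargs={"scale": 0.7}).images[0]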
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py
index e49e12b92ea3..153efae876cd 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py
@@ -27,7 +27,14 @@
 from ...models.attention_processor import Attention
 from ...models.lora import adjust_lora_scale_text_encoder
 from ...schedulers import KarrasDiffusionSchedulers
-from ...utils import deprecate, logging, replace_example_docstring, scale_lora_layers, unscale_lora_layers
+from ...utils import (
+    USE_PEFT_BACKEND,
+    deprecate,
+    logging,
+    replace_example_docstring,
+    scale_lora_layers,
+    unscale_lora_layers,
+)
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline
 from . import StableDiffusionPipelineOutput
@@ -332,7 +339,7 @@ def encode_prompt(
             self._lora_scale = lora_scale

             # dynamically adjust the LoRA scale
-            if not self.use_peft_backend:
+            if not USE_PEFT_BACKEND:
                 adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
             else:
                 scale_lora_layers(self.text_encoder, lora_scale)
@@ -460,7 +467,7 @@ def encode_prompt(
             negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
             negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

-        if isinstance(self, LoraLoaderMixin) and self.use_peft_backend:
+        if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
             # Retrieve the original scale by scaling back the LoRA layers
             unscale_lora_layers(self.text_encoder)
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py
index 95c3a79cf0c5..d73cf769e3ae 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py
@@ -28,7 +28,7 @@
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...models.lora import adjust_lora_scale_text_encoder
 from ...schedulers import KarrasDiffusionSchedulers
-from ...utils import PIL_INTERPOLATION, deprecate, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import PIL_INTERPOLATION, USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
@@ -213,7 +213,7 @@ def encode_prompt(
             self._lora_scale = lora_scale

             # dynamically adjust the LoRA scale
-            if not self.use_peft_backend:
+            if not USE_PEFT_BACKEND:
                 adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
             else:
                 scale_lora_layers(self.text_encoder, lora_scale)
@@ -341,7 +341,7 @@ def encode_prompt(
             negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
             negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

-        if isinstance(self, LoraLoaderMixin) and self.use_peft_backend:
+        if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
             # Retrieve the original scale by scaling back the LoRA layers
             unscale_lora_layers(self.text_encoder)
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py
index 7126b798feb5..451ef690a759 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py
@@ -30,6 +30,7 @@
 from ...schedulers import DDIMInverseScheduler, KarrasDiffusionSchedulers
 from ...utils import (
     PIL_INTERPOLATION,
+    USE_PEFT_BACKEND,
     BaseOutput,
     deprecate,
     logging,
@@ -483,7 +484,7 @@ def encode_prompt(
             self._lora_scale = lora_scale

             # dynamically adjust the LoRA scale
-            if not self.use_peft_backend:
+            if not USE_PEFT_BACKEND:
                 adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
             else:
                 scale_lora_layers(self.text_encoder, lora_scale)
@@ -611,7 +612,7 @@ def encode_prompt(
             negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
             negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

-        if isinstance(self, LoraLoaderMixin) and self.use_peft_backend:
+        if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
             # Retrieve the original scale by scaling back the LoRA layers
             unscale_lora_layers(self.text_encoder)
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_gligen.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_gligen.py
index f176f08d5d8c..ce7faaed2ab1 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_gligen.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_gligen.py
@@ -26,7 +26,14 @@
 from ...models.attention import GatedSelfAttentionDense
 from ...models.lora import adjust_lora_scale_text_encoder
 from ...schedulers import KarrasDiffusionSchedulers
-from ...utils import deprecate, logging, replace_example_docstring, scale_lora_layers, unscale_lora_layers
+from ...utils import (
+    USE_PEFT_BACKEND,
+    deprecate,
+    logging,
+    replace_example_docstring,
+    scale_lora_layers,
+    unscale_lora_layers,
+)
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline
 from . import StableDiffusionPipelineOutput
@@ -274,7 +281,7 @@ def encode_prompt(
             self._lora_scale = lora_scale

             # dynamically adjust the LoRA scale
-            if not self.use_peft_backend:
+            if not USE_PEFT_BACKEND:
                 adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
             else:
                 scale_lora_layers(self.text_encoder, lora_scale)
@@ -402,7 +409,7 @@ def encode_prompt(
             negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
             negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

-        if isinstance(self, LoraLoaderMixin) and self.use_peft_backend:
+        if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
             # Retrieve the original scale by scaling back the LoRA layers
             unscale_lora_layers(self.text_encoder)
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_gligen_text_image.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_gligen_text_image.py
index ba418b4cb3c3..67f3fe0e9448 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_gligen_text_image.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_gligen_text_image.py
@@ -32,7 +32,7 @@
 from ...models.attention import GatedSelfAttentionDense
 from ...models.lora import adjust_lora_scale_text_encoder
 from ...schedulers import KarrasDiffusionSchedulers
-from ...utils import logging, replace_example_docstring, scale_lora_layers, unscale_lora_layers
+from ...utils import USE_PEFT_BACKEND, logging, replace_example_docstring, scale_lora_layers, unscale_lora_layers
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline
 from . import StableDiffusionPipelineOutput
@@ -306,7 +306,7 @@ def encode_prompt(
             self._lora_scale = lora_scale

             # dynamically adjust the LoRA scale
-            if not self.use_peft_backend:
+            if not USE_PEFT_BACKEND:
                 adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
             else:
                 scale_lora_layers(self.text_encoder, lora_scale)
@@ -434,7 +434,7 @@ def encode_prompt(
             negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
             negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

-        if isinstance(self, LoraLoaderMixin) and self.use_peft_backend:
+        if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
             # Retrieve the original scale by scaling back the LoRA layers
             unscale_lora_layers(self.text_encoder)
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
index 8c180f5224b7..ffbd8246603e 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
@@ -29,6 +29,7 @@
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
     PIL_INTERPOLATION,
+    USE_PEFT_BACKEND,
     deprecate,
     logging,
     replace_example_docstring,
@@ -304,7 +305,7 @@ def encode_prompt(
             self._lora_scale = lora_scale

             # dynamically adjust the LoRA scale
-            if not self.use_peft_backend:
+            if not USE_PEFT_BACKEND:
                 adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
             else:
                 scale_lora_layers(self.text_encoder, lora_scale)
@@ -432,7 +433,7 @@ def encode_prompt(
             negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
             negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

-        if isinstance(self, LoraLoaderMixin) and self.use_peft_backend:
+        if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
             # Retrieve the original scale by scaling back the LoRA layers
             unscale_lora_layers(self.text_encoder)
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
index e792eb8f8c12..e185ed588047 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
@@ -27,7 +27,7 @@
 from ...models import AsymmetricAutoencoderKL, AutoencoderKL, UNet2DConditionModel
 from ...models.lora import adjust_lora_scale_text_encoder
 from ...schedulers import KarrasDiffusionSchedulers
-from ...utils import deprecate, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline
 from . import StableDiffusionPipelineOutput
@@ -375,7 +375,7 @@ def encode_prompt(
             self._lora_scale = lora_scale

             # dynamically adjust the LoRA scale
-            if not self.use_peft_backend:
+            if not USE_PEFT_BACKEND:
                 adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
             else:
                 scale_lora_layers(self.text_encoder, lora_scale)
@@ -503,7 +503,7 @@ def encode_prompt(
             negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
             negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

-        if isinstance(self, LoraLoaderMixin) and self.use_peft_backend:
+        if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
             # Retrieve the original scale by scaling back the LoRA layers
             unscale_lora_layers(self.text_encoder)
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py
index 4b555e0367c6..513c660c30cf 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py
@@ -27,7 +27,7 @@
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...models.lora import adjust_lora_scale_text_encoder
 from ...schedulers import KarrasDiffusionSchedulers
-from ...utils import PIL_INTERPOLATION, deprecate, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import PIL_INTERPOLATION, USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline
 from . import StableDiffusionPipelineOutput
@@ -297,7 +297,7 @@ def encode_prompt(
             self._lora_scale = lora_scale

             # dynamically adjust the LoRA scale
-            if not self.use_peft_backend:
+            if not USE_PEFT_BACKEND:
                 adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
             else:
                 scale_lora_layers(self.text_encoder, lora_scale)
@@ -425,7 +425,7 @@ def encode_prompt(
             negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
             negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

-        if isinstance(self, LoraLoaderMixin) and self.use_peft_backend:
+        if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
             # Retrieve the original scale by scaling back the LoRA layers
             unscale_lora_layers(self.text_encoder)
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py
index a5c447792ff5..e0bb9b6e0b14 100755
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py
@@ -24,7 +24,7 @@
 from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
 from ...models.lora import adjust_lora_scale_text_encoder
 from ...schedulers import LMSDiscreteScheduler
-from ...utils import deprecate, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline
 from . import StableDiffusionPipelineOutput
@@ -211,7 +211,7 @@ def encode_prompt(
             self._lora_scale = lora_scale

             # dynamically adjust the LoRA scale
-            if not self.use_peft_backend:
+            if not USE_PEFT_BACKEND:
                 adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
             else:
                 scale_lora_layers(self.text_encoder, lora_scale)
@@ -339,7 +339,7 @@ def encode_prompt(
             negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
             negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

-        if isinstance(self, LoraLoaderMixin) and self.use_peft_backend:
+        if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
             # Retrieve the original scale by scaling back the LoRA layers
             unscale_lora_layers(self.text_encoder)
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py
index eb3ba4b90a71..2e514a55108c 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py
@@ -26,7 +26,15 @@
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...models.lora import adjust_lora_scale_text_encoder
 from ...schedulers import KarrasDiffusionSchedulers
-from ...utils import BaseOutput, deprecate, logging, replace_example_docstring, scale_lora_layers, unscale_lora_layers
+from ...utils import (
+    USE_PEFT_BACKEND,
+    BaseOutput,
+    deprecate,
+    logging,
+    replace_example_docstring,
+    scale_lora_layers,
+    unscale_lora_layers,
+)
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline
 from .safety_checker import StableDiffusionSafetyChecker
@@ -267,7 +275,7 @@ def encode_prompt(
             self._lora_scale = lora_scale

             # dynamically adjust the LoRA scale
-            if not self.use_peft_backend:
+            if not USE_PEFT_BACKEND:
                 adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
             else:
                 scale_lora_layers(self.text_encoder, lora_scale)
@@ -395,7 +403,7 @@ def encode_prompt(
             negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
             negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

-        if isinstance(self, LoraLoaderMixin) and self.use_peft_backend:
+        if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
             # Retrieve the original scale by scaling back the LoRA layers
             unscale_lora_layers(self.text_encoder)
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_model_editing.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_model_editing.py
index e67c04ebcf7c..6c78d190d97f 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_model_editing.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_model_editing.py
@@ -24,7 +24,7 @@
 from ...models.lora import adjust_lora_scale_text_encoder
 from ...schedulers import PNDMScheduler
 from ...schedulers.scheduling_utils import SchedulerMixin
-from ...utils import deprecate, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline
 from . import StableDiffusionPipelineOutput
@@ -244,7 +244,7 @@ def encode_prompt(
             self._lora_scale = lora_scale

             # dynamically adjust the LoRA scale
-            if not self.use_peft_backend:
+            if not USE_PEFT_BACKEND:
                 adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
             else:
                 scale_lora_layers(self.text_encoder, lora_scale)
@@ -372,7 +372,7 @@ def encode_prompt(
             negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
             negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

-        if isinstance(self, LoraLoaderMixin) and self.use_peft_backend:
+        if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
             # Retrieve the original scale by scaling back the LoRA layers
             unscale_lora_layers(self.text_encoder)
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py
index 1704e28f0c7f..bac1f83fb336 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py
@@ -23,7 +23,14 @@
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...models.lora import adjust_lora_scale_text_encoder
 from ...schedulers import DDIMScheduler
-from ...utils import deprecate, logging, replace_example_docstring, scale_lora_layers, unscale_lora_layers
+from ...utils import (
+    USE_PEFT_BACKEND,
+    deprecate,
+    logging,
+    replace_example_docstring,
+    scale_lora_layers,
+    unscale_lora_layers,
+)
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline
 from . import StableDiffusionPipelineOutput
@@ -221,7 +228,7 @@ def encode_prompt(
             self._lora_scale = lora_scale

             # dynamically adjust the LoRA scale
-            if not self.use_peft_backend:
+            if not USE_PEFT_BACKEND:
                 adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
             else:
                 scale_lora_layers(self.text_encoder, lora_scale)
@@ -349,7 +356,7 @@ def encode_prompt(
             negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
             negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

-        if isinstance(self, LoraLoaderMixin) and self.use_peft_backend:
+        if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
             # Retrieve the original scale by scaling back the LoRA layers
             unscale_lora_layers(self.text_encoder)
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_paradigms.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_paradigms.py
index 256286904804..161f656fee2e 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_paradigms.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_paradigms.py
@@ -24,6 +24,7 @@
 from ...models.lora import adjust_lora_scale_text_encoder
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
+    USE_PEFT_BACKEND,
     deprecate,
     logging,
     replace_example_docstring,
@@ -258,7 +259,7 @@ def encode_prompt(
             self._lora_scale = lora_scale

             # dynamically adjust the LoRA scale
-            if not self.use_peft_backend:
+            if not USE_PEFT_BACKEND:
                 adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
             else:
                 scale_lora_layers(self.text_encoder, lora_scale)
@@ -386,7 +387,7 @@ def encode_prompt(
             negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
             negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

-        if isinstance(self, LoraLoaderMixin) and self.use_peft_backend:
+        if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
             # Retrieve the original scale by scaling back the LoRA layers
             unscale_lora_layers(self.text_encoder)
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py
index 6cbea1d1da7e..6d4286a04686 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py
@@ -37,6 +37,7 @@
 from ...schedulers.scheduling_ddim_inverse import DDIMInverseScheduler
 from ...utils import (
     PIL_INTERPOLATION,
+    USE_PEFT_BACKEND,
     BaseOutput,
     deprecate,
     logging,
@@ -448,7 +449,7 @@ def encode_prompt(
             self._lora_scale = lora_scale

             # dynamically adjust the LoRA scale
-            if not self.use_peft_backend:
+            if not USE_PEFT_BACKEND:
                 adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
             else:
                 scale_lora_layers(self.text_encoder, lora_scale)
@@ -576,7 +577,7 @@ def encode_prompt(
             negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
             negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

-        if isinstance(self, LoraLoaderMixin) and self.use_peft_backend:
+        if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
             # Retrieve the original scale by scaling back the LoRA layers
             unscale_lora_layers(self.text_encoder)
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py
index 42cc9905c49a..6a78d4da4545 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py
@@ -24,7 +24,14 @@
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...models.lora import adjust_lora_scale_text_encoder
 from ...schedulers import KarrasDiffusionSchedulers
-from ...utils import deprecate, logging, replace_example_docstring, scale_lora_layers, unscale_lora_layers
+from ...utils import (
+    USE_PEFT_BACKEND,
+    deprecate,
+    logging,
+    replace_example_docstring,
+    scale_lora_layers,
+    unscale_lora_layers,
+)
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline
 from . import StableDiffusionPipelineOutput
@@ -244,7 +251,7 @@ def encode_prompt(
             self._lora_scale = lora_scale

             # dynamically adjust the LoRA scale
-            if not self.use_peft_backend:
+            if not USE_PEFT_BACKEND:
                 adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
             else:
                 scale_lora_layers(self.text_encoder, lora_scale)
@@ -372,7 +379,7 @@ def encode_prompt(
             negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
             negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

-        if isinstance(self, LoraLoaderMixin) and self.use_peft_backend:
+        if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
             # Retrieve the original scale by scaling back the LoRA layers
             unscale_lora_layers(self.text_encoder)
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py
index 8d01e0a0d086..f3aa01ebeebb 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py
@@ -32,7 +32,7 @@
 )
 from ...models.lora import adjust_lora_scale_text_encoder
 from ...schedulers import DDPMScheduler, KarrasDiffusionSchedulers
-from ...utils import deprecate, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline
 from . import StableDiffusionPipelineOutput
@@ -242,7 +242,7 @@ def encode_prompt(
             self._lora_scale = lora_scale

             # dynamically adjust the LoRA scale
-            if not self.use_peft_backend:
+            if not USE_PEFT_BACKEND:
                 adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
             else:
                 scale_lora_layers(self.text_encoder, lora_scale)
@@ -370,7 +370,7 @@ def encode_prompt(
             negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
             negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

-        if isinstance(self, LoraLoaderMixin) and self.use_peft_backend:
+        if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
             # Retrieve the original scale by scaling back the LoRA layers
             unscale_lora_layers(self.text_encoder)
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py
index 3b12058eda7b..3bce80fdb5b1 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py
@@ -25,7 +25,14 @@
 from ...models.embeddings import get_timestep_embedding
 from ...models.lora import adjust_lora_scale_text_encoder
 from ...schedulers import KarrasDiffusionSchedulers
-from ...utils import deprecate, logging, replace_example_docstring, scale_lora_layers, unscale_lora_layers
+from ...utils import (
+    USE_PEFT_BACKEND,
+    deprecate,
+    logging,
+    replace_example_docstring,
+    scale_lora_layers,
+    unscale_lora_layers,
+)
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
 from .stable_unclip_image_normalizer import StableUnCLIPImageNormalizer
@@ -342,7 +349,7 @@ def encode_prompt(
             self._lora_scale = lora_scale

             # dynamically adjust the LoRA scale
-            if not self.use_peft_backend:
+            if not USE_PEFT_BACKEND:
                 adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
             else:
                 scale_lora_layers(self.text_encoder, lora_scale)
@@ -470,7 +477,7 @@ def encode_prompt(
             negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
             negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

-        if isinstance(self, LoraLoaderMixin) and self.use_peft_backend:
+        if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
             # Retrieve the original scale by scaling back the LoRA layers
             unscale_lora_layers(self.text_encoder)
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py
index 3ef1994b0cb3..a17a674b7066 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py
@@ -25,7 +25,14 @@
 from ...models.embeddings import get_timestep_embedding
 from ...models.lora import adjust_lora_scale_text_encoder
 from ...schedulers import KarrasDiffusionSchedulers
-from ...utils import deprecate, logging, replace_example_docstring, scale_lora_layers, unscale_lora_layers
+from ...utils import (
+    USE_PEFT_BACKEND,
+    deprecate,
+    logging,
+    replace_example_docstring,
+    scale_lora_layers,
+    unscale_lora_layers,
+)
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
 from .stable_unclip_image_normalizer import StableUnCLIPImageNormalizer
@@ -296,7 +303,7 @@ def encode_prompt(
             self._lora_scale = lora_scale

             # dynamically adjust the LoRA scale
-            if not self.use_peft_backend:
+            if not USE_PEFT_BACKEND:
                 adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
             else:
                 scale_lora_layers(self.text_encoder, lora_scale)
@@ -424,7 +431,7 @@ def encode_prompt(
             negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
             negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

-        if isinstance(self, LoraLoaderMixin) and self.use_peft_backend:
+        if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
             # Retrieve the original scale by scaling back the LoRA layers
             unscale_lora_layers(self.text_encoder)
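The SDXL-flavored pipelines below repeat the same pattern once per text encoder. A condensed sketch of what their encode_prompt effectively does under the PEFT backend (the pipelines inline this rather than using a helper; `try/finally` is an editorial simplification):

    from diffusers.utils import scale_lora_layers, unscale_lora_layers

    def encode_prompt_scaled(pipe, lora_scale, encode):
        # bracket both SDXL text encoders, mirroring the diffs below
        scale_lora_layers(pipe.text_encoder, lora_scale)
        scale_lora_layers(pipe.text_encoder_2, lora_scale)
        try:
            return encode()
        finally:
            unscale_lora_layers(pipe.text_encoder)
            unscale_lora_layers(pipe.text_encoder_2)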
diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py
index 61856d16a197..cbfe4e0d3835 100644
--- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py
+++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py
@@ -34,6 +34,7 @@
 from ...models.lora import adjust_lora_scale_text_encoder
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
+    USE_PEFT_BACKEND,
     is_invisible_watermark_available,
     is_torch_xla_available,
     logging,
@@ -274,7 +275,7 @@ def encode_prompt(
             self._lora_scale = lora_scale

             # dynamically adjust the LoRA scale
-            if not self.use_peft_backend:
+            if not USE_PEFT_BACKEND:
                 adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
                 adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
             else:
@@ -416,7 +417,7 @@ def encode_prompt(
                 bs_embed * num_images_per_prompt, -1
             )

-        if isinstance(self, StableDiffusionXLLoraLoaderMixin) and self.use_peft_backend:
+        if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
             # Retrieve the original scale by scaling back the LoRA layers
             unscale_lora_layers(self.text_encoder)
             unscale_lora_layers(self.text_encoder_2)
@@ -796,9 +797,8 @@ def __call__(
         do_classifier_free_guidance = guidance_scale > 1.0

         # 3. Encode input prompt
-        text_encoder_lora_scale = (
-            cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
-        )
+        lora_scale = cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
+
         (
             prompt_embeds,
             negative_prompt_embeds,
@@ -816,7 +816,7 @@ def __call__(
             negative_prompt_embeds=negative_prompt_embeds,
             pooled_prompt_embeds=pooled_prompt_embeds,
             negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
-            lora_scale=text_encoder_lora_scale,
+            lora_scale=lora_scale,
             clip_skip=clip_skip,
         )
diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py
index a5fb134f9913..75f814ca84cd 100644
--- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py
+++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py
@@ -31,6 +31,7 @@
 from ...models.lora import adjust_lora_scale_text_encoder
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
+    USE_PEFT_BACKEND,
     is_invisible_watermark_available,
     is_torch_xla_available,
     logging,
@@ -281,7 +282,7 @@ def encode_prompt(
             self._lora_scale = lora_scale

             # dynamically adjust the LoRA scale
-            if not self.use_peft_backend:
+            if not USE_PEFT_BACKEND:
                 adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
                 adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
             else:
@@ -423,7 +424,7 @@ def encode_prompt(
                 bs_embed * num_images_per_prompt, -1
             )

-        if isinstance(self, StableDiffusionXLLoraLoaderMixin) and self.use_peft_backend:
+        if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
             # Retrieve the original scale by scaling back the LoRA layers
             unscale_lora_layers(self.text_encoder)
             unscale_lora_layers(self.text_encoder_2)
diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py
index a6e0531eae3a..39ec59048f39 100644
--- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py
+++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py
@@ -32,6 +32,7 @@
 from ...models.lora import adjust_lora_scale_text_encoder
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
+    USE_PEFT_BACKEND,
     deprecate,
     is_invisible_watermark_available,
     is_torch_xla_available,
@@ -430,7 +431,7 @@ def encode_prompt(
             self._lora_scale = lora_scale

             # dynamically adjust the LoRA scale
-            if not self.use_peft_backend:
+            if not USE_PEFT_BACKEND:
                 adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
                 adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
             else:
@@ -572,7 +573,7 @@ def encode_prompt(
                 bs_embed * num_images_per_prompt, -1
             )

-        if isinstance(self, StableDiffusionXLLoraLoaderMixin) and self.use_peft_backend:
+        if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
             # Retrieve the original scale by scaling back the LoRA layers
             unscale_lora_layers(self.text_encoder)
             unscale_lora_layers(self.text_encoder_2)
diff --git a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py
index 2ab3bf00c8fc..54ba178846f4 100644
--- a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py
+++ b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py
@@ -28,6 +28,7 @@
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
     PIL_INTERPOLATION,
+    USE_PEFT_BACKEND,
     BaseOutput,
     deprecate,
     logging,
@@ -298,7 +299,7 @@ def encode_prompt(
             self._lora_scale = lora_scale

             # dynamically adjust the LoRA scale
-            if not self.use_peft_backend:
+            if not USE_PEFT_BACKEND:
                 adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
             else:
                 scale_lora_layers(self.text_encoder, lora_scale)
@@ -426,7 +427,7 @@ def encode_prompt(
             negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
             negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

-        if isinstance(self, LoraLoaderMixin) and self.use_peft_backend:
+        if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
             # Retrieve the original scale by scaling back the LoRA layers
             unscale_lora_layers(self.text_encoder)
diff --git a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py
index b32c852481ab..eb73302f8121 100644
--- a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py
+++ b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py
@@ -31,7 +31,14 @@
 )
 from ...models.lora import adjust_lora_scale_text_encoder
 from ...schedulers import KarrasDiffusionSchedulers
-from ...utils import PIL_INTERPOLATION, logging, replace_example_docstring, scale_lora_layers, unscale_lora_layers
+from ...utils import (
+    PIL_INTERPOLATION,
+    USE_PEFT_BACKEND,
+    logging,
+    replace_example_docstring,
+    scale_lora_layers,
+    unscale_lora_layers,
+)
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline
 from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
@@ -283,7 +290,7 @@ def encode_prompt(
             self._lora_scale = lora_scale

             # dynamically adjust the LoRA scale
-            if not self.use_peft_backend:
+            if not USE_PEFT_BACKEND:
                 adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
                 adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
             else:
@@ -425,7 +432,7 @@ def encode_prompt(
                 bs_embed * num_images_per_prompt, -1
             )

-        if isinstance(self, StableDiffusionXLLoraLoaderMixin) and self.use_peft_backend:
+        if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
             # Retrieve the original scale by scaling back the LoRA layers
             unscale_lora_layers(self.text_encoder)
             unscale_lora_layers(self.text_encoder_2)
diff --git a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py
index 42c00597beee..83c31596940e 100644
--- a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py
+++ b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py
@@ -23,7 +23,14 @@
 from ...models import AutoencoderKL, UNet3DConditionModel
 from ...models.lora import adjust_lora_scale_text_encoder
 from ...schedulers import KarrasDiffusionSchedulers
-from ...utils import deprecate, logging, replace_example_docstring, scale_lora_layers, unscale_lora_layers
+from ...utils import (
+    USE_PEFT_BACKEND,
+    deprecate,
+    logging,
+    replace_example_docstring,
+    scale_lora_layers,
+    unscale_lora_layers,
+)
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline
 from . import TextToVideoSDPipelineOutput
@@ -224,7 +231,7 @@ def encode_prompt(
             self._lora_scale = lora_scale

             # dynamically adjust the LoRA scale
-            if not self.use_peft_backend:
+            if not USE_PEFT_BACKEND:
                 adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
             else:
                 scale_lora_layers(self.text_encoder, lora_scale)
@@ -352,7 +359,7 @@ def encode_prompt(
             negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
             negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

-        if isinstance(self, LoraLoaderMixin) and self.use_peft_backend:
+        if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
             # Retrieve the original scale by scaling back the LoRA layers
             unscale_lora_layers(self.text_encoder)
diff --git a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py
index c571d3d6bc5e..0d886cb00677 100644
--- a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py
+++ b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py
@@ -24,7 +24,14 @@
 from ...models import AutoencoderKL, UNet3DConditionModel
 from ...models.lora import adjust_lora_scale_text_encoder
 from ...schedulers import KarrasDiffusionSchedulers
-from ...utils import deprecate, logging, replace_example_docstring, scale_lora_layers, unscale_lora_layers
+from ...utils import (
+    USE_PEFT_BACKEND,
+    deprecate,
+    logging,
+    replace_example_docstring,
+    scale_lora_layers,
+    unscale_lora_layers,
+)
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline
 from . import TextToVideoSDPipelineOutput
@@ -286,7 +293,7 @@ def encode_prompt(
             self._lora_scale = lora_scale

             # dynamically adjust the LoRA scale
-            if not self.use_peft_backend:
+            if not USE_PEFT_BACKEND:
                 adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
             else:
                 scale_lora_layers(self.text_encoder, lora_scale)
@@ -414,7 +421,7 @@ def encode_prompt(
             negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
             negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

-        if isinstance(self, LoraLoaderMixin) and self.use_peft_backend:
+        if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
             # Retrieve the original scale by scaling back the LoRA layers
             unscale_lora_layers(self.text_encoder)
diff --git a/src/diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py b/src/diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py
index 7e0b07cc79ef..0d5880ac0d4f 100644
--- a/src/diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py
+++ b/src/diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py
@@ -18,7 +18,7 @@
 from ...models import AutoencoderKL
 from ...models.lora import adjust_lora_scale_text_encoder
 from ...schedulers import KarrasDiffusionSchedulers
-from ...utils import deprecate, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
 from ...utils.outputs import BaseOutput
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline
@@ -426,7 +426,7 @@ def encode_prompt(
             self._lora_scale = lora_scale

             # dynamically adjust the LoRA scale
-            if not self.use_peft_backend:
+            if not USE_PEFT_BACKEND:
                 adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
             else:
                 scale_lora_layers(self.text_encoder, lora_scale)
@@ -554,7 +554,7 @@ def encode_prompt(
             negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
             negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

-        if isinstance(self, LoraLoaderMixin) and self.use_peft_backend:
+        if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
             # Retrieve the original scale by scaling back the LoRA layers
             unscale_lora_layers(self.text_encoder)
diff --git a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py
index 4e50bbefe933..2ed3deeb1225 100644
--- a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py
+++ b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py
@@ -31,7 +31,7 @@
 )
 from ...models.transformer_2d import Transformer2DModel
 from ...models.unet_2d_condition import UNet2DConditionOutput
-from ...utils import is_torch_version, logging
+from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
 from ...utils.torch_utils import apply_freeu
@@ -1211,6 +1211,9 @@ def forward(
         # 3. down
         lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
+        if USE_PEFT_BACKEND:
+            # weight the lora layers by setting `lora_scale` for each PEFT layer
+            scale_lora_layers(self, lora_scale)

         is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
         is_adapter = mid_block_additional_residual is None and down_block_additional_residuals is not None
@@ -1310,6 +1313,10 @@ def forward(
         sample = self.conv_act(sample)
         sample = self.conv_out(sample)

+        if USE_PEFT_BACKEND:
+            # remove `lora_scale` from each PEFT layer
+            unscale_lora_layers(self)
+
         if not return_dict:
             return (sample,)
diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py
index 128ebb1fb737..b4d6bdab33eb 100644
--- a/src/diffusers/utils/__init__.py
+++ b/src/diffusers/utils/__init__.py
@@ -26,9 +26,11 @@
     FLAX_WEIGHTS_NAME,
     HF_MODULES_CACHE,
     HUGGINGFACE_CO_RESOLVE_ENDPOINT,
+    MIN_PEFT_VERSION,
     ONNX_EXTERNAL_WEIGHTS_NAME,
     ONNX_WEIGHTS_NAME,
     SAFETENSORS_WEIGHTS_NAME,
+    USE_PEFT_BACKEND,
     WEIGHTS_NAME,
 )
 from .deprecation_utils import deprecate
@@ -86,6 +88,7 @@
 from .logging import get_logger
 from .outputs import BaseOutput
 from .peft_utils import (
+    check_peft_version,
     get_adapter_name,
     get_peft_kwargs,
     recurse_remove_peft_layers,
@@ -95,7 +98,11 @@
     unscale_lora_layers,
 )
 from .pil_utils import PIL_INTERPOLATION, make_image_grid, numpy_to_pil, pt_to_pil
-from .state_dict_utils import convert_state_dict_to_diffusers, convert_state_dict_to_peft
+from .state_dict_utils import (
+    convert_state_dict_to_diffusers,
+    convert_state_dict_to_peft,
+    convert_unet_state_dict_to_peft,
+)


 logger = get_logger(__name__)
diff --git a/src/diffusers/utils/constants.py b/src/diffusers/utils/constants.py
index 1f51f2c0497b..3023cb476fe0 100644
--- a/src/diffusers/utils/constants.py
+++ b/src/diffusers/utils/constants.py
@@ -11,13 +11,19 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import importlib
 import os

 from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE, hf_cache_home
+from packaging import version
+
+from .import_utils import is_peft_available, is_transformers_available


 default_cache_path = HUGGINGFACE_HUB_CACHE

+MIN_PEFT_VERSION = "0.5.0"
+

 CONFIG_NAME = "config.json"
 WEIGHTS_NAME = "diffusion_pytorch_model.bin"
@@ -30,3 +36,16 @@
 DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules"
 HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(hf_cache_home, "modules"))
 DEPRECATED_REVISION_ARGS = ["fp16", "non-ema"]
+
+# Below should be `True` if the current versions of `peft` and `transformers` are compatible with
+# the PEFT backend. The PEFT backend is used automatically when compatible versions of both
+# libraries are installed.
+# For PEFT it has to be greater than 0.5.0 and for transformers it has to be greater than 4.33.
+_required_peft_version = is_peft_available() and version.parse(
+    version.parse(importlib.metadata.version("peft")).base_version
+) > version.parse(MIN_PEFT_VERSION)
+_required_transformers_version = is_transformers_available() and version.parse(
+    version.parse(importlib.metadata.version("transformers")).base_version
+) > version.parse("4.33")
+
+USE_PEFT_BACKEND = _required_peft_version and _required_transformers_version
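Since the flag is resolved once at import time, downstream modules can branch on it cheaply. A standalone restatement of the same gate, useful for checking why the PEFT path is or is not taken; the thresholds mirror MIN_PEFT_VERSION and the transformers floor above:

    import importlib.metadata

    from packaging import version

    def peft_backend_available() -> bool:
        # mirrors USE_PEFT_BACKEND: peft > 0.5.0 and transformers > 4.33, both installed
        try:
            peft_v = importlib.metadata.version("peft")
            transformers_v = importlib.metadata.version("transformers")
        except importlib.metadata.PackageNotFoundError:
            return False
        peft_ok = version.parse(version.parse(peft_v).base_version) > version.parse("0.5.0")
        transformers_ok = version.parse(version.parse(transformers_v).base_version) > version.parse("4.33")
        return peft_ok and transformers_ok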
diff --git a/src/diffusers/utils/peft_utils.py b/src/diffusers/utils/peft_utils.py index 9accf07a137b..efc977518b14 100644 --- a/src/diffusers/utils/peft_utils.py +++ b/src/diffusers/utils/peft_utils.py @@ -15,8 +15,11 @@ PEFT utilities: Utilities related to peft library """ import collections +import importlib -from .import_utils import is_torch_available +from packaging import version + +from .import_utils import is_peft_available, is_torch_available def recurse_remove_peft_layers(model): @@ -53,7 +56,6 @@ def recurse_remove_peft_layers(model): module.padding, module.dilation, module.groups, - module.bias, ).to(module.weight.device) new_module.weight = module.weight @@ -106,10 +108,11 @@ def unscale_lora_layers(model): module.unscale_layer() -def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict): +def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, is_unet=True): rank_pattern = {} alpha_pattern = {} r = lora_alpha = list(rank_dict.values())[0] + if len(set(rank_dict.values())) > 1: # get the rank occuring the most number of times r = collections.Counter(rank_dict.values()).most_common()[0][0] @@ -118,13 +121,22 @@ def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict): rank_pattern = dict(filter(lambda x: x[1] != r, rank_dict.items())) rank_pattern = {k.split(".lora_B.")[0]: v for k, v in rank_pattern.items()} - if network_alpha_dict is not None and len(set(network_alpha_dict.values())) > 1: - # get the alpha occuring the most number of times - lora_alpha = collections.Counter(network_alpha_dict.values()).most_common()[0][0] - - # for modules with alpha different from the most occuring alpha, add it to the `alpha_pattern` - alpha_pattern = dict(filter(lambda x: x[1] != lora_alpha, network_alpha_dict.items())) - alpha_pattern = {".".join(k.split(".down.")[0].split(".")[:-1]): v for k, v in alpha_pattern.items()} + if network_alpha_dict is not None: + if len(set(network_alpha_dict.values())) > 1: + # get the alpha occurring the most number of times + lora_alpha = collections.Counter(network_alpha_dict.values()).most_common()[0][0] + + # for modules with alpha different from the most occurring alpha, add it to the `alpha_pattern` + alpha_pattern = dict(filter(lambda x: x[1] != lora_alpha, network_alpha_dict.items())) + if is_unet: + alpha_pattern = { + ".".join(k.split(".lora_A.")[0].split(".")).replace(".alpha", ""): v + for k, v in alpha_pattern.items() + } + else: + alpha_pattern = {".".join(k.split(".down.")[0].split(".")[:-1]): v for k, v in alpha_pattern.items()} + else: + lora_alpha = set(network_alpha_dict.values()).pop() # layer names without the Diffusers specific target_modules = list({name.split(".lora")[0] for name in peft_state_dict.keys()}) @@ -155,9 +167,9 @@ def set_adapter_layers(model, enabled=True): if isinstance(module, BaseTunerLayer): # The recent version of PEFT needs to call `enable_adapters` instead if hasattr(module, "enable_adapters"): - module.enable_adapters(enabled=False) + module.enable_adapters(enabled=enabled) else: - module.disable_adapters = True + module.disable_adapters = not enabled def set_weights_and_activate_adapters(model, adapter_names, weights): @@ -182,3 +194,23 @@ module.set_adapter(adapter_names) else: module.active_adapter = adapter_names + + +def check_peft_version(min_version: str) -> None: + r""" + Checks if the version of PEFT is compatible. + + Args: + min_version (`str`): + The minimum version of PEFT to check against. + """ + if not is_peft_available(): + raise ValueError("PEFT is not installed. Please install it with `pip install peft`") + + is_peft_version_compatible = version.parse(importlib.metadata.version("peft")) > version.parse(min_version) + + if not is_peft_version_compatible: + raise ValueError( + f"The version of PEFT you are using is not compatible; please use a version" + f" greater than {min_version}" + )
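To make `get_peft_kwargs` concrete, here is a toy walk-through of the rank bookkeeping (the state-dict keys are hypothetical): the most frequent rank becomes the global `r`, and any module that deviates is recorded in `rank_pattern`, keyed by its module path.

import collections

# Hypothetical per-module ranks as read from a LoRA state dict.
rank_dict = {
    "down_blocks.0.attn.to_q.lora_B.weight": 4,
    "down_blocks.0.attn.to_k.lora_B.weight": 4,
    "mid_block.attn.to_v.lora_B.weight": 8,
}

# The most common rank wins...
r = collections.Counter(rank_dict.values()).most_common()[0][0]  # -> 4

# ...and outliers are kept per module, keyed by the module path.
rank_pattern = {k.split(".lora_B.")[0]: v for k, v in rank_dict.items() if v != r}
# -> {"mid_block.attn.to_v": 8}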
diff --git a/src/diffusers/utils/state_dict_utils.py b/src/diffusers/utils/state_dict_utils.py index 1164dba18042..777c611f7150 100644 --- a/src/diffusers/utils/state_dict_utils.py +++ b/src/diffusers/utils/state_dict_utils.py @@ -28,6 +28,22 @@ class StateDictType(enum.Enum): DIFFUSERS = "diffusers" +# We need to define a proper mapping for Unet since it uses different output keys than text encoder +# e.g. to_q_lora -> q_proj (text encoder) / to_q (Unet) +UNET_TO_DIFFUSERS = { + ".to_out_lora.up": ".to_out.0.lora_B", + ".to_out_lora.down": ".to_out.0.lora_A", + ".to_q_lora.down": ".to_q.lora_A", + ".to_q_lora.up": ".to_q.lora_B", + ".to_k_lora.down": ".to_k.lora_A", + ".to_k_lora.up": ".to_k.lora_B", + ".to_v_lora.down": ".to_v.lora_A", + ".to_v_lora.up": ".to_v.lora_B", + ".lora.up": ".lora_B", + ".lora.down": ".lora_A", +} + + DIFFUSERS_TO_PEFT = { ".q_proj.lora_linear_layer.up": ".q_proj.lora_B", ".q_proj.lora_linear_layer.down": ".q_proj.lora_A", @@ -50,6 +66,8 @@ class StateDictType(enum.Enum): ".to_v_lora.down": ".v_proj.lora_A", ".to_out_lora.up": ".out_proj.lora_B", ".to_out_lora.down": ".out_proj.lora_A", + ".lora_linear_layer.up": ".lora_B", + ".lora_linear_layer.down": ".lora_A", } PEFT_TO_DIFFUSERS = { @@ -84,6 +102,10 @@ class StateDictType(enum.Enum): StateDictType.PEFT: PEFT_TO_DIFFUSERS, } +KEYS_TO_ALWAYS_REPLACE = { + ".processor.": ".", +} + def convert_state_dict(state_dict, mapping): r""" @@ -103,6 +125,12 @@ """ converted_state_dict = {} for k, v in state_dict.items(): + # First, filter out the keys that we always want to replace + for pattern in KEYS_TO_ALWAYS_REPLACE.keys(): + if pattern in k: + new_pattern = KEYS_TO_ALWAYS_REPLACE[pattern] + k = k.replace(pattern, new_pattern) + for pattern in mapping.keys(): if pattern in k: new_pattern = mapping[pattern] @@ -184,3 +212,11 @@ def convert_state_dict_to_diffusers(state_dict, original_type=None, **kwargs): mapping = DIFFUSERS_STATE_DICT_MAPPINGS[original_type] return convert_state_dict(state_dict, mapping) + + +def convert_unet_state_dict_to_peft(state_dict): + r""" + Converts a state dict from the old UNet LoRA format to the PEFT format, i.e. renames the + up/down LoRA keys to `lora_B`/`lora_A` using the `UNET_TO_DIFFUSERS` mapping. + """ + mapping = UNET_TO_DIFFUSERS + return convert_state_dict(state_dict, mapping) diff --git a/tests/lora/test_lora_layers_old_backend.py b/tests/lora/test_lora_layers_old_backend.py index b99b8803d031..047cdddfa95a 100644 --- a/tests/lora/test_lora_layers_old_backend.py +++ b/tests/lora/test_lora_layers_old_backend.py @@ -673,6 +673,7 @@ def test_lora_save_load_with_xformers(self): self.assertFalse(torch.allclose(torch.from_numpy(orig_image_slice), torch.from_numpy(lora_image_slice))) +@deprecate_after_peft_backend class SDXInpaintLoraMixinTests(unittest.TestCase): def get_dummy_inputs(self, device, seed=0, img_res=64, output_pil=True): # TODO: use tensor inputs instead of PIL, this is here just to leave the old expected_slices untouched @@ -1387,6 +1388,7 @@ def test_save_load_fused_lora_modules(self): ), "The pipeline was serialized with LoRA parameters fused inside of the respected modules. The loaded pipeline should yield proper outputs, henceforth."
+@deprecate_after_peft_backend class UNet2DConditionLoRAModelTests(unittest.TestCase): model_class = UNet2DConditionModel main_input_name = "sample" @@ -1635,6 +1637,7 @@ def test_lora_xformers_on_off(self, expected_max_diff=6e-4): assert max_diff_off_sample < expected_max_diff +@deprecate_after_peft_backend class UNet3DConditionModelTests(unittest.TestCase): model_class = UNet3DConditionModel main_input_name = "sample" @@ -1877,6 +1880,7 @@ def test_lora_xformers_on_off(self): @slow +@deprecate_after_peft_backend @require_torch_gpu class LoraIntegrationTests(unittest.TestCase): def test_dreambooth_old_format(self): diff --git a/tests/lora/test_lora_layers_peft.py b/tests/lora/test_lora_layers_peft.py index 1862437fce88..198ff53340c8 100644 --- a/tests/lora/test_lora_layers_peft.py +++ b/tests/lora/test_lora_layers_peft.py @@ -12,21 +12,27 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import copy import os import tempfile +import time import unittest import numpy as np import torch import torch.nn as nn import torch.nn.functional as F +from huggingface_hub.repocard import RepoCard from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer from diffusers import ( AutoencoderKL, + ControlNetModel, DDIMScheduler, + DiffusionPipeline, EulerDiscreteScheduler, StableDiffusionPipeline, + StableDiffusionXLControlNetPipeline, StableDiffusionXLPipeline, UNet2DConditionModel, ) @@ -35,9 +41,20 @@ LoRAAttnProcessor, LoRAAttnProcessor2_0, ) -from diffusers.utils.import_utils import is_peft_available -from diffusers.utils.testing_utils import floats_tensor, require_peft_backend, require_torch_gpu, slow +from diffusers.utils.import_utils import is_accelerate_available, is_peft_available +from diffusers.utils.testing_utils import ( + floats_tensor, + load_image, + nightly, + require_peft_backend, + require_torch_gpu, + slow, + torch_device, +) + +if is_accelerate_available(): + from accelerate.utils import release_memory if is_peft_available(): from peft import LoraConfig @@ -45,6 +62,18 @@ from peft.utils import get_peft_model_state_dict +def state_dicts_almost_equal(sd1, sd2): + sd1 = dict(sorted(sd1.items())) + sd2 = dict(sorted(sd2.items())) + + models_are_equal = True + for ten1, ten2 in zip(sd1.values(), sd2.values()): + if (ten1 - ten2).abs().max() > 1e-3: + models_are_equal = False + + return models_are_equal + + def create_unet_lora_layers(unet: nn.Module): lora_attn_procs = {} for name in unet.attn_processors.keys(): @@ -94,6 +123,10 @@ def get_dummy_components(self): r=4, lora_alpha=4, target_modules=["q_proj", "k_proj", "v_proj", "out_proj"], init_lora_weights=False ) + unet_lora_config = LoraConfig( + r=4, lora_alpha=4, target_modules=["to_q", "to_k", "to_v", "to_out.0"], init_lora_weights=False + ) + unet_lora_attn_procs, unet_lora_layers = create_unet_lora_layers(unet) if self.has_two_text_encoders: @@ -120,7 +153,7 @@ def get_dummy_components(self): "unet_lora_layers": unet_lora_layers, "unet_lora_attn_procs": unet_lora_attn_procs, } - return pipeline_components, lora_components, text_lora_config + return pipeline_components, lora_components, text_lora_config, unet_lora_config def get_dummy_inputs(self, with_generator=True): batch_size = 1 @@ -166,7 +199,7 @@ def test_simple_inference(self): """ Tests a simple inference and makes sure it works as expected """ - components, _, _ = self.get_dummy_components() + components, _, _, _ = 
self.get_dummy_components() pipe = self.pipeline_class(**components) pipe = pipe.to(self.torch_device) pipe.set_progress_bar_config(disable=None) @@ -180,7 +213,7 @@ def test_simple_inference_with_text_lora(self): Tests a simple inference with lora attached on the text encoder and makes sure it works as expected """ - components, _, text_lora_config = self.get_dummy_components() + components, _, text_lora_config, _ = self.get_dummy_components() pipe = self.pipeline_class(**components) pipe = pipe.to(self.torch_device) pipe.set_progress_bar_config(disable=None) @@ -208,7 +241,7 @@ def test_simple_inference_with_text_lora_and_scale(self): Tests a simple inference with lora attached on the text encoder + scale argument and makes sure it works as expected """ - components, _, text_lora_config = self.get_dummy_components() + components, _, text_lora_config, _ = self.get_dummy_components() pipe = self.pipeline_class(**components) pipe = pipe.to(self.torch_device) pipe.set_progress_bar_config(disable=None) @@ -252,7 +285,7 @@ def test_simple_inference_with_text_lora_fused(self): Tests a simple inference with lora attached into text encoder + fuses the lora weights into base model and makes sure it works as expected """ - components, _, text_lora_config = self.get_dummy_components() + components, _, text_lora_config, _ = self.get_dummy_components() pipe = self.pipeline_class(**components) pipe = pipe.to(self.torch_device) pipe.set_progress_bar_config(disable=None) @@ -289,7 +322,7 @@ def test_simple_inference_with_text_lora_unloaded(self): Tests a simple inference with lora attached to text encoder, then unloads the lora weights and makes sure it works as expected """ - components, _, text_lora_config = self.get_dummy_components() + components, _, text_lora_config, _ = self.get_dummy_components() pipe = self.pipeline_class(**components) pipe = pipe.to(self.torch_device) pipe.set_progress_bar_config(disable=None) @@ -327,7 +360,7 @@ def test_simple_inference_with_text_lora_save_load(self): """ Tests a simple usecase where users could use saving utilities for LoRA. 
""" - components, _, text_lora_config = self.get_dummy_components() + components, _, text_lora_config, _ = self.get_dummy_components() pipe = self.pipeline_class(**components) pipe = pipe.to(self.torch_device) pipe.set_progress_bar_config(disable=None) @@ -387,7 +420,7 @@ def test_simple_inference_save_pretrained(self): """ Tests a simple usecase where users could use saving utilities for LoRA through save_pretrained """ - components, _, text_lora_config = self.get_dummy_components() + components, _, text_lora_config, _ = self.get_dummy_components() pipe = self.pipeline_class(**components) pipe = pipe.to(self.torch_device) pipe.set_progress_bar_config(disable=None) @@ -431,108 +464,680 @@ def test_simple_inference_save_pretrained(self): "Loading from saved checkpoints should give same results.", ) + def test_simple_inference_with_text_unet_lora_save_load(self): + """ + Tests a simple usecase where users could use saving utilities for LoRA for Unet + text encoder + """ + components, _, text_lora_config, unet_lora_config = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe = pipe.to(self.torch_device) + pipe.set_progress_bar_config(disable=None) + _, _, inputs = self.get_dummy_inputs(with_generator=False) -class StableDiffusionLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase): - pipeline_class = StableDiffusionPipeline - scheduler_cls = DDIMScheduler - scheduler_kwargs = { - "beta_start": 0.00085, - "beta_end": 0.012, - "beta_schedule": "scaled_linear", - "clip_sample": False, - "set_alpha_to_one": False, - "steps_offset": 1, - } - unet_kwargs = { - "block_out_channels": (32, 64), - "layers_per_block": 2, - "sample_size": 32, - "in_channels": 4, - "out_channels": 4, - "down_block_types": ("DownBlock2D", "CrossAttnDownBlock2D"), - "up_block_types": ("CrossAttnUpBlock2D", "UpBlock2D"), - "cross_attention_dim": 32, - } - vae_kwargs = { - "block_out_channels": [32, 64], - "in_channels": 3, - "out_channels": 3, - "down_block_types": ["DownEncoderBlock2D", "DownEncoderBlock2D"], - "up_block_types": ["UpDecoderBlock2D", "UpDecoderBlock2D"], - "latent_channels": 4, - } + output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images + self.assertTrue(output_no_lora.shape == (1, 64, 64, 3)) - @slow - @require_torch_gpu - def test_integration_logits_with_scale(self): - path = "runwayml/stable-diffusion-v1-5" - lora_id = "takuma104/lora-test-text-encoder-lora-target" + pipe.text_encoder.add_adapter(text_lora_config) + pipe.unet.add_adapter(unet_lora_config) - pipe = StableDiffusionPipeline.from_pretrained(path, torch_dtype=torch.float32) - pipe.load_lora_weights(lora_id) - pipe = pipe.to("cuda") + self.assertTrue(self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder") + self.assertTrue(self.check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet") + + if self.has_two_text_encoders: + pipe.text_encoder_2.add_adapter(text_lora_config) + self.assertTrue( + self.check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" + ) + + images_lora = pipe(**inputs, generator=torch.manual_seed(0)).images + + with tempfile.TemporaryDirectory() as tmpdirname: + text_encoder_state_dict = get_peft_model_state_dict(pipe.text_encoder) + unet_state_dict = get_peft_model_state_dict(pipe.unet) + if self.has_two_text_encoders: + text_encoder_2_state_dict = get_peft_model_state_dict(pipe.text_encoder_2) + + self.pipeline_class.save_lora_weights( + save_directory=tmpdirname, + 
text_encoder_lora_layers=text_encoder_state_dict, + text_encoder_2_lora_layers=text_encoder_2_state_dict, + unet_lora_layers=unet_state_dict, + safe_serialization=False, + ) + else: + self.pipeline_class.save_lora_weights( + save_directory=tmpdirname, + text_encoder_lora_layers=text_encoder_state_dict, + unet_lora_layers=unet_state_dict, + safe_serialization=False, + ) + + self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))) + pipe.unload_lora_weights() + + pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin")) + + images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0)).images + self.assertTrue(self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder") + self.assertTrue(self.check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet") + + if self.has_two_text_encoders: + self.assertTrue( + self.check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" + ) self.assertTrue( - self.check_if_lora_correctly_set(pipe.text_encoder), - "Lora not correctly set in text encoder 2", + np.allclose(images_lora, images_lora_from_pretrained, atol=1e-3, rtol=1e-3), + "Loading from saved checkpoints should give same results.", ) - prompt = "a red sks dog" + def test_simple_inference_with_text_unet_lora_and_scale(self): + """ + Tests a simple inference with lora attached on the text encoder + Unet + scale argument + and makes sure it works as expected + """ + components, _, text_lora_config, unet_lora_config = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe = pipe.to(self.torch_device) + pipe.set_progress_bar_config(disable=None) + _, _, inputs = self.get_dummy_inputs(with_generator=False) - images = pipe( - prompt=prompt, - num_inference_steps=15, - cross_attention_kwargs={"scale": 0.5}, - generator=torch.manual_seed(0), - output_type="np", + output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images + self.assertTrue(output_no_lora.shape == (1, 64, 64, 3)) + + pipe.text_encoder.add_adapter(text_lora_config) + pipe.unet.add_adapter(unet_lora_config) + self.assertTrue(self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder") + self.assertTrue(self.check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet") + + if self.has_two_text_encoders: + pipe.text_encoder_2.add_adapter(text_lora_config) + self.assertTrue( + self.check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" + ) + + output_lora = pipe(**inputs, generator=torch.manual_seed(0)).images + self.assertTrue( + not np.allclose(output_lora, output_no_lora, atol=1e-3, rtol=1e-3), "Lora should change the output" + ) + + output_lora_scale = pipe( + **inputs, generator=torch.manual_seed(0), cross_attention_kwargs={"scale": 0.5} ).images + self.assertTrue( + not np.allclose(output_lora, output_lora_scale, atol=1e-3, rtol=1e-3), + "Lora + scale should change the output", + ) - expected_slice_scale = np.array([0.307, 0.283, 0.310, 0.310, 0.300, 0.314, 0.336, 0.314, 0.321]) + output_lora_0_scale = pipe( + **inputs, generator=torch.manual_seed(0), cross_attention_kwargs={"scale": 0.0} + ).images + self.assertTrue( + np.allclose(output_no_lora, output_lora_0_scale, atol=1e-3, rtol=1e-3), + "Lora + 0 scale should lead to same result as no LoRA", + ) - predicted_slice = images[0, -3:, -3:, -1].flatten() + self.assertTrue( + 
pipe.text_encoder.text_model.encoder.layers[0].self_attn.q_proj.scaling["default"] == 1.0, + "The scaling parameter has not been correctly restored!", + ) - self.assertTrue(np.allclose(expected_slice_scale, predicted_slice, atol=1e-3, rtol=1e-3)) + def test_simple_inference_with_text_lora_unet_fused(self): + """ + Tests a simple inference with lora attached into text encoder + fuses the lora weights into base model + and makes sure it works as expected - with unet + """ + components, _, text_lora_config, unet_lora_config = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe = pipe.to(self.torch_device) + pipe.set_progress_bar_config(disable=None) + _, _, inputs = self.get_dummy_inputs(with_generator=False) - @slow - @require_torch_gpu - def test_integration_logits_no_scale(self): - path = "runwayml/stable-diffusion-v1-5" - lora_id = "takuma104/lora-test-text-encoder-lora-target" + output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images + self.assertTrue(output_no_lora.shape == (1, 64, 64, 3)) - pipe = StableDiffusionPipeline.from_pretrained(path, torch_dtype=torch.float32) - pipe.load_lora_weights(lora_id) - pipe = pipe.to("cuda") + pipe.text_encoder.add_adapter(text_lora_config) + pipe.unet.add_adapter(unet_lora_config) + + self.assertTrue(self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder") + self.assertTrue(self.check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet") + + if self.has_two_text_encoders: + pipe.text_encoder_2.add_adapter(text_lora_config) + self.assertTrue( + self.check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" + ) + + pipe.fuse_lora() + # Fusing should still keep the LoRA layers + self.assertTrue(self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder") + self.assertTrue(self.check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in unet") + + if self.has_two_text_encoders: + self.assertTrue( + self.check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" + ) + + ouput_fused = pipe(**inputs, generator=torch.manual_seed(0)).images + self.assertFalse( + np.allclose(ouput_fused, output_no_lora, atol=1e-3, rtol=1e-3), "Fused lora should change the output" + ) + + def test_simple_inference_with_text_unet_lora_unloaded(self): + """ + Tests a simple inference with lora attached to text encoder and unet, then unloads the lora weights + and makes sure it works as expected + """ + components, _, text_lora_config, unet_lora_config = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe = pipe.to(self.torch_device) + pipe.set_progress_bar_config(disable=None) + _, _, inputs = self.get_dummy_inputs(with_generator=False) + + output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images + self.assertTrue(output_no_lora.shape == (1, 64, 64, 3)) + + pipe.text_encoder.add_adapter(text_lora_config) + pipe.unet.add_adapter(unet_lora_config) + self.assertTrue(self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder") + self.assertTrue(self.check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet") + + if self.has_two_text_encoders: + pipe.text_encoder_2.add_adapter(text_lora_config) + self.assertTrue( + self.check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" + ) + + pipe.unload_lora_weights() + # unloading should remove the LoRA layers + 
self.assertFalse( + self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly unloaded in text encoder" + ) + self.assertFalse(self.check_if_lora_correctly_set(pipe.unet), "Lora not correctly unloaded in Unet") + + if self.has_two_text_encoders: + self.assertFalse( + self.check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly unloaded in text encoder 2" + ) + ouput_unloaded = pipe(**inputs, generator=torch.manual_seed(0)).images self.assertTrue( - self.check_if_lora_correctly_set(pipe.text_encoder), - "Lora not correctly set in text encoder", + np.allclose(ouput_unloaded, output_no_lora, atol=1e-3, rtol=1e-3), "Fused lora should change the output" ) - prompt = "a red sks dog" + def test_simple_inference_with_text_unet_lora_unfused(self): + """ + Tests a simple inference with lora attached to text encoder and unet, then unloads the lora weights + and makes sure it works as expected + """ + components, _, text_lora_config, unet_lora_config = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe = pipe.to(self.torch_device) + pipe.set_progress_bar_config(disable=None) + _, _, inputs = self.get_dummy_inputs(with_generator=False) - images = pipe(prompt=prompt, num_inference_steps=30, generator=torch.manual_seed(0), output_type="np").images + pipe.text_encoder.add_adapter(text_lora_config) + pipe.unet.add_adapter(unet_lora_config) - expected_slice_scale = np.array([0.074, 0.064, 0.073, 0.0842, 0.069, 0.0641, 0.0794, 0.076, 0.084]) + self.assertTrue(self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder") + self.assertTrue(self.check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet") - predicted_slice = images[0, -3:, -3:, -1].flatten() + if self.has_two_text_encoders: + pipe.text_encoder_2.add_adapter(text_lora_config) + self.assertTrue( + self.check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" + ) - self.assertTrue(np.allclose(expected_slice_scale, predicted_slice, atol=1e-3, rtol=1e-3)) + pipe.fuse_lora() + output_fused_lora = pipe(**inputs, generator=torch.manual_seed(0)).images -class StableDiffusionXLLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase): - has_two_text_encoders = True - pipeline_class = StableDiffusionXLPipeline - scheduler_cls = EulerDiscreteScheduler - scheduler_kwargs = { - "beta_start": 0.00085, - "beta_end": 0.012, - "beta_schedule": "scaled_linear", - "timestep_spacing": "leading", - "steps_offset": 1, - } - unet_kwargs = { - "block_out_channels": (32, 64), - "layers_per_block": 2, + pipe.unfuse_lora() + + output_unfused_lora = pipe(**inputs, generator=torch.manual_seed(0)).images + # unloading should remove the LoRA layers + self.assertTrue(self.check_if_lora_correctly_set(pipe.text_encoder), "Unfuse should still keep LoRA layers") + self.assertTrue(self.check_if_lora_correctly_set(pipe.unet), "Unfuse should still keep LoRA layers") + + if self.has_two_text_encoders: + self.assertTrue( + self.check_if_lora_correctly_set(pipe.text_encoder_2), "Unfuse should still keep LoRA layers" + ) + + # Fuse and unfuse should lead to the same results + self.assertTrue( + np.allclose(output_fused_lora, output_unfused_lora, atol=1e-3, rtol=1e-3), + "Fused lora should change the output", + ) + + def test_simple_inference_with_text_unet_multi_adapter(self): + """ + Tests a simple inference with lora attached to text encoder and unet, attaches + multiple adapters and set them + """ + components, _, text_lora_config, unet_lora_config = 
self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe = pipe.to(self.torch_device) + pipe.set_progress_bar_config(disable=None) + _, _, inputs = self.get_dummy_inputs(with_generator=False) + + output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images + + pipe.text_encoder.add_adapter(text_lora_config, "adapter-1") + pipe.text_encoder.add_adapter(text_lora_config, "adapter-2") + + pipe.unet.add_adapter(unet_lora_config, "adapter-1") + pipe.unet.add_adapter(unet_lora_config, "adapter-2") + + self.assertTrue(self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder") + self.assertTrue(self.check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet") + + if self.has_two_text_encoders: + pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1") + pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2") + self.assertTrue( + self.check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" + ) + + pipe.set_adapters("adapter-1") + + output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0)).images + + pipe.set_adapters("adapter-2") + output_adapter_2 = pipe(**inputs, generator=torch.manual_seed(0)).images + + pipe.set_adapters(["adapter-1", "adapter-2"]) + + output_adapter_mixed = pipe(**inputs, generator=torch.manual_seed(0)).images + + # Each adapter (and their mix) should lead to a different result + self.assertFalse( + np.allclose(output_adapter_1, output_adapter_2, atol=1e-3, rtol=1e-3), + "Adapter 1 and 2 should give different results", + ) + + self.assertFalse( + np.allclose(output_adapter_1, output_adapter_mixed, atol=1e-3, rtol=1e-3), + "Adapter 1 and mixed adapters should give different results", + ) + + self.assertFalse( + np.allclose(output_adapter_2, output_adapter_mixed, atol=1e-3, rtol=1e-3), + "Adapter 2 and mixed adapters should give different results", + ) + + pipe.disable_lora() + + output_disabled = pipe(**inputs, generator=torch.manual_seed(0)).images + + self.assertTrue( + np.allclose(output_no_lora, output_disabled, atol=1e-3, rtol=1e-3), + "output with no lora and output with lora disabled should give same results", + ) + + def test_lora_fuse_nan(self): + components, _, text_lora_config, unet_lora_config = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe = pipe.to(self.torch_device) + pipe.set_progress_bar_config(disable=None) + _, _, inputs = self.get_dummy_inputs(with_generator=False) + + pipe.text_encoder.add_adapter(text_lora_config, "adapter-1") + + pipe.unet.add_adapter(unet_lora_config, "adapter-1") + + self.assertTrue(self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder") + self.assertTrue(self.check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet") + + # corrupt one LoRA weight with `inf` values + with torch.no_grad(): + pipe.unet.mid_block.attentions[0].transformer_blocks[0].attn1.to_q.lora_A["adapter-1"].weight += float( + "inf" + ) + + # with `safe_fusing=True` we should see an error + with self.assertRaises(ValueError): + pipe.fuse_lora(safe_fusing=True) + + # without it we should not see an error, but every image will be black + pipe.fuse_lora(safe_fusing=False) + + out = pipe("test", num_inference_steps=2, output_type="np").images + + self.assertTrue(np.isnan(out).all())
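The `safe_fusing` flag exercised by `test_lora_fuse_nan` above is the user-facing guard against corrupted adapters; a small usage sketch (assuming `pipe` is a pipeline with a loaded LoRA):

# With safe_fusing=True, fuse_lora() validates the merged weights and raises a
# ValueError when they contain NaN/Inf, instead of silently producing black images.
try:
    pipe.fuse_lora(safe_fusing=True)
except ValueError:
    pipe.unload_lora_weights()  # fall back: drop the corrupted adapter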
+ def test_get_adapters(self): + """ + Tests a simple use case where we attach multiple adapters and check that the + results are the expected ones + """ + components, _, text_lora_config, unet_lora_config = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe = pipe.to(self.torch_device) + pipe.set_progress_bar_config(disable=None) + _, _, inputs = self.get_dummy_inputs(with_generator=False) + + pipe.text_encoder.add_adapter(text_lora_config, "adapter-1") + pipe.unet.add_adapter(unet_lora_config, "adapter-1") + + adapter_names = pipe.get_active_adapters() + self.assertListEqual(adapter_names, ["adapter-1"]) + + pipe.text_encoder.add_adapter(text_lora_config, "adapter-2") + pipe.unet.add_adapter(unet_lora_config, "adapter-2") + + adapter_names = pipe.get_active_adapters() + self.assertListEqual(adapter_names, ["adapter-2"]) + + pipe.set_adapters(["adapter-1", "adapter-2"]) + self.assertListEqual(pipe.get_active_adapters(), ["adapter-1", "adapter-2"]) + + def test_get_list_adapters(self): + """ + Tests a simple use case where we attach multiple adapters and check that the + results are the expected ones + """ + components, _, text_lora_config, unet_lora_config = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe = pipe.to(self.torch_device) + pipe.set_progress_bar_config(disable=None) + + pipe.text_encoder.add_adapter(text_lora_config, "adapter-1") + pipe.unet.add_adapter(unet_lora_config, "adapter-1") + + adapter_names = pipe.get_list_adapters() + self.assertDictEqual(adapter_names, {"text_encoder": ["adapter-1"], "unet": ["adapter-1"]}) + + pipe.text_encoder.add_adapter(text_lora_config, "adapter-2") + pipe.unet.add_adapter(unet_lora_config, "adapter-2") + + adapter_names = pipe.get_list_adapters() + self.assertDictEqual( + adapter_names, {"text_encoder": ["adapter-1", "adapter-2"], "unet": ["adapter-1", "adapter-2"]} + ) + + pipe.set_adapters(["adapter-1", "adapter-2"]) + self.assertDictEqual( + pipe.get_list_adapters(), {"unet": ["adapter-1", "adapter-2"], "text_encoder": ["adapter-1", "adapter-2"]} + ) + + pipe.unet.add_adapter(unet_lora_config, "adapter-3") + self.assertDictEqual( + pipe.get_list_adapters(), + {"unet": ["adapter-1", "adapter-2", "adapter-3"], "text_encoder": ["adapter-1", "adapter-2"]}, + ) + + @unittest.skip("This is failing for now - need to investigate") + def test_simple_inference_with_text_unet_lora_unfused_torch_compile(self): + """ + Tests a simple inference with lora attached to the text encoder and unet after compiling both + with torch.compile, and makes sure it works as expected + """ + components, _, text_lora_config, unet_lora_config = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe = pipe.to(self.torch_device) + pipe.set_progress_bar_config(disable=None) + _, _, inputs = self.get_dummy_inputs(with_generator=False) + + pipe.text_encoder.add_adapter(text_lora_config) + pipe.unet.add_adapter(unet_lora_config) + + self.assertTrue(self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder") + self.assertTrue(self.check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet") + + if self.has_two_text_encoders: + pipe.text_encoder_2.add_adapter(text_lora_config) + self.assertTrue( + self.check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" + ) + + pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True) + pipe.text_encoder = torch.compile(pipe.text_encoder, mode="reduce-overhead", fullgraph=True) + + if self.has_two_text_encoders: + pipe.text_encoder_2 = torch.compile(pipe.text_encoder_2, mode="reduce-overhead", fullgraph=True) + + # Just makes sure it works. + _ = pipe(**inputs, generator=torch.manual_seed(0)).images
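Stripped of the assertions, the attach-and-check pattern that the mixin tests repeat comes down to a few lines; a sketch assuming a Stable Diffusion `pipe` and `peft` installed (the target module names follow the configs built in `get_dummy_components`):

from peft import LoraConfig

# LoRA on the text encoder attention projections...
text_lora_config = LoraConfig(
    r=4, lora_alpha=4, target_modules=["q_proj", "k_proj", "v_proj", "out_proj"], init_lora_weights=False
)
# ...and on the UNet attention projections.
unet_lora_config = LoraConfig(
    r=4, lora_alpha=4, target_modules=["to_q", "to_k", "to_v", "to_out.0"], init_lora_weights=False
)

pipe.text_encoder.add_adapter(text_lora_config)
pipe.unet.add_adapter(unet_lora_config)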
+ + +class StableDiffusionLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase): + pipeline_class = StableDiffusionPipeline + scheduler_cls = DDIMScheduler + scheduler_kwargs = { + "beta_start": 0.00085, + "beta_end": 0.012, + "beta_schedule": "scaled_linear", + "clip_sample": False, + "set_alpha_to_one": False, + "steps_offset": 1, + } + unet_kwargs = { + "block_out_channels": (32, 64), + "layers_per_block": 2, + "sample_size": 32, + "in_channels": 4, + "out_channels": 4, + "down_block_types": ("DownBlock2D", "CrossAttnDownBlock2D"), + "up_block_types": ("CrossAttnUpBlock2D", "UpBlock2D"), + "cross_attention_dim": 32, + } + vae_kwargs = { + "block_out_channels": [32, 64], + "in_channels": 3, + "out_channels": 3, + "down_block_types": ["DownEncoderBlock2D", "DownEncoderBlock2D"], + "up_block_types": ["UpDecoderBlock2D", "UpDecoderBlock2D"], + "latent_channels": 4, + } + + @slow + @require_torch_gpu + def test_integration_move_lora_cpu(self): + path = "runwayml/stable-diffusion-v1-5" + lora_id = "takuma104/lora-test-text-encoder-lora-target" + + pipe = StableDiffusionPipeline.from_pretrained(path, torch_dtype=torch.float16) + pipe.load_lora_weights(lora_id, adapter_name="adapter-1") + pipe.load_lora_weights(lora_id, adapter_name="adapter-2") + pipe = pipe.to("cuda") + + self.assertTrue( + self.check_if_lora_correctly_set(pipe.text_encoder), + "Lora not correctly set in text encoder", + ) + + self.assertTrue( + self.check_if_lora_correctly_set(pipe.unet), + "Lora not correctly set in Unet", + ) + + # We will offload the first adapter to CPU and check that the offloading + # has been performed correctly + pipe.set_lora_device(["adapter-1"], "cpu") + + for name, module in pipe.unet.named_modules(): + if "adapter-1" in name and not isinstance(module, (nn.Dropout, nn.Identity)): + self.assertTrue(module.weight.device == torch.device("cpu")) + elif "adapter-2" in name and not isinstance(module, (nn.Dropout, nn.Identity)): + self.assertTrue(module.weight.device != torch.device("cpu")) + + for name, module in pipe.text_encoder.named_modules(): + if "adapter-1" in name and not isinstance(module, (nn.Dropout, nn.Identity)): + self.assertTrue(module.weight.device == torch.device("cpu")) + elif "adapter-2" in name and not isinstance(module, (nn.Dropout, nn.Identity)): + self.assertTrue(module.weight.device != torch.device("cpu")) + + pipe.set_lora_device(["adapter-1"], 0) + + for n, m in pipe.unet.named_modules(): + if "adapter-1" in n and not isinstance(m, (nn.Dropout, nn.Identity)): + self.assertTrue(m.weight.device != torch.device("cpu")) + + for n, m in pipe.text_encoder.named_modules(): + if "adapter-1" in n and not isinstance(m, (nn.Dropout, nn.Identity)): + self.assertTrue(m.weight.device != torch.device("cpu")) + + pipe.set_lora_device(["adapter-1", "adapter-2"], "cuda") + + for n, m in pipe.unet.named_modules(): + if ("adapter-1" in n or "adapter-2" in n) and not isinstance(m, (nn.Dropout, nn.Identity)): + self.assertTrue(m.weight.device != torch.device("cpu")) + + for n, m in pipe.text_encoder.named_modules(): + if ("adapter-1" in n or "adapter-2" in n) and not isinstance(m, (nn.Dropout, nn.Identity)): + self.assertTrue(m.weight.device != torch.device("cpu"))
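The device-placement API covered by `test_integration_move_lora_cpu`, condensed into a usage sketch (adapter names as loaded in the test):

pipe.set_lora_device(["adapter-1"], "cpu")  # park one adapter on CPU to free VRAM
pipe.set_lora_device(["adapter-1"], 0)  # move it back to GPU 0
pipe.set_lora_device(["adapter-1", "adapter-2"], "cuda")  # or place several at once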
+ @slow + @require_torch_gpu + def test_integration_logits_with_scale(self): + path = "runwayml/stable-diffusion-v1-5" + lora_id = "takuma104/lora-test-text-encoder-lora-target" + + pipe = StableDiffusionPipeline.from_pretrained(path, torch_dtype=torch.float32) + pipe.load_lora_weights(lora_id) + pipe = pipe.to("cuda") + + self.assertTrue( + self.check_if_lora_correctly_set(pipe.text_encoder), + "Lora not correctly set in text encoder", + ) + + prompt = "a red sks dog" + + images = pipe( + prompt=prompt, + num_inference_steps=15, + cross_attention_kwargs={"scale": 0.5}, + generator=torch.manual_seed(0), + output_type="np", + ).images + + expected_slice_scale = np.array([0.307, 0.283, 0.310, 0.310, 0.300, 0.314, 0.336, 0.314, 0.321]) + + predicted_slice = images[0, -3:, -3:, -1].flatten() + + self.assertTrue(np.allclose(expected_slice_scale, predicted_slice, atol=1e-3, rtol=1e-3)) + + @slow + @require_torch_gpu + def test_integration_logits_no_scale(self): + path = "runwayml/stable-diffusion-v1-5" + lora_id = "takuma104/lora-test-text-encoder-lora-target" + + pipe = StableDiffusionPipeline.from_pretrained(path, torch_dtype=torch.float32) + pipe.load_lora_weights(lora_id) + pipe = pipe.to("cuda") + + self.assertTrue( + self.check_if_lora_correctly_set(pipe.text_encoder), + "Lora not correctly set in text encoder", + ) + + prompt = "a red sks dog" + + images = pipe(prompt=prompt, num_inference_steps=30, generator=torch.manual_seed(0), output_type="np").images + + expected_slice_scale = np.array([0.074, 0.064, 0.073, 0.0842, 0.069, 0.0641, 0.0794, 0.076, 0.084]) + + predicted_slice = images[0, -3:, -3:, -1].flatten() + + self.assertTrue(np.allclose(expected_slice_scale, predicted_slice, atol=1e-3, rtol=1e-3)) + + @nightly + @require_torch_gpu + def test_integration_logits_multi_adapter(self): + path = "stabilityai/stable-diffusion-xl-base-1.0" + lora_id = "CiroN2022/toy-face" + + pipe = StableDiffusionXLPipeline.from_pretrained(path, torch_dtype=torch.float16) + pipe.load_lora_weights(lora_id, weight_name="toy_face_sdxl.safetensors", adapter_name="toy") + pipe = pipe.to("cuda") + + self.assertTrue( + self.check_if_lora_correctly_set(pipe.unet), + "Lora not correctly set in Unet", + ) + + prompt = "toy_face of a hacker with a hoodie" + + lora_scale = 0.9 + + images = pipe( + prompt=prompt, + num_inference_steps=30, + generator=torch.manual_seed(0), + cross_attention_kwargs={"scale": lora_scale}, + output_type="np", + ).images + expected_slice_scale = np.array([0.538, 0.539, 0.540, 0.540, 0.542, 0.539, 0.538, 0.541, 0.539]) + + predicted_slice = images[0, -3:, -3:, -1].flatten() + self.assertTrue(np.allclose(expected_slice_scale, predicted_slice, atol=1e-3, rtol=1e-3)) + + pipe.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel") + pipe.set_adapters("pixel") + + prompt = "pixel art, a hacker with a hoodie, simple, flat colors" + images = pipe( + prompt, + num_inference_steps=30, + guidance_scale=7.5, + cross_attention_kwargs={"scale": lora_scale}, + generator=torch.manual_seed(0), + output_type="np", + ).images + + predicted_slice = images[0, -3:, -3:, -1].flatten() + expected_slice_scale = np.array( + [0.61973065, 0.62018543, 0.62181497, 0.61933696, 0.6208608, 0.620576, 0.6200281, 0.62258327, 0.6259889] + ) + self.assertTrue(np.allclose(expected_slice_scale, predicted_slice, atol=1e-3, rtol=1e-3)) + + # multi-adapter inference + pipe.set_adapters(["pixel", "toy"], adapter_weights=[0.5, 1.0]) + images = pipe( + prompt, + num_inference_steps=30, + guidance_scale=7.5, + cross_attention_kwargs={"scale": 1.0}, + generator=torch.manual_seed(0), + output_type="np", + ).images + predicted_slice = images[0, -3:, -3:, -1].flatten() + expected_slice_scale = np.array([0.5977,
0.5985, 0.6039, 0.5976, 0.6025, 0.6036, 0.5946, 0.5979, 0.5998]) + self.assertTrue(np.allclose(expected_slice_scale, predicted_slice, atol=1e-3, rtol=1e-3)) + + # Lora disabled + pipe.disable_lora() + images = pipe( + prompt, + num_inference_steps=30, + guidance_scale=7.5, + cross_attention_kwargs={"scale": lora_scale}, + generator=torch.manual_seed(0), + output_type="np", + ).images + predicted_slice = images[0, -3:, -3:, -1].flatten() + expected_slice_scale = np.array([0.54625, 0.5473, 0.5495, 0.5465, 0.5476, 0.5461, 0.5452, 0.5485, 0.5493]) + self.assertTrue(np.allclose(expected_slice_scale, predicted_slice, atol=1e-3, rtol=1e-3)) + + +class StableDiffusionXLLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase): + has_two_text_encoders = True + pipeline_class = StableDiffusionXLPipeline + scheduler_cls = EulerDiscreteScheduler + scheduler_kwargs = { + "beta_start": 0.00085, + "beta_end": 0.012, + "beta_schedule": "scaled_linear", + "timestep_spacing": "leading", + "steps_offset": 1, + } + unet_kwargs = { + "block_out_channels": (32, 64), + "layers_per_block": 2, "sample_size": 32, "in_channels": 4, "out_channels": 4, @@ -555,3 +1160,606 @@ class StableDiffusionXLLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase): "latent_channels": 4, "sample_size": 128, } + + +@slow +@require_torch_gpu +class LoraIntegrationTests(unittest.TestCase): + def tearDown(self): + import gc + + gc.collect() + torch.cuda.empty_cache() + gc.collect() + + def test_dreambooth_old_format(self): + generator = torch.Generator("cpu").manual_seed(0) + + lora_model_id = "hf-internal-testing/lora_dreambooth_dog_example" + card = RepoCard.load(lora_model_id) + base_model_id = card.data.to_dict()["base_model"] + + pipe = StableDiffusionPipeline.from_pretrained(base_model_id, safety_checker=None) + pipe = pipe.to(torch_device) + pipe.load_lora_weights(lora_model_id) + + images = pipe( + "A photo of a sks dog floating in the river", output_type="np", generator=generator, num_inference_steps=2 + ).images + + images = images[0, -3:, -3:, -1].flatten() + + expected = np.array([0.7207, 0.6787, 0.6010, 0.7478, 0.6838, 0.6064, 0.6984, 0.6443, 0.5785]) + + self.assertTrue(np.allclose(images, expected, atol=1e-4)) + release_memory(pipe) + + def test_dreambooth_text_encoder_new_format(self): + generator = torch.Generator().manual_seed(0) + + lora_model_id = "hf-internal-testing/lora-trained" + card = RepoCard.load(lora_model_id) + base_model_id = card.data.to_dict()["base_model"] + + pipe = StableDiffusionPipeline.from_pretrained(base_model_id, safety_checker=None) + pipe = pipe.to(torch_device) + pipe.load_lora_weights(lora_model_id) + + images = pipe("A photo of a sks dog", output_type="np", generator=generator, num_inference_steps=2).images + + images = images[0, -3:, -3:, -1].flatten() + + expected = np.array([0.6628, 0.6138, 0.5390, 0.6625, 0.6130, 0.5463, 0.6166, 0.5788, 0.5359]) + + self.assertTrue(np.allclose(images, expected, atol=1e-4)) + release_memory(pipe) + + def test_a1111(self): + generator = torch.Generator().manual_seed(0) + + pipe = StableDiffusionPipeline.from_pretrained("hf-internal-testing/Counterfeit-V2.5", safety_checker=None).to( + torch_device + ) + lora_model_id = "hf-internal-testing/civitai-light-shadow-lora" + lora_filename = "light_and_shadow.safetensors" + pipe.load_lora_weights(lora_model_id, weight_name=lora_filename) + + images = pipe( + "masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2 + ).images + + images = images[0, -3:, -3:, -1].flatten() + 
expected = np.array([0.3636, 0.3708, 0.3694, 0.3679, 0.3829, 0.3677, 0.3692, 0.3688, 0.3292]) + + self.assertTrue(np.allclose(images, expected, atol=1e-3)) + release_memory(pipe) + + def test_lycoris(self): + generator = torch.Generator().manual_seed(0) + + pipe = StableDiffusionPipeline.from_pretrained( + "hf-internal-testing/Amixx", safety_checker=None, use_safetensors=True, variant="fp16" + ).to(torch_device) + lora_model_id = "hf-internal-testing/edgLycorisMugler-light" + lora_filename = "edgLycorisMugler-light.safetensors" + pipe.load_lora_weights(lora_model_id, weight_name=lora_filename) + + images = pipe( + "masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2 + ).images + + images = images[0, -3:, -3:, -1].flatten() + expected = np.array([0.6463, 0.658, 0.599, 0.6542, 0.6512, 0.6213, 0.658, 0.6485, 0.6017]) + + self.assertTrue(np.allclose(images, expected, atol=1e-3)) + release_memory(pipe) + + def test_a1111_with_model_cpu_offload(self): + generator = torch.Generator().manual_seed(0) + + pipe = StableDiffusionPipeline.from_pretrained("hf-internal-testing/Counterfeit-V2.5", safety_checker=None) + pipe.enable_model_cpu_offload() + lora_model_id = "hf-internal-testing/civitai-light-shadow-lora" + lora_filename = "light_and_shadow.safetensors" + pipe.load_lora_weights(lora_model_id, weight_name=lora_filename) + + images = pipe( + "masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2 + ).images + + images = images[0, -3:, -3:, -1].flatten() + expected = np.array([0.3636, 0.3708, 0.3694, 0.3679, 0.3829, 0.3677, 0.3692, 0.3688, 0.3292]) + + self.assertTrue(np.allclose(images, expected, atol=1e-3)) + release_memory(pipe) + + def test_a1111_with_sequential_cpu_offload(self): + generator = torch.Generator().manual_seed(0) + + pipe = StableDiffusionPipeline.from_pretrained("hf-internal-testing/Counterfeit-V2.5", safety_checker=None) + pipe.enable_sequential_cpu_offload() + lora_model_id = "hf-internal-testing/civitai-light-shadow-lora" + lora_filename = "light_and_shadow.safetensors" + pipe.load_lora_weights(lora_model_id, weight_name=lora_filename) + + images = pipe( + "masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2 + ).images + + images = images[0, -3:, -3:, -1].flatten() + expected = np.array([0.3636, 0.3708, 0.3694, 0.3679, 0.3829, 0.3677, 0.3692, 0.3688, 0.3292]) + + self.assertTrue(np.allclose(images, expected, atol=1e-3)) + release_memory(pipe) + + def test_kohya_sd_v15_with_higher_dimensions(self): + generator = torch.Generator().manual_seed(0) + + pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", safety_checker=None).to( + torch_device + ) + lora_model_id = "hf-internal-testing/urushisato-lora" + lora_filename = "urushisato_v15.safetensors" + pipe.load_lora_weights(lora_model_id, weight_name=lora_filename) + + images = pipe( + "masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2 + ).images + + images = images[0, -3:, -3:, -1].flatten() + expected = np.array([0.7165, 0.6616, 0.5833, 0.7504, 0.6718, 0.587, 0.6871, 0.6361, 0.5694]) + + self.assertTrue(np.allclose(images, expected, atol=1e-3)) + release_memory(pipe) + + def test_vanilla_finetuning(self): + generator = torch.Generator().manual_seed(0) + + lora_model_id = "hf-internal-testing/sd-model-finetuned-lora-t4" + card = RepoCard.load(lora_model_id) + base_model_id = card.data.to_dict()["base_model"] + + pipe = 
StableDiffusionPipeline.from_pretrained(base_model_id, safety_checker=None) + pipe = pipe.to(torch_device) + pipe.load_lora_weights(lora_model_id) + + images = pipe("A pokemon with blue eyes.", output_type="np", generator=generator, num_inference_steps=2).images + + images = images[0, -3:, -3:, -1].flatten() + + expected = np.array([0.7406, 0.699, 0.5963, 0.7493, 0.7045, 0.6096, 0.6886, 0.6388, 0.583]) + + self.assertTrue(np.allclose(images, expected, atol=1e-4)) + release_memory(pipe) + + def test_unload_kohya_lora(self): + generator = torch.manual_seed(0) + prompt = "masterpiece, best quality, mountain" + num_inference_steps = 2 + + pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", safety_checker=None).to( + torch_device + ) + initial_images = pipe( + prompt, output_type="np", generator=generator, num_inference_steps=num_inference_steps + ).images + initial_images = initial_images[0, -3:, -3:, -1].flatten() + + lora_model_id = "hf-internal-testing/civitai-colored-icons-lora" + lora_filename = "Colored_Icons_by_vizsumit.safetensors" + + pipe.load_lora_weights(lora_model_id, weight_name=lora_filename) + generator = torch.manual_seed(0) + lora_images = pipe( + prompt, output_type="np", generator=generator, num_inference_steps=num_inference_steps + ).images + lora_images = lora_images[0, -3:, -3:, -1].flatten() + + pipe.unload_lora_weights() + generator = torch.manual_seed(0) + unloaded_lora_images = pipe( + prompt, output_type="np", generator=generator, num_inference_steps=num_inference_steps + ).images + unloaded_lora_images = unloaded_lora_images[0, -3:, -3:, -1].flatten() + + self.assertFalse(np.allclose(initial_images, lora_images)) + self.assertTrue(np.allclose(initial_images, unloaded_lora_images, atol=1e-3)) + release_memory(pipe) + + def test_load_unload_load_kohya_lora(self): + # This test ensures that a Kohya-style LoRA can be safely unloaded and then loaded + # without introducing any side-effects. Even though the test uses a Kohya-style + # LoRA, the underlying adapter handling mechanism is format-agnostic. + generator = torch.manual_seed(0) + prompt = "masterpiece, best quality, mountain" + num_inference_steps = 2 + + pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", safety_checker=None).to( + torch_device + ) + initial_images = pipe( + prompt, output_type="np", generator=generator, num_inference_steps=num_inference_steps + ).images + initial_images = initial_images[0, -3:, -3:, -1].flatten() + + lora_model_id = "hf-internal-testing/civitai-colored-icons-lora" + lora_filename = "Colored_Icons_by_vizsumit.safetensors" + + pipe.load_lora_weights(lora_model_id, weight_name=lora_filename) + generator = torch.manual_seed(0) + lora_images = pipe( + prompt, output_type="np", generator=generator, num_inference_steps=num_inference_steps + ).images + lora_images = lora_images[0, -3:, -3:, -1].flatten() + + pipe.unload_lora_weights() + generator = torch.manual_seed(0) + unloaded_lora_images = pipe( + prompt, output_type="np", generator=generator, num_inference_steps=num_inference_steps + ).images + unloaded_lora_images = unloaded_lora_images[0, -3:, -3:, -1].flatten() + + self.assertFalse(np.allclose(initial_images, lora_images)) + self.assertTrue(np.allclose(initial_images, unloaded_lora_images, atol=1e-3)) + + # make sure we can load a LoRA again after unloading and they don't have + # any undesired effects. 
+ pipe.load_lora_weights(lora_model_id, weight_name=lora_filename) + generator = torch.manual_seed(0) + lora_images_again = pipe( + prompt, output_type="np", generator=generator, num_inference_steps=num_inference_steps + ).images + lora_images_again = lora_images_again[0, -3:, -3:, -1].flatten() + + self.assertTrue(np.allclose(lora_images, lora_images_again, atol=1e-3)) + release_memory(pipe) + + +@slow +@require_torch_gpu +class LoraSDXLIntegrationTests(unittest.TestCase): + def tearDown(self): + import gc + + gc.collect() + torch.cuda.empty_cache() + gc.collect() + + def test_sdxl_0_9_lora_one(self): + generator = torch.Generator().manual_seed(0) + + pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-0.9") + lora_model_id = "hf-internal-testing/sdxl-0.9-daiton-lora" + lora_filename = "daiton-xl-lora-test.safetensors" + pipe.load_lora_weights(lora_model_id, weight_name=lora_filename) + pipe.enable_model_cpu_offload() + + images = pipe( + "masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2 + ).images + + images = images[0, -3:, -3:, -1].flatten() + expected = np.array([0.3838, 0.3482, 0.3588, 0.3162, 0.319, 0.3369, 0.338, 0.3366, 0.3213]) + + self.assertTrue(np.allclose(images, expected, atol=1e-3)) + release_memory(pipe) + + def test_sdxl_0_9_lora_two(self): + generator = torch.Generator().manual_seed(0) + + pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-0.9") + lora_model_id = "hf-internal-testing/sdxl-0.9-costumes-lora" + lora_filename = "saijo.safetensors" + pipe.load_lora_weights(lora_model_id, weight_name=lora_filename) + pipe.enable_model_cpu_offload() + + images = pipe( + "masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2 + ).images + + images = images[0, -3:, -3:, -1].flatten() + expected = np.array([0.3137, 0.3269, 0.3355, 0.255, 0.2577, 0.2563, 0.2679, 0.2758, 0.2626]) + + self.assertTrue(np.allclose(images, expected, atol=1e-3)) + release_memory(pipe) + + def test_sdxl_0_9_lora_three(self): + generator = torch.Generator().manual_seed(0) + + pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-0.9") + lora_model_id = "hf-internal-testing/sdxl-0.9-kamepan-lora" + lora_filename = "kame_sdxl_v2-000020-16rank.safetensors" + pipe.load_lora_weights(lora_model_id, weight_name=lora_filename) + pipe.enable_model_cpu_offload() + + images = pipe( + "masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2 + ).images + + images = images[0, -3:, -3:, -1].flatten() + expected = np.array([0.4015, 0.3761, 0.3616, 0.3745, 0.3462, 0.3337, 0.3564, 0.3649, 0.3468]) + + self.assertTrue(np.allclose(images, expected, atol=5e-3)) + release_memory(pipe) + + def test_sdxl_1_0_lora(self): + generator = torch.Generator().manual_seed(0) + + pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0") + pipe.enable_model_cpu_offload() + lora_model_id = "hf-internal-testing/sdxl-1.0-lora" + lora_filename = "sd_xl_offset_example-lora_1.0.safetensors" + pipe.load_lora_weights(lora_model_id, weight_name=lora_filename) + + images = pipe( + "masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2 + ).images + + images = images[0, -3:, -3:, -1].flatten() + expected = np.array([0.4468, 0.4087, 0.4134, 0.366, 0.3202, 0.3505, 0.3786, 0.387, 0.3535]) + + self.assertTrue(np.allclose(images, expected, atol=1e-4)) + 
release_memory(pipe) + + def test_sdxl_1_0_lora_fusion(self): + generator = torch.Generator().manual_seed(0) + + pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0") + lora_model_id = "hf-internal-testing/sdxl-1.0-lora" + lora_filename = "sd_xl_offset_example-lora_1.0.safetensors" + pipe.load_lora_weights(lora_model_id, weight_name=lora_filename) + + pipe.fuse_lora() + # We need to unload the lora weights since in the previous API `fuse_lora` led to lora weights being + # silently deleted - otherwise this will CPU OOM + pipe.unload_lora_weights() + + pipe.enable_model_cpu_offload() + + images = pipe( + "masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2 + ).images + + images = images[0, -3:, -3:, -1].flatten() + # This way we also test equivalence between LoRA fusion and the non-fusion behaviour. + expected = np.array([0.4468, 0.4087, 0.4134, 0.366, 0.3202, 0.3505, 0.3786, 0.387, 0.3535]) + + self.assertTrue(np.allclose(images, expected, atol=1e-4)) + release_memory(pipe) + + def test_sdxl_1_0_lora_unfusion(self): + generator = torch.Generator().manual_seed(0) + + pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0") + lora_model_id = "hf-internal-testing/sdxl-1.0-lora" + lora_filename = "sd_xl_offset_example-lora_1.0.safetensors" + pipe.load_lora_weights(lora_model_id, weight_name=lora_filename) + pipe.fuse_lora() + + pipe.enable_model_cpu_offload() + + images = pipe( + "masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2 + ).images + images_with_fusion = images[0, -3:, -3:, -1].flatten() + + pipe.unfuse_lora() + generator = torch.Generator().manual_seed(0) + images = pipe( + "masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2 + ).images + images_without_fusion = images[0, -3:, -3:, -1].flatten() + + self.assertTrue(np.allclose(images_with_fusion, images_without_fusion, atol=1e-3)) + release_memory(pipe) + + def test_sdxl_1_0_lora_unfusion_effectivity(self): + pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0") + pipe.enable_model_cpu_offload() + + generator = torch.Generator().manual_seed(0) + images = pipe( + "masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2 + ).images + original_image_slice = images[0, -3:, -3:, -1].flatten() + + lora_model_id = "hf-internal-testing/sdxl-1.0-lora" + lora_filename = "sd_xl_offset_example-lora_1.0.safetensors" + pipe.load_lora_weights(lora_model_id, weight_name=lora_filename) + pipe.fuse_lora() + + generator = torch.Generator().manual_seed(0) + _ = pipe( + "masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2 + ).images + + pipe.unfuse_lora() + + # We need to unload the lora weights - in the old API unfuse led to unloading the adapter weights + pipe.unload_lora_weights() + + generator = torch.Generator().manual_seed(0) + images = pipe( + "masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2 + ).images + images_without_fusion_slice = images[0, -3:, -3:, -1].flatten() + + self.assertTrue(np.allclose(original_image_slice, images_without_fusion_slice, atol=1e-3)) + release_memory(pipe) + + def test_sdxl_1_0_lora_fusion_efficiency(self): + generator = torch.Generator().manual_seed(0) + lora_model_id = "hf-internal-testing/sdxl-1.0-lora" + lora_filename = 
"sd_xl_offset_example-lora_1.0.safetensors" + + pipe = DiffusionPipeline.from_pretrained( + "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16 + ) + pipe.load_lora_weights(lora_model_id, weight_name=lora_filename, torch_dtype=torch.bfloat16) + pipe.enable_model_cpu_offload() + + start_time = time.time() + for _ in range(3): + pipe( + "masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2 + ).images + end_time = time.time() + elapsed_time_non_fusion = end_time - start_time + + del pipe + + pipe = DiffusionPipeline.from_pretrained( + "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16 + ) + pipe.load_lora_weights(lora_model_id, weight_name=lora_filename, torch_dtype=torch.bfloat16) + pipe.fuse_lora() + # We need to unload the lora weights since in the previous API `fuse_lora` led to lora weights being + # silently deleted - otherwise this will CPU OOM + pipe.unload_lora_weights() + + pipe.enable_model_cpu_offload() + + start_time = time.time() + generator = torch.Generator().manual_seed(0) + for _ in range(3): + pipe( + "masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2 + ).images + end_time = time.time() + elapsed_time_fusion = end_time - start_time + + self.assertTrue(elapsed_time_fusion < elapsed_time_non_fusion) + release_memory(pipe) + + def test_sdxl_1_0_last_ben(self): + generator = torch.Generator().manual_seed(0) + + pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0") + pipe.enable_model_cpu_offload() + lora_model_id = "TheLastBen/Papercut_SDXL" + lora_filename = "papercut.safetensors" + pipe.load_lora_weights(lora_model_id, weight_name=lora_filename) + + images = pipe("papercut.safetensors", output_type="np", generator=generator, num_inference_steps=2).images + + images = images[0, -3:, -3:, -1].flatten() + expected = np.array([0.5244, 0.4347, 0.4312, 0.4246, 0.4398, 0.4409, 0.4884, 0.4938, 0.4094]) + + self.assertTrue(np.allclose(images, expected, atol=1e-3)) + release_memory(pipe) + + def test_sdxl_1_0_fuse_unfuse_all(self): + pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16) + text_encoder_1_sd = copy.deepcopy(pipe.text_encoder.state_dict()) + text_encoder_2_sd = copy.deepcopy(pipe.text_encoder_2.state_dict()) + unet_sd = copy.deepcopy(pipe.unet.state_dict()) + + pipe.load_lora_weights( + "davizca87/sun-flower", weight_name="snfw3rXL-000004.safetensors", torch_dtype=torch.float16 + ) + + fused_te_state_dict = pipe.text_encoder.state_dict() + fused_te_2_state_dict = pipe.text_encoder_2.state_dict() + unet_state_dict = pipe.unet.state_dict() + + for key, value in text_encoder_1_sd.items(): + self.assertTrue(torch.allclose(fused_te_state_dict[key], value)) + + for key, value in text_encoder_2_sd.items(): + self.assertTrue(torch.allclose(fused_te_2_state_dict[key], value)) + + for key, value in unet_state_dict.items(): + self.assertTrue(torch.allclose(unet_state_dict[key], value)) + + pipe.fuse_lora() + pipe.unload_lora_weights() + + assert not state_dicts_almost_equal(text_encoder_1_sd, pipe.text_encoder.state_dict()) + assert not state_dicts_almost_equal(text_encoder_2_sd, pipe.text_encoder_2.state_dict()) + assert not state_dicts_almost_equal(unet_sd, pipe.unet.state_dict()) + release_memory(pipe) + del unet_sd, text_encoder_1_sd, text_encoder_2_sd + + def test_sdxl_1_0_lora_with_sequential_cpu_offloading(self): + generator = 
torch.Generator().manual_seed(0) + + pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0") + pipe.enable_sequential_cpu_offload() + lora_model_id = "hf-internal-testing/sdxl-1.0-lora" + lora_filename = "sd_xl_offset_example-lora_1.0.safetensors" + + pipe.load_lora_weights(lora_model_id, weight_name=lora_filename) + + images = pipe( + "masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2 + ).images + + images = images[0, -3:, -3:, -1].flatten() + expected = np.array([0.4468, 0.4087, 0.4134, 0.366, 0.3202, 0.3505, 0.3786, 0.387, 0.3535]) + + self.assertTrue(np.allclose(images, expected, atol=1e-3)) + release_memory(pipe) + + def test_canny_lora(self): + controlnet = ControlNetModel.from_pretrained("diffusers/controlnet-canny-sdxl-1.0") + + pipe = StableDiffusionXLControlNetPipeline.from_pretrained( + "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet + ) + pipe.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors") + pipe.enable_sequential_cpu_offload() + + generator = torch.Generator(device="cpu").manual_seed(0) + prompt = "corgi" + image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png" + ) + + images = pipe(prompt, image=image, generator=generator, output_type="np", num_inference_steps=3).images + + assert images[0].shape == (768, 512, 3) + + original_image = images[0, -3:, -3:, -1].flatten() + expected_image = np.array([0.4574, 0.4461, 0.4435, 0.4462, 0.4396, 0.439, 0.4474, 0.4486, 0.4333]) + assert np.allclose(original_image, expected_image, atol=1e-04) + release_memory(pipe) + + @nightly + def test_sequential_fuse_unfuse(self): + pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16) + + # 1. round + pipe.load_lora_weights("Pclanglais/TintinIA", torch_dtype=torch.float16) + pipe.to("cuda") + pipe.fuse_lora() + + generator = torch.Generator().manual_seed(0) + images = pipe( + "masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2 + ).images + image_slice = images[0, -3:, -3:, -1].flatten() + + pipe.unfuse_lora() + + # 2. round + pipe.load_lora_weights("ProomptEngineer/pe-balloon-diffusion-style", torch_dtype=torch.float16) + pipe.fuse_lora() + pipe.unfuse_lora() + + # 3. round + pipe.load_lora_weights("ostris/crayon_style_lora_sdxl", torch_dtype=torch.float16) + pipe.fuse_lora() + pipe.unfuse_lora() + + # 4. back to 1st round + pipe.load_lora_weights("Pclanglais/TintinIA", torch_dtype=torch.float16) + pipe.fuse_lora() + + generator = torch.Generator().manual_seed(0) + images_2 = pipe( + "masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2 + ).images + image_slice_2 = images_2[0, -3:, -3:, -1].flatten() + + self.assertTrue(np.allclose(image_slice, image_slice_2, atol=1e-3)) + release_memory(pipe)
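Taken together, the integration tests above exercise a flow that in user code looks roughly as follows; a sketch only (model and adapter IDs taken from the tests, not a prescribed recipe):

import torch

from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")

# Load two adapters and blend them.
pipe.load_lora_weights("CiroN2022/toy-face", weight_name="toy_face_sdxl.safetensors", adapter_name="toy")
pipe.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
pipe.set_adapters(["pixel", "toy"], adapter_weights=[0.5, 1.0])

image = pipe("pixel art, a corgi with a hoodie", num_inference_steps=30).images[0]

pipe.fuse_lora()  # merge the active adapters into the base weights for faster inference
pipe.unfuse_lora()  # restore the original base weights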