From 99da518407bb9a5b6a396779ad3fc2bd0256ee38 Mon Sep 17 00:00:00 2001
From: Kashif Rasul
Date: Mon, 4 Dec 2023 16:36:56 +0100
Subject: [PATCH 1/5] [Peft] fix saving / loading when unet is not "unet"

---
 src/diffusers/loaders/lora.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/src/diffusers/loaders/lora.py b/src/diffusers/loaders/lora.py
index dde717959f8e..6d0b184a0691 100644
--- a/src/diffusers/loaders/lora.py
+++ b/src/diffusers/loaders/lora.py
@@ -816,10 +816,10 @@ def pack_weights(layers, prefix):
             raise ValueError("You must pass at least one of `unet_lora_layers`, `text_encoder_lora_layers`.")
 
         if unet_lora_layers:
-            state_dict.update(pack_weights(unet_lora_layers, "unet"))
+            state_dict.update(pack_weights(unet_lora_layers, cls.unet_name))
 
         if text_encoder_lora_layers:
-            state_dict.update(pack_weights(text_encoder_lora_layers, "text_encoder"))
+            state_dict.update(pack_weights(text_encoder_lora_layers, cls.text_encoder_name))
 
         # Save the model
         cls.write_lora_layers(
@@ -1376,10 +1376,10 @@ def pack_weights(layers, prefix):
             )
 
         if unet_lora_layers:
-            state_dict.update(pack_weights(unet_lora_layers, "unet"))
+            state_dict.update(pack_weights(unet_lora_layers, cls.unet_name))
 
         if text_encoder_lora_layers and text_encoder_2_lora_layers:
-            state_dict.update(pack_weights(text_encoder_lora_layers, "text_encoder"))
+            state_dict.update(pack_weights(text_encoder_lora_layers, cls.text_encoder_name))
             state_dict.update(pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))
 
         cls.write_lora_layers(

From 4045a1a8a4b9d18f99a3db0c5c29be9c17f7dccb Mon Sep 17 00:00:00 2001
From: Kashif Rasul
Date: Tue, 5 Dec 2023 13:43:16 +0100
Subject: [PATCH 2/5] Update src/diffusers/loaders/lora.py

Co-authored-by: Sayak Paul
---
 src/diffusers/loaders/lora.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/diffusers/loaders/lora.py b/src/diffusers/loaders/lora.py
index 6d0b184a0691..c60ba1b5b73f 100644
--- a/src/diffusers/loaders/lora.py
+++ b/src/diffusers/loaders/lora.py
@@ -1380,7 +1380,7 @@ def pack_weights(layers, prefix):
 
         if text_encoder_lora_layers and text_encoder_2_lora_layers:
             state_dict.update(pack_weights(text_encoder_lora_layers, cls.text_encoder_name))
-            state_dict.update(pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))
+            state_dict.update(pack_weights(text_encoder_2_lora_layers, f"{cls.text_encoder}_2"))
 
         cls.write_lora_layers(
             state_dict=state_dict,

From 597d1fd9888776d8baf9ac2414dd4a94003118a9 Mon Sep 17 00:00:00 2001
From: Kashif Rasul
Date: Thu, 7 Dec 2023 09:32:48 +0100
Subject: [PATCH 3/5] undo stablediffusion-xl changes

---
 src/diffusers/loaders/lora.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/diffusers/loaders/lora.py b/src/diffusers/loaders/lora.py
index 7906a2fca284..6066f5efcda1 100644
--- a/src/diffusers/loaders/lora.py
+++ b/src/diffusers/loaders/lora.py
@@ -1376,11 +1376,11 @@ def pack_weights(layers, prefix):
             )
 
         if unet_lora_layers:
-            state_dict.update(pack_weights(unet_lora_layers, cls.unet_name))
+            state_dict.update(pack_weights(unet_lora_layers, "unet"))
 
         if text_encoder_lora_layers and text_encoder_2_lora_layers:
-            state_dict.update(pack_weights(text_encoder_lora_layers, cls.text_encoder_name))
-            state_dict.update(pack_weights(text_encoder_2_lora_layers, f"{cls.text_encoder}_2"))
+            state_dict.update(pack_weights(text_encoder_lora_layers, "text_encoder"))
+            state_dict.update(pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))
 
         cls.write_lora_layers(
             state_dict=state_dict,
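Patches 1-3 net out to one change in the base `save_lora_weights` path: the UNet prefix written into the serialized LoRA state dict comes from `cls.unet_name` rather than the hard-coded string "unet", while the SDXL variant keeps its literal prefixes. A minimal sketch of that prefixing, assuming a `pack_weights` helper that flattens a module or dict under a prefix; the toy pipeline class and its `prior` attribute below are illustrative, not part of the patch:

import torch

def pack_weights(layers, prefix):
    # Flatten a module (or an already-flat dict of tensors) into {"<prefix>.<name>": tensor}.
    weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
    return {f"{prefix}.{name}": param for name, param in weights.items()}

class ToyPriorPipeline:
    # Hypothetical pipeline whose denoiser attribute is "prior" rather than "unet".
    unet_name = "prior"
    text_encoder_name = "text_encoder"

unet_lora_layers = {"down_blocks.0.attn.lora_A.weight": torch.zeros(4, 4)}
state_dict = {}
state_dict.update(pack_weights(unet_lora_layers, ToyPriorPipeline.unet_name))
print(list(state_dict))  # ['prior.down_blocks.0.attn.lora_A.weight']

Because the same `unet_name` is consulted when the weights are loaded back, keying the saved dict on it keeps save and load symmetric for such pipelines.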
From 9c386f8d7fdd745448d52c00c79da54c78aa2a25 Mon Sep 17 00:00:00 2001
From: Kashif Rasul
Date: Thu, 7 Dec 2023 09:51:34 +0100
Subject: [PATCH 4/5] use unet_name to get unet for lora helpers

---
 src/diffusers/loaders/lora.py | 42 ++++++++++++++++++++++-------------
 1 file changed, 26 insertions(+), 16 deletions(-)

diff --git a/src/diffusers/loaders/lora.py b/src/diffusers/loaders/lora.py
index 6066f5efcda1..e76ea2fd3cf1 100644
--- a/src/diffusers/loaders/lora.py
+++ b/src/diffusers/loaders/lora.py
@@ -876,6 +876,8 @@ def unload_lora_weights(self):
         >>> ...
         ```
         """
+        unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
+
         if not USE_PEFT_BACKEND:
             if version.parse(__version__) > version.parse("0.23"):
                 logger.warn(
@@ -883,13 +885,13 @@ def unload_lora_weights(self):
                     "you can use `pipe.enable_lora()` or `pipe.disable_lora()`. After installing the latest version of PEFT."
                 )
 
-            for _, module in self.unet.named_modules():
+            for _, module in unet.named_modules():
                 if hasattr(module, "set_lora_layer"):
                     module.set_lora_layer(None)
         else:
-            recurse_remove_peft_layers(self.unet)
-            if hasattr(self.unet, "peft_config"):
-                del self.unet.peft_config
+            recurse_remove_peft_layers(unet)
+            if hasattr(unet, "peft_config"):
+                del unet.peft_config
 
         # Safe to call the following regardless of LoRA.
         self._remove_text_encoder_monkey_patch()
@@ -928,7 +930,8 @@ def fuse_lora(
             )
 
         if fuse_unet:
-            self.unet.fuse_lora(lora_scale, safe_fusing=safe_fusing)
+            unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
+            unet.fuse_lora(lora_scale, safe_fusing=safe_fusing)
 
         if USE_PEFT_BACKEND:
             from peft.tuners.tuners_utils import BaseTunerLayer
@@ -981,13 +984,14 @@ def unfuse_lora(self, unfuse_unet: bool = True, unfuse_text_encoder: bool = True
                 Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
                 LoRA parameters then it won't have any effect.
""" + unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet if unfuse_unet: if not USE_PEFT_BACKEND: - self.unet.unfuse_lora() + unet.unfuse_lora() else: from peft.tuners.tuners_utils import BaseTunerLayer - for module in self.unet.modules(): + for module in unet.modules(): if isinstance(module, BaseTunerLayer): module.unmerge() @@ -1103,8 +1107,9 @@ def set_adapters( adapter_names: Union[List[str], str], adapter_weights: Optional[List[float]] = None, ): + unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet # Handle the UNET - self.unet.set_adapters(adapter_names, adapter_weights) + unet.set_adapters(adapter_names, adapter_weights) # Handle the Text Encoder if hasattr(self, "text_encoder"): @@ -1117,7 +1122,8 @@ def disable_lora(self): raise ValueError("PEFT backend is required for this method.") # Disable unet adapters - self.unet.disable_lora() + unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet + unet.disable_lora() # Disable text encoder adapters if hasattr(self, "text_encoder"): @@ -1130,7 +1136,8 @@ def enable_lora(self): raise ValueError("PEFT backend is required for this method.") # Enable unet adapters - self.unet.enable_lora() + unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet + unet.enable_lora() # Enable text encoder adapters if hasattr(self, "text_encoder"): @@ -1152,7 +1159,8 @@ def delete_adapters(self, adapter_names: Union[List[str], str]): adapter_names = [adapter_names] # Delete unet adapters - self.unet.delete_adapters(adapter_names) + unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet + unet.delete_adapters(adapter_names) for adapter_name in adapter_names: # Delete text encoder adapters @@ -1185,8 +1193,8 @@ def get_active_adapters(self) -> List[str]: from peft.tuners.tuners_utils import BaseTunerLayer active_adapters = [] - - for module in self.unet.modules(): + unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet + for module in unet.modules(): if isinstance(module, BaseTunerLayer): active_adapters = module.active_adapters break @@ -1210,8 +1218,9 @@ def get_list_adapters(self) -> Dict[str, List[str]]: if hasattr(self, "text_encoder_2") and hasattr(self.text_encoder_2, "peft_config"): set_adapters["text_encoder_2"] = list(self.text_encoder_2.peft_config.keys()) - if hasattr(self, "unet") and hasattr(self.unet, "peft_config"): - set_adapters["unet"] = list(self.unet.peft_config.keys()) + unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet + if hasattr(self, self.unet_name) and hasattr(unet, "peft_config"): + set_adapters[self.unet_name] = list(self.unet.peft_config.keys()) return set_adapters @@ -1232,7 +1241,8 @@ def set_lora_device(self, adapter_names: List[str], device: Union[torch.device, from peft.tuners.tuners_utils import BaseTunerLayer # Handle the UNET - for unet_module in self.unet.modules(): + unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet + for unet_module in unet.modules(): if isinstance(unet_module, BaseTunerLayer): for adapter_name in adapter_names: unet_module.lora_A[adapter_name].to(device) From 0e276cd137e34ffae72b79214e2def816b67019c Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Wed, 20 Dec 2023 11:22:18 +0100 Subject: [PATCH 5/5] use unet_name --- src/diffusers/loaders/ip_adapter.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/diffusers/loaders/ip_adapter.py 
From 0e276cd137e34ffae72b79214e2def816b67019c Mon Sep 17 00:00:00 2001
From: Kashif Rasul
Date: Wed, 20 Dec 2023 11:22:18 +0100
Subject: [PATCH 5/5] use unet_name

---
 src/diffusers/loaders/ip_adapter.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/src/diffusers/loaders/ip_adapter.py b/src/diffusers/loaders/ip_adapter.py
index 158bde436374..3df0492380e5 100644
--- a/src/diffusers/loaders/ip_adapter.py
+++ b/src/diffusers/loaders/ip_adapter.py
@@ -149,9 +149,11 @@ def load_ip_adapter(
             self.feature_extractor = CLIPImageProcessor()
 
         # load ip-adapter into unet
-        self.unet._load_ip_adapter_weights(state_dict)
+        unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
+        unet._load_ip_adapter_weights(state_dict)
 
     def set_ip_adapter_scale(self, scale):
-        for attn_processor in self.unet.attn_processors.values():
+        unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
+        for attn_processor in unet.attn_processors.values():
            if isinstance(attn_processor, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0)):
                 attn_processor.scale = scale
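Patch 5 applies the same resolution in the IP-Adapter loader, so `_load_ip_adapter_weights` and `set_ip_adapter_scale` reach the denoiser even when it is not stored under `unet`. A toy sketch of the scale update with stand-in classes; none of these are the real diffusers processors:

from types import SimpleNamespace

class FakeIPAdapterAttnProcessor:
    def __init__(self):
        self.scale = 1.0

class FakePriorPipeline:
    unet_name = "prior"

    def __init__(self):
        # The denoiser is registered as `prior`, mimicking a pipeline without a `unet` attribute.
        self.prior = SimpleNamespace(attn_processors={"attn1": FakeIPAdapterAttnProcessor()})

    def set_ip_adapter_scale(self, scale):
        # Resolve the denoiser the same way the patch does, then update each IP-Adapter processor.
        unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
        for attn_processor in unet.attn_processors.values():
            if isinstance(attn_processor, FakeIPAdapterAttnProcessor):
                attn_processor.scale = scale

pipe = FakePriorPipeline()
pipe.set_ip_adapter_scale(0.6)
print(pipe.prior.attn_processors["attn1"].scale)  # 0.6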