diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py
index 4554016bdd53..e0427858b8a4 100644
--- a/src/diffusers/models/unet_2d_condition.py
+++ b/src/diffusers/models/unet_2d_condition.py
@@ -829,6 +829,17 @@ def unfuse_qkv_projections(self):
         if self.original_attn_processors is not None:
             self.set_attn_processor(self.original_attn_processors)
 
+    def unload_lora(self):
+        """Unloads LoRA weights."""
+        deprecate(
+            "unload_lora",
+            "0.28.0",
+            "Calling `unload_lora()` is deprecated and will be removed in a future version. Please install `peft` and then call `disable_adapters()`.",
+        )
+        for module in self.modules():
+            if hasattr(module, "set_lora_layer"):
+                module.set_lora_layer(None)
+
     def forward(
         self,
         sample: torch.FloatTensor,
diff --git a/src/diffusers/models/unet_3d_condition.py b/src/diffusers/models/unet_3d_condition.py
index 2fd629b2f8d9..fc8695e064b5 100644
--- a/src/diffusers/models/unet_3d_condition.py
+++ b/src/diffusers/models/unet_3d_condition.py
@@ -22,7 +22,7 @@
 
 from ..configuration_utils import ConfigMixin, register_to_config
 from ..loaders import UNet2DConditionLoadersMixin
-from ..utils import BaseOutput, logging
+from ..utils import BaseOutput, deprecate, logging
 from .activations import get_activation
 from .attention_processor import (
     ADDED_KV_ATTENTION_PROCESSORS,
@@ -503,6 +503,18 @@ def disable_freeu(self):
             if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None:
                 setattr(upsample_block, k, None)
 
+    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.unload_lora
+    def unload_lora(self):
+        """Unloads LoRA weights."""
+        deprecate(
+            "unload_lora",
+            "0.28.0",
+            "Calling `unload_lora()` is deprecated and will be removed in a future version. Please install `peft` and then call `disable_adapters()`.",
+        )
+        for module in self.modules():
+            if hasattr(module, "set_lora_layer"):
+                module.set_lora_layer(None)
+
     def forward(
         self,
         sample: torch.FloatTensor,
diff --git a/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py
index 6f95112c3d50..86e3cfefaf4f 100644
--- a/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py
+++ b/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py
@@ -1034,6 +1034,17 @@ def unfuse_qkv_projections(self):
         if self.original_attn_processors is not None:
             self.set_attn_processor(self.original_attn_processors)
 
+    def unload_lora(self):
+        """Unloads LoRA weights."""
+        deprecate(
+            "unload_lora",
+            "0.28.0",
+            "Calling `unload_lora()` is deprecated and will be removed in a future version. Please install `peft` and then call `disable_adapters()`.",
+        )
+        for module in self.modules():
+            if hasattr(module, "set_lora_layer"):
+                module.set_lora_layer(None)
+
     def forward(
         self,
         sample: torch.FloatTensor,
diff --git a/tests/lora/test_lora_layers_old_backend.py b/tests/lora/test_lora_layers_old_backend.py
index 584f8a6c4f4a..fbc12d74f1aa 100644
--- a/tests/lora/test_lora_layers_old_backend.py
+++ b/tests/lora/test_lora_layers_old_backend.py
@@ -151,9 +151,7 @@ def create_unet_lora_layers(unet: nn.Module, rank=4, mock_weights=True):
     unet_lora_sd = unet_lora_state_dict(unet)
 
     # Unload LoRA.
-    for module in unet.modules():
-        if hasattr(module, "set_lora_layer"):
-            module.set_lora_layer(None)
+    unet.unload_lora()
 
     return unet_lora_parameters, unet_lora_sd
 
@@ -230,9 +228,7 @@ def create_3d_unet_lora_layers(unet: nn.Module, rank=4, mock_weights=True):
     unet_lora_sd = unet_lora_state_dict(unet)
 
     # Unload LoRA.
-    for module in unet.modules():
-        if hasattr(module, "set_lora_layer"):
-            module.set_lora_layer(None)
+    unet.unload_lora()
 
     return unet_lora_sd
 
@@ -1545,9 +1541,7 @@ def test_lora_on_off(self, expected_max_diff=1e-3):
         sample = model(**inputs_dict, cross_attention_kwargs={"scale": 0.0}).sample
 
         # Unload LoRA.
-        for module in model.modules():
-            if hasattr(module, "set_lora_layer"):
-                module.set_lora_layer(None)
+        model.unload_lora()
 
         with torch.no_grad():
             new_sample = model(**inputs_dict).sample
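For reference, a minimal usage sketch of the new method. This assumes a diffusers build that includes this patch; the checkpoint id is illustrative only, and `disable_adapters()` is the `peft`-backed replacement named in the deprecation message:

```python
# Minimal sketch, assuming a diffusers build that includes this patch.
# The model id below is illustrative only.
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet"
)

# ... attach LoRA layers via the old (non-peft) backend ...

# Deprecated path: clears the LoRA layer on every submodule that
# exposes `set_lora_layer`, emitting the deprecation warning above.
unet.unload_lora()

# With `peft` installed, the recommended replacement is the adapter API:
# unet.disable_adapters()
```

The method centralizes the unload loop that the tests previously duplicated inline, so the old-backend cleanup logic now lives in one place per UNet class.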