diff --git a/examples/research_projects/controlnetxs/controlnetxs.py b/examples/research_projects/controlnetxs/controlnetxs.py index c6419b44daeb..20c8d0fdf0f1 100644 --- a/examples/research_projects/controlnetxs/controlnetxs.py +++ b/examples/research_projects/controlnetxs/controlnetxs.py @@ -494,9 +494,7 @@ def attn_processors(self) -> Dict[str, AttentionProcessor]: """ return self.control_model.attn_processors - def set_attn_processor( - self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False - ): + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): r""" Sets the attention processor to use to compute attention. @@ -509,7 +507,7 @@ def set_attn_processor( processor. This is strongly recommended when setting trainable attention processors. """ - self.control_model.set_attn_processor(processor, _remove_lora) + self.control_model.set_attn_processor(processor) def set_default_attn_processor(self): """ diff --git a/src/diffusers/loaders/lora.py b/src/diffusers/loaders/lora.py index bbd01a995061..424e95f0843e 100644 --- a/src/diffusers/loaders/lora.py +++ b/src/diffusers/loaders/lora.py @@ -980,7 +980,7 @@ def unload_lora_weights(self): if not USE_PEFT_BACKEND: if version.parse(__version__) > version.parse("0.23"): - logger.warn( + logger.warning( "You are using `unload_lora_weights` to disable and unload lora weights. If you want to iteratively enable and disable adapter weights," "you can use `pipe.enable_lora()` or `pipe.disable_lora()`. After installing the latest version of PEFT." ) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 23a3e2bb3791..ac9563e186bb 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -373,29 +373,14 @@ def set_attention_slice(self, slice_size: int) -> None: self.set_processor(processor) - def set_processor(self, processor: "AttnProcessor", _remove_lora: bool = False) -> None: + def set_processor(self, processor: "AttnProcessor") -> None: r""" Set the attention processor to use. Args: processor (`AttnProcessor`): The attention processor to use. - _remove_lora (`bool`, *optional*, defaults to `False`): - Set to `True` to remove LoRA layers from the model. """ - if not USE_PEFT_BACKEND and hasattr(self, "processor") and _remove_lora and self.to_q.lora_layer is not None: - deprecate( - "set_processor to offload LoRA", - "0.26.0", - "In detail, removing LoRA layers via calling `set_default_attn_processor` is deprecated. 
Please make sure to call `pipe.unload_lora_weights()` instead.", - ) - # TODO(Patrick, Sayak) - this can be deprecated once PEFT LoRA integration is complete - # We need to remove all LoRA layers - # Don't forget to remove ALL `_remove_lora` from the codebase - for module in self.modules(): - if hasattr(module, "set_lora_layer"): - module.set_lora_layer(None) - # if current processor is in `self._modules` and if passed `processor` is not, we need to # pop `processor` from `self._modules` if ( diff --git a/src/diffusers/models/autoencoders/autoencoder_kl.py b/src/diffusers/models/autoencoders/autoencoder_kl.py index ae2d90c548f8..10a3ae58de9f 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl.py @@ -182,9 +182,7 @@ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: return processors # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor - def set_attn_processor( - self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False - ): + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): r""" Sets the attention processor to use to compute attention. @@ -208,9 +206,9 @@ def set_attn_processor( def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): if hasattr(module, "set_processor"): if not isinstance(processor, dict): - module.set_processor(processor, _remove_lora=_remove_lora) + module.set_processor(processor) else: - module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora) + module.set_processor(processor.pop(f"{name}.processor")) for sub_name, child in module.named_children(): fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) @@ -232,7 +230,7 @@ def set_default_attn_processor(self): f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" ) - self.set_attn_processor(processor, _remove_lora=True) + self.set_attn_processor(processor) @apply_forward_hook def encode( diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py b/src/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py index 0b7f8d1f5336..dbafb4571d4a 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py @@ -267,9 +267,7 @@ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: return processors # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor - def set_attn_processor( - self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False - ): + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): r""" Sets the attention processor to use to compute attention. 
@@ -293,9 +291,9 @@ def set_attn_processor( def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): if hasattr(module, "set_processor"): if not isinstance(processor, dict): - module.set_processor(processor, _remove_lora=_remove_lora) + module.set_processor(processor) else: - module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora) + module.set_processor(processor.pop(f"{name}.processor")) for sub_name, child in module.named_children(): fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) @@ -314,7 +312,7 @@ def set_default_attn_processor(self): f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" ) - self.set_attn_processor(processor, _remove_lora=True) + self.set_attn_processor(processor) @apply_forward_hook def encode( diff --git a/src/diffusers/models/autoencoders/consistency_decoder_vae.py b/src/diffusers/models/autoencoders/consistency_decoder_vae.py index d92423eafc31..ca670fec4b28 100644 --- a/src/diffusers/models/autoencoders/consistency_decoder_vae.py +++ b/src/diffusers/models/autoencoders/consistency_decoder_vae.py @@ -212,9 +212,7 @@ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: return processors # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor - def set_attn_processor( - self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False - ): + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): r""" Sets the attention processor to use to compute attention. @@ -238,9 +236,9 @@ def set_attn_processor( def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): if hasattr(module, "set_processor"): if not isinstance(processor, dict): - module.set_processor(processor, _remove_lora=_remove_lora) + module.set_processor(processor) else: - module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora) + module.set_processor(processor.pop(f"{name}.processor")) for sub_name, child in module.named_children(): fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) @@ -262,7 +260,7 @@ def set_default_attn_processor(self): f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" ) - self.set_attn_processor(processor, _remove_lora=True) + self.set_attn_processor(processor) @apply_forward_hook def encode( diff --git a/src/diffusers/models/controlnet.py b/src/diffusers/models/controlnet.py index 3139bb2a5c6c..1102f4f9d36d 100644 --- a/src/diffusers/models/controlnet.py +++ b/src/diffusers/models/controlnet.py @@ -534,9 +534,7 @@ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: return processors # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor - def set_attn_processor( - self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False - ): + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): r""" Sets the attention processor to use to compute attention. 
@@ -560,9 +558,9 @@ def set_attn_processor( def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): if hasattr(module, "set_processor"): if not isinstance(processor, dict): - module.set_processor(processor, _remove_lora=_remove_lora) + module.set_processor(processor) else: - module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora) + module.set_processor(processor.pop(f"{name}.processor")) for sub_name, child in module.named_children(): fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) @@ -584,7 +582,7 @@ def set_default_attn_processor(self): f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" ) - self.set_attn_processor(processor, _remove_lora=True) + self.set_attn_processor(processor) # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice def set_attention_slice(self, slice_size: Union[str, int, List[int]]) -> None: diff --git a/src/diffusers/models/prior_transformer.py b/src/diffusers/models/prior_transformer.py index 6c5e406ad378..8ada0a7c08a5 100644 --- a/src/diffusers/models/prior_transformer.py +++ b/src/diffusers/models/prior_transformer.py @@ -192,9 +192,7 @@ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: return processors # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor - def set_attn_processor( - self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False - ): + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): r""" Sets the attention processor to use to compute attention. @@ -218,9 +216,9 @@ def set_attn_processor( def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): if hasattr(module, "set_processor"): if not isinstance(processor, dict): - module.set_processor(processor, _remove_lora=_remove_lora) + module.set_processor(processor) else: - module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora) + module.set_processor(processor.pop(f"{name}.processor")) for sub_name, child in module.named_children(): fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) @@ -242,7 +240,7 @@ def set_default_attn_processor(self): f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" ) - self.set_attn_processor(processor, _remove_lora=True) + self.set_attn_processor(processor) def forward( self, diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index 623e4d88d564..4554016bdd53 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -643,9 +643,7 @@ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: return processors - def set_attn_processor( - self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False - ): + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): r""" Sets the attention processor to use to compute attention. 
@@ -669,9 +667,9 @@ def set_attn_processor( def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): if hasattr(module, "set_processor"): if not isinstance(processor, dict): - module.set_processor(processor, _remove_lora=_remove_lora) + module.set_processor(processor) else: - module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora) + module.set_processor(processor.pop(f"{name}.processor")) for sub_name, child in module.named_children(): fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) @@ -692,7 +690,7 @@ def set_default_attn_processor(self): f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" ) - self.set_attn_processor(processor, _remove_lora=True) + self.set_attn_processor(processor) def set_attention_slice(self, slice_size): r""" diff --git a/src/diffusers/models/unet_3d_condition.py b/src/diffusers/models/unet_3d_condition.py index 3c76b5aa8452..2fd629b2f8d9 100644 --- a/src/diffusers/models/unet_3d_condition.py +++ b/src/diffusers/models/unet_3d_condition.py @@ -375,9 +375,7 @@ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[i fn_recursive_set_attention_slice(module, reversed_slice_size) # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor - def set_attn_processor( - self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False - ): + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): r""" Sets the attention processor to use to compute attention. @@ -401,9 +399,9 @@ def set_attn_processor( def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): if hasattr(module, "set_processor"): if not isinstance(processor, dict): - module.set_processor(processor, _remove_lora=_remove_lora) + module.set_processor(processor) else: - module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora) + module.set_processor(processor.pop(f"{name}.processor")) for sub_name, child in module.named_children(): fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) @@ -465,7 +463,7 @@ def set_default_attn_processor(self): f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" ) - self.set_attn_processor(processor, _remove_lora=True) + self.set_attn_processor(processor) def _set_gradient_checkpointing(self, module, value: bool = False) -> None: if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)): diff --git a/src/diffusers/models/unet_motion_model.py b/src/diffusers/models/unet_motion_model.py index 0bbc573e7df1..b5f0302b4a43 100644 --- a/src/diffusers/models/unet_motion_model.py +++ b/src/diffusers/models/unet_motion_model.py @@ -549,9 +549,7 @@ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: return processors # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor - def set_attn_processor( - self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False - ): + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): r""" Sets the attention processor to use to compute attention. 
@@ -575,9 +573,9 @@ def set_attn_processor( def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): if hasattr(module, "set_processor"): if not isinstance(processor, dict): - module.set_processor(processor, _remove_lora=_remove_lora) + module.set_processor(processor) else: - module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora) + module.set_processor(processor.pop(f"{name}.processor")) for sub_name, child in module.named_children(): fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) @@ -641,7 +639,7 @@ def set_default_attn_processor(self) -> None: f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" ) - self.set_attn_processor(processor, _remove_lora=True) + self.set_attn_processor(processor) def _set_gradient_checkpointing(self, module, value: bool = False) -> None: if isinstance(module, (CrossAttnDownBlockMotion, DownBlockMotion, CrossAttnUpBlockMotion, UpBlockMotion)): diff --git a/src/diffusers/models/uvit_2d.py b/src/diffusers/models/uvit_2d.py index 14dd8aee8e89..a49c77a51b02 100644 --- a/src/diffusers/models/uvit_2d.py +++ b/src/diffusers/models/uvit_2d.py @@ -237,9 +237,7 @@ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: return processors # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor - def set_attn_processor( - self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False - ): + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): r""" Sets the attention processor to use to compute attention. @@ -263,9 +261,9 @@ def set_attn_processor( def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): if hasattr(module, "set_processor"): if not isinstance(processor, dict): - module.set_processor(processor, _remove_lora=_remove_lora) + module.set_processor(processor) else: - module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora) + module.set_processor(processor.pop(f"{name}.processor")) for sub_name, child in module.named_children(): fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) @@ -287,7 +285,7 @@ def set_default_attn_processor(self): f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" ) - self.set_attn_processor(processor, _remove_lora=True) + self.set_attn_processor(processor) class UVit2DConvEmbed(nn.Module): diff --git a/src/diffusers/pipelines/audioldm2/modeling_audioldm2.py b/src/diffusers/pipelines/audioldm2/modeling_audioldm2.py index e855c2f0d6f1..d39b2c99ddd0 100644 --- a/src/diffusers/pipelines/audioldm2/modeling_audioldm2.py +++ b/src/diffusers/pipelines/audioldm2/modeling_audioldm2.py @@ -538,9 +538,7 @@ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: return processors # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor - def set_attn_processor( - self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False - ): + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): r""" Sets the attention processor to use to compute attention. 
@@ -564,9 +562,9 @@ def set_attn_processor( def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): if hasattr(module, "set_processor"): if not isinstance(processor, dict): - module.set_processor(processor, _remove_lora=_remove_lora) + module.set_processor(processor) else: - module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora) + module.set_processor(processor.pop(f"{name}.processor")) for sub_name, child in module.named_children(): fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) @@ -588,7 +586,7 @@ def set_default_attn_processor(self): f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" ) - self.set_attn_processor(processor, _remove_lora=True) + self.set_attn_processor(processor) # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice def set_attention_slice(self, slice_size): diff --git a/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py index 7c9936a0bd4e..6f95112c3d50 100644 --- a/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +++ b/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py @@ -848,9 +848,7 @@ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: return processors - def set_attn_processor( - self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False - ): + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): r""" Sets the attention processor to use to compute attention. @@ -874,9 +872,9 @@ def set_attn_processor( def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): if hasattr(module, "set_processor"): if not isinstance(processor, dict): - module.set_processor(processor, _remove_lora=_remove_lora) + module.set_processor(processor) else: - module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora) + module.set_processor(processor.pop(f"{name}.processor")) for sub_name, child in module.named_children(): fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) @@ -897,7 +895,7 @@ def set_default_attn_processor(self): f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" ) - self.set_attn_processor(processor, _remove_lora=True) + self.set_attn_processor(processor) def set_attention_slice(self, slice_size): r""" diff --git a/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py b/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py index a7d9e32fb6c9..d4502639cebc 100644 --- a/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +++ b/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py @@ -91,9 +91,7 @@ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: return processors # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor - def set_attn_processor( - self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False - ): + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): r""" Sets the attention processor to use to compute attention. 
@@ -117,9 +115,9 @@ def set_attn_processor( def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): if hasattr(module, "set_processor"): if not isinstance(processor, dict): - module.set_processor(processor, _remove_lora=_remove_lora) + module.set_processor(processor) else: - module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora) + module.set_processor(processor.pop(f"{name}.processor")) for sub_name, child in module.named_children(): fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) @@ -141,7 +139,7 @@ def set_default_attn_processor(self): f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" ) - self.set_attn_processor(processor, _remove_lora=True) + self.set_attn_processor(processor) def _set_gradient_checkpointing(self, module, value=False): self.gradient_checkpointing = value diff --git a/tests/lora/test_lora_layers_old_backend.py b/tests/lora/test_lora_layers_old_backend.py index 7d6d30169455..09bb87c85163 100644 --- a/tests/lora/test_lora_layers_old_backend.py +++ b/tests/lora/test_lora_layers_old_backend.py @@ -61,7 +61,8 @@ ) -def text_encoder_attn_modules(text_encoder): +def text_encoder_attn_modules(text_encoder: nn.Module): + """Fetches the attention modules from `text_encoder`.""" attn_modules = [] if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)): @@ -75,7 +76,8 @@ def text_encoder_attn_modules(text_encoder): return attn_modules -def text_encoder_lora_state_dict(text_encoder): +def text_encoder_lora_state_dict(text_encoder: nn.Module): + """Returns the LoRA state dict of the `text_encoder`. Assumes that `_modify_text_encoder()` was already called on it.""" state_dict = {} for name, module in text_encoder_attn_modules(text_encoder): @@ -95,6 +97,8 @@ def text_encoder_lora_state_dict(text_encoder): def create_unet_lora_layers(unet: nn.Module, rank=4, mock_weights=True): + """Creates and returns the LoRA state dict for the UNet.""" + # So that we accidentally don't end up using the in-place modified UNet. unet_lora_parameters = [] for attn_processor_name, attn_processor in unet.attn_processors.items(): @@ -145,10 +149,17 @@ def create_unet_lora_layers(unet: nn.Module, rank=4, mock_weights=True): unet_lora_parameters.extend(attn_module.to_v.lora_layer.parameters()) unet_lora_parameters.extend(attn_module.to_out[0].lora_layer.parameters()) - return unet_lora_parameters, unet_lora_state_dict(unet) + unet_lora_sd = unet_lora_state_dict(unet) + # Unload LoRA. + for module in unet.modules(): + if hasattr(module, "set_lora_layer"): + module.set_lora_layer(None) + + return unet_lora_parameters, unet_lora_sd def create_3d_unet_lora_layers(unet: nn.Module, rank=4, mock_weights=True): + """Creates and returns the LoRA state dict for the 3D UNet.""" for attn_processor_name in unet.attn_processors.keys(): has_cross_attention = attn_processor_name.endswith("attn2.processor") and not ( attn_processor_name.startswith("transformer_in") or "temp_attentions" in attn_processor_name.split(".") @@ -216,10 +227,18 @@ def create_3d_unet_lora_layers(unet: nn.Module, rank=4, mock_weights=True): attn_module.to_v.lora_layer.up.weight += 1 attn_module.to_out[0].lora_layer.up.weight += 1 - return unet_lora_state_dict(unet) + unet_lora_sd = unet_lora_state_dict(unet) + + # Unload LoRA. 
+ for module in unet.modules(): + if hasattr(module, "set_lora_layer"): + module.set_lora_layer(None) + + return unet_lora_sd def set_lora_weights(lora_attn_parameters, randn_weight=False, var=1.0): + """Randomizes the LoRA params if specified.""" if not isinstance(lora_attn_parameters, dict): with torch.no_grad(): for parameter in lora_attn_parameters: @@ -1441,6 +1460,7 @@ def test_save_load_fused_lora_modules(self): class UNet2DConditionLoRAModelTests(unittest.TestCase): model_class = UNet2DConditionModel main_input_name = "sample" + lora_rank = 4 @property def dummy_input(self): @@ -1489,7 +1509,7 @@ def test_lora_processors(self): with torch.no_grad(): sample1 = model(**inputs_dict).sample - _, lora_params = create_unet_lora_layers(model) + _, lora_params = create_unet_lora_layers(model, rank=self.lora_rank) # make sure we can set a list of attention processors model.load_attn_procs(lora_params) @@ -1522,13 +1542,16 @@ def test_lora_on_off(self, expected_max_diff=1e-3): with torch.no_grad(): old_sample = model(**inputs_dict).sample - _, lora_params = create_unet_lora_layers(model) + _, lora_params = create_unet_lora_layers(model, rank=self.lora_rank) model.load_attn_procs(lora_params) with torch.no_grad(): sample = model(**inputs_dict, cross_attention_kwargs={"scale": 0.0}).sample - model.set_default_attn_processor() + # Unload LoRA. + for module in model.modules(): + if hasattr(module, "set_lora_layer"): + module.set_lora_layer(None) with torch.no_grad(): new_sample = model(**inputs_dict).sample @@ -1552,7 +1575,7 @@ def test_lora_xformers_on_off(self, expected_max_diff=6e-4): torch.manual_seed(0) model = self.model_class(**init_dict) model.to(torch_device) - _, lora_params = create_unet_lora_layers(model) + _, lora_params = create_unet_lora_layers(model, rank=self.lora_rank) model.load_attn_procs(lora_params) # default
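
Migration sketch, assuming a diffusers build that already contains the changes above; the checkpoint name and processor class are illustrative, not taken from the patch. `set_attn_processor`/`set_processor` no longer accept `_remove_lora`, so swapping attention processors and detaching LoRA layers are now two separate steps:

from diffusers import StableDiffusionPipeline
from diffusers.models.attention_processor import AttnProcessor2_0

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

# Swapping the attention processor no longer touches LoRA layers.
pipe.unet.set_attn_processor(AttnProcessor2_0())

# Removing LoRA is an explicit call, as the deprecation message above suggests ...
pipe.unload_lora_weights()

# ... or, at the module level, the same pattern the updated tests use:
for module in pipe.unet.modules():
    if hasattr(module, "set_lora_layer"):
        module.set_lora_layer(None)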
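
The per-layer dict form of `set_attn_processor` (the recursive dispatch simplified in every hunk above) is unchanged apart from the dropped keyword. A minimal sketch, assuming the default `UNet2DConditionModel` config; the split between processor classes is arbitrary and only demonstrates per-layer keys:

from diffusers import UNet2DConditionModel
from diffusers.models.attention_processor import AttnProcessor, AttnProcessor2_0

unet = UNet2DConditionModel()  # default config

# Keys must match `unet.attn_processors` exactly, e.g.
# "down_blocks.0.attentions.0.transformer_blocks.0.attn1.processor".
procs = {
    name: AttnProcessor2_0() if "attn1" in name else AttnProcessor()
    for name in unet.attn_processors.keys()
}

unet.set_attn_processor(procs)  # no `_remove_lora` argument anymore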