diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml
index ae5de5673100..e0aea2ce1493 100644
--- a/docs/source/en/_toctree.yml
+++ b/docs/source/en/_toctree.yml
@@ -212,6 +212,8 @@
         title: Textual Inversion
       - local: api/loaders/unet
         title: UNet
+      - local: api/loaders/peft
+        title: PEFT
       title: Loaders
   - sections:
       - local: api/models/overview
diff --git a/docs/source/en/api/loaders/peft.md b/docs/source/en/api/loaders/peft.md
new file mode 100644
index 000000000000..1aff389d3142
--- /dev/null
+++ b/docs/source/en/api/loaders/peft.md
@@ -0,0 +1,25 @@
+<!--Copyright 2023 The HuggingFace Team. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
+the License. You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
+an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
+specific language governing permissions and limitations under the License.
+-->
+
+# PEFT
+
+Diffusers supports working with adapters (such as [LoRA](../../using-diffusers/loading_adapters)) via the [`peft` library](https://huggingface.co/docs/peft/index). We provide a `PeftAdapterMixin` class to handle this for modeling classes in Diffusers (such as [`UNet2DConditionModel`]).
+
+<Tip>
+
+Refer to [this guide](../../tutorials/using_peft_for_inference.md) for an overview of how to work with `peft` in Diffusers for inference.
+
+</Tip>
+
+## PeftAdapterMixin
+
+[[autodoc]] loaders.peft.PeftAdapterMixin
\ No newline at end of file
diff --git a/src/diffusers/loaders/__init__.py b/src/diffusers/loaders/__init__.py
index 45c8c97c76eb..d7855206a287 100644
--- a/src/diffusers/loaders/__init__.py
+++ b/src/diffusers/loaders/__init__.py
@@ -1,7 +1,7 @@
 from typing import TYPE_CHECKING
 
 from ..utils import DIFFUSERS_SLOW_IMPORT, _LazyModule, deprecate
-from ..utils.import_utils import is_torch_available, is_transformers_available
+from ..utils.import_utils import is_peft_available, is_torch_available, is_transformers_available
 
 
 def text_encoder_lora_state_dict(text_encoder):
@@ -64,6 +64,8 @@ def text_encoder_attn_modules(text_encoder):
     _import_structure["textual_inversion"] = ["TextualInversionLoaderMixin"]
     _import_structure["ip_adapter"] = ["IPAdapterMixin"]
 
+_import_structure["peft"] = ["PeftAdapterMixin"]
+
 
 if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
     if is_torch_available():
@@ -76,6 +78,8 @@ def text_encoder_attn_modules(text_encoder):
         from .lora import LoraLoaderMixin, StableDiffusionXLLoraLoaderMixin
         from .single_file import FromSingleFileMixin
         from .textual_inversion import TextualInversionLoaderMixin
+
+        from .peft import PeftAdapterMixin
 else:
     import sys
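For context, a minimal sketch of the workflow the new doc page describes: attaching a fresh LoRA adapter to a model for training. The checkpoint and LoRA hyperparameters below are illustrative, not part of this diff, and assume `diffusers`, `peft`, and `torch` are installed.

```python
from diffusers import UNet2DConditionModel
from peft import LoraConfig

unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet"
)

# Target the attention projections; these module names match the Attention
# blocks used throughout Diffusers UNets.
lora_config = LoraConfig(
    r=4,
    lora_alpha=4,
    target_modules=["to_q", "to_k", "to_v", "to_out.0"],
)

unet.add_adapter(lora_config, adapter_name="my_lora")
print(unet.active_adapters())  # the newly added adapter is active by default
```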
diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py
new file mode 100644
index 000000000000..a7f1e1c938f0
--- /dev/null
+++ b/src/diffusers/loaders/peft.py
@@ -0,0 +1,188 @@
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import List, Union
+
+from ..utils import MIN_PEFT_VERSION, check_peft_version, is_peft_available
+
+
+class PeftAdapterMixin:
+    """
+    A class containing all functions for loading and using adapter weights that are supported in the PEFT library.
+    For more details about adapters and injecting them into a transformer-based model, check out the PEFT
+    documentation: https://huggingface.co/docs/peft/index.
+
+    With this mixin, if the correct PEFT version is installed, it is possible to:
+
+    - Attach new adapters to the model.
+    - Attach multiple adapters and iteratively activate/deactivate them.
+    - Activate/deactivate all adapters attached to the model.
+    - Get a list of the active adapters.
+    """
+
+    _hf_peft_config_loaded = False
+
+    def add_adapter(self, adapter_config, adapter_name: str = "default") -> None:
+        r"""
+        Adds a new adapter to the current model for training. If no adapter name is passed, a default name is
+        assigned to the adapter, following the convention of the PEFT library.
+
+        If you are not familiar with adapters and PEFT methods, we invite you to read more about them in the PEFT
+        [documentation](https://huggingface.co/docs/peft).
+
+        Args:
+            adapter_config ([`~peft.PeftConfig`]):
+                The configuration of the adapter to add; supported adapters are non-prefix-tuning and adaption-prompt
+                methods.
+            adapter_name (`str`, *optional*, defaults to `"default"`):
+                The name of the adapter to add. If no name is passed, a default name is assigned to the adapter.
+        """
+        check_peft_version(min_version=MIN_PEFT_VERSION)
+
+        if not is_peft_available():
+            raise ImportError("PEFT is not available. Please install PEFT to use this function: `pip install peft`.")
+
+        from peft import PeftConfig, inject_adapter_in_model
+
+        if not self._hf_peft_config_loaded:
+            self._hf_peft_config_loaded = True
+        elif adapter_name in self.peft_config:
+            raise ValueError(f"Adapter with name {adapter_name} already exists. Please use a different name.")
+
+        if not isinstance(adapter_config, PeftConfig):
+            raise ValueError(
+                f"adapter_config should be an instance of PeftConfig. Got {type(adapter_config)} instead."
+            )
+
+        # Unlike transformers, here we don't need to retrieve the name_or_path of the unet as the loading logic is
+        # handled by the `load_lora_layers` or `LoraLoaderMixin`. Therefore we set it to `None` here.
+        adapter_config.base_model_name_or_path = None
+        inject_adapter_in_model(adapter_config, self, adapter_name)
+        self.set_adapter(adapter_name)
+
+    def set_adapter(self, adapter_name: Union[str, List[str]]) -> None:
+        """
+        Sets a specific adapter by forcing the model to only use that adapter and disabling the others.
+
+        If you are not familiar with adapters and PEFT methods, we invite you to read more about them in the official
+        PEFT [documentation](https://huggingface.co/docs/peft).
+
+        Args:
+            adapter_name (`Union[str, List[str]]`):
+                The name of the adapter to set, or a list of adapter names to activate multiple adapters at once.
+        """
+        check_peft_version(min_version=MIN_PEFT_VERSION)
+
+        if not self._hf_peft_config_loaded:
+            raise ValueError("No adapter loaded. Please load an adapter first.")
+
+        if isinstance(adapter_name, str):
+            adapter_name = [adapter_name]
+
+        missing = set(adapter_name) - set(self.peft_config)
+        if len(missing) > 0:
+            raise ValueError(
+                f"The following adapter(s) could not be found: {', '.join(missing)}. Make sure you are passing"
+ f" current loaded adapters are: {list(self.peft_config.keys())}" + ) + + from peft.tuners.tuners_utils import BaseTunerLayer + + _adapters_has_been_set = False + + for _, module in self.named_modules(): + if isinstance(module, BaseTunerLayer): + if hasattr(module, "set_adapter"): + module.set_adapter(adapter_name) + # Previous versions of PEFT does not support multi-adapter inference + elif not hasattr(module, "set_adapter") and len(adapter_name) != 1: + raise ValueError( + "You are trying to set multiple adapters and you have a PEFT version that does not support multi-adapter inference. Please upgrade to the latest version of PEFT." + " `pip install -U peft` or `pip install -U git+https://github.com/huggingface/peft.git`" + ) + else: + module.active_adapter = adapter_name + _adapters_has_been_set = True + + if not _adapters_has_been_set: + raise ValueError( + "Did not succeeded in setting the adapter. Please make sure you are using a model that supports adapters." + ) + + def disable_adapters(self) -> None: + r""" + Disable all adapters attached to the model and fallback to inference with the base model only. + + If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT + official documentation: https://huggingface.co/docs/peft + """ + check_peft_version(min_version=MIN_PEFT_VERSION) + + if not self._hf_peft_config_loaded: + raise ValueError("No adapter loaded. Please load an adapter first.") + + from peft.tuners.tuners_utils import BaseTunerLayer + + for _, module in self.named_modules(): + if isinstance(module, BaseTunerLayer): + if hasattr(module, "enable_adapters"): + module.enable_adapters(enabled=False) + else: + # support for older PEFT versions + module.disable_adapters = True + + def enable_adapters(self) -> None: + """ + Enable adapters that are attached to the model. The model will use `self.active_adapters()` to retrieve the + list of adapters to enable. + + If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT + official documentation: https://huggingface.co/docs/peft + """ + check_peft_version(min_version=MIN_PEFT_VERSION) + + if not self._hf_peft_config_loaded: + raise ValueError("No adapter loaded. Please load an adapter first.") + + from peft.tuners.tuners_utils import BaseTunerLayer + + for _, module in self.named_modules(): + if isinstance(module, BaseTunerLayer): + if hasattr(module, "enable_adapters"): + module.enable_adapters(enabled=True) + else: + # support for older PEFT versions + module.disable_adapters = False + + def active_adapters(self) -> List[str]: + """ + Gets the current list of active adapters of the model. + + If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT + official documentation: https://huggingface.co/docs/peft + """ + check_peft_version(min_version=MIN_PEFT_VERSION) + + if not is_peft_available(): + raise ImportError("PEFT is not available. Please install PEFT to use this function: `pip install peft`.") + + if not self._hf_peft_config_loaded: + raise ValueError("No adapter loaded. 
+
+        from peft.tuners.tuners_utils import BaseTunerLayer
+
+        for _, module in self.named_modules():
+            if isinstance(module, BaseTunerLayer):
+                return module.active_adapter
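Continuing the sketch above, the remaining mixin methods cover switching between adapters and toggling them off and on. The adapter names are hypothetical, and as the `set_adapter` code notes, activating several adapters at once requires a recent PEFT version.

```python
from peft import LoraConfig

# Add a second adapter next to "my_lora"; add_adapter() also activates it.
unet.add_adapter(
    LoraConfig(r=8, target_modules=["to_q", "to_v"]), adapter_name="other_lora"
)

unet.set_adapter("my_lora")                  # only "my_lora" is active
unet.set_adapter(["my_lora", "other_lora"])  # activate both (recent PEFT only)

unet.disable_adapters()  # fall back to the base model
unet.enable_adapters()   # re-enable the adapters selected above
print(unet.active_adapters())
```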
diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py
index 546c5b20f937..445c3ca71caf 100644
--- a/src/diffusers/models/modeling_utils.py
+++ b/src/diffusers/models/modeling_utils.py
@@ -32,12 +32,10 @@
 from ..utils import (
     CONFIG_NAME,
     FLAX_WEIGHTS_NAME,
-    MIN_PEFT_VERSION,
     SAFETENSORS_WEIGHTS_NAME,
     WEIGHTS_NAME,
     _add_variant,
     _get_model_file,
-    check_peft_version,
     deprecate,
     is_accelerate_available,
     is_torch_version,
@@ -197,7 +195,6 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
     _automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"]
     _supports_gradient_checkpointing = False
     _keys_to_ignore_on_load_unexpected = None
-    _hf_peft_config_loaded = False
 
     def __init__(self):
         super().__init__()
@@ -303,153 +300,6 @@ def disable_xformers_memory_efficient_attention(self) -> None:
         """
         self.set_use_memory_efficient_attention_xformers(False)
 
-    def add_adapter(self, adapter_config, adapter_name: str = "default") -> None:
-        r"""
-        Adds a new adapter to the current model for training. If no adapter name is passed, a default name is assigned
-        to the adapter to follow the convention of the PEFT library.
-
-        If you are not familiar with adapters and PEFT methods, we invite you to read more about them in the PEFT
-        [documentation](https://huggingface.co/docs/peft).
-
-        Args:
-            adapter_config (`[~peft.PeftConfig]`):
-                The configuration of the adapter to add; supported adapters are non-prefix tuning and adaption prompt
-                methods.
-            adapter_name (`str`, *optional*, defaults to `"default"`):
-                The name of the adapter to add. If no name is passed, a default name is assigned to the adapter.
-        """
-        check_peft_version(min_version=MIN_PEFT_VERSION)
-
-        from peft import PeftConfig, inject_adapter_in_model
-
-        if not self._hf_peft_config_loaded:
-            self._hf_peft_config_loaded = True
-        elif adapter_name in self.peft_config:
-            raise ValueError(f"Adapter with name {adapter_name} already exists. Please use a different name.")
-
-        if not isinstance(adapter_config, PeftConfig):
-            raise ValueError(
-                f"adapter_config should be an instance of PeftConfig. Got {type(adapter_config)} instead."
-            )
-
-        # Unlike transformers, here we don't need to retrieve the name_or_path of the unet as the loading logic is
-        # handled by the `load_lora_layers` or `LoraLoaderMixin`. Therefore we set it to `None` here.
-        adapter_config.base_model_name_or_path = None
-        inject_adapter_in_model(adapter_config, self, adapter_name)
-        self.set_adapter(adapter_name)
-
-    def set_adapter(self, adapter_name: Union[str, List[str]]) -> None:
-        """
-        Sets a specific adapter by forcing the model to only use that adapter and disables the other adapters.
-
-        If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
-        official documentation: https://huggingface.co/docs/peft
-
-        Args:
-            adapter_name (Union[str, List[str]])):
-                The list of adapters to set or the adapter name in case of single adapter.
-        """
-        check_peft_version(min_version=MIN_PEFT_VERSION)
-
-        if not self._hf_peft_config_loaded:
-            raise ValueError("No adapter loaded. Please load an adapter first.")
-
-        if isinstance(adapter_name, str):
-            adapter_name = [adapter_name]
-
-        missing = set(adapter_name) - set(self.peft_config)
-        if len(missing) > 0:
-            raise ValueError(
-                f"Following adapter(s) could not be found: {', '.join(missing)}. Make sure you are passing the correct adapter name(s)."
-                f" current loaded adapters are: {list(self.peft_config.keys())}"
-            )
-
-        from peft.tuners.tuners_utils import BaseTunerLayer
-
-        _adapters_has_been_set = False
-
-        for _, module in self.named_modules():
-            if isinstance(module, BaseTunerLayer):
-                if hasattr(module, "set_adapter"):
-                    module.set_adapter(adapter_name)
-                # Previous versions of PEFT does not support multi-adapter inference
-                elif not hasattr(module, "set_adapter") and len(adapter_name) != 1:
-                    raise ValueError(
-                        "You are trying to set multiple adapters and you have a PEFT version that does not support multi-adapter inference. Please upgrade to the latest version of PEFT."
-                        " `pip install -U peft` or `pip install -U git+https://github.com/huggingface/peft.git`"
-                    )
-                else:
-                    module.active_adapter = adapter_name
-                _adapters_has_been_set = True
-
-        if not _adapters_has_been_set:
-            raise ValueError(
-                "Did not succeeded in setting the adapter. Please make sure you are using a model that supports adapters."
-            )
-
-    def disable_adapters(self) -> None:
-        r"""
-        Disable all adapters attached to the model and fallback to inference with the base model only.
-
-        If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
-        official documentation: https://huggingface.co/docs/peft
-        """
-        check_peft_version(min_version=MIN_PEFT_VERSION)
-
-        if not self._hf_peft_config_loaded:
-            raise ValueError("No adapter loaded. Please load an adapter first.")
-
-        from peft.tuners.tuners_utils import BaseTunerLayer
-
-        for _, module in self.named_modules():
-            if isinstance(module, BaseTunerLayer):
-                if hasattr(module, "enable_adapters"):
-                    module.enable_adapters(enabled=False)
-                else:
-                    # support for older PEFT versions
-                    module.disable_adapters = True
-
-    def enable_adapters(self) -> None:
-        """
-        Enable adapters that are attached to the model. The model will use `self.active_adapters()` to retrieve the
-        list of adapters to enable.
-
-        If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
-        official documentation: https://huggingface.co/docs/peft
-        """
-        check_peft_version(min_version=MIN_PEFT_VERSION)
-
-        if not self._hf_peft_config_loaded:
-            raise ValueError("No adapter loaded. Please load an adapter first.")
-
-        from peft.tuners.tuners_utils import BaseTunerLayer
-
-        for _, module in self.named_modules():
-            if isinstance(module, BaseTunerLayer):
-                if hasattr(module, "enable_adapters"):
-                    module.enable_adapters(enabled=True)
-                else:
-                    # support for older PEFT versions
-                    module.disable_adapters = False
-
-    def active_adapters(self) -> List[str]:
-        """
-        Gets the current list of active adapters of the model.
-
-        If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
-        official documentation: https://huggingface.co/docs/peft
-        """
-        check_peft_version(min_version=MIN_PEFT_VERSION)
-
-        if not self._hf_peft_config_loaded:
-            raise ValueError("No adapter loaded. Please load an adapter first.")
-
-        from peft.tuners.tuners_utils import BaseTunerLayer
-
-        for _, module in self.named_modules():
-            if isinstance(module, BaseTunerLayer):
-                return module.active_adapter
-
     def save_pretrained(
         self,
         save_directory: Union[str, os.PathLike],
diff --git a/src/diffusers/models/prior_transformer.py b/src/diffusers/models/prior_transformer.py
index 8ada0a7c08a5..6b52ea344d41 100644
--- a/src/diffusers/models/prior_transformer.py
+++ b/src/diffusers/models/prior_transformer.py
@@ -6,7 +6,7 @@
 from torch import nn
 
 from ..configuration_utils import ConfigMixin, register_to_config
-from ..loaders import UNet2DConditionLoadersMixin
+from ..loaders import PeftAdapterMixin, UNet2DConditionLoadersMixin
 from ..utils import BaseOutput
 from .attention import BasicTransformerBlock
 from .attention_processor import (
@@ -33,7 +33,7 @@ class PriorTransformerOutput(BaseOutput):
     predicted_image_embedding: torch.FloatTensor
 
 
-class PriorTransformer(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
+class PriorTransformer(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin):
     """
     A Prior Transformer model.
diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py
index e0427858b8a4..7b4f9f5594ea 100644
--- a/src/diffusers/models/unet_2d_condition.py
+++ b/src/diffusers/models/unet_2d_condition.py
@@ -19,7 +19,7 @@
 import torch.utils.checkpoint
 
 from ..configuration_utils import ConfigMixin, register_to_config
-from ..loaders import UNet2DConditionLoadersMixin
+from ..loaders import PeftAdapterMixin, UNet2DConditionLoadersMixin
 from ..utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers
 from .activations import get_activation
 from .attention_processor import (
@@ -68,7 +68,7 @@ class UNet2DConditionOutput(BaseOutput):
     sample: torch.FloatTensor = None
 
 
-class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
+class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin):
     r"""
     A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
     shaped output.
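Because the adapter API now lives on a mixin rather than on `ModelMixin`, every class that adds `PeftAdapterMixin` to its bases, as the two hunks above do for `PriorTransformer` and `UNet2DConditionModel`, gains the same methods. A sketch of what that enables; the Kandinsky checkpoint is just one example of a hosted `PriorTransformer`:

```python
from diffusers import PriorTransformer
from diffusers.loaders import PeftAdapterMixin
from peft import LoraConfig

prior = PriorTransformer.from_pretrained(
    "kandinsky-community/kandinsky-2-1-prior", subfolder="prior"
)
assert isinstance(prior, PeftAdapterMixin)

# The same adapter API as on the UNet now works here too.
prior.add_adapter(
    LoraConfig(r=4, target_modules=["to_q", "to_v"]), adapter_name="prior_lora"
)
```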
diff --git a/src/diffusers/models/uvit_2d.py b/src/diffusers/models/uvit_2d.py index a49c77a51b02..c0e224562cf2 100644 --- a/src/diffusers/models/uvit_2d.py +++ b/src/diffusers/models/uvit_2d.py @@ -21,6 +21,7 @@ from torch.utils.checkpoint import checkpoint from ..configuration_utils import ConfigMixin, register_to_config +from ..loaders import PeftAdapterMixin from .attention import BasicTransformerBlock, SkipFFTransformerBlock from .attention_processor import ( ADDED_KV_ATTENTION_PROCESSORS, @@ -35,7 +36,7 @@ from .resnet import Downsample2D, Upsample2D -class UVit2DModel(ModelMixin, ConfigMixin): +class UVit2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin): _supports_gradient_checkpointing = True @register_to_config diff --git a/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py b/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py index d4502639cebc..8b494fa32476 100644 --- a/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +++ b/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py @@ -20,7 +20,7 @@ import torch.nn as nn from ...configuration_utils import ConfigMixin, register_to_config -from ...loaders import UNet2DConditionLoadersMixin +from ...loaders import PeftAdapterMixin, UNet2DConditionLoadersMixin from ...models.attention_processor import ( ADDED_KV_ATTENTION_PROCESSORS, CROSS_ATTENTION_PROCESSORS, @@ -34,7 +34,7 @@ from .modeling_wuerstchen_common import AttnBlock, ResBlock, TimestepBlock, WuerstchenLayerNorm -class WuerstchenPrior(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin): +class WuerstchenPrior(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin): unet_name = "prior" _supports_gradient_checkpointing = True
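Taken together, any of the models touched by this diff can be adapter-tuned with the same pattern. A sketch of the training-side setup, reusing the `lora_config` from the first sketch on a freshly loaded `unet`; the freeze-then-inject order mirrors the Diffusers LoRA training scripts:

```python
import torch

unet.requires_grad_(False)     # freeze all base weights first
unet.add_adapter(lora_config)  # newly injected LoRA weights remain trainable

lora_params = [p for p in unet.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(lora_params, lr=1e-4)
```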