diff --git a/examples/dreambooth/train_dreambooth_lora.py b/examples/dreambooth/train_dreambooth_lora.py
index e3222357ae9b..8cc17281d016 100644
--- a/examples/dreambooth/train_dreambooth_lora.py
+++ b/examples/dreambooth/train_dreambooth_lora.py
@@ -23,6 +23,7 @@
import shutil
import warnings
from pathlib import Path
+from typing import Dict
import numpy as np
import torch
@@ -50,7 +51,10 @@
StableDiffusionPipeline,
UNet2DConditionModel,
)
-from diffusers.loaders import AttnProcsLayers, LoraLoaderMixin
+from diffusers.loaders import (
+ LoraLoaderMixin,
+ text_encoder_lora_state_dict,
+)
from diffusers.models.attention_processor import (
AttnAddedKVProcessor,
AttnAddedKVProcessor2_0,
@@ -60,7 +64,7 @@
SlicedAttnAddedKVProcessor,
)
from diffusers.optimization import get_scheduler
-from diffusers.utils import TEXT_ENCODER_ATTN_MODULE, check_min_version, is_wandb_available
+from diffusers.utils import check_min_version, is_wandb_available
from diffusers.utils.import_utils import is_xformers_available
@@ -653,6 +657,22 @@ def encode_prompt(text_encoder, input_ids, attention_mask, text_encoder_use_atte
return prompt_embeds
+def unet_attn_processors_state_dict(unet) -> Dict[str, torch.Tensor]:
+ r"""
+ Returns:
+ a state dict containing just the attention processor parameters.
+ """
+ attn_processors = unet.attn_processors
+
+ attn_processors_state_dict = {}
+
+ for attn_processor_key, attn_processor in attn_processors.items():
+ for parameter_key, parameter in attn_processor.state_dict().items():
+ attn_processors_state_dict[f"{attn_processor_key}.{parameter_key}"] = parameter
+
+ return attn_processors_state_dict
+
+
def main(args):
logging_dir = Path(args.output_dir, args.logging_dir)
@@ -833,6 +853,7 @@ def main(args):
# Set correct lora layers
unet_lora_attn_procs = {}
+ unet_lora_parameters = []
for name, attn_processor in unet.attn_processors.items():
cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
if name.startswith("mid_block"):
@@ -850,35 +871,18 @@ def main(args):
lora_attn_processor_class = (
LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
)
- unet_lora_attn_procs[name] = lora_attn_processor_class(
- hidden_size=hidden_size,
- cross_attention_dim=cross_attention_dim,
- rank=args.rank,
- )
+
+ module = lora_attn_processor_class(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
+ unet_lora_attn_procs[name] = module
+ unet_lora_parameters.extend(module.parameters())
unet.set_attn_processor(unet_lora_attn_procs)
- unet_lora_layers = AttnProcsLayers(unet.attn_processors)
# The text encoder comes from 🤗 transformers, so we cannot directly modify it.
- # So, instead, we monkey-patch the forward calls of its attention-blocks. For this,
- # we first load a dummy pipeline with the text encoder and then do the monkey-patching.
- text_encoder_lora_layers = None
+ # So, instead, we monkey-patch the forward calls of its attention-blocks.
if args.train_text_encoder:
- text_lora_attn_procs = {}
- for name, module in text_encoder.named_modules():
- if name.endswith(TEXT_ENCODER_ATTN_MODULE):
- text_lora_attn_procs[name] = LoRAAttnProcessor(
- hidden_size=module.out_proj.out_features,
- cross_attention_dim=None,
- rank=args.rank,
- )
- text_encoder_lora_layers = AttnProcsLayers(text_lora_attn_procs)
- temp_pipeline = DiffusionPipeline.from_pretrained(
- args.pretrained_model_name_or_path, text_encoder=text_encoder
- )
- temp_pipeline._modify_text_encoder(text_lora_attn_procs)
- text_encoder = temp_pipeline.text_encoder
- del temp_pipeline
+ # ensure that dtype is float32, even if the rest of the model (which isn't trained) is loaded in fp16
+ text_lora_parameters = LoraLoaderMixin._modify_text_encoder(text_encoder, dtype=torch.float32)
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
def save_model_hook(models, weights, output_dir):
@@ -887,23 +891,13 @@ def save_model_hook(models, weights, output_dir):
unet_lora_layers_to_save = None
text_encoder_lora_layers_to_save = None
- if args.train_text_encoder:
- text_encoder_keys = accelerator.unwrap_model(text_encoder_lora_layers).state_dict().keys()
- unet_keys = accelerator.unwrap_model(unet_lora_layers).state_dict().keys()
-
for model in models:
- state_dict = model.state_dict()
-
- if (
- text_encoder_lora_layers is not None
- and text_encoder_keys is not None
- and state_dict.keys() == text_encoder_keys
- ):
- # text encoder
- text_encoder_lora_layers_to_save = state_dict
- elif state_dict.keys() == unet_keys:
- # unet
- unet_lora_layers_to_save = state_dict
+ if isinstance(model, type(accelerator.unwrap_model(unet))):
+ unet_lora_layers_to_save = unet_attn_processors_state_dict(model)
+ elif isinstance(model, type(accelerator.unwrap_model(text_encoder))):
+ text_encoder_lora_layers_to_save = text_encoder_lora_state_dict(model)
+ else:
+ raise ValueError(f"unexpected save model: {model.__class__}")
# make sure to pop weight so that corresponding model is not saved again
weights.pop()
@@ -915,27 +909,24 @@ def save_model_hook(models, weights, output_dir):
)
def load_model_hook(models, input_dir):
- # Note we DON'T pass the unet and text encoder here an purpose
- # so that the we don't accidentally override the LoRA layers of
- # unet_lora_layers and text_encoder_lora_layers which are stored in `models`
- # with new torch.nn.Modules / weights. We simply use the pipeline class as
- # an easy way to load the lora checkpoints
- temp_pipeline = DiffusionPipeline.from_pretrained(
- args.pretrained_model_name_or_path,
- revision=args.revision,
- torch_dtype=weight_dtype,
- )
- temp_pipeline.load_lora_weights(input_dir)
+ unet_ = None
+ text_encoder_ = None
- # load lora weights into models
- models[0].load_state_dict(AttnProcsLayers(temp_pipeline.unet.attn_processors).state_dict())
- if len(models) > 1:
- models[1].load_state_dict(AttnProcsLayers(temp_pipeline.text_encoder_lora_attn_procs).state_dict())
+ while len(models) > 0:
+ model = models.pop()
- # delete temporary pipeline and pop models
- del temp_pipeline
- for _ in range(len(models)):
- models.pop()
+ if isinstance(model, type(accelerator.unwrap_model(unet))):
+ unet_ = model
+ elif isinstance(model, type(accelerator.unwrap_model(text_encoder))):
+ text_encoder_ = model
+ else:
+ raise ValueError(f"unexpected save model: {model.__class__}")
+
+ lora_state_dict, network_alpha = LoraLoaderMixin.lora_state_dict(input_dir)
+ LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alpha=network_alpha, unet=unet_)
+ LoraLoaderMixin.load_lora_into_text_encoder(
+ lora_state_dict, network_alpha=network_alpha, text_encoder=text_encoder_
+ )
accelerator.register_save_state_pre_hook(save_model_hook)
accelerator.register_load_state_pre_hook(load_model_hook)
@@ -965,9 +956,9 @@ def load_model_hook(models, input_dir):
# Optimizer creation
params_to_optimize = (
- itertools.chain(unet_lora_layers.parameters(), text_encoder_lora_layers.parameters())
+ itertools.chain(unet_lora_parameters, text_lora_parameters)
if args.train_text_encoder
- else unet_lora_layers.parameters()
+ else unet_lora_parameters
)
optimizer = optimizer_class(
params_to_optimize,
@@ -1056,12 +1047,12 @@ def compute_text_embeddings(prompt):
# Prepare everything with our `accelerator`.
if args.train_text_encoder:
- unet_lora_layers, text_encoder_lora_layers, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
- unet_lora_layers, text_encoder_lora_layers, optimizer, train_dataloader, lr_scheduler
+ unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ unet, text_encoder, optimizer, train_dataloader, lr_scheduler
)
else:
- unet_lora_layers, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
- unet_lora_layers, optimizer, train_dataloader, lr_scheduler
+ unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ unet, optimizer, train_dataloader, lr_scheduler
)
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
@@ -1210,9 +1201,9 @@ def compute_text_embeddings(prompt):
accelerator.backward(loss)
if accelerator.sync_gradients:
params_to_clip = (
- itertools.chain(unet_lora_layers.parameters(), text_encoder_lora_layers.parameters())
+ itertools.chain(unet_lora_parameters, text_lora_parameters)
if args.train_text_encoder
- else unet_lora_layers.parameters()
+ else unet_lora_parameters
)
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
optimizer.step()
@@ -1301,15 +1292,17 @@ def compute_text_embeddings(prompt):
pipeline_args = {"prompt": args.validation_prompt}
if args.validation_images is None:
- images = [
- pipeline(**pipeline_args, generator=generator).images[0]
- for _ in range(args.num_validation_images)
- ]
+ images = []
+ for _ in range(args.num_validation_images):
+ with torch.cuda.amp.autocast():
+ image = pipeline(**pipeline_args, generator=generator).images[0]
+ images.append(image)
else:
images = []
for image in args.validation_images:
image = Image.open(image)
- image = pipeline(**pipeline_args, image=image, generator=generator).images[0]
+ with torch.cuda.amp.autocast():
+ image = pipeline(**pipeline_args, image=image, generator=generator).images[0]
images.append(image)
for tracker in accelerator.trackers:
@@ -1332,12 +1325,16 @@ def compute_text_embeddings(prompt):
# Save the lora layers
accelerator.wait_for_everyone()
if accelerator.is_main_process:
+ unet = accelerator.unwrap_model(unet)
unet = unet.to(torch.float32)
- unet_lora_layers = accelerator.unwrap_model(unet_lora_layers)
+ unet_lora_layers = unet_attn_processors_state_dict(unet)
- if text_encoder is not None:
+ if text_encoder is not None and args.train_text_encoder:
+ text_encoder = accelerator.unwrap_model(text_encoder)
text_encoder = text_encoder.to(torch.float32)
- text_encoder_lora_layers = accelerator.unwrap_model(text_encoder_lora_layers)
+ text_encoder_lora_layers = text_encoder_lora_state_dict(text_encoder)
+ else:
+ text_encoder_lora_layers = None
LoraLoaderMixin.save_lora_weights(
save_directory=args.output_dir,
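
To see the intended call pattern, note that the new `load_model_hook` above resolves the serialized LoRA weights once via `LoraLoaderMixin.lora_state_dict` and then routes the unet- and text-encoder-prefixed entries into the respective modules. The same flow can be used to restore a training checkpoint outside of the hooks; a minimal sketch, where the checkpoint directory and base model id are placeholders rather than part of this diff:

    import torch
    from diffusers import DiffusionPipeline
    from diffusers.loaders import LoraLoaderMixin

    # placeholder paths/ids, for illustration only
    checkpoint_dir = "dreambooth-lora-output/checkpoint-500"
    pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float32)

    # resolve the serialized LoRA weights once ...
    lora_state_dict, network_alpha = LoraLoaderMixin.lora_state_dict(checkpoint_dir)

    # ... then load the unet and text encoder parts into the right modules
    LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alpha=network_alpha, unet=pipe.unet)
    LoraLoaderMixin.load_lora_into_text_encoder(
        lora_state_dict, network_alpha=network_alpha, text_encoder=pipe.text_encoder
    )
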
diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py
index 525bb446b77e..89ce88455e20 100644
--- a/src/diffusers/loaders.py
+++ b/src/diffusers/loaders.py
@@ -20,6 +20,7 @@
import torch
import torch.nn.functional as F
from huggingface_hub import hf_hub_download
+from torch import nn
from .models.attention_processor import (
AttnAddedKVProcessor,
@@ -29,6 +30,7 @@
LoRAAttnAddedKVProcessor,
LoRAAttnProcessor,
LoRAAttnProcessor2_0,
+ LoRALinearLayer,
LoRAXFormersAttnProcessor,
SlicedAttnAddedKVProcessor,
XFormersAttnProcessor,
@@ -36,7 +38,6 @@
from .utils import (
DIFFUSERS_CACHE,
HF_HUB_OFFLINE,
- TEXT_ENCODER_ATTN_MODULE,
_get_model_file,
deprecate,
is_safetensors_available,
@@ -49,7 +50,7 @@
import safetensors
if is_transformers_available():
- from transformers import PreTrainedModel, PreTrainedTokenizer
+ from transformers import CLIPTextModel, PreTrainedModel, PreTrainedTokenizer
logger = logging.get_logger(__name__)
@@ -67,6 +68,64 @@
CUSTOM_DIFFUSION_WEIGHT_NAME_SAFE = "pytorch_custom_diffusion_weights.safetensors"
+class PatchedLoraProjection(nn.Module):
+ def __init__(self, regular_linear_layer, lora_scale=1, network_alpha=None, rank=4, dtype=None):
+ super().__init__()
+ self.regular_linear_layer = regular_linear_layer
+
+ device = self.regular_linear_layer.weight.device
+
+ if dtype is None:
+ dtype = self.regular_linear_layer.weight.dtype
+
+ self.lora_linear_layer = LoRALinearLayer(
+ self.regular_linear_layer.in_features,
+ self.regular_linear_layer.out_features,
+ network_alpha=network_alpha,
+ device=device,
+ dtype=dtype,
+ rank=rank,
+ )
+
+ self.lora_scale = lora_scale
+
+ def forward(self, input):
+ return self.regular_linear_layer(input) + self.lora_scale * self.lora_linear_layer(input)
+
+
+def text_encoder_attn_modules(text_encoder):
+ attn_modules = []
+
+ if isinstance(text_encoder, CLIPTextModel):
+ for i, layer in enumerate(text_encoder.text_model.encoder.layers):
+ name = f"text_model.encoder.layers.{i}.self_attn"
+ mod = layer.self_attn
+ attn_modules.append((name, mod))
+ else:
+ raise ValueError(f"do not know how to get attention modules for: {text_encoder.__class__.__name__}")
+
+ return attn_modules
+
+
+def text_encoder_lora_state_dict(text_encoder):
+ state_dict = {}
+
+ for name, module in text_encoder_attn_modules(text_encoder):
+ for k, v in module.q_proj.lora_linear_layer.state_dict().items():
+ state_dict[f"{name}.q_proj.lora_linear_layer.{k}"] = v
+
+ for k, v in module.k_proj.lora_linear_layer.state_dict().items():
+ state_dict[f"{name}.k_proj.lora_linear_layer.{k}"] = v
+
+ for k, v in module.v_proj.lora_linear_layer.state_dict().items():
+ state_dict[f"{name}.v_proj.lora_linear_layer.{k}"] = v
+
+ for k, v in module.out_proj.lora_linear_layer.state_dict().items():
+ state_dict[f"{name}.out_proj.lora_linear_layer.{k}"] = v
+
+ return state_dict
+
+
class AttnProcsLayers(torch.nn.Module):
def __init__(self, state_dict: Dict[str, torch.Tensor]):
super().__init__()
@@ -744,9 +803,48 @@ class LoraLoaderMixin:
unet_name = UNET_NAME
def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
+ """
+ Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into self.unet and self.text_encoder.
+
+ All kwargs are forwarded to `self.lora_state_dict`.
+
+ See [`~loaders.LoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
+
+ See [`~loaders.LoraLoaderMixin.load_lora_into_unet`] for more details on how the state dict is loaded into
+ `self.unet`.
+
+ See [`~loaders.LoraLoaderMixin.load_lora_into_text_encoder`] for more details on how the state dict is loaded
+ into `self.text_encoder`.
+
+ Parameters:
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+ See [`~loaders.LoraLoaderMixin.lora_state_dict`].
+
+ kwargs:
+ See [`~loaders.LoraLoaderMixin.lora_state_dict`].
+ """
+ state_dict, network_alpha = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
+ self.load_lora_into_unet(state_dict, network_alpha=network_alpha, unet=self.unet)
+ self.load_lora_into_text_encoder(
+ state_dict, network_alpha=network_alpha, text_encoder=self.text_encoder, lora_scale=self.lora_scale
+ )
+
+ @classmethod
+ def lora_state_dict(
+ cls,
+ pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+ **kwargs,
+ ):
r"""
- Load pretrained LoRA attention processor layers into [`UNet2DConditionModel`] and
- [`CLIPTextModel`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel).
+ Return state dict for lora weights
+
+ <Tip warning={true}>
+
+ We support loading A1111 formatted LoRA checkpoints in a limited capacity.
+
+ This function is experimental and might change in the future.
+
+ </Tip>
Parameters:
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
@@ -801,9 +899,6 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
weight_name = kwargs.pop("weight_name", None)
use_safetensors = kwargs.pop("use_safetensors", None)
- # set lora scale to a reasonable default
- self._lora_scale = 1.0
-
if use_safetensors and not is_safetensors_available():
raise ValueError(
"`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetensors"
@@ -840,7 +935,7 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
user_agent=user_agent,
)
state_dict = safetensors.torch.load_file(model_file, device="cpu")
- except IOError as e:
+ except (IOError, safetensors.SafetensorError) as e:
if not allow_pickle:
raise e
# try loading non-safetensors weights
@@ -866,286 +961,185 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
# Convert kohya-ss Style LoRA attn procs to diffusers attn procs
network_alpha = None
if all((k.startswith("lora_te_") or k.startswith("lora_unet_")) for k in state_dict.keys()):
- state_dict, network_alpha = self._convert_kohya_lora_to_diffusers(state_dict)
+ state_dict, network_alpha = cls._convert_kohya_lora_to_diffusers(state_dict)
+
+ return state_dict, network_alpha
+
+ @classmethod
+ def load_lora_into_unet(cls, state_dict, network_alpha, unet):
+ """
+ This will load the LoRA layers specified in `state_dict` into `unet`
+
+ Parameters:
+ state_dict (`dict`):
+ A standard state dict containing the lora layer parameters. The keys can either be indexed directly
+ into the unet or prefixed with an additional `unet`, which can be used to distinguish them from the
+ text encoder lora layers.
+ network_alpha (`float`):
+ See `LoRALinearLayer` for more details.
+ unet (`UNet2DConditionModel`):
+ The UNet model to load the LoRA layers into.
+ """
# If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
# then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as
# their prefixes.
keys = list(state_dict.keys())
- if all(key.startswith(self.unet_name) or key.startswith(self.text_encoder_name) for key in keys):
+ if all(key.startswith(cls.unet_name) or key.startswith(cls.text_encoder_name) for key in keys):
# Load the layers corresponding to UNet.
- unet_keys = [k for k in keys if k.startswith(self.unet_name)]
- logger.info(f"Loading {self.unet_name}.")
+ unet_keys = [k for k in keys if k.startswith(cls.unet_name)]
+ logger.info(f"Loading {cls.unet_name}.")
unet_lora_state_dict = {
- k.replace(f"{self.unet_name}.", ""): v for k, v in state_dict.items() if k in unet_keys
+ k.replace(f"{cls.unet_name}.", ""): v for k, v in state_dict.items() if k in unet_keys
}
- self.unet.load_attn_procs(unet_lora_state_dict, network_alpha=network_alpha)
-
- # Load the layers corresponding to text encoder and make necessary adjustments.
- text_encoder_keys = [k for k in keys if k.startswith(self.text_encoder_name)]
- text_encoder_lora_state_dict = {
- k.replace(f"{self.text_encoder_name}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys
- }
- if len(text_encoder_lora_state_dict) > 0:
- logger.info(f"Loading {self.text_encoder_name}.")
- attn_procs_text_encoder = self._load_text_encoder_attn_procs(
- text_encoder_lora_state_dict, network_alpha=network_alpha
- )
- self._modify_text_encoder(attn_procs_text_encoder)
-
- # save lora attn procs of text encoder so that it can be easily retrieved
- self._text_encoder_lora_attn_procs = attn_procs_text_encoder
+ unet.load_attn_procs(unet_lora_state_dict, network_alpha=network_alpha)
# Otherwise, we're dealing with the old format. This means the `state_dict` should only
# contain the module names of the `unet` as its keys WITHOUT any prefix.
elif not all(
- key.startswith(self.unet_name) or key.startswith(self.text_encoder_name) for key in state_dict.keys()
+ key.startswith(cls.unet_name) or key.startswith(cls.text_encoder_name) for key in state_dict.keys()
):
- self.unet.load_attn_procs(state_dict)
+ unet.load_attn_procs(state_dict)
warn_message = "You have saved the LoRA weights using the old format. To convert the old LoRA weights to the new format, you can first load them in a dictionary and then create a new dictionary like the following: `new_state_dict = {f'unet'.{module_name}: params for module_name, params in old_state_dict.items()}`."
warnings.warn(warn_message)
+ @classmethod
+ def load_lora_into_text_encoder(cls, state_dict, network_alpha, text_encoder, lora_scale=1.0):
+ """
+ This will load the LoRA layers specified in `state_dict` into `text_encoder`
+
+ Parameters:
+ state_dict (`dict`):
+ A standard state dict containing the lora layer parameters. The keys should be prefixed with an
+ additional `text_encoder` to distinguish them from the unet lora layers.
+ network_alpha (`float`):
+ See `LoRALinearLayer` for more details.
+ text_encoder (`CLIPTextModel`):
+ The text encoder model to load the LoRA layers into.
+ lora_scale (`float`):
+ How much to scale the output of the lora linear layer before it is added to the output of the regular
+ linear layer.
+ """
+
+ # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
+ # then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as
+ # their prefixes.
+ keys = list(state_dict.keys())
+ if all(key.startswith(cls.unet_name) or key.startswith(cls.text_encoder_name) for key in keys):
+ # Load the layers corresponding to text encoder and make necessary adjustments.
+ text_encoder_keys = [k for k in keys if k.startswith(cls.text_encoder_name)]
+ text_encoder_lora_state_dict = {
+ k.replace(f"{cls.text_encoder_name}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys
+ }
+ if len(text_encoder_lora_state_dict) > 0:
+ logger.info(f"Loading {cls.text_encoder_name}.")
+
+ if any("to_out_lora" in k for k in text_encoder_lora_state_dict.keys()):
+ # Convert from the old naming convention to the new naming convention.
+ #
+ # Previously, the old LoRA layers were stored on the state dict at the
+ # same level as the attention block i.e.
+ # `text_model.encoder.layers.11.self_attn.to_out_lora.up.weight`.
+ #
+ # There is no actual module at that point; the layers were monkey-patched onto the
+ # existing module. We want to be able to load them via their actual state dict.
+ # They're in `PatchedLoraProjection.lora_linear_layer` now.
+ for name, _ in text_encoder_attn_modules(text_encoder):
+ text_encoder_lora_state_dict[
+ f"{name}.q_proj.lora_linear_layer.up.weight"
+ ] = text_encoder_lora_state_dict.pop(f"{name}.to_q_lora.up.weight")
+ text_encoder_lora_state_dict[
+ f"{name}.k_proj.lora_linear_layer.up.weight"
+ ] = text_encoder_lora_state_dict.pop(f"{name}.to_k_lora.up.weight")
+ text_encoder_lora_state_dict[
+ f"{name}.v_proj.lora_linear_layer.up.weight"
+ ] = text_encoder_lora_state_dict.pop(f"{name}.to_v_lora.up.weight")
+ text_encoder_lora_state_dict[
+ f"{name}.out_proj.lora_linear_layer.up.weight"
+ ] = text_encoder_lora_state_dict.pop(f"{name}.to_out_lora.up.weight")
+
+ text_encoder_lora_state_dict[
+ f"{name}.q_proj.lora_linear_layer.down.weight"
+ ] = text_encoder_lora_state_dict.pop(f"{name}.to_q_lora.down.weight")
+ text_encoder_lora_state_dict[
+ f"{name}.k_proj.lora_linear_layer.down.weight"
+ ] = text_encoder_lora_state_dict.pop(f"{name}.to_k_lora.down.weight")
+ text_encoder_lora_state_dict[
+ f"{name}.v_proj.lora_linear_layer.down.weight"
+ ] = text_encoder_lora_state_dict.pop(f"{name}.to_v_lora.down.weight")
+ text_encoder_lora_state_dict[
+ f"{name}.out_proj.lora_linear_layer.down.weight"
+ ] = text_encoder_lora_state_dict.pop(f"{name}.to_out_lora.down.weight")
+
+ rank = text_encoder_lora_state_dict[
+ "text_model.encoder.layers.0.self_attn.out_proj.lora_linear_layer.up.weight"
+ ].shape[1]
+
+ cls._modify_text_encoder(text_encoder, lora_scale, network_alpha, rank=rank)
+
+ # set correct dtype & device
+ text_encoder_lora_state_dict = {
+ k: v.to(device=text_encoder.device, dtype=text_encoder.dtype)
+ for k, v in text_encoder_lora_state_dict.items()
+ }
+
+ load_state_dict_results = text_encoder.load_state_dict(text_encoder_lora_state_dict, strict=False)
+ if len(load_state_dict_results.unexpected_keys) != 0:
+ raise ValueError(
+ f"failed to load text encoder state dict, unexpected keys: {load_state_dict_results.unexpected_keys}"
+ )
+
@property
def lora_scale(self) -> float:
# property function that returns the lora scale which can be set at run time by the pipeline.
# if _lora_scale has not been set, return 1
return self._lora_scale if hasattr(self, "_lora_scale") else 1.0
- @property
- def text_encoder_lora_attn_procs(self):
- if hasattr(self, "_text_encoder_lora_attn_procs"):
- return self._text_encoder_lora_attn_procs
- return
-
def _remove_text_encoder_monkey_patch(self):
- # Loop over the CLIPAttention module of text_encoder
- for name, attn_module in self.text_encoder.named_modules():
- if name.endswith(TEXT_ENCODER_ATTN_MODULE):
- # Loop over the LoRA layers
- for _, text_encoder_attr in self._lora_attn_processor_attr_to_text_encoder_attr.items():
- # Retrieve the q/k/v/out projection of CLIPAttention
- module = attn_module.get_submodule(text_encoder_attr)
- if hasattr(module, "old_forward"):
- # restore original `forward` to remove monkey-patch
- module.forward = module.old_forward
- delattr(module, "old_forward")
-
- def _modify_text_encoder(self, attn_processors: Dict[str, LoRAAttnProcessor]):
+ self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder)
+
+ @classmethod
+ def _remove_text_encoder_monkey_patch_classmethod(cls, text_encoder):
+ for _, attn_module in text_encoder_attn_modules(text_encoder):
+ if isinstance(attn_module.q_proj, PatchedLoraProjection):
+ attn_module.q_proj = attn_module.q_proj.regular_linear_layer
+ attn_module.k_proj = attn_module.k_proj.regular_linear_layer
+ attn_module.v_proj = attn_module.v_proj.regular_linear_layer
+ attn_module.out_proj = attn_module.out_proj.regular_linear_layer
+
+ @classmethod
+ def _modify_text_encoder(cls, text_encoder, lora_scale=1, network_alpha=None, rank=4, dtype=None):
r"""
Monkey-patches the forward passes of attention modules of the text encoder.
-
- Parameters:
- attn_processors: Dict[str, `LoRAAttnProcessor`]:
- A dictionary mapping the module names and their corresponding [`~LoRAAttnProcessor`].
"""
# First, remove any monkey-patch that might have been applied before
- self._remove_text_encoder_monkey_patch()
-
- # Loop over the CLIPAttention module of text_encoder
- for name, attn_module in self.text_encoder.named_modules():
- if name.endswith(TEXT_ENCODER_ATTN_MODULE):
- # Loop over the LoRA layers
- for attn_proc_attr, text_encoder_attr in self._lora_attn_processor_attr_to_text_encoder_attr.items():
- # Retrieve the q/k/v/out projection of CLIPAttention and its corresponding LoRA layer.
- module = attn_module.get_submodule(text_encoder_attr)
- lora_layer = attn_processors[name].get_submodule(attn_proc_attr)
-
- # save old_forward to module that can be used to remove monkey-patch
- old_forward = module.old_forward = module.forward
-
- # create a new scope that locks in the old_forward, lora_layer value for each new_forward function
- # for more detail, see https://github.com/huggingface/diffusers/pull/3490#issuecomment-1555059060
- def make_new_forward(old_forward, lora_layer):
- def new_forward(x):
- result = old_forward(x) + self.lora_scale * lora_layer(x)
- return result
-
- return new_forward
-
- # Monkey-patch.
- module.forward = make_new_forward(old_forward, lora_layer)
-
- @property
- def _lora_attn_processor_attr_to_text_encoder_attr(self):
- return {
- "to_q_lora": "q_proj",
- "to_k_lora": "k_proj",
- "to_v_lora": "v_proj",
- "to_out_lora": "out_proj",
- }
-
- def _load_text_encoder_attn_procs(
- self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs
- ):
- r"""
- Load pretrained attention processor layers for
- [`CLIPTextModel`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel).
-
- <Tip warning={true}>
-
- This function is experimental and might change in the future.
-
- </Tip>
-
- Parameters:
- pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
- Can be either:
-
- - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
- Valid model ids should have an organization name, like `google/ddpm-celebahq-256`.
- - A path to a *directory* containing model weights saved using [`~ModelMixin.save_config`], e.g.,
- `./my_model_directory/`.
- - A [torch state
- dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
-
- cache_dir (`Union[str, os.PathLike]`, *optional*):
- Path to a directory in which a downloaded pretrained model configuration should be cached if the
- standard cache should not be used.
- force_download (`bool`, *optional*, defaults to `False`):
- Whether or not to force the (re-)download of the model weights and configuration files, overriding the
- cached versions if they exist.
- resume_download (`bool`, *optional*, defaults to `False`):
- Whether or not to delete incompletely received files. Will attempt to resume the download if such a
- file exists.
- proxies (`Dict[str, str]`, *optional*):
- A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
- 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
- local_files_only (`bool`, *optional*, defaults to `False`):
- Whether or not to only look at local files (i.e., do not try to download the model).
- use_auth_token (`str` or *bool*, *optional*):
- The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
- when running `diffusers-cli login` (stored in `~/.huggingface`).
- revision (`str`, *optional*, defaults to `"main"`):
- The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
- git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
- identifier allowed by git.
- subfolder (`str`, *optional*, defaults to `""`):
- In case the relevant files are located inside a subfolder of the model repo (either remote in
- huggingface.co or downloaded locally), you can specify the folder name here.
- mirror (`str`, *optional*):
- Mirror source to accelerate downloads in China. If you are from China and have an accessibility
- problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
- Please refer to the mirror site for more information.
-
- Returns:
- `Dict[name, LoRAAttnProcessor]`: Mapping between the module names and their corresponding
- [`LoRAAttnProcessor`].
-
-
+ cls._remove_text_encoder_monkey_patch_classmethod(text_encoder)
- It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated
- models](https://huggingface.co/docs/hub/models-gated#gated-models).
-
-
- """
+ lora_parameters = []
- cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
- force_download = kwargs.pop("force_download", False)
- resume_download = kwargs.pop("resume_download", False)
- proxies = kwargs.pop("proxies", None)
- local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
- use_auth_token = kwargs.pop("use_auth_token", None)
- revision = kwargs.pop("revision", None)
- subfolder = kwargs.pop("subfolder", None)
- weight_name = kwargs.pop("weight_name", None)
- use_safetensors = kwargs.pop("use_safetensors", None)
- network_alpha = kwargs.pop("network_alpha", None)
-
- if use_safetensors and not is_safetensors_available():
- raise ValueError(
- "`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetensors"
+ for _, attn_module in text_encoder_attn_modules(text_encoder):
+ attn_module.q_proj = PatchedLoraProjection(
+ attn_module.q_proj, lora_scale, network_alpha, rank=rank, dtype=dtype
)
+ lora_parameters.extend(attn_module.q_proj.lora_linear_layer.parameters())
- allow_pickle = False
- if use_safetensors is None:
- use_safetensors = is_safetensors_available()
- allow_pickle = True
-
- user_agent = {
- "file_type": "attn_procs_weights",
- "framework": "pytorch",
- }
-
- model_file = None
- if not isinstance(pretrained_model_name_or_path_or_dict, dict):
- # Let's first try to load .safetensors weights
- if (use_safetensors and weight_name is None) or (
- weight_name is not None and weight_name.endswith(".safetensors")
- ):
- try:
- model_file = _get_model_file(
- pretrained_model_name_or_path_or_dict,
- weights_name=weight_name or LORA_WEIGHT_NAME_SAFE,
- cache_dir=cache_dir,
- force_download=force_download,
- resume_download=resume_download,
- proxies=proxies,
- local_files_only=local_files_only,
- use_auth_token=use_auth_token,
- revision=revision,
- subfolder=subfolder,
- user_agent=user_agent,
- )
- state_dict = safetensors.torch.load_file(model_file, device="cpu")
- except IOError as e:
- if not allow_pickle:
- raise e
- # try loading non-safetensors weights
- pass
- if model_file is None:
- model_file = _get_model_file(
- pretrained_model_name_or_path_or_dict,
- weights_name=weight_name or LORA_WEIGHT_NAME,
- cache_dir=cache_dir,
- force_download=force_download,
- resume_download=resume_download,
- proxies=proxies,
- local_files_only=local_files_only,
- use_auth_token=use_auth_token,
- revision=revision,
- subfolder=subfolder,
- user_agent=user_agent,
- )
- state_dict = torch.load(model_file, map_location="cpu")
- else:
- state_dict = pretrained_model_name_or_path_or_dict
-
- # fill attn processors
- attn_processors = {}
-
- is_lora = all("lora" in k for k in state_dict.keys())
-
- if is_lora:
- lora_grouped_dict = defaultdict(dict)
- for key, value in state_dict.items():
- attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:])
- lora_grouped_dict[attn_processor_key][sub_key] = value
-
- for key, value_dict in lora_grouped_dict.items():
- rank = value_dict["to_k_lora.down.weight"].shape[0]
- cross_attention_dim = value_dict["to_k_lora.down.weight"].shape[1]
- hidden_size = value_dict["to_k_lora.up.weight"].shape[0]
+ attn_module.k_proj = PatchedLoraProjection(
+ attn_module.k_proj, lora_scale, network_alpha, rank=rank, dtype=dtype
+ )
+ lora_parameters.extend(attn_module.k_proj.lora_linear_layer.parameters())
- attn_processor_class = (
- LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
- )
- attn_processors[key] = attn_processor_class(
- hidden_size=hidden_size,
- cross_attention_dim=cross_attention_dim,
- rank=rank,
- network_alpha=network_alpha,
- )
- attn_processors[key].load_state_dict(value_dict)
+ attn_module.v_proj = PatchedLoraProjection(
+ attn_module.v_proj, lora_scale, network_alpha, rank=rank, dtype=dtype
+ )
+ lora_parameters.extend(attn_module.v_proj.lora_linear_layer.parameters())
- else:
- raise ValueError(f"{model_file} does not seem to be in the correct format expected by LoRA training.")
+ attn_module.out_proj = PatchedLoraProjection(
+ attn_module.out_proj, lora_scale, network_alpha, rank=rank, dtype=dtype
+ )
+ lora_parameters.extend(attn_module.out_proj.lora_linear_layer.parameters())
- # set correct dtype & device
- attn_processors = {
- k: v.to(device=self.device, dtype=self.text_encoder.dtype) for k, v in attn_processors.items()
- }
- return attn_processors
+ return lora_parameters
@classmethod
def save_lora_weights(
@@ -1225,7 +1219,8 @@ def save_function(weights, filename):
save_function(state_dict, os.path.join(save_directory, weight_name))
logger.info(f"Model weights saved in {os.path.join(save_directory, weight_name)}")
- def _convert_kohya_lora_to_diffusers(self, state_dict):
+ @classmethod
+ def _convert_kohya_lora_to_diffusers(cls, state_dict):
unet_state_dict = {}
te_state_dict = {}
network_alpha = None
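
The central change in loaders.py is that the text encoder monkey-patch is now plain module swapping: each q/k/v/out projection is replaced by a `PatchedLoraProjection` whose forward is `regular_linear_layer(x) + lora_scale * lora_linear_layer(x)`, so the patch can be undone by restoring `regular_linear_layer` and its weights can be saved and loaded through an ordinary state dict. A minimal sketch against a standalone `nn.Linear` (the sizes are arbitrary):

    import torch
    from torch import nn
    from diffusers.loaders import PatchedLoraProjection

    # wrap a plain projection the same way _modify_text_encoder now wraps q/k/v/out_proj
    base = nn.Linear(32, 32)
    patched = PatchedLoraProjection(base, lora_scale=1.0, rank=4, dtype=torch.float32)

    x = torch.randn(1, 8, 32)
    out = patched(x)

    # the patched output is the frozen projection plus the scaled low-rank update;
    # removing the patch is just re-assigning patched.regular_linear_layer
    expected = patched.regular_linear_layer(x) + patched.lora_scale * patched.lora_linear_layer(x)
    assert torch.allclose(out, expected)
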
diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py
index 5b6a161f8466..da2920fa671a 100644
--- a/src/diffusers/models/attention_processor.py
+++ b/src/diffusers/models/attention_processor.py
@@ -506,14 +506,14 @@ def __call__(
class LoRALinearLayer(nn.Module):
- def __init__(self, in_features, out_features, rank=4, network_alpha=None):
+ def __init__(self, in_features, out_features, rank=4, network_alpha=None, device=None, dtype=None):
super().__init__()
if rank > min(in_features, out_features):
raise ValueError(f"LoRA rank {rank} must be less or equal than {min(in_features, out_features)}")
- self.down = nn.Linear(in_features, rank, bias=False)
- self.up = nn.Linear(rank, out_features, bias=False)
+ self.down = nn.Linear(in_features, rank, bias=False, device=device, dtype=dtype)
+ self.up = nn.Linear(rank, out_features, bias=False, device=device, dtype=dtype)
# This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
# See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
self.network_alpha = network_alpha
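
Passing `device` and `dtype` through to the two `nn.Linear` factors lets callers such as `LoraLoaderMixin._modify_text_encoder(..., dtype=torch.float32)` keep the trainable LoRA weights in fp32 even when the frozen base layer is loaded in fp16. A small sketch, assuming the usual LoRA forward of `up(down(x))` (scaled by `network_alpha / rank` when set), which is not shown in this hunk:

    import torch
    from torch import nn
    from diffusers.models.attention_processor import LoRALinearLayer

    # frozen base projection kept in fp16, as in mixed-precision training
    base = nn.Linear(768, 768, bias=False).to(dtype=torch.float16)

    # the new device/dtype arguments create the trainable low-rank factors in fp32
    # on the same device as the base layer, instead of inheriting its fp16 dtype
    lora = LoRALinearLayer(base.in_features, base.out_features, rank=4, device=base.weight.device, dtype=torch.float32)
    assert lora.down.weight.dtype == torch.float32 and lora.up.weight.dtype == torch.float32

    # the caller (e.g. PatchedLoraProjection) combines base(x) with the low-rank update lora(x)
    x = torch.randn(1, 77, 768)
    delta = lora(x)
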
diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py
index 7449df99ba80..98fac64497e7 100644
--- a/src/diffusers/utils/__init__.py
+++ b/src/diffusers/utils/__init__.py
@@ -30,7 +30,6 @@
ONNX_EXTERNAL_WEIGHTS_NAME,
ONNX_WEIGHTS_NAME,
SAFETENSORS_WEIGHTS_NAME,
- TEXT_ENCODER_ATTN_MODULE,
WEIGHTS_NAME,
)
from .deprecation_utils import deprecate
diff --git a/src/diffusers/utils/constants.py b/src/diffusers/utils/constants.py
index 3c641a259a81..b9e60a2a873b 100644
--- a/src/diffusers/utils/constants.py
+++ b/src/diffusers/utils/constants.py
@@ -30,4 +30,3 @@
DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules"
HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(hf_cache_home, "modules"))
DEPRECATED_REVISION_ARGS = ["fp16", "non-ema"]
-TEXT_ENCODER_ATTN_MODULE = ".self_attn"
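
With the `TEXT_ENCODER_ATTN_MODULE` suffix constant removed, code that used to filter `text_encoder.named_modules()` for names ending in ".self_attn" is expected to iterate the new `text_encoder_attn_modules` helper instead, which yields the same (name, module) pairs for CLIP text encoders. A small sketch with a toy config whose values are illustrative only:

    from transformers import CLIPTextConfig, CLIPTextModel

    from diffusers.loaders import text_encoder_attn_modules

    # toy config, just large enough for the sketch to run
    config = CLIPTextConfig(hidden_size=32, intermediate_size=37, num_attention_heads=4, num_hidden_layers=2, vocab_size=1000)
    text_encoder = CLIPTextModel(config)

    for name, module in text_encoder_attn_modules(text_encoder):
        # e.g. "text_model.encoder.layers.0.self_attn", the same modules the old
        # name.endswith(".self_attn") filter matched
        print(name, module.__class__.__name__)
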
diff --git a/tests/models/test_lora_layers.py b/tests/models/test_lora_layers.py
index aaacf1e68f9f..3190a123898c 100644
--- a/tests/models/test_lora_layers.py
+++ b/tests/models/test_lora_layers.py
@@ -12,18 +12,19 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import gc
import os
import tempfile
import unittest
+import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
+from huggingface_hub.repocard import RepoCard
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel
-from diffusers.loaders import AttnProcsLayers, LoraLoaderMixin
+from diffusers.loaders import AttnProcsLayers, LoraLoaderMixin, PatchedLoraProjection, text_encoder_attn_modules
from diffusers.models.attention_processor import (
Attention,
AttnProcessor,
@@ -33,7 +34,8 @@
LoRAXFormersAttnProcessor,
XFormersAttnProcessor,
)
-from diffusers.utils import TEXT_ENCODER_ATTN_MODULE, floats_tensor, torch_device
+from diffusers.utils import floats_tensor, torch_device
+from diffusers.utils.testing_utils import require_torch_gpu, slow
def create_unet_lora_layers(unet: nn.Module):
@@ -63,11 +65,15 @@ def create_text_encoder_lora_attn_procs(text_encoder: nn.Module):
lora_attn_processor_class = (
LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
)
- for name, module in text_encoder.named_modules():
- if name.endswith(TEXT_ENCODER_ATTN_MODULE):
- text_lora_attn_procs[name] = lora_attn_processor_class(
- hidden_size=module.out_proj.out_features, cross_attention_dim=None
- )
+ for name, module in text_encoder_attn_modules(text_encoder):
+ if isinstance(module.out_proj, nn.Linear):
+ out_features = module.out_proj.out_features
+ elif isinstance(module.out_proj, PatchedLoraProjection):
+ out_features = module.out_proj.regular_linear_layer.out_features
+ else:
+ assert False, module.out_proj.__class__
+
+ text_lora_attn_procs[name] = lora_attn_processor_class(hidden_size=out_features, cross_attention_dim=None)
return text_lora_attn_procs
@@ -77,17 +83,13 @@ def create_text_encoder_lora_layers(text_encoder: nn.Module):
return text_encoder_lora_layers
-def set_lora_up_weights(text_lora_attn_procs, randn_weight=False):
- for _, attn_proc in text_lora_attn_procs.items():
- # set up.weights
- for layer_name, layer_module in attn_proc.named_modules():
- if layer_name.endswith("_lora"):
- weight = (
- torch.randn_like(layer_module.up.weight)
- if randn_weight
- else torch.zeros_like(layer_module.up.weight)
- )
- layer_module.up.weight = torch.nn.Parameter(weight)
+def set_lora_weights(text_lora_attn_parameters, randn_weight=False):
+ with torch.no_grad():
+ for parameter in text_lora_attn_parameters:
+ if randn_weight:
+ parameter[:] = torch.randn_like(parameter)
+ else:
+ torch.zero_(parameter)
class LoraLoaderMixinTests(unittest.TestCase):
@@ -281,16 +283,10 @@ def test_text_encoder_lora_monkey_patch(self):
outputs_without_lora = pipe.text_encoder(**dummy_tokens)[0]
assert outputs_without_lora.shape == (1, 77, 32)
- # create lora_attn_procs with zeroed out up.weights
- text_attn_procs = create_text_encoder_lora_attn_procs(pipe.text_encoder)
- set_lora_up_weights(text_attn_procs, randn_weight=False)
-
# monkey patch
- pipe._modify_text_encoder(text_attn_procs)
+ params = pipe._modify_text_encoder(pipe.text_encoder, pipe.lora_scale)
- # verify that it's okay to release the text_attn_procs which holds the LoRAAttnProcessor.
- del text_attn_procs
- gc.collect()
+ set_lora_weights(params, randn_weight=False)
# inference with lora
outputs_with_lora = pipe.text_encoder(**dummy_tokens)[0]
@@ -301,15 +297,12 @@ def test_text_encoder_lora_monkey_patch(self):
), "lora_up_weight are all zero, so the lora outputs should be the same to without lora outputs"
# create lora_attn_procs with randn up.weights
- text_attn_procs = create_text_encoder_lora_attn_procs(pipe.text_encoder)
- set_lora_up_weights(text_attn_procs, randn_weight=True)
+ create_text_encoder_lora_attn_procs(pipe.text_encoder)
# monkey patch
- pipe._modify_text_encoder(text_attn_procs)
+ params = pipe._modify_text_encoder(pipe.text_encoder, pipe.lora_scale)
- # verify that it's okay to release the text_attn_procs which holds the LoRAAttnProcessor.
- del text_attn_procs
- gc.collect()
+ set_lora_weights(params, randn_weight=True)
# inference with lora
outputs_with_lora = pipe.text_encoder(**dummy_tokens)[0]
@@ -329,16 +322,10 @@ def test_text_encoder_lora_remove_monkey_patch(self):
outputs_without_lora = pipe.text_encoder(**dummy_tokens)[0]
assert outputs_without_lora.shape == (1, 77, 32)
- # create lora_attn_procs with randn up.weights
- text_attn_procs = create_text_encoder_lora_attn_procs(pipe.text_encoder)
- set_lora_up_weights(text_attn_procs, randn_weight=True)
-
# monkey patch
- pipe._modify_text_encoder(text_attn_procs)
+ params = pipe._modify_text_encoder(pipe.text_encoder, pipe.lora_scale)
- # verify that it's okay to release the text_attn_procs which holds the LoRAAttnProcessor.
- del text_attn_procs
- gc.collect()
+ set_lora_weights(params, randn_weight=True)
# inference with lora
outputs_with_lora = pipe.text_encoder(**dummy_tokens)[0]
@@ -467,3 +454,86 @@ def test_lora_save_load_with_xformers(self):
# Outputs shouldn't match.
self.assertFalse(torch.allclose(torch.from_numpy(orig_image_slice), torch.from_numpy(lora_image_slice)))
+
+
+@slow
+@require_torch_gpu
+class LoraIntegrationTests(unittest.TestCase):
+ def test_dreambooth_old_format(self):
+ generator = torch.Generator("cpu").manual_seed(0)
+
+ lora_model_id = "hf-internal-testing/lora_dreambooth_dog_example"
+ card = RepoCard.load(lora_model_id)
+ base_model_id = card.data.to_dict()["base_model"]
+
+ pipe = StableDiffusionPipeline.from_pretrained(base_model_id, safety_checker=None)
+ pipe = pipe.to(torch_device)
+ pipe.load_lora_weights(lora_model_id)
+
+ images = pipe(
+ "A photo of a sks dog floating in the river", output_type="np", generator=generator, num_inference_steps=2
+ ).images
+
+ images = images[0, -3:, -3:, -1].flatten()
+
+ expected = np.array([0.7207, 0.6787, 0.6010, 0.7478, 0.6838, 0.6064, 0.6984, 0.6443, 0.5785])
+
+ self.assertTrue(np.allclose(images, expected, atol=1e-4))
+
+ def test_dreambooth_text_encoder_new_format(self):
+ generator = torch.Generator().manual_seed(0)
+
+ lora_model_id = "hf-internal-testing/lora-trained"
+ card = RepoCard.load(lora_model_id)
+ base_model_id = card.data.to_dict()["base_model"]
+
+ pipe = StableDiffusionPipeline.from_pretrained(base_model_id, safety_checker=None)
+ pipe = pipe.to(torch_device)
+ pipe.load_lora_weights(lora_model_id)
+
+ images = pipe("A photo of a sks dog", output_type="np", generator=generator, num_inference_steps=2).images
+
+ images = images[0, -3:, -3:, -1].flatten()
+
+ expected = np.array([0.6628, 0.6138, 0.5390, 0.6625, 0.6130, 0.5463, 0.6166, 0.5788, 0.5359])
+
+ self.assertTrue(np.allclose(images, expected, atol=1e-4))
+
+ def test_a1111(self):
+ generator = torch.Generator().manual_seed(0)
+
+ pipe = StableDiffusionPipeline.from_pretrained("hf-internal-testing/Counterfeit-V2.5", safety_checker=None).to(
+ torch_device
+ )
+ lora_model_id = "hf-internal-testing/civitai-light-shadow-lora"
+ lora_filename = "light_and_shadow.safetensors"
+ pipe.load_lora_weights(lora_model_id, weight_name=lora_filename)
+
+ images = pipe(
+ "masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2
+ ).images
+
+ images = images[0, -3:, -3:, -1].flatten()
+
+ expected = np.array([0.3743, 0.3893, 0.3835, 0.3891, 0.3949, 0.3649, 0.3858, 0.3802, 0.3245])
+
+ self.assertTrue(np.allclose(images, expected, atol=1e-4))
+
+ def test_vanilla_funetuning(self):
+ generator = torch.Generator().manual_seed(0)
+
+ lora_model_id = "hf-internal-testing/sd-model-finetuned-lora-t4"
+ card = RepoCard.load(lora_model_id)
+ base_model_id = card.data.to_dict()["base_model"]
+
+ pipe = StableDiffusionPipeline.from_pretrained(base_model_id, safety_checker=None)
+ pipe = pipe.to(torch_device)
+ pipe.load_lora_weights(lora_model_id)
+
+ images = pipe("A pokemon with blue eyes.", output_type="np", generator=generator, num_inference_steps=2).images
+
+ images = images[0, -3:, -3:, -1].flatten()
+
+ expected = np.array([0.7406, 0.699, 0.5963, 0.7493, 0.7045, 0.6096, 0.6886, 0.6388, 0.583])
+
+ self.assertTrue(np.allclose(images, expected, atol=1e-4))