diff --git a/src/diffusers/loaders/lora.py b/src/diffusers/loaders/lora.py
index 8b42f412add1..94910956d022 100644
--- a/src/diffusers/loaders/lora.py
+++ b/src/diffusers/loaders/lora.py
@@ -30,6 +30,7 @@
     _get_model_file,
     convert_state_dict_to_diffusers,
     convert_state_dict_to_peft,
+    convert_unet_state_dict_to_peft,
     delete_adapter_layers,
     get_adapter_name,
     get_peft_kwargs,
@@ -1543,6 +1544,11 @@ def load_lora_into_transformer(cls, state_dict, transformer, adapter_name=None,
         }
 
         if len(state_dict.keys()) > 0:
+            # Use the first key to detect the format: PEFT-format keys contain "lora_A"; otherwise convert.
+            first_key = next(iter(state_dict.keys()))
+            if "lora_A" not in first_key:
+                state_dict = convert_unet_state_dict_to_peft(state_dict)
+
             if adapter_name in getattr(transformer, "peft_config", {}):
                 raise ValueError(
                     f"Adapter name {adapter_name} already in use in the transformer - please select a new adapter name."
diff --git a/tests/lora/test_lora_layers_sd3.py b/tests/lora/test_lora_layers_sd3.py
index e76c65778174..2589625ec6ac 100644
--- a/tests/lora/test_lora_layers_sd3.py
+++ b/tests/lora/test_lora_layers_sd3.py
@@ -27,7 +27,7 @@
     SD3Transformer2DModel,
     StableDiffusion3Pipeline,
 )
-from diffusers.utils.testing_utils import is_peft_available, require_peft_backend, torch_device
+from diffusers.utils.testing_utils import is_peft_available, require_peft_backend, require_torch_gpu, torch_device
 
 
 if is_peft_available():
@@ -287,3 +287,24 @@ def test_simple_inference_with_transformer_fuse_unfuse(self):
         self.assertTrue(
             np.allclose(ouput_fused, output_unfused_lora, atol=1e-3, rtol=1e-3), "Fused lora should change the output"
         )
+
+    @require_torch_gpu
+    def test_sd3_lora(self):
+        """
+        Test loading the loras that are saved with the diffusers and peft formats.
+        Related PR: https://github.com/huggingface/diffusers/pull/8584
+        """
+        components = self.get_dummy_components()
+
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        lora_model_id = "hf-internal-testing/tiny-sd3-loras"
+
+        lora_filename = "lora_diffusers_format.safetensors"
+        pipe.load_lora_weights(lora_model_id, weight_name=lora_filename)
+        pipe.unload_lora_weights()
+
+        lora_filename = "lora_peft_format.safetensors"
+        pipe.load_lora_weights(lora_model_id, weight_name=lora_filename)