From 51dc061ee664e6aa81e9669c42cbe541ce412426 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Sat, 10 Jan 2026 06:15:36 +0100 Subject: [PATCH 1/4] Improve incorrect LoRA format error message --- src/diffusers/loaders/lora_pipeline.py | 38 +++++++++++++------------- 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 03a2fe9f3f8e..06408a6764ba 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -212,7 +212,7 @@ def load_lora_weights( is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: - raise ValueError("Invalid LoRA checkpoint.") + raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `lora`.") self.load_lora_into_unet( state_dict, @@ -639,7 +639,7 @@ def load_lora_weights( is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: - raise ValueError("Invalid LoRA checkpoint.") + raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `lora`.") self.load_lora_into_unet( state_dict, @@ -1079,7 +1079,7 @@ def load_lora_weights( is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: - raise ValueError("Invalid LoRA checkpoint.") + raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `lora`.") self.load_lora_into_transformer( state_dict, @@ -1375,7 +1375,7 @@ def load_lora_weights( is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: - raise ValueError("Invalid LoRA checkpoint.") + raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `lora`.") self.load_lora_into_transformer( state_dict, @@ -1657,7 +1657,7 @@ def load_lora_weights( ) if not (has_lora_keys or has_norm_keys): - raise ValueError("Invalid LoRA checkpoint.") + raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `lora`.") transformer_lora_state_dict = { k: state_dict.get(k) @@ -2504,7 +2504,7 @@ def load_lora_weights( is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: - raise ValueError("Invalid LoRA checkpoint.") + raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `lora`.") self.load_lora_into_transformer( state_dict, @@ -2701,7 +2701,7 @@ def load_lora_weights( is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: - raise ValueError("Invalid LoRA checkpoint.") + raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `lora`.") self.load_lora_into_transformer( state_dict, @@ -2904,7 +2904,7 @@ def load_lora_weights( is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: - raise ValueError("Invalid LoRA checkpoint.") + raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `lora`.") self.load_lora_into_transformer( state_dict, @@ -3104,7 +3104,7 @@ def load_lora_weights( is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: - raise ValueError("Invalid LoRA checkpoint.") + raise ValueError("Invalid LoRA checkpoint. 
Make sure all LoRA param names contain `lora`.") self.load_lora_into_transformer( state_dict, @@ -3307,7 +3307,7 @@ def load_lora_weights( is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: - raise ValueError("Invalid LoRA checkpoint.") + raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `lora`.") self.load_lora_into_transformer( state_dict, @@ -3511,7 +3511,7 @@ def load_lora_weights( is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: - raise ValueError("Invalid LoRA checkpoint.") + raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `lora`.") self.load_lora_into_transformer( state_dict, @@ -3711,7 +3711,7 @@ def load_lora_weights( is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: - raise ValueError("Invalid LoRA checkpoint.") + raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `lora`.") self.load_lora_into_transformer( state_dict, @@ -3965,7 +3965,7 @@ def load_lora_weights( ) is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: - raise ValueError("Invalid LoRA checkpoint.") + raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `lora`.") load_into_transformer_2 = kwargs.pop("load_into_transformer_2", False) if load_into_transformer_2: @@ -4242,7 +4242,7 @@ def load_lora_weights( ) is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: - raise ValueError("Invalid LoRA checkpoint.") + raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `lora`.") load_into_transformer_2 = kwargs.pop("load_into_transformer_2", False) if load_into_transformer_2: @@ -4462,7 +4462,7 @@ def load_lora_weights( is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: - raise ValueError("Invalid LoRA checkpoint.") + raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `lora`.") self.load_lora_into_transformer( state_dict, @@ -4665,7 +4665,7 @@ def load_lora_weights( is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: - raise ValueError("Invalid LoRA checkpoint.") + raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `lora`.") self.load_lora_into_transformer( state_dict, @@ -4871,7 +4871,7 @@ def load_lora_weights( is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: - raise ValueError("Invalid LoRA checkpoint.") + raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `lora`.") self.load_lora_into_transformer( state_dict, @@ -5077,7 +5077,7 @@ def load_lora_weights( is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: - raise ValueError("Invalid LoRA checkpoint.") + raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `lora`.") self.load_lora_into_transformer( state_dict, @@ -5280,7 +5280,7 @@ def load_lora_weights( is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: - raise ValueError("Invalid LoRA checkpoint.") + raise ValueError("Invalid LoRA checkpoint. 
Make sure all LoRA param names contain `lora`.") self.load_lora_into_transformer( state_dict, From dc43efbc4c938d2c9e64803aff2dc2c404cb9140 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Sat, 10 Jan 2026 07:54:27 +0100 Subject: [PATCH 2/4] Add flag in PeftLoraLoaderMixinTests to disable text encoder LoRA tests --- tests/lora/test_lora_layers_auraflow.py | 22 ++------------------- tests/lora/test_lora_layers_cogvideox.py | 22 ++------------------- tests/lora/test_lora_layers_cogview4.py | 22 ++------------------- tests/lora/test_lora_layers_flux2.py | 22 ++------------------- tests/lora/test_lora_layers_hunyuanvideo.py | 22 ++------------------- tests/lora/test_lora_layers_ltx_video.py | 22 ++------------------- tests/lora/test_lora_layers_lumina2.py | 22 ++------------------- tests/lora/test_lora_layers_mochi.py | 22 ++------------------- tests/lora/test_lora_layers_qwenimage.py | 22 ++------------------- tests/lora/test_lora_layers_sana.py | 22 ++------------------- tests/lora/test_lora_layers_wan.py | 22 ++------------------- tests/lora/test_lora_layers_wanvace.py | 22 ++------------------- tests/lora/test_lora_layers_z_image.py | 22 ++------------------- tests/lora/utils.py | 19 ++++++++++++++++++ 14 files changed, 45 insertions(+), 260 deletions(-) diff --git a/tests/lora/test_lora_layers_auraflow.py b/tests/lora/test_lora_layers_auraflow.py index 91f63c4b56c4..78ef4ce151be 100644 --- a/tests/lora/test_lora_layers_auraflow.py +++ b/tests/lora/test_lora_layers_auraflow.py @@ -76,6 +76,8 @@ class AuraFlowLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): text_encoder_target_modules = ["q", "k", "v", "o"] denoiser_target_modules = ["to_q", "to_k", "to_v", "to_out.0", "linear_1"] + supports_text_encoder_loras = False + @property def output_shape(self): return (1, 8, 8, 3) @@ -114,23 +116,3 @@ def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(se @unittest.skip("Not supported in AuraFlow.") def test_modify_padding_mode(self): pass - - @unittest.skip("Text encoder LoRA is not supported in AuraFlow.") - def test_simple_inference_with_partial_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in AuraFlow.") - def test_simple_inference_with_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in AuraFlow.") - def test_simple_inference_with_text_lora_and_scale(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in AuraFlow.") - def test_simple_inference_with_text_lora_fused(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in AuraFlow.") - def test_simple_inference_with_text_lora_save_load(self): - pass diff --git a/tests/lora/test_lora_layers_cogvideox.py b/tests/lora/test_lora_layers_cogvideox.py index fa57b4c9c2f9..7bd54b77ca35 100644 --- a/tests/lora/test_lora_layers_cogvideox.py +++ b/tests/lora/test_lora_layers_cogvideox.py @@ -87,6 +87,8 @@ class CogVideoXLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): text_encoder_target_modules = ["q", "k", "v", "o"] + supports_text_encoder_loras = False + @property def output_shape(self): return (1, 9, 16, 16, 3) @@ -147,26 +149,6 @@ def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(se def test_modify_padding_mode(self): pass - @unittest.skip("Text encoder LoRA is not supported in CogVideoX.") - def test_simple_inference_with_partial_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in CogVideoX.") - def test_simple_inference_with_text_lora(self): - pass - - 
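For reference, the gate that PATCH 1/4 annotates treats a checkpoint as LoRA-formatted only when every state-dict key contains the substring "lora". A minimal sketch of that check, assuming hypothetical key names (real keys come from `lora_state_dict`):

    # Sketch of the format gate, not the diffusers implementation; key names are illustrative.
    state_dict = {
        "transformer.blocks.0.attn.to_q.lora_A.weight": 0.0,  # contains "lora": passes
        "transformer.blocks.0.attn.to_q.weight": 0.0,  # plain model weight: trips the check
    }

    try:
        if not all("lora" in key for key in state_dict):
            raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `lora`.")
    except ValueError as err:
        print(err)  # the improved message now says why the checkpoint was rejected
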
@unittest.skip("Text encoder LoRA is not supported in CogVideoX.") - def test_simple_inference_with_text_lora_and_scale(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in CogVideoX.") - def test_simple_inference_with_text_lora_fused(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in CogVideoX.") - def test_simple_inference_with_text_lora_save_load(self): - pass - @unittest.skip("Not supported in CogVideoX.") def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self): pass diff --git a/tests/lora/test_lora_layers_cogview4.py b/tests/lora/test_lora_layers_cogview4.py index 30eb8fbb6367..e8ee6e7a7db6 100644 --- a/tests/lora/test_lora_layers_cogview4.py +++ b/tests/lora/test_lora_layers_cogview4.py @@ -85,6 +85,8 @@ class CogView4LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): "text_encoder", ) + supports_text_encoder_loras = False + @property def output_shape(self): return (1, 32, 32, 3) @@ -162,23 +164,3 @@ def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(se @unittest.skip("Not supported in CogView4.") def test_modify_padding_mode(self): pass - - @unittest.skip("Text encoder LoRA is not supported in CogView4.") - def test_simple_inference_with_partial_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in CogView4.") - def test_simple_inference_with_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in CogView4.") - def test_simple_inference_with_text_lora_and_scale(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in CogView4.") - def test_simple_inference_with_text_lora_fused(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in CogView4.") - def test_simple_inference_with_text_lora_save_load(self): - pass diff --git a/tests/lora/test_lora_layers_flux2.py b/tests/lora/test_lora_layers_flux2.py index 4ae189aceb66..d970b7d7847f 100644 --- a/tests/lora/test_lora_layers_flux2.py +++ b/tests/lora/test_lora_layers_flux2.py @@ -66,6 +66,8 @@ class Flux2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): text_encoder_cls, text_encoder_id = Mistral3ForConditionalGeneration, "hf-internal-testing/tiny-mistral3-diffusers" denoiser_target_modules = ["to_qkv_mlp_proj", "to_k"] + supports_text_encoder_loras = False + @property def output_shape(self): return (1, 8, 8, 3) @@ -146,23 +148,3 @@ def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(se @unittest.skip("Not supported in Flux2.") def test_modify_padding_mode(self): pass - - @unittest.skip("Text encoder LoRA is not supported in Flux2.") - def test_simple_inference_with_partial_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in Flux2.") - def test_simple_inference_with_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in Flux2.") - def test_simple_inference_with_text_lora_and_scale(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in Flux2.") - def test_simple_inference_with_text_lora_fused(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in Flux2.") - def test_simple_inference_with_text_lora_save_load(self): - pass diff --git a/tests/lora/test_lora_layers_hunyuanvideo.py b/tests/lora/test_lora_layers_hunyuanvideo.py index cfd5d3146a91..e59bc5662fe1 100644 --- a/tests/lora/test_lora_layers_hunyuanvideo.py +++ b/tests/lora/test_lora_layers_hunyuanvideo.py @@ -117,6 +117,8 @@ class HunyuanVideoLoRATests(unittest.TestCase, 
PeftLoraLoaderMixinTests): "text_encoder_2", ) + supports_text_encoder_loras = False + @property def output_shape(self): return (1, 9, 32, 32, 3) @@ -172,26 +174,6 @@ def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(se def test_modify_padding_mode(self): pass - @unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.") - def test_simple_inference_with_partial_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.") - def test_simple_inference_with_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.") - def test_simple_inference_with_text_lora_and_scale(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.") - def test_simple_inference_with_text_lora_fused(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.") - def test_simple_inference_with_text_lora_save_load(self): - pass - @nightly @require_torch_accelerator diff --git a/tests/lora/test_lora_layers_ltx_video.py b/tests/lora/test_lora_layers_ltx_video.py index 6ab51a5e513f..095e5b577cf0 100644 --- a/tests/lora/test_lora_layers_ltx_video.py +++ b/tests/lora/test_lora_layers_ltx_video.py @@ -76,6 +76,8 @@ class LTXVideoLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): text_encoder_target_modules = ["q", "k", "v", "o"] + supports_text_encoder_loras = False + @property def output_shape(self): return (1, 9, 32, 32, 3) @@ -125,23 +127,3 @@ def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(se @unittest.skip("Not supported in LTXVideo.") def test_modify_padding_mode(self): pass - - @unittest.skip("Text encoder LoRA is not supported in LTXVideo.") - def test_simple_inference_with_partial_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in LTXVideo.") - def test_simple_inference_with_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in LTXVideo.") - def test_simple_inference_with_text_lora_and_scale(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in LTXVideo.") - def test_simple_inference_with_text_lora_fused(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in LTXVideo.") - def test_simple_inference_with_text_lora_save_load(self): - pass diff --git a/tests/lora/test_lora_layers_lumina2.py b/tests/lora/test_lora_layers_lumina2.py index 0417b05b33a1..da032229a785 100644 --- a/tests/lora/test_lora_layers_lumina2.py +++ b/tests/lora/test_lora_layers_lumina2.py @@ -74,6 +74,8 @@ class Lumina2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): tokenizer_cls, tokenizer_id = AutoTokenizer, "hf-internal-testing/dummy-gemma" text_encoder_cls, text_encoder_id = GemmaForCausalLM, "hf-internal-testing/dummy-gemma-diffusers" + supports_text_encoder_loras = False + @property def output_shape(self): return (1, 4, 4, 3) @@ -113,26 +115,6 @@ def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(se def test_modify_padding_mode(self): pass - @unittest.skip("Text encoder LoRA is not supported in Lumina2.") - def test_simple_inference_with_partial_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in Lumina2.") - def test_simple_inference_with_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in Lumina2.") - def test_simple_inference_with_text_lora_and_scale(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in Lumina2.") - def 
test_simple_inference_with_text_lora_fused(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in Lumina2.") - def test_simple_inference_with_text_lora_save_load(self): - pass - @skip_mps @pytest.mark.xfail( condition=torch.device(torch_device).type == "cpu" and is_torch_version(">=", "2.5"), diff --git a/tests/lora/test_lora_layers_mochi.py b/tests/lora/test_lora_layers_mochi.py index 7be81273db77..ee8254112924 100644 --- a/tests/lora/test_lora_layers_mochi.py +++ b/tests/lora/test_lora_layers_mochi.py @@ -67,6 +67,8 @@ class MochiLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): text_encoder_target_modules = ["q", "k", "v", "o"] + supports_text_encoder_loras = False + @property def output_shape(self): return (1, 7, 16, 16, 3) @@ -117,26 +119,6 @@ def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(se def test_modify_padding_mode(self): pass - @unittest.skip("Text encoder LoRA is not supported in Mochi.") - def test_simple_inference_with_partial_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in Mochi.") - def test_simple_inference_with_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in Mochi.") - def test_simple_inference_with_text_lora_and_scale(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in Mochi.") - def test_simple_inference_with_text_lora_fused(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in Mochi.") - def test_simple_inference_with_text_lora_save_load(self): - pass - @unittest.skip("Not supported in CogVideoX.") def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self): pass diff --git a/tests/lora/test_lora_layers_qwenimage.py b/tests/lora/test_lora_layers_qwenimage.py index 51de2f8e20e1..73fd026a670c 100644 --- a/tests/lora/test_lora_layers_qwenimage.py +++ b/tests/lora/test_lora_layers_qwenimage.py @@ -69,6 +69,8 @@ class QwenImageLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): ) denoiser_target_modules = ["to_q", "to_k", "to_v", "to_out.0"] + supports_text_encoder_loras = False + @property def output_shape(self): return (1, 8, 8, 3) @@ -107,23 +109,3 @@ def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(se @unittest.skip("Not supported in Qwen Image.") def test_modify_padding_mode(self): pass - - @unittest.skip("Text encoder LoRA is not supported in Qwen Image.") - def test_simple_inference_with_partial_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in Qwen Image.") - def test_simple_inference_with_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in Qwen Image.") - def test_simple_inference_with_text_lora_and_scale(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in Qwen Image.") - def test_simple_inference_with_text_lora_fused(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in Qwen Image.") - def test_simple_inference_with_text_lora_save_load(self): - pass diff --git a/tests/lora/test_lora_layers_sana.py b/tests/lora/test_lora_layers_sana.py index a860b7b44f2c..97bf5cbba920 100644 --- a/tests/lora/test_lora_layers_sana.py +++ b/tests/lora/test_lora_layers_sana.py @@ -75,6 +75,8 @@ class SanaLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): tokenizer_cls, tokenizer_id = GemmaTokenizer, "hf-internal-testing/dummy-gemma" text_encoder_cls, text_encoder_id = Gemma2Model, "hf-internal-testing/dummy-gemma-for-diffusers" + supports_text_encoder_loras = False + 
@property def output_shape(self): return (1, 32, 32, 3) @@ -117,26 +119,6 @@ def test_simple_inference_with_text_denoiser_block_scale(self): def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self): pass - @unittest.skip("Text encoder LoRA is not supported in SANA.") - def test_simple_inference_with_partial_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in SANA.") - def test_simple_inference_with_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in SANA.") - def test_simple_inference_with_text_lora_and_scale(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in SANA.") - def test_simple_inference_with_text_lora_fused(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in SANA.") - def test_simple_inference_with_text_lora_save_load(self): - pass - @unittest.skipIf(IS_GITHUB_ACTIONS, reason="Skipping test inside GitHub Actions environment") def test_layerwise_casting_inference_denoiser(self): return super().test_layerwise_casting_inference_denoiser() diff --git a/tests/lora/test_lora_layers_wan.py b/tests/lora/test_lora_layers_wan.py index 5734509b410f..5ae16ab4b9da 100644 --- a/tests/lora/test_lora_layers_wan.py +++ b/tests/lora/test_lora_layers_wan.py @@ -73,6 +73,8 @@ class WanLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): text_encoder_target_modules = ["q", "k", "v", "o"] + supports_text_encoder_loras = False + @property def output_shape(self): return (1, 9, 32, 32, 3) @@ -121,23 +123,3 @@ def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(se @unittest.skip("Not supported in Wan.") def test_modify_padding_mode(self): pass - - @unittest.skip("Text encoder LoRA is not supported in Wan.") - def test_simple_inference_with_partial_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in Wan.") - def test_simple_inference_with_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in Wan.") - def test_simple_inference_with_text_lora_and_scale(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in Wan.") - def test_simple_inference_with_text_lora_fused(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in Wan.") - def test_simple_inference_with_text_lora_save_load(self): - pass diff --git a/tests/lora/test_lora_layers_wanvace.py b/tests/lora/test_lora_layers_wanvace.py index ab1f57bfc9da..c8acaea9bef0 100644 --- a/tests/lora/test_lora_layers_wanvace.py +++ b/tests/lora/test_lora_layers_wanvace.py @@ -85,6 +85,8 @@ class WanVACELoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): text_encoder_target_modules = ["q", "k", "v", "o"] + supports_text_encoder_loras = False + @property def output_shape(self): return (1, 9, 16, 16, 3) @@ -139,26 +141,6 @@ def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(se def test_modify_padding_mode(self): pass - @unittest.skip("Text encoder LoRA is not supported in Wan VACE.") - def test_simple_inference_with_partial_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in Wan VACE.") - def test_simple_inference_with_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in Wan VACE.") - def test_simple_inference_with_text_lora_and_scale(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in Wan VACE.") - def test_simple_inference_with_text_lora_fused(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in Wan 
VACE.") - def test_simple_inference_with_text_lora_save_load(self): - pass - def test_layerwise_casting_inference_denoiser(self): super().test_layerwise_casting_inference_denoiser() diff --git a/tests/lora/test_lora_layers_z_image.py b/tests/lora/test_lora_layers_z_image.py index 35d1389d9612..8432ea56a6fa 100644 --- a/tests/lora/test_lora_layers_z_image.py +++ b/tests/lora/test_lora_layers_z_image.py @@ -75,6 +75,8 @@ class ZImageLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): text_encoder_cls, text_encoder_id = Qwen3Model, None # Will be created inline denoiser_target_modules = ["to_q", "to_k", "to_v", "to_out.0"] + supports_text_encoder_loras = False + @property def output_shape(self): return (1, 32, 32, 3) @@ -263,23 +265,3 @@ def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(se @unittest.skip("Not supported in ZImage.") def test_modify_padding_mode(self): pass - - @unittest.skip("Text encoder LoRA is not supported in ZImage.") - def test_simple_inference_with_partial_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in ZImage.") - def test_simple_inference_with_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in ZImage.") - def test_simple_inference_with_text_lora_and_scale(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in ZImage.") - def test_simple_inference_with_text_lora_fused(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in ZImage.") - def test_simple_inference_with_text_lora_save_load(self): - pass diff --git a/tests/lora/utils.py b/tests/lora/utils.py index 5fae6cac0a7f..efa49b9f4838 100644 --- a/tests/lora/utils.py +++ b/tests/lora/utils.py @@ -117,6 +117,7 @@ class PeftLoraLoaderMixinTests: tokenizer_cls, tokenizer_id, tokenizer_subfolder = None, None, "" tokenizer_2_cls, tokenizer_2_id, tokenizer_2_subfolder = None, None, "" tokenizer_3_cls, tokenizer_3_id, tokenizer_3_subfolder = None, None, "" + supports_text_encoder_loras = True unet_kwargs = None transformer_cls = None @@ -333,6 +334,9 @@ def test_simple_inference_with_text_lora(self): Tests a simple inference with lora attached on the text encoder and makes sure it works as expected """ + if not self.supports_text_encoder_loras: + pytest.skip("Skipping test as text encoder LoRAs are not currently supported.") + components, text_lora_config, _ = self.get_dummy_components() pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) @@ -457,6 +461,9 @@ def test_simple_inference_with_text_lora_and_scale(self): Tests a simple inference with lora attached on the text encoder + scale argument and makes sure it works as expected """ + if not self.supports_text_encoder_loras: + pytest.skip("Skipping test as text encoder LoRAs are not currently supported.") + attention_kwargs_name = determine_attention_kwargs_name(self.pipeline_class) components, text_lora_config, _ = self.get_dummy_components() pipe = self.pipeline_class(**components) @@ -494,6 +501,9 @@ def test_simple_inference_with_text_lora_fused(self): Tests a simple inference with lora attached into text encoder + fuses the lora weights into base model and makes sure it works as expected """ + if not self.supports_text_encoder_loras: + pytest.skip("Skipping test as text encoder LoRAs are not currently supported.") + components, text_lora_config, _ = self.get_dummy_components() pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) @@ -555,6 +565,9 @@ def test_simple_inference_with_text_lora_save_load(self): 
""" Tests a simple usecase where users could use saving utilities for LoRA. """ + if not self.supports_text_encoder_loras: + pytest.skip("Skipping test as text encoder LoRAs are not currently supported.") + components, text_lora_config, _ = self.get_dummy_components() pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) @@ -593,6 +606,9 @@ def test_simple_inference_with_partial_text_lora(self): with different ranks and some adapters removed and makes sure it works as expected """ + if not self.supports_text_encoder_loras: + pytest.skip("Skipping test as text encoder LoRAs are not currently supported.") + components, _, _ = self.get_dummy_components() # Verify `StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder` handles different ranks per module (PR#8324). text_lora_config = LoraConfig( @@ -651,6 +667,9 @@ def test_simple_inference_save_pretrained_with_text_lora(self): """ Tests a simple usecase where users could use saving utilities for LoRA through save_pretrained """ + if not self.supports_text_encoder_loras: + pytest.skip("Skipping test as text encoder LoRAs are not currently supported.") + components, text_lora_config, _ = self.get_dummy_components() pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) From 0b107461404d14cb5ef98d8a04d310d649e2e8e0 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Sat, 10 Jan 2026 08:02:24 +0100 Subject: [PATCH 3/4] Apply changes to LTX2LoraTests --- tests/lora/test_lora_layers_ltx2.py | 26 ++------------------------ 1 file changed, 2 insertions(+), 24 deletions(-) diff --git a/tests/lora/test_lora_layers_ltx2.py b/tests/lora/test_lora_layers_ltx2.py index 886ae70b7d46..0a4b14454f5b 100644 --- a/tests/lora/test_lora_layers_ltx2.py +++ b/tests/lora/test_lora_layers_ltx2.py @@ -150,6 +150,8 @@ class LTX2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): denoiser_target_modules = ["to_q", "to_k", "to_out.0"] + supports_text_encoder_loras = False + @property def output_shape(self): return (1, 5, 32, 32, 3) @@ -267,27 +269,3 @@ def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(se @unittest.skip("Not supported in LTX2.") def test_modify_padding_mode(self): pass - - @unittest.skip("Text encoder LoRA is not supported in LTX2.") - def test_simple_inference_with_partial_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in LTX2.") - def test_simple_inference_with_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in LTX2.") - def test_simple_inference_with_text_lora_and_scale(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in LTX2.") - def test_simple_inference_with_text_lora_fused(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in LTX2.") - def test_simple_inference_with_text_lora_save_load(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in LTX2.") - def test_simple_inference_save_pretrained_with_text_lora(self): - pass From 150047f51a7ad4c361935a5f0749f13eefbb1d45 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Mon, 12 Jan 2026 23:59:44 +0100 Subject: [PATCH 4/4] Further improve incorrect LoRA format error msg following review --- src/diffusers/loaders/lora_pipeline.py | 40 +++++++++++++------------- 1 file changed, 20 insertions(+), 20 deletions(-) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index b76fe0a0e0dc..24d1fd7b9308 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ 
-214,7 +214,7 @@ def load_lora_weights( is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: - raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `lora`.") + raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain the `'lora'` substring.") self.load_lora_into_unet( state_dict, @@ -641,7 +641,7 @@ def load_lora_weights( is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: - raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `lora`.") + raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain the `'lora'` substring.") self.load_lora_into_unet( state_dict, @@ -1081,7 +1081,7 @@ def load_lora_weights( is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: - raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `lora`.") + raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain the `'lora'` substring.") self.load_lora_into_transformer( state_dict, @@ -1377,7 +1377,7 @@ def load_lora_weights( is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: - raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `lora`.") + raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain the `'lora'` substring.") self.load_lora_into_transformer( state_dict, @@ -1659,7 +1659,7 @@ def load_lora_weights( ) if not (has_lora_keys or has_norm_keys): - raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `lora`.") + raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain the `'lora'` substring.") transformer_lora_state_dict = { k: state_dict.get(k) @@ -2506,7 +2506,7 @@ def load_lora_weights( is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: - raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `lora`.") + raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain the `'lora'` substring.") self.load_lora_into_transformer( state_dict, @@ -2703,7 +2703,7 @@ def load_lora_weights( is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: - raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `lora`.") + raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain the `'lora'` substring.") self.load_lora_into_transformer( state_dict, @@ -2906,7 +2906,7 @@ def load_lora_weights( is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: - raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `lora`.") + raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain the `'lora'` substring.") self.load_lora_into_transformer( state_dict, @@ -3115,7 +3115,7 @@ def load_lora_weights( is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: - raise ValueError("Invalid LoRA checkpoint.") + raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain the `'lora'` substring.") transformer_peft_state_dict = { k: v for k, v in state_dict.items() if k.startswith(f"{self.transformer_name}.") @@ -3333,7 +3333,7 @@ def load_lora_weights( is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: - raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `lora`.") + raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain the `'lora'` substring.") self.load_lora_into_transformer( state_dict, @@ -3536,7 +3536,7 @@ def load_lora_weights( is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: - raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `lora`.") + raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain the `'lora'` substring.") self.load_lora_into_transformer( state_dict, @@ -3740,7 +3740,7 @@ def load_lora_weights( is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: - raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `lora`.") + raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain the `'lora'` substring.") self.load_lora_into_transformer( state_dict, @@ -3940,7 +3940,7 @@ def load_lora_weights( is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: - raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `lora`.") + raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain the `'lora'` substring.") self.load_lora_into_transformer( state_dict, @@ -4194,7 +4194,7 @@ def load_lora_weights( ) is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: - raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `lora`.") + raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain the `'lora'` substring.") load_into_transformer_2 = kwargs.pop("load_into_transformer_2", False) if load_into_transformer_2: @@ -4471,7 +4471,7 @@ def load_lora_weights( ) is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: - raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `lora`.") + raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain the `'lora'` substring.") load_into_transformer_2 = kwargs.pop("load_into_transformer_2", False) if load_into_transformer_2: @@ -4691,7 +4691,7 @@ def load_lora_weights( is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: - raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `lora`.") + raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain the `'lora'` substring.") self.load_lora_into_transformer( state_dict, @@ -4894,7 +4894,7 @@ def load_lora_weights( is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: - raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `lora`.") + raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain the `'lora'` substring.") self.load_lora_into_transformer( state_dict, @@ -5100,7 +5100,7 @@ def load_lora_weights( is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: - raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `lora`.") + raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain the `'lora'` substring.") self.load_lora_into_transformer( state_dict, @@ -5306,7 +5306,7 @@ def load_lora_weights( is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: - raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `lora`.") + raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain the `'lora'` substring.") self.load_lora_into_transformer( state_dict, @@ -5509,7 +5509,7 @@ def load_lora_weights( is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: - raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `lora`.") + raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain the `'lora'` substring.") self.load_lora_into_transformer( state_dict,
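One site in this series (the @@ -1659 hunk) gates on LoRA or norm keys instead of requiring "lora" in every key, because that loader also accepts norm-layer-only checkpoints. A hedged sketch of that looser gate; the bare "norm" substring match below is a simplification, since the real loader derives its predicates from the model's module names:

    def validate_lora_checkpoint(state_dict):
        # Sketch only: simplified predicates for illustration.
        has_lora_keys = any("lora" in k for k in state_dict)
        has_norm_keys = any("norm" in k for k in state_dict)
        if not (has_lora_keys or has_norm_keys):
            raise ValueError(
                "Invalid LoRA checkpoint. Make sure all LoRA param names contain the `'lora'` substring."
            )


    validate_lora_checkpoint({"transformer.norm_out.linear.weight": 0.0})  # norm-only: accepted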