-
Notifications
You must be signed in to change notification settings - Fork 6.8k
[Z-Image] various small changes, Z-Image transformer tests, etc. #12741
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
28 commits
Select commit
Hold shift + click to select a range
12608de
start zimage model tests.
sayakpaul 6d47d10
up
sayakpaul 1b0888c
up
sayakpaul 7c47ae0
up
sayakpaul d54bd6c
up
sayakpaul 9b0028a
up
sayakpaul a74a8f7
up
sayakpaul c137ae1
up
sayakpaul 66b6922
Merge branch 'main' into z-image-tests
sayakpaul 2c367f8
up
sayakpaul 76dbf63
up
sayakpaul 8ee24fc
up
sayakpaul a11cdd2
up
sayakpaul bca3e27
up
sayakpaul 52c6d2f
Revert "up"
sayakpaul 3eef952
Merge branch 'main' into z-image-tests
sayakpaul 91a8c2a
expand upon compilation failure reason.
sayakpaul 1513e52
Update tests/models/transformers/test_models_transformer_z_image.py
sayakpaul e4702dc
up
sayakpaul cf34435
reinitialize the padding tokens to ones to prevent NaN problems.
sayakpaul a538b7a
updates
sayakpaul 5671c1d
Merge branch 'main' into z-image-tests
sayakpaul 1e4e272
Merge branch 'main' into z-image-tests
sayakpaul d910a26
up
sayakpaul 4ca68f2
skipping ZImage DiT tests
sayakpaul 5613ff0
up
sayakpaul 21d8a8d
Merge branch 'main' into z-image-tests
sayakpaul 6da67f2
up
sayakpaul File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -15,17 +15,13 @@ | |
| import sys | ||
| import unittest | ||
|
|
||
| import numpy as np | ||
| import torch | ||
| from transformers import Qwen2Tokenizer, Qwen3Config, Qwen3Model | ||
|
|
||
| from diffusers import ( | ||
| AutoencoderKL, | ||
| FlowMatchEulerDiscreteScheduler, | ||
| ZImagePipeline, | ||
| ZImageTransformer2DModel, | ||
| ) | ||
| from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, ZImagePipeline, ZImageTransformer2DModel | ||
|
|
||
| from ..testing_utils import floats_tensor, is_peft_available, require_peft_backend | ||
| from ..testing_utils import floats_tensor, is_peft_available, require_peft_backend, skip_mps, torch_device | ||
|
|
||
|
|
||
| if is_peft_available(): | ||
|
|
@@ -34,13 +30,9 @@ | |
|
|
||
| sys.path.append(".") | ||
|
|
||
| from .utils import PeftLoraLoaderMixinTests # noqa: E402 | ||
| from .utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set # noqa: E402 | ||
|
|
||
|
|
||
| @unittest.skip( | ||
| "ZImage LoRA tests are skipped due to non-deterministic behavior from complex64 RoPE operations " | ||
| "and torch.empty padding tokens. LoRA functionality works correctly with real models." | ||
| ) | ||
|
Comment on lines
-40
to
-43
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
| @require_peft_backend | ||
| class ZImageLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): | ||
| pipeline_class = ZImagePipeline | ||
|
|
@@ -127,6 +119,12 @@ def get_dummy_components(self, scheduler_cls=None, use_dora=False, lora_alpha=No | |
| tokenizer = Qwen2Tokenizer.from_pretrained(self.tokenizer_id) | ||
|
|
||
| transformer = self.transformer_cls(**self.transformer_kwargs) | ||
| # `x_pad_token` and `cap_pad_token` are initialized with `torch.empty`. | ||
| # This can cause NaN data values in our testing environment. Fixating them | ||
| # helps prevent that issue. | ||
| with torch.no_grad(): | ||
| transformer.x_pad_token.copy_(torch.ones_like(transformer.x_pad_token.data)) | ||
| transformer.cap_pad_token.copy_(torch.ones_like(transformer.cap_pad_token.data)) | ||
| vae = self.vae_cls(**self.vae_kwargs) | ||
|
|
||
| if scheduler_cls is None: | ||
|
|
@@ -161,3 +159,127 @@ def get_dummy_components(self, scheduler_cls=None, use_dora=False, lora_alpha=No | |
| } | ||
|
|
||
| return pipeline_components, text_lora_config, denoiser_lora_config | ||
|
|
||
| def test_correct_lora_configs_with_different_ranks(self): | ||
| components, _, denoiser_lora_config = self.get_dummy_components() | ||
| pipe = self.pipeline_class(**components) | ||
| pipe = pipe.to(torch_device) | ||
| pipe.set_progress_bar_config(disable=None) | ||
| _, _, inputs = self.get_dummy_inputs(with_generator=False) | ||
|
|
||
| original_output = pipe(**inputs, generator=torch.manual_seed(0))[0] | ||
|
|
||
| pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1") | ||
|
|
||
| lora_output_same_rank = pipe(**inputs, generator=torch.manual_seed(0))[0] | ||
|
|
||
| pipe.transformer.delete_adapters("adapter-1") | ||
|
|
||
| denoiser = pipe.unet if self.unet_kwargs is not None else pipe.transformer | ||
| for name, _ in denoiser.named_modules(): | ||
| if "to_k" in name and "attention" in name and "lora" not in name: | ||
| module_name_to_rank_update = name.replace(".base_layer.", ".") | ||
| break | ||
|
|
||
| # change the rank_pattern | ||
| updated_rank = denoiser_lora_config.r * 2 | ||
| denoiser_lora_config.rank_pattern = {module_name_to_rank_update: updated_rank} | ||
|
|
||
| pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1") | ||
| updated_rank_pattern = pipe.transformer.peft_config["adapter-1"].rank_pattern | ||
|
|
||
| self.assertTrue(updated_rank_pattern == {module_name_to_rank_update: updated_rank}) | ||
|
|
||
| lora_output_diff_rank = pipe(**inputs, generator=torch.manual_seed(0))[0] | ||
| self.assertTrue(not np.allclose(original_output, lora_output_same_rank, atol=1e-3, rtol=1e-3)) | ||
| self.assertTrue(not np.allclose(lora_output_diff_rank, lora_output_same_rank, atol=1e-3, rtol=1e-3)) | ||
|
|
||
| pipe.transformer.delete_adapters("adapter-1") | ||
|
|
||
| # similarly change the alpha_pattern | ||
| updated_alpha = denoiser_lora_config.lora_alpha * 2 | ||
| denoiser_lora_config.alpha_pattern = {module_name_to_rank_update: updated_alpha} | ||
|
|
||
| pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1") | ||
| self.assertTrue( | ||
| pipe.transformer.peft_config["adapter-1"].alpha_pattern == {module_name_to_rank_update: updated_alpha} | ||
| ) | ||
|
|
||
| lora_output_diff_alpha = pipe(**inputs, generator=torch.manual_seed(0))[0] | ||
| self.assertTrue(not np.allclose(original_output, lora_output_diff_alpha, atol=1e-3, rtol=1e-3)) | ||
| self.assertTrue(not np.allclose(lora_output_diff_alpha, lora_output_same_rank, atol=1e-3, rtol=1e-3)) | ||
|
|
||
| @skip_mps | ||
| def test_lora_fuse_nan(self): | ||
| components, _, denoiser_lora_config = self.get_dummy_components() | ||
| pipe = self.pipeline_class(**components) | ||
| pipe = pipe.to(torch_device) | ||
| pipe.set_progress_bar_config(disable=None) | ||
| _, _, inputs = self.get_dummy_inputs(with_generator=False) | ||
|
|
||
| denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet | ||
| denoiser.add_adapter(denoiser_lora_config, "adapter-1") | ||
| self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.") | ||
|
|
||
| # corrupt one LoRA weight with `inf` values | ||
| with torch.no_grad(): | ||
| possible_tower_names = ["noise_refiner"] | ||
| filtered_tower_names = [ | ||
| tower_name for tower_name in possible_tower_names if hasattr(pipe.transformer, tower_name) | ||
| ] | ||
| for tower_name in filtered_tower_names: | ||
| transformer_tower = getattr(pipe.transformer, tower_name) | ||
| transformer_tower[0].attention.to_q.lora_A["adapter-1"].weight += float("inf") | ||
|
|
||
| # with `safe_fusing=True` we should see an Error | ||
| with self.assertRaises(ValueError): | ||
| pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=True) | ||
|
|
||
| # without we should not see an error, but every image will be black | ||
| pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=False) | ||
| out = pipe(**inputs)[0] | ||
|
|
||
| self.assertTrue(np.isnan(out).all()) | ||
|
|
||
| def test_lora_scale_kwargs_match_fusion(self): | ||
| super().test_lora_scale_kwargs_match_fusion(5e-2, 5e-2) | ||
|
|
||
| @unittest.skip("Needs to be debugged.") | ||
| def test_set_adapters_match_attention_kwargs(self): | ||
| super().test_set_adapters_match_attention_kwargs() | ||
|
|
||
| @unittest.skip("Needs to be debugged.") | ||
| def test_simple_inference_with_text_denoiser_lora_and_scale(self): | ||
| super().test_simple_inference_with_text_denoiser_lora_and_scale() | ||
|
|
||
| @unittest.skip("Not supported in ZImage.") | ||
| def test_simple_inference_with_text_denoiser_block_scale(self): | ||
| pass | ||
|
|
||
| @unittest.skip("Not supported in ZImage.") | ||
| def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self): | ||
| pass | ||
|
|
||
| @unittest.skip("Not supported in ZImage.") | ||
| def test_modify_padding_mode(self): | ||
| pass | ||
|
|
||
| @unittest.skip("Text encoder LoRA is not supported in ZImage.") | ||
| def test_simple_inference_with_partial_text_lora(self): | ||
| pass | ||
|
|
||
| @unittest.skip("Text encoder LoRA is not supported in ZImage.") | ||
| def test_simple_inference_with_text_lora(self): | ||
| pass | ||
|
|
||
| @unittest.skip("Text encoder LoRA is not supported in ZImage.") | ||
| def test_simple_inference_with_text_lora_and_scale(self): | ||
| pass | ||
|
|
||
| @unittest.skip("Text encoder LoRA is not supported in ZImage.") | ||
| def test_simple_inference_with_text_lora_fused(self): | ||
| pass | ||
|
|
||
| @unittest.skip("Text encoder LoRA is not supported in ZImage.") | ||
| def test_simple_inference_with_text_lora_save_load(self): | ||
| pass | ||
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.