Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 16 additions & 41 deletions src/diffusers/models/transformers/transformer_z_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from ...models.normalization import RMSNorm
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention_dispatch import dispatch_attention_fn
from ..modeling_outputs import Transformer2DModelOutput


ADALN_EMBED_DIM = 256
Expand All @@ -39,17 +40,9 @@ def __init__(self, out_size, mid_size=None, frequency_embedding_size=256):
if mid_size is None:
mid_size = out_size
self.mlp = nn.Sequential(
nn.Linear(
frequency_embedding_size,
mid_size,
bias=True,
),
nn.Linear(frequency_embedding_size, mid_size, bias=True),
nn.SiLU(),
nn.Linear(
mid_size,
out_size,
bias=True,
),
nn.Linear(mid_size, out_size, bias=True),
)

self.frequency_embedding_size = frequency_embedding_size
Expand Down Expand Up @@ -211,9 +204,7 @@ def __init__(

self.modulation = modulation
if modulation:
self.adaLN_modulation = nn.Sequential(
nn.Linear(min(dim, ADALN_EMBED_DIM), 4 * dim, bias=True),
)
self.adaLN_modulation = nn.Sequential(nn.Linear(min(dim, ADALN_EMBED_DIM), 4 * dim, bias=True))

def forward(
self,
Expand All @@ -230,33 +221,19 @@ def forward(

# Attention block
attn_out = self.attention(
self.attention_norm1(x) * scale_msa,
attention_mask=attn_mask,
freqs_cis=freqs_cis,
self.attention_norm1(x) * scale_msa, attention_mask=attn_mask, freqs_cis=freqs_cis
)
x = x + gate_msa * self.attention_norm2(attn_out)

# FFN block
x = x + gate_mlp * self.ffn_norm2(
self.feed_forward(
self.ffn_norm1(x) * scale_mlp,
)
)
x = x + gate_mlp * self.ffn_norm2(self.feed_forward(self.ffn_norm1(x) * scale_mlp))
else:
# Attention block
attn_out = self.attention(
self.attention_norm1(x),
attention_mask=attn_mask,
freqs_cis=freqs_cis,
)
attn_out = self.attention(self.attention_norm1(x), attention_mask=attn_mask, freqs_cis=freqs_cis)
x = x + self.attention_norm2(attn_out)

# FFN block
x = x + self.ffn_norm2(
self.feed_forward(
self.ffn_norm1(x),
)
)
x = x + self.ffn_norm2(self.feed_forward(self.ffn_norm1(x)))

return x

Expand Down Expand Up @@ -404,10 +381,7 @@ def __init__(
]
)
self.t_embedder = TimestepEmbedder(min(dim, ADALN_EMBED_DIM), mid_size=1024)
self.cap_embedder = nn.Sequential(
RMSNorm(cap_feat_dim, eps=norm_eps),
nn.Linear(cap_feat_dim, dim, bias=True),
)
self.cap_embedder = nn.Sequential(RMSNorm(cap_feat_dim, eps=norm_eps), nn.Linear(cap_feat_dim, dim, bias=True))

self.x_pad_token = nn.Parameter(torch.empty((1, dim)))
self.cap_pad_token = nn.Parameter(torch.empty((1, dim)))
Expand Down Expand Up @@ -494,11 +468,8 @@ def patchify_and_embed(
)

# padded feature
cap_padded_feat = torch.cat(
[cap_feat, cap_feat[-1:].repeat(cap_padding_len, 1)],
dim=0,
)
all_cap_feats_out.append(cap_padded_feat if cap_padding_len > 0 else cap_feat)
cap_padded_feat = torch.cat([cap_feat, cap_feat[-1:].repeat(cap_padding_len, 1)], dim=0)
all_cap_feats_out.append(cap_padded_feat)

### Process Image
C, F, H, W = image.size()
Expand Down Expand Up @@ -564,6 +535,7 @@ def forward(
cap_feats: List[torch.Tensor],
patch_size=2,
f_patch_size=1,
return_dict: bool = True,
):
assert patch_size in self.all_patch_size
assert f_patch_size in self.all_f_patch_size
Expand Down Expand Up @@ -672,4 +644,7 @@ def forward(
unified = list(unified.unbind(dim=0))
x = self.unpatchify(unified, x_size, patch_size, f_patch_size)

return x, {}
if not return_dict:
return (x,)

return Transformer2DModelOutput(sample=x)
4 changes: 1 addition & 3 deletions src/diffusers/pipelines/z_image/pipeline_z_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -525,9 +525,7 @@ def __call__(
latent_model_input_list = list(latent_model_input.unbind(dim=0))

model_out_list = self.transformer(
latent_model_input_list,
timestep_model_input,
prompt_embeds_model_input,
latent_model_input_list, timestep_model_input, prompt_embeds_model_input, return_dict=False
)[0]

if apply_cfg:
Expand Down
146 changes: 134 additions & 12 deletions tests/lora/test_lora_layers_z_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,13 @@
import sys
import unittest

import numpy as np
import torch
from transformers import Qwen2Tokenizer, Qwen3Config, Qwen3Model

from diffusers import (
AutoencoderKL,
FlowMatchEulerDiscreteScheduler,
ZImagePipeline,
ZImageTransformer2DModel,
)
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, ZImagePipeline, ZImageTransformer2DModel

from ..testing_utils import floats_tensor, is_peft_available, require_peft_backend
from ..testing_utils import floats_tensor, is_peft_available, require_peft_backend, skip_mps, torch_device


if is_peft_available():
Expand All @@ -34,13 +30,9 @@

sys.path.append(".")

from .utils import PeftLoraLoaderMixinTests # noqa: E402
from .utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set # noqa: E402


@unittest.skip(
"ZImage LoRA tests are skipped due to non-deterministic behavior from complex64 RoPE operations "
"and torch.empty padding tokens. LoRA functionality works correctly with real models."
)
Comment on lines -40 to -43
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@require_peft_backend
class ZImageLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
pipeline_class = ZImagePipeline
Expand Down Expand Up @@ -127,6 +119,12 @@ def get_dummy_components(self, scheduler_cls=None, use_dora=False, lora_alpha=No
tokenizer = Qwen2Tokenizer.from_pretrained(self.tokenizer_id)

transformer = self.transformer_cls(**self.transformer_kwargs)
# `x_pad_token` and `cap_pad_token` are initialized with `torch.empty`.
# This can cause NaN data values in our testing environment. Fixating them
# helps prevent that issue.
with torch.no_grad():
transformer.x_pad_token.copy_(torch.ones_like(transformer.x_pad_token.data))
transformer.cap_pad_token.copy_(torch.ones_like(transformer.cap_pad_token.data))
vae = self.vae_cls(**self.vae_kwargs)

if scheduler_cls is None:
Expand Down Expand Up @@ -161,3 +159,127 @@ def get_dummy_components(self, scheduler_cls=None, use_dora=False, lora_alpha=No
}

return pipeline_components, text_lora_config, denoiser_lora_config

def test_correct_lora_configs_with_different_ranks(self):
    """Verify that per-module ``rank_pattern`` / ``alpha_pattern`` overrides in a
    LoRA config are honored by the ZImage transformer and visibly change the
    pipeline output compared to a uniform-rank adapter and the no-LoRA baseline.
    """
    components, _, denoiser_lora_config = self.get_dummy_components()
    pipe = self.pipeline_class(**components)
    pipe = pipe.to(torch_device)
    pipe.set_progress_bar_config(disable=None)
    _, _, inputs = self.get_dummy_inputs(with_generator=False)

    # Baseline output with no adapter attached; the fixed seed makes the
    # three inference runs below comparable.
    original_output = pipe(**inputs, generator=torch.manual_seed(0))[0]

    pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")

    # Output with the unmodified (uniform-rank) adapter.
    lora_output_same_rank = pipe(**inputs, generator=torch.manual_seed(0))[0]

    pipe.transformer.delete_adapters("adapter-1")

    denoiser = pipe.unet if self.unet_kwargs is not None else pipe.transformer
    # Pick one attention `to_k` module to override. `.base_layer.` appears
    # in names once PEFT has wrapped a module, so strip it to recover the
    # original module path that rank_pattern/alpha_pattern keys expect.
    for name, _ in denoiser.named_modules():
        if "to_k" in name and "attention" in name and "lora" not in name:
            module_name_to_rank_update = name.replace(".base_layer.", ".")
            break

    # change the rank_pattern
    updated_rank = denoiser_lora_config.r * 2
    denoiser_lora_config.rank_pattern = {module_name_to_rank_update: updated_rank}

    pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
    updated_rank_pattern = pipe.transformer.peft_config["adapter-1"].rank_pattern

    self.assertTrue(updated_rank_pattern == {module_name_to_rank_update: updated_rank})

    # The mixed-rank adapter must differ both from the baseline and from the
    # uniform-rank adapter's output.
    lora_output_diff_rank = pipe(**inputs, generator=torch.manual_seed(0))[0]
    self.assertTrue(not np.allclose(original_output, lora_output_same_rank, atol=1e-3, rtol=1e-3))
    self.assertTrue(not np.allclose(lora_output_diff_rank, lora_output_same_rank, atol=1e-3, rtol=1e-3))

    pipe.transformer.delete_adapters("adapter-1")

    # similarly change the alpha_pattern
    updated_alpha = denoiser_lora_config.lora_alpha * 2
    denoiser_lora_config.alpha_pattern = {module_name_to_rank_update: updated_alpha}

    pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
    self.assertTrue(
        pipe.transformer.peft_config["adapter-1"].alpha_pattern == {module_name_to_rank_update: updated_alpha}
    )

    # And the per-module alpha override must change the output as well.
    lora_output_diff_alpha = pipe(**inputs, generator=torch.manual_seed(0))[0]
    self.assertTrue(not np.allclose(original_output, lora_output_diff_alpha, atol=1e-3, rtol=1e-3))
    self.assertTrue(not np.allclose(lora_output_diff_alpha, lora_output_same_rank, atol=1e-3, rtol=1e-3))

@skip_mps
def test_lora_fuse_nan(self):
    """Corrupt a LoRA weight with ``inf`` and check that fusion raises with
    ``safe_fusing=True`` but, with ``safe_fusing=False``, fuses anyway and
    produces an all-NaN output image.
    """
    components, _, denoiser_lora_config = self.get_dummy_components()
    pipe = self.pipeline_class(**components)
    pipe = pipe.to(torch_device)
    pipe.set_progress_bar_config(disable=None)
    _, _, inputs = self.get_dummy_inputs(with_generator=False)

    denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
    denoiser.add_adapter(denoiser_lora_config, "adapter-1")
    self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")

    # corrupt one LoRA weight with `inf` values
    with torch.no_grad():
        # Only towers that actually exist on this transformer are touched;
        # ZImage exposes `noise_refiner` here.
        possible_tower_names = ["noise_refiner"]
        filtered_tower_names = [
            tower_name for tower_name in possible_tower_names if hasattr(pipe.transformer, tower_name)
        ]
        for tower_name in filtered_tower_names:
            transformer_tower = getattr(pipe.transformer, tower_name)
            transformer_tower[0].attention.to_q.lora_A["adapter-1"].weight += float("inf")

    # with `safe_fusing=True` we should see an Error
    with self.assertRaises(ValueError):
        pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=True)

    # without we should not see an error, but every image will be black
    pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=False)
    out = pipe(**inputs)[0]

    self.assertTrue(np.isnan(out).all())

def test_lora_scale_kwargs_match_fusion(self):
    """Run the mixin's scale-vs-fusion comparison with looser tolerances for ZImage."""
    expected_atol = 5e-2
    expected_rtol = 5e-2
    super().test_lora_scale_kwargs_match_fusion(expected_atol, expected_rtol)

# Temporarily disabled: the mixin test currently fails for ZImage and is
# pending investigation.
@unittest.skip("Needs to be debugged.")
def test_set_adapters_match_attention_kwargs(self):
    super().test_set_adapters_match_attention_kwargs()

# Temporarily disabled: the mixin test currently fails for ZImage and is
# pending investigation.
@unittest.skip("Needs to be debugged.")
def test_simple_inference_with_text_denoiser_lora_and_scale(self):
    super().test_simple_inference_with_text_denoiser_lora_and_scale()

# Per-block LoRA scaling is not implemented for the ZImage transformer.
@unittest.skip("Not supported in ZImage.")
def test_simple_inference_with_text_denoiser_block_scale(self):
    pass

# Per-block LoRA scaling is not implemented for the ZImage transformer.
@unittest.skip("Not supported in ZImage.")
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
    pass

# Padding-mode modification does not apply to this architecture.
@unittest.skip("Not supported in ZImage.")
def test_modify_padding_mode(self):
    pass

# ZImage only supports LoRA on the transformer, not the text encoder.
@unittest.skip("Text encoder LoRA is not supported in ZImage.")
def test_simple_inference_with_partial_text_lora(self):
    pass

# ZImage only supports LoRA on the transformer, not the text encoder.
@unittest.skip("Text encoder LoRA is not supported in ZImage.")
def test_simple_inference_with_text_lora(self):
    pass

# ZImage only supports LoRA on the transformer, not the text encoder.
@unittest.skip("Text encoder LoRA is not supported in ZImage.")
def test_simple_inference_with_text_lora_and_scale(self):
    pass

# ZImage only supports LoRA on the transformer, not the text encoder.
@unittest.skip("Text encoder LoRA is not supported in ZImage.")
def test_simple_inference_with_text_lora_fused(self):
    pass

# ZImage only supports LoRA on the transformer, not the text encoder.
@unittest.skip("Text encoder LoRA is not supported in ZImage.")
def test_simple_inference_with_text_lora_save_load(self):
    pass
Loading
Loading