diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 7791282ebdb9..06174027a2eb 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -83,6 +83,7 @@ "ControlNetModel", "ControlNetXSAdapter", "DiTTransformer2DModel", + "HunyuanDiT2DModel", "I2VGenXLUNet", "Kandinsky3UNet", "ModelMixin", @@ -229,6 +230,7 @@ "BlipDiffusionPipeline", "CLIPImageProjection", "CycleDiffusionPipeline", + "HunyuanDiTPipeline", "I2VGenXLPipeline", "IFImg2ImgPipeline", "IFImg2ImgSuperResolutionPipeline", @@ -487,6 +489,7 @@ ControlNetModel, ControlNetXSAdapter, DiTTransformer2DModel, + HunyuanDiT2DModel, I2VGenXLUNet, Kandinsky3UNet, ModelMixin, @@ -611,6 +614,7 @@ AudioLDMPipeline, CLIPImageProjection, CycleDiffusionPipeline, + HunyuanDiTPipeline, I2VGenXLPipeline, IFImg2ImgPipeline, IFImg2ImgSuperResolutionPipeline, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 52eda31acc76..f50670821b93 100644 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -38,6 +38,7 @@ _import_structure["embeddings"] = ["ImageProjection"] _import_structure["modeling_utils"] = ["ModelMixin"] _import_structure["transformers.dit_transformer_2d"] = ["DiTTransformer2DModel"] + _import_structure["transformers.hunyuan_transformer_2d"] = ["HunyuanDiT2DModel"] _import_structure["transformers.pixart_transformer_2d"] = ["PixArtTransformer2DModel"] _import_structure["transformers.prior_transformer"] = ["PriorTransformer"] _import_structure["transformers.t5_film_transformer"] = ["T5FilmDecoder"] @@ -78,6 +79,7 @@ from .transformers import ( DiTTransformer2DModel, DualTransformer2DModel, + HunyuanDiT2DModel, PixArtTransformer2DModel, PriorTransformer, T5FilmDecoder, diff --git a/src/diffusers/models/activations.py b/src/diffusers/models/activations.py index f94b6c8d6d06..28ee92ddb2e3 100644 --- a/src/diffusers/models/activations.py +++ b/src/diffusers/models/activations.py @@ -50,6 +50,18 @@ def get_activation(act_fn: str) -> nn.Module: raise ValueError(f"Unsupported activation function: {act_fn}") +class FP32SiLU(nn.Module): + r""" + SiLU activation function with input upcasted to torch.float32. + """ + + def __init__(self): + super().__init__() + + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + return F.silu(inputs.float(), inplace=False).to(inputs.dtype) + + class GELU(nn.Module): r""" GELU activation function with tanh approximation support with `approximate="tanh"`. diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index cbb07eafa37f..de4b0722591d 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -103,6 +103,7 @@ def __init__( upcast_softmax: bool = False, cross_attention_norm: Optional[str] = None, cross_attention_norm_num_groups: int = 32, + qk_norm: Optional[str] = None, added_kv_proj_dim: Optional[int] = None, norm_num_groups: Optional[int] = None, spatial_norm_dim: Optional[int] = None, @@ -161,6 +162,15 @@ def __init__( else: self.spatial_norm = None + if qk_norm is None: + self.norm_q = None + self.norm_k = None + elif qk_norm == "layer_norm": + self.norm_q = nn.LayerNorm(dim_head, eps=eps) + self.norm_k = nn.LayerNorm(dim_head, eps=eps) + else: + raise ValueError(f"unknown qk_norm: {qk_norm}. Should be None or 'layer_norm'") + if cross_attention_norm is None: self.norm_cross = None elif cross_attention_norm == "layer_norm": @@ -1426,6 +1436,104 @@ def __call__( return hidden_states +class HunyuanAttnProcessor2_0: + r""" + Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is + used in the HunyuanDiT model. It applies a s normalization layer and rotary embedding on query and key vector. + """ + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + temb: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + from .embeddings import apply_rotary_emb + + residual = hidden_states + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # Apply RoPE if needed + if image_rotary_emb is not None: + query = apply_rotary_emb(query, image_rotary_emb) + if not attn.is_cross_attention: + key = apply_rotary_emb(key, image_rotary_emb) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + class FusedAttnProcessor2_0: r""" Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). It uses diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index d2940e861c83..98178dad5e23 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -16,10 +16,11 @@ import numpy as np import torch +import torch.nn.functional as F from torch import nn from ..utils import deprecate -from .activations import get_activation +from .activations import FP32SiLU, get_activation from .attention_processor import Attention @@ -135,6 +136,7 @@ def __init__( flatten=True, bias=True, interpolation_scale=1, + pos_embed_type="sincos", ): super().__init__() @@ -156,10 +158,18 @@ def __init__( self.height, self.width = height // patch_size, width // patch_size self.base_size = height // patch_size self.interpolation_scale = interpolation_scale - pos_embed = get_2d_sincos_pos_embed( - embed_dim, int(num_patches**0.5), base_size=self.base_size, interpolation_scale=self.interpolation_scale - ) - self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=False) + if pos_embed_type is None: + self.pos_embed = None + elif pos_embed_type == "sincos": + pos_embed = get_2d_sincos_pos_embed( + embed_dim, + int(num_patches**0.5), + base_size=self.base_size, + interpolation_scale=self.interpolation_scale, + ) + self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=False) + else: + raise ValueError(f"Unsupported pos_embed_type: {pos_embed_type}") def forward(self, latent): height, width = latent.shape[-2] // self.patch_size, latent.shape[-1] // self.patch_size @@ -169,6 +179,8 @@ def forward(self, latent): latent = latent.flatten(2).transpose(1, 2) # BCHW -> BNC if self.layer_norm: latent = self.norm(latent) + if self.pos_embed is None: + return latent.to(latent.dtype) # Interpolate positional embeddings if needed. # (For PixArt-Alpha: https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L162C151-L162C160) @@ -187,6 +199,113 @@ def forward(self, latent): return (latent + pos_embed).to(latent.dtype) +def get_2d_rotary_pos_embed(embed_dim, crops_coords, grid_size, use_real=True): + """ + RoPE for image tokens with 2d structure. + + Args: + embed_dim: (`int`): + The embedding dimension size + crops_coords (`Tuple[int]`) + The top-left and bottom-right coordinates of the crop. + grid_size (`Tuple[int]`): + The grid size of the positional embedding. + use_real (`bool`): + If True, return real part and imaginary part separately. Otherwise, return complex numbers. + + Returns: + `torch.Tensor`: positional embdding with shape `( grid_size * grid_size, embed_dim/2)`. + """ + start, stop = crops_coords + grid_h = np.linspace(start[0], stop[0], grid_size[0], endpoint=False, dtype=np.float32) + grid_w = np.linspace(start[1], stop[1], grid_size[1], endpoint=False, dtype=np.float32) + grid = np.meshgrid(grid_w, grid_h) # here w goes first + grid = np.stack(grid, axis=0) # [2, W, H] + + grid = grid.reshape([2, 1, *grid.shape[1:]]) + pos_embed = get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=use_real) + return pos_embed + + +def get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=False): + assert embed_dim % 4 == 0 + + # use half of dimensions to encode grid_h + emb_h = get_1d_rotary_pos_embed(embed_dim // 2, grid[0].reshape(-1), use_real=use_real) # (H*W, D/4) + emb_w = get_1d_rotary_pos_embed(embed_dim // 2, grid[1].reshape(-1), use_real=use_real) # (H*W, D/4) + + if use_real: + cos = torch.cat([emb_h[0], emb_w[0]], dim=1) # (H*W, D/2) + sin = torch.cat([emb_h[1], emb_w[1]], dim=1) # (H*W, D/2) + return cos, sin + else: + emb = torch.cat([emb_h, emb_w], dim=1) # (H*W, D/2) + return emb + + +def get_1d_rotary_pos_embed(dim: int, pos: Union[np.ndarray, int], theta: float = 10000.0, use_real=False): + """ + Precompute the frequency tensor for complex exponentials (cis) with given dimensions. + + This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' and the end + index 'end'. The 'theta' parameter scales the frequencies. The returned tensor contains complex values in complex64 + data type. + + Args: + dim (`int`): Dimension of the frequency tensor. + pos (`np.ndarray` or `int`): Position indices for the frequency tensor. [S] or scalar + theta (`float`, *optional*, defaults to 10000.0): + Scaling factor for frequency computation. Defaults to 10000.0. + use_real (`bool`, *optional*): + If True, return real part and imaginary part separately. Otherwise, return complex numbers. + + Returns: + `torch.Tensor`: Precomputed frequency tensor with complex exponentials. [S, D/2] + """ + if isinstance(pos, int): + pos = np.arange(pos) + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) # [D/2] + t = torch.from_numpy(pos).to(freqs.device) # type: ignore # [S] + freqs = torch.outer(t, freqs).float() # type: ignore # [S, D/2] + if use_real: + freqs_cos = freqs.cos().repeat_interleave(2, dim=1) # [S, D] + freqs_sin = freqs.sin().repeat_interleave(2, dim=1) # [S, D] + return freqs_cos, freqs_sin + else: + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 # [S, D/2] + return freqs_cis + + +def apply_rotary_emb( + x: torch.Tensor, + freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]], +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings + to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are + reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting + tensors contain rotary embeddings and are returned as real tensors. + + Args: + x (`torch.Tensor`): + Query or key tensor to apply rotary embeddings. [B, H, S, D] xk (torch.Tensor): Key tensor to apply + freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],) + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings. + """ + cos, sin = freqs_cis # [S, D] + cos = cos[None, None] + sin = sin[None, None] + cos, sin = cos.to(x.device), sin.to(x.device) + + x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2] + x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3) + out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype) + + return out + + class TimestepEmbedding(nn.Module): def __init__( self, @@ -507,6 +626,88 @@ def forward(self, timestep, class_labels, hidden_dtype=None): return conditioning +class HunyuanDiTAttentionPool(nn.Module): + # Copied from https://github.com/Tencent/HunyuanDiT/blob/cb709308d92e6c7e8d59d0dff41b74d35088db6a/hydit/modules/poolers.py#L6 + + def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): + super().__init__() + self.positional_embedding = nn.Parameter(torch.randn(spacial_dim + 1, embed_dim) / embed_dim**0.5) + self.k_proj = nn.Linear(embed_dim, embed_dim) + self.q_proj = nn.Linear(embed_dim, embed_dim) + self.v_proj = nn.Linear(embed_dim, embed_dim) + self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) + self.num_heads = num_heads + + def forward(self, x): + x = x.permute(1, 0, 2) # NLC -> LNC + x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (L+1)NC + x = x + self.positional_embedding[:, None, :].to(x.dtype) # (L+1)NC + x, _ = F.multi_head_attention_forward( + query=x[:1], + key=x, + value=x, + embed_dim_to_check=x.shape[-1], + num_heads=self.num_heads, + q_proj_weight=self.q_proj.weight, + k_proj_weight=self.k_proj.weight, + v_proj_weight=self.v_proj.weight, + in_proj_weight=None, + in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), + bias_k=None, + bias_v=None, + add_zero_attn=False, + dropout_p=0, + out_proj_weight=self.c_proj.weight, + out_proj_bias=self.c_proj.bias, + use_separate_proj_weight=True, + training=self.training, + need_weights=False, + ) + return x.squeeze(0) + + +class HunyuanCombinedTimestepTextSizeStyleEmbedding(nn.Module): + def __init__(self, embedding_dim, pooled_projection_dim=1024, seq_len=256, cross_attention_dim=2048): + super().__init__() + + self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) + self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) + + self.pooler = HunyuanDiTAttentionPool( + seq_len, cross_attention_dim, num_heads=8, output_dim=pooled_projection_dim + ) + # Here we use a default learned embedder layer for future extension. + self.style_embedder = nn.Embedding(1, embedding_dim) + extra_in_dim = 256 * 6 + embedding_dim + pooled_projection_dim + self.extra_embedder = PixArtAlphaTextProjection( + in_features=extra_in_dim, + hidden_size=embedding_dim * 4, + out_features=embedding_dim, + act_fn="silu_fp32", + ) + + def forward(self, timestep, encoder_hidden_states, image_meta_size, style, hidden_dtype=None): + timesteps_proj = self.time_proj(timestep) + timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, 256) + + # extra condition1: text + pooled_projections = self.pooler(encoder_hidden_states) # (N, 1024) + + # extra condition2: image meta size embdding + image_meta_size = get_timestep_embedding(image_meta_size.view(-1), 256, True, 0) + image_meta_size = image_meta_size.to(dtype=hidden_dtype) + image_meta_size = image_meta_size.view(-1, 6 * 256) # (N, 1536) + + # extra condition3: style embedding + style_embedding = self.style_embedder(style) # (N, embedding_dim) + + # Concatenate all extra vectors + extra_cond = torch.cat([pooled_projections, image_meta_size, style_embedding], dim=1) + conditioning = timesteps_emb + self.extra_embedder(extra_cond) # [B, D] + + return conditioning + + class TextTimeEmbedding(nn.Module): def __init__(self, encoder_dim: int, time_embed_dim: int, num_heads: int = 64): super().__init__() @@ -793,11 +994,18 @@ class PixArtAlphaTextProjection(nn.Module): Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py """ - def __init__(self, in_features, hidden_size, num_tokens=120): + def __init__(self, in_features, hidden_size, out_features=None, act_fn="gelu_tanh"): super().__init__() + if out_features is None: + out_features = hidden_size self.linear_1 = nn.Linear(in_features=in_features, out_features=hidden_size, bias=True) - self.act_1 = nn.GELU(approximate="tanh") - self.linear_2 = nn.Linear(in_features=hidden_size, out_features=hidden_size, bias=True) + if act_fn == "gelu_tanh": + self.act_1 = nn.GELU(approximate="tanh") + elif act_fn == "silu_fp32": + self.act_1 = FP32SiLU() + else: + raise ValueError(f"Unknown activation function: {act_fn}") + self.linear_2 = nn.Linear(in_features=hidden_size, out_features=out_features, bias=True) def forward(self, caption): hidden_states = self.linear_1(caption) diff --git a/src/diffusers/models/normalization.py b/src/diffusers/models/normalization.py index 036a66890e67..a30c09277063 100644 --- a/src/diffusers/models/normalization.py +++ b/src/diffusers/models/normalization.py @@ -176,7 +176,8 @@ def __init__( raise ValueError(f"unknown norm_type {norm_type}") def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor: - emb = self.linear(self.silu(conditioning_embedding)) + # convert back to the original dtype in case `conditioning_embedding`` is upcasted to float32 (needed for hunyuanDiT) + emb = self.linear(self.silu(conditioning_embedding).to(x.dtype)) scale, shift = torch.chunk(emb, 2, dim=1) x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :] return x diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py index f668bc664238..c0ab665493aa 100644 --- a/src/diffusers/models/transformers/__init__.py +++ b/src/diffusers/models/transformers/__init__.py @@ -4,6 +4,7 @@ if is_torch_available(): from .dit_transformer_2d import DiTTransformer2DModel from .dual_transformer_2d import DualTransformer2DModel + from .hunyuan_transformer_2d import HunyuanDiT2DModel from .pixart_transformer_2d import PixArtTransformer2DModel from .prior_transformer import PriorTransformer from .t5_film_transformer import T5FilmDecoder diff --git a/src/diffusers/models/transformers/hunyuan_transformer_2d.py b/src/diffusers/models/transformers/hunyuan_transformer_2d.py new file mode 100644 index 000000000000..1f859ef9bd94 --- /dev/null +++ b/src/diffusers/models/transformers/hunyuan_transformer_2d.py @@ -0,0 +1,427 @@ +# Copyright 2024 HunyuanDiT Authors and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Optional + +import torch +import torch.nn.functional as F +from torch import nn + +from ...configuration_utils import ConfigMixin, register_to_config +from ...utils import logging +from ...utils.torch_utils import maybe_allow_in_graph +from ..attention import FeedForward +from ..attention_processor import Attention, HunyuanAttnProcessor2_0 +from ..embeddings import ( + HunyuanCombinedTimestepTextSizeStyleEmbedding, + PatchEmbed, + PixArtAlphaTextProjection, +) +from ..modeling_outputs import Transformer2DModelOutput +from ..modeling_utils import ModelMixin +from ..normalization import AdaLayerNormContinuous + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class FP32LayerNorm(nn.LayerNorm): + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + origin_dtype = inputs.dtype + return F.layer_norm( + inputs.float(), self.normalized_shape, self.weight.float(), self.bias.float(), self.eps + ).to(origin_dtype) + + +class AdaLayerNormShift(nn.Module): + r""" + Norm layer modified to incorporate timestep embeddings. + + Parameters: + embedding_dim (`int`): The size of each embedding vector. + num_embeddings (`int`): The size of the embeddings dictionary. + """ + + def __init__(self, embedding_dim: int, elementwise_affine=True, eps=1e-6): + super().__init__() + self.silu = nn.SiLU() + self.linear = nn.Linear(embedding_dim, embedding_dim) + self.norm = FP32LayerNorm(embedding_dim, elementwise_affine=elementwise_affine, eps=eps) + + def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor: + shift = self.linear(self.silu(emb.to(torch.float32)).to(emb.dtype)) + x = self.norm(x) + shift.unsqueeze(dim=1) + return x + + +@maybe_allow_in_graph +class HunyuanDiTBlock(nn.Module): + r""" + Transformer block used in Hunyuan-DiT model (https://github.com/Tencent/HunyuanDiT). Allow skip connection and + QKNorm + + Parameters: + dim (`int`): + The number of channels in the input and output. + num_attention_heads (`int`): + The number of headsto use for multi-head attention. + cross_attention_dim (`int`,*optional*): + The size of the encoder_hidden_states vector for cross attention. + dropout(`float`, *optional*, defaults to 0.0): + The dropout probability to use. + activation_fn (`str`,*optional*, defaults to `"geglu"`): + Activation function to be used in feed-forward. . + norm_elementwise_affine (`bool`, *optional*, defaults to `True`): + Whether to use learnable elementwise affine parameters for normalization. + norm_eps (`float`, *optional*, defaults to 1e-6): + A small constant added to the denominator in normalization layers to prevent division by zero. + final_dropout (`bool` *optional*, defaults to False): + Whether to apply a final dropout after the last feed-forward layer. + ff_inner_dim (`int`, *optional*): + The size of the hidden layer in the feed-forward block. Defaults to `None`. + ff_bias (`bool`, *optional*, defaults to `True`): + Whether to use bias in the feed-forward block. + skip (`bool`, *optional*, defaults to `False`): + Whether to use skip connection. Defaults to `False` for down-blocks and mid-blocks. + qk_norm (`bool`, *optional*, defaults to `True`): + Whether to use normalization in QK calculation. Defaults to `True`. + """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + cross_attention_dim: int = 1024, + dropout=0.0, + activation_fn: str = "geglu", + norm_elementwise_affine: bool = True, + norm_eps: float = 1e-6, + final_dropout: bool = False, + ff_inner_dim: Optional[int] = None, + ff_bias: bool = True, + skip: bool = False, + qk_norm: bool = True, + ): + super().__init__() + + # Define 3 blocks. Each block has its own normalization layer. + # NOTE: when new version comes, check norm2 and norm 3 + # 1. Self-Attn + self.norm1 = AdaLayerNormShift(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) + + self.attn1 = Attention( + query_dim=dim, + cross_attention_dim=None, + dim_head=dim // num_attention_heads, + heads=num_attention_heads, + qk_norm="layer_norm" if qk_norm else None, + eps=1e-6, + bias=True, + processor=HunyuanAttnProcessor2_0(), + ) + + # 2. Cross-Attn + self.norm2 = FP32LayerNorm(dim, norm_eps, norm_elementwise_affine) + + self.attn2 = Attention( + query_dim=dim, + cross_attention_dim=cross_attention_dim, + dim_head=dim // num_attention_heads, + heads=num_attention_heads, + qk_norm="layer_norm" if qk_norm else None, + eps=1e-6, + bias=True, + processor=HunyuanAttnProcessor2_0(), + ) + # 3. Feed-forward + self.norm3 = FP32LayerNorm(dim, norm_eps, norm_elementwise_affine) + + self.ff = FeedForward( + dim, + dropout=dropout, ### 0.0 + activation_fn=activation_fn, ### approx GeLU + final_dropout=final_dropout, ### 0.0 + inner_dim=ff_inner_dim, ### int(dim * mlp_ratio) + bias=ff_bias, + ) + + # 4. Skip Connection + if skip: + self.skip_norm = FP32LayerNorm(2 * dim, norm_eps, elementwise_affine=True) + self.skip_linear = nn.Linear(2 * dim, dim) + else: + self.skip_linear = None + + # let chunk size default to None + self._chunk_size = None + self._chunk_dim = 0 + + def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0): + # Sets chunk feed-forward + self._chunk_size = chunk_size + self._chunk_dim = dim + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + temb: Optional[torch.Tensor] = None, + image_rotary_emb=None, + skip=None, + ) -> torch.Tensor: + # Notice that normalization is always applied before the real computation in the following blocks. + # 0. Long Skip Connection + if self.skip_linear is not None: + cat = torch.cat([hidden_states, skip], dim=-1) + cat = self.skip_norm(cat) + hidden_states = self.skip_linear(cat) + + # 1. Self-Attention + norm_hidden_states = self.norm1(hidden_states, temb) ### checked: self.norm1 is correct + attn_output = self.attn1( + norm_hidden_states, + image_rotary_emb=image_rotary_emb, + ) + hidden_states = hidden_states + attn_output + + # 2. Cross-Attention + hidden_states = hidden_states + self.attn2( + self.norm2(hidden_states), + encoder_hidden_states=encoder_hidden_states, + image_rotary_emb=image_rotary_emb, + ) + + # FFN Layer ### TODO: switch norm2 and norm3 in the state dict + mlp_inputs = self.norm3(hidden_states) + hidden_states = hidden_states + self.ff(mlp_inputs) + + return hidden_states + + +class HunyuanDiT2DModel(ModelMixin, ConfigMixin): + """ + HunYuanDiT: Diffusion model with a Transformer backbone. + + Inherit ModelMixin and ConfigMixin to be compatible with the sampler StableDiffusionPipeline of diffusers. + + Parameters: + num_attention_heads (`int`, *optional*, defaults to 16): + The number of heads to use for multi-head attention. + attention_head_dim (`int`, *optional*, defaults to 88): + The number of channels in each head. + in_channels (`int`, *optional*): + The number of channels in the input and output (specify if the input is **continuous**). + patch_size (`int`, *optional*): + The size of the patch to use for the input. + activation_fn (`str`, *optional*, defaults to `"geglu"`): + Activation function to use in feed-forward. + sample_size (`int`, *optional*): + The width of the latent images. This is fixed during training since it is used to learn a number of + position embeddings. + dropout (`float`, *optional*, defaults to 0.0): + The dropout probability to use. + cross_attention_dim (`int`, *optional*): + The number of dimension in the clip text embedding. + hidden_size (`int`, *optional*): + The size of hidden layer in the conditioning embedding layers. + num_layers (`int`, *optional*, defaults to 1): + The number of layers of Transformer blocks to use. + mlp_ratio (`float`, *optional*, defaults to 4.0): + The ratio of the hidden layer size to the input size. + learn_sigma (`bool`, *optional*, defaults to `True`): + Whether to predict variance. + cross_attention_dim_t5 (`int`, *optional*): + The number dimensions in t5 text embedding. + pooled_projection_dim (`int`, *optional*): + The size of the pooled projection. + text_len (`int`, *optional*): + The length of the clip text embedding. + text_len_t5 (`int`, *optional*): + The length of the T5 text embedding. + """ + + @register_to_config + def __init__( + self, + num_attention_heads: int = 16, + attention_head_dim: int = 88, + in_channels: Optional[int] = None, + patch_size: Optional[int] = None, + activation_fn: str = "gelu-approximate", + sample_size=32, + hidden_size=1152, + num_layers: int = 28, + mlp_ratio: float = 4.0, + learn_sigma: bool = True, + cross_attention_dim: int = 1024, + norm_type: str = "layer_norm", + cross_attention_dim_t5: int = 2048, + pooled_projection_dim: int = 1024, + text_len: int = 77, + text_len_t5: int = 256, + ): + super().__init__() + self.out_channels = in_channels * 2 if learn_sigma else in_channels + self.num_heads = num_attention_heads + self.inner_dim = num_attention_heads * attention_head_dim + + self.text_embedder = PixArtAlphaTextProjection( + in_features=cross_attention_dim_t5, + hidden_size=cross_attention_dim_t5 * 4, + out_features=cross_attention_dim, + act_fn="silu_fp32", + ) + + self.text_embedding_padding = nn.Parameter( + torch.randn(text_len + text_len_t5, cross_attention_dim, dtype=torch.float32) + ) + + self.pos_embed = PatchEmbed( + height=sample_size, + width=sample_size, + in_channels=in_channels, + embed_dim=hidden_size, + patch_size=patch_size, + pos_embed_type=None, + ) + + self.time_extra_emb = HunyuanCombinedTimestepTextSizeStyleEmbedding( + hidden_size, + pooled_projection_dim=pooled_projection_dim, + seq_len=text_len_t5, + cross_attention_dim=cross_attention_dim_t5, + ) + + # HunyuanDiT Blocks + self.blocks = nn.ModuleList( + [ + HunyuanDiTBlock( + dim=self.inner_dim, + num_attention_heads=self.config.num_attention_heads, + activation_fn=activation_fn, + ff_inner_dim=int(self.inner_dim * mlp_ratio), + cross_attention_dim=cross_attention_dim, + qk_norm=True, # See http://arxiv.org/abs/2302.05442 for details. + skip=layer > num_layers // 2, + ) + for layer in range(num_layers) + ] + ) + + self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6) + self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True) + + def forward( + self, + hidden_states, + timestep, + encoder_hidden_states=None, + text_embedding_mask=None, + encoder_hidden_states_t5=None, + text_embedding_mask_t5=None, + image_meta_size=None, + style=None, + image_rotary_emb=None, + return_dict=True, + ): + """ + The [`HunyuanDiT2DModel`] forward method. + + Args: + hidden_states (`torch.Tensor` of shape `(batch size, dim, height, width)`): + The input tensor. + timestep ( `torch.LongTensor`, *optional*): + Used to indicate denoising step. + encoder_hidden_states ( `torch.Tensor` of shape `(batch size, sequence len, embed dims)`, *optional*): + Conditional embeddings for cross attention layer. This is the output of `BertModel`. + text_embedding_mask: torch.Tensor + An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. This is the output + of `BertModel`. + encoder_hidden_states_t5 ( `torch.Tensor` of shape `(batch size, sequence len, embed dims)`, *optional*): + Conditional embeddings for cross attention layer. This is the output of T5 Text Encoder. + text_embedding_mask_t5: torch.Tensor + An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. This is the output + of T5 Text Encoder. + image_meta_size (torch.Tensor): + Conditional embedding indicate the image sizes + style: torch.Tensor: + Conditional embedding indicate the style + image_rotary_emb (`torch.Tensor`): + The image rotary embeddings to apply on query and key tensors during attention calculation. + return_dict: bool + Whether to return a dictionary. + """ + + height, width = hidden_states.shape[-2:] + + hidden_states = self.pos_embed(hidden_states) + + temb = self.time_extra_emb( + timestep, encoder_hidden_states_t5, image_meta_size, style, hidden_dtype=timestep.dtype + ) # [B, D] + + # text projection + batch_size, sequence_length, _ = encoder_hidden_states_t5.shape + encoder_hidden_states_t5 = self.text_embedder( + encoder_hidden_states_t5.view(-1, encoder_hidden_states_t5.shape[-1]) + ) + encoder_hidden_states_t5 = encoder_hidden_states_t5.view(batch_size, sequence_length, -1) + + encoder_hidden_states = torch.cat([encoder_hidden_states, encoder_hidden_states_t5], dim=1) + text_embedding_mask = torch.cat([text_embedding_mask, text_embedding_mask_t5], dim=-1) + text_embedding_mask = text_embedding_mask.unsqueeze(2).bool() + + encoder_hidden_states = torch.where(text_embedding_mask, encoder_hidden_states, self.text_embedding_padding) + + skips = [] + for layer, block in enumerate(self.blocks): + if layer > self.config.num_layers // 2: + skip = skips.pop() + hidden_states = block( + hidden_states, + temb=temb, + encoder_hidden_states=encoder_hidden_states, + image_rotary_emb=image_rotary_emb, + skip=skip, + ) # (N, L, D) + else: + hidden_states = block( + hidden_states, + temb=temb, + encoder_hidden_states=encoder_hidden_states, + image_rotary_emb=image_rotary_emb, + ) # (N, L, D) + + if layer < (self.config.num_layers // 2 - 1): + skips.append(hidden_states) + + # final layer + hidden_states = self.norm_out(hidden_states, temb.to(torch.float32)) + hidden_states = self.proj_out(hidden_states) + # (N, L, patch_size ** 2 * out_channels) + + # unpatchify: (N, out_channels, H, W) + patch_size = self.pos_embed.patch_size + height = height // patch_size + width = width // patch_size + + hidden_states = hidden_states.reshape( + shape=(hidden_states.shape[0], height, width, patch_size, patch_size, self.out_channels) + ) + hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states) + output = hidden_states.reshape( + shape=(hidden_states.shape[0], self.out_channels, height * patch_size, width * patch_size) + ) + if not return_dict: + return (output,) + return Transformer2DModelOutput(sample=output) diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 36031e546a77..4987628cee18 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -150,6 +150,7 @@ "IFPipeline", "IFSuperResolutionPipeline", ] + _import_structure["hunyuandit"] = ["HunyuanDiTPipeline"] _import_structure["kandinsky"] = [ "KandinskyCombinedPipeline", "KandinskyImg2ImgCombinedPipeline", @@ -418,6 +419,7 @@ VersatileDiffusionTextToImagePipeline, VQDiffusionPipeline, ) + from .hunyuandit import HunyuanDiTPipeline from .i2vgen_xl import I2VGenXLPipeline from .kandinsky import ( KandinskyCombinedPipeline, diff --git a/src/diffusers/pipelines/hunyuandit/__init__.py b/src/diffusers/pipelines/hunyuandit/__init__.py new file mode 100644 index 000000000000..8337399106f0 --- /dev/null +++ b/src/diffusers/pipelines/hunyuandit/__init__.py @@ -0,0 +1,48 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_hunyuandit"] = ["HunyuanDiTPipeline"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_hunyuandit import HunyuanDiTPipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py b/src/diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py new file mode 100644 index 000000000000..835a553c1245 --- /dev/null +++ b/src/diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py @@ -0,0 +1,881 @@ +# Copyright 2024 HunyuanDiT Authors and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +from transformers import BertModel, BertTokenizer, CLIPImageProcessor, MT5Tokenizer, T5EncoderModel + +from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import VaeImageProcessor +from ...models import AutoencoderKL, HunyuanDiT2DModel +from ...models.embeddings import get_2d_rotary_pos_embed +from ...pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker +from ...schedulers import DDPMScheduler +from ...utils import ( + is_torch_xla_available, + logging, + replace_example_docstring, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import HunyuanDiTPipeline + + >>> pipe = HunyuanDiTPipeline.from_pretrained("Tencent-Hunyuan/HunyuanDiT", torch_dtype=torch.float16) + >>> pipe.to("cuda") + + >>> # You may also use English prompt as HunyuanDiT supports both English and Chinese + >>> # prompt = "An astronaut riding a horse" + >>> prompt = "一个宇航员在骑马" + >>> image = pipe(prompt).images[0] + ``` +""" + +STANDARD_RATIO = np.array( + [ + 1.0, # 1:1 + 4.0 / 3.0, # 4:3 + 3.0 / 4.0, # 3:4 + 16.0 / 9.0, # 16:9 + 9.0 / 16.0, # 9:16 + ] +) +STANDARD_SHAPE = [ + [(1024, 1024), (1280, 1280)], # 1:1 + [(1024, 768), (1152, 864), (1280, 960)], # 4:3 + [(768, 1024), (864, 1152), (960, 1280)], # 3:4 + [(1280, 768)], # 16:9 + [(768, 1280)], # 9:16 +] +STANDARD_AREA = [np.array([w * h for w, h in shapes]) for shapes in STANDARD_SHAPE] +SUPPORTED_SHAPE = [ + (1024, 1024), + (1280, 1280), # 1:1 + (1024, 768), + (1152, 864), + (1280, 960), # 4:3 + (768, 1024), + (864, 1152), + (960, 1280), # 3:4 + (1280, 768), # 16:9 + (768, 1280), # 9:16 +] + + +def map_to_standard_shapes(target_width, target_height): + target_ratio = target_width / target_height + closest_ratio_idx = np.argmin(np.abs(STANDARD_RATIO - target_ratio)) + closest_area_idx = np.argmin(np.abs(STANDARD_AREA[closest_ratio_idx] - target_width * target_height)) + width, height = STANDARD_SHAPE[closest_ratio_idx][closest_area_idx] + return width, height + + +def get_resize_crop_region_for_grid(src, tgt_size): + th = tw = tgt_size + h, w = src + + r = h / w + + # resize + if r > 1: + resize_height = th + resize_width = int(round(th / h * w)) + else: + resize_width = tw + resize_height = int(round(tw / w * h)) + + crop_top = int(round((th - resize_height) / 2.0)) + crop_left = int(round((tw - resize_width) / 2.0)) + + return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width) + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + """ + Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + +class HunyuanDiTPipeline(DiffusionPipeline): + r""" + Pipeline for English/Chinese-to-image generation using HunyuanDiT. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + HunyuanDiT uses two text encoders: [mT5](https://huggingface.co/google/mt5-base) and [bilingual CLIP](fine-tuned by + ourselves) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. We use + `sdxl-vae-fp16-fix`. + text_encoder (Optional[`~transformers.BertModel`, `~transformers.CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + HunyuanDiT uses a fine-tuned [bilingual CLIP]. + tokenizer (Optional[`~transformers.BertTokenizer`, `~transformers.CLIPTokenizer`]): + A `BertTokenizer` or `CLIPTokenizer` to tokenize text. + transformer ([`HunyuanDiT2DModel`]): + The HunyuanDiT model designed by Tencent Hunyuan. + text_encoder_2 (`T5EncoderModel`): + The mT5 embedder. Specifically, it is 't5-v1_1-xxl'. + tokenizer_2 (`MT5Tokenizer`): + The tokenizer for the mT5 embedder. + scheduler ([`DDPMScheduler`]): + A scheduler to be used in combination with HunyuanDiT to denoise the encoded image latents. + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae" + _optional_components = [ + "safety_checker", + "feature_extractor", + "text_encoder_2", + "tokenizer_2", + "text_encoder", + "tokenizer", + ] + _exclude_from_cpu_offload = ["safety_checker"] + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + "prompt_embeds_2", + "negative_prompt_embeds_2", + ] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: BertModel, + tokenizer: BertTokenizer, + transformer: HunyuanDiT2DModel, + scheduler: DDPMScheduler, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + requires_safety_checker: bool = True, + text_encoder_2=T5EncoderModel, + tokenizer_2=MT5Tokenizer, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + transformer=transformer, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + text_encoder_2=text_encoder_2, + ) + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.register_to_config(requires_safety_checker=requires_safety_checker) + self.default_sample_size = self.transformer.config.sample_size + + def encode_prompt( + self, + prompt: str, + device: torch.device, + dtype: torch.dtype, + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt: Optional[str] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + prompt_attention_mask: Optional[torch.Tensor] = None, + negative_prompt_attention_mask: Optional[torch.Tensor] = None, + max_sequence_length: Optional[int] = None, + text_encoder_index: int = 0, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + dtype (`torch.dtype`): + torch dtype + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + prompt_attention_mask (`torch.Tensor`, *optional*): + Attention mask for the prompt. Required when `prompt_embeds` is passed directly. + negative_prompt_attention_mask (`torch.Tensor`, *optional*): + Attention mask for the negative prompt. Required when `negative_prompt_embeds` is passed directly. + max_sequence_length (`int`, *optional*): maximum sequence length to use for the prompt. + text_encoder_index (`int`, *optional*): + Index of the text encoder to use. `0` for clip and `1` for T5. + """ + tokenizers = [self.tokenizer, self.tokenizer_2] + text_encoders = [self.text_encoder, self.text_encoder_2] + + tokenizer = tokenizers[text_encoder_index] + text_encoder = text_encoders[text_encoder_index] + + if max_sequence_length is None: + if text_encoder_index == 0: + max_length = 77 + if text_encoder_index == 1: + max_length = 256 + else: + max_length = max_sequence_length + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=max_length, + truncation=True, + return_attention_mask=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {tokenizer.model_max_length} tokens: {removed_text}" + ) + + prompt_attention_mask = text_inputs.attention_mask.to(device) + prompt_embeds = text_encoder( + text_input_ids.to(device), + attention_mask=prompt_attention_mask, + ) + prompt_embeds = prompt_embeds[0] + prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1) + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + max_length = prompt_embeds.shape[1] + uncond_input = tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + negative_prompt_attention_mask = uncond_input.attention_mask.to(device) + negative_prompt_embeds = text_encoder( + uncond_input.input_ids.to(device), + attention_mask=negative_prompt_attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1) + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + return prompt_embeds, negative_prompt_embeds, prompt_attention_mask, negative_prompt_attention_mask + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + height, + width, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_attention_mask=None, + negative_prompt_attention_mask=None, + prompt_embeds_2=None, + negative_prompt_embeds_2=None, + prompt_attention_mask_2=None, + negative_prompt_attention_mask_2=None, + callback_on_step_end_tensor_inputs=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is None and prompt_embeds_2 is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds_2`. Cannot leave both `prompt` and `prompt_embeds_2` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt_embeds is not None and prompt_attention_mask is None: + raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.") + + if prompt_embeds_2 is not None and prompt_attention_mask_2 is None: + raise ValueError("Must provide `prompt_attention_mask_2` when specifying `prompt_embeds_2`.") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if negative_prompt_embeds is not None and negative_prompt_attention_mask is None: + raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.") + + if negative_prompt_embeds_2 is not None and negative_prompt_attention_mask_2 is None: + raise ValueError( + "Must provide `negative_prompt_attention_mask_2` when specifying `negative_prompt_embeds_2`." + ) + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + if prompt_embeds_2 is not None and negative_prompt_embeds_2 is not None: + if prompt_embeds_2.shape != negative_prompt_embeds_2.shape: + raise ValueError( + "`prompt_embeds_2` and `negative_prompt_embeds_2` must have the same shape when passed directly, but" + f" got: `prompt_embeds_2` {prompt_embeds_2.shape} != `negative_prompt_embeds_2`" + f" {negative_prompt_embeds_2.shape}." + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def guidance_rescale(self): + return self._guidance_rescale + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: Optional[int] = 50, + guidance_scale: Optional[float] = 5.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: Optional[float] = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + prompt_embeds_2: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds_2: Optional[torch.Tensor] = None, + prompt_attention_mask: Optional[torch.Tensor] = None, + prompt_attention_mask_2: Optional[torch.Tensor] = None, + negative_prompt_attention_mask: Optional[torch.Tensor] = None, + negative_prompt_attention_mask_2: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + guidance_rescale: float = 0.0, + original_size: Optional[Tuple[int, int]] = (1024, 1024), + target_size: Optional[Tuple[int, int]] = None, + crops_coords_top_left: Tuple[int, int] = (0, 0), + use_resolution_binning: bool = True, + ): + r""" + The call function to the pipeline for generation with HunyuanDiT. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + height (`int`): + The height in pixels of the generated image. + width (`int`): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. This parameter is modulated by `strength`. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies + to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + prompt_embeds_2 (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + negative_prompt_embeds_2 (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + prompt_attention_mask (`torch.Tensor`, *optional*): + Attention mask for the prompt. Required when `prompt_embeds` is passed directly. + prompt_attention_mask_2 (`torch.Tensor`, *optional*): + Attention mask for the prompt. Required when `prompt_embeds_2` is passed directly. + negative_prompt_attention_mask (`torch.Tensor`, *optional*): + Attention mask for the negative prompt. Required when `negative_prompt_embeds` is passed directly. + negative_prompt_attention_mask_2 (`torch.Tensor`, *optional*): + Attention mask for the negative prompt. Required when `negative_prompt_embeds_2` is passed directly. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback_on_step_end (`Callable[[int, int, Dict], None]`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A callback function or a list of callback functions to be called at the end of each denoising step. + callback_on_step_end_tensor_inputs (`List[str]`, *optional*): + A list of tensor inputs that should be passed to the callback function. If not defined, all tensor + inputs will be passed. + guidance_rescale (`float`, *optional*, defaults to 0.0): + Rescale the noise_cfg according to `guidance_rescale`. Based on findings of [Common Diffusion Noise + Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 + original_size (`Tuple[int, int]`, *optional*, defaults to `(1024, 1024)`): + The original size of the image. Used to calculate the time ids. + target_size (`Tuple[int, int]`, *optional*): + The target size of the image. Used to calculate the time ids. + crops_coords_top_left (`Tuple[int, int]`, *optional*, defaults to `(0, 0)`): + The top left coordinates of the crop. Used to calculate the time ids. + use_resolution_binning (`bool`, *optional*, defaults to `True`): + Whether to use resolution binning or not. If `True`, the input resolution will be mapped to the closest + standard resolution. Supported resolutions are 1024x1024, 1280x1280, 1024x768, 1152x864, 1280x960, + 768x1024, 864x1152, 960x1280, 1280x768, and 768x1280. It is recommended to set this to `True`. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images and the + second element is a list of `bool`s indicating whether the corresponding generated image contains + "not-safe-for-work" (nsfw) content. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 0. default height and width + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + height = int((height // 16) * 16) + width = int((width // 16) * 16) + + if use_resolution_binning and (height, width) not in SUPPORTED_SHAPE: + width, height = map_to_standard_shapes(width, height) + height = int(height) + width = int(width) + logger.warning(f"Reshaped to (height, width)=({height}, {width}), Supported shapes are {SUPPORTED_SHAPE}") + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + prompt_attention_mask, + negative_prompt_attention_mask, + prompt_embeds_2, + negative_prompt_embeds_2, + prompt_attention_mask_2, + negative_prompt_attention_mask_2, + callback_on_step_end_tensor_inputs, + ) + self._guidance_scale = guidance_scale + self._guidance_rescale = guidance_rescale + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. Encode input prompt + + ( + prompt_embeds, + negative_prompt_embeds, + prompt_attention_mask, + negative_prompt_attention_mask, + ) = self.encode_prompt( + prompt=prompt, + device=device, + dtype=self.transformer.dtype, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + max_sequence_length=77, + text_encoder_index=0, + ) + ( + prompt_embeds_2, + negative_prompt_embeds_2, + prompt_attention_mask_2, + negative_prompt_attention_mask_2, + ) = self.encode_prompt( + prompt=prompt, + device=device, + dtype=self.transformer.dtype, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds_2, + negative_prompt_embeds=negative_prompt_embeds_2, + prompt_attention_mask=prompt_attention_mask_2, + negative_prompt_attention_mask=negative_prompt_attention_mask_2, + max_sequence_length=256, + text_encoder_index=1, + ) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7 create image_rotary_emb, style embedding & time ids + grid_height = height // 8 // self.transformer.config.patch_size + grid_width = width // 8 // self.transformer.config.patch_size + base_size = 512 // 8 // self.transformer.config.patch_size + grid_crops_coords = get_resize_crop_region_for_grid((grid_height, grid_width), base_size) + image_rotary_emb = get_2d_rotary_pos_embed( + self.transformer.inner_dim // self.transformer.num_heads, grid_crops_coords, (grid_height, grid_width) + ) + + style = torch.tensor([0], device=device) + + target_size = target_size or (height, width) + add_time_ids = list(original_size + target_size + crops_coords_top_left) + add_time_ids = torch.tensor([add_time_ids], dtype=prompt_embeds.dtype) + + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask]) + prompt_embeds_2 = torch.cat([negative_prompt_embeds_2, prompt_embeds_2]) + prompt_attention_mask_2 = torch.cat([negative_prompt_attention_mask_2, prompt_attention_mask_2]) + add_time_ids = torch.cat([add_time_ids] * 2, dim=0) + style = torch.cat([style] * 2, dim=0) + + prompt_embeds = prompt_embeds.to(device=device) + prompt_attention_mask = prompt_attention_mask.to(device=device) + prompt_embeds_2 = prompt_embeds_2.to(device=device) + prompt_attention_mask_2 = prompt_attention_mask_2.to(device=device) + add_time_ids = add_time_ids.to(dtype=prompt_embeds.dtype, device=device).repeat( + batch_size * num_images_per_prompt, 1 + ) + style = style.to(device=device).repeat(batch_size * num_images_per_prompt) + + # 8. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # expand scalar t to 1-D tensor to match the 1st dim of latent_model_input + t_expand = torch.tensor([t] * latent_model_input.shape[0], device=device).to( + dtype=latent_model_input.dtype + ) + + # predict the noise residual + noise_pred = self.transformer( + latent_model_input, + t_expand, + encoder_hidden_states=prompt_embeds, + text_embedding_mask=prompt_attention_mask, + encoder_hidden_states_t5=prompt_embeds_2, + text_embedding_mask_t5=prompt_attention_mask_2, + image_meta_size=add_time_ids, + style=style, + image_rotary_emb=image_rotary_emb, + return_dict=False, + )[0] + + noise_pred, _ = noise_pred.chunk(2, dim=1) + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + if self.do_classifier_free_guidance and guidance_rescale > 0.0: + # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + prompt_embeds_2 = callback_outputs.pop("prompt_embeds_2", prompt_embeds_2) + negative_prompt_embeds_2 = callback_outputs.pop( + "negative_prompt_embeds_2", negative_prompt_embeds_2 + ) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index d3b79be1cb9c..e9a28de73d8f 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -122,6 +122,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class HunyuanDiT2DModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class I2VGenXLUNet(metaclass=DummyObject): _backends = ["torch"] diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index df436bc46c06..7c02091b19c5 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -212,6 +212,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class HunyuanDiTPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class I2VGenXLPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] diff --git a/tests/pipelines/hunyuan_dit/__init__.py b/tests/pipelines/hunyuan_dit/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/pipelines/hunyuan_dit/test_hunyuan_dit.py b/tests/pipelines/hunyuan_dit/test_hunyuan_dit.py new file mode 100644 index 000000000000..ec445ad7bf6e --- /dev/null +++ b/tests/pipelines/hunyuan_dit/test_hunyuan_dit.py @@ -0,0 +1,266 @@ +# coding=utf-8 +# Copyright 2024 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import tempfile +import unittest + +import numpy as np +import torch +from transformers import AutoTokenizer, BertModel, T5EncoderModel + +from diffusers import ( + AutoencoderKL, + DDPMScheduler, + HunyuanDiT2DModel, + HunyuanDiTPipeline, +) +from diffusers.utils.testing_utils import ( + enable_full_determinism, + numpy_cosine_similarity_distance, + require_torch_gpu, + slow, + torch_device, +) + +from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS +from ..test_pipelines_common import PipelineTesterMixin, to_np + + +enable_full_determinism() + + +class HunyuanDiTPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = HunyuanDiTPipeline + params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"} + batch_params = TEXT_TO_IMAGE_BATCH_PARAMS + image_params = TEXT_TO_IMAGE_IMAGE_PARAMS + image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS + + required_optional_params = PipelineTesterMixin.required_optional_params + + def get_dummy_components(self): + torch.manual_seed(0) + transformer = HunyuanDiT2DModel( + sample_size=16, + num_layers=2, + patch_size=2, + attention_head_dim=8, + num_attention_heads=3, + in_channels=4, + cross_attention_dim=32, + cross_attention_dim_t5=32, + pooled_projection_dim=16, + hidden_size=24, + activation_fn="gelu-approximate", + ) + torch.manual_seed(0) + vae = AutoencoderKL() + + scheduler = DDPMScheduler() + text_encoder = BertModel.from_pretrained("hf-internal-testing/tiny-random-BertModel") + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-BertModel") + text_encoder_2 = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5") + tokenizer_2 = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5") + + components = { + "transformer": transformer.eval(), + "vae": vae.eval(), + "scheduler": scheduler, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + "text_encoder_2": text_encoder_2, + "tokenizer_2": tokenizer_2, + "safety_checker": None, + "feature_extractor": None, + } + return components + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + inputs = { + "prompt": "A painting of a squirrel eating a burger", + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 5.0, + "output_type": "np", + "use_resolution_binning": False, + } + return inputs + + def test_inference(self): + device = "cpu" + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + image = pipe(**inputs).images + image_slice = image[0, -3:, -3:, -1] + + self.assertEqual(image.shape, (1, 16, 16, 3)) + expected_slice = np.array( + [0.56939435, 0.34541583, 0.35915792, 0.46489206, 0.38775963, 0.45004836, 0.5957267, 0.59481275, 0.33287364] + ) + max_diff = np.abs(image_slice.flatten() - expected_slice).max() + self.assertLessEqual(max_diff, 1e-3) + + def test_sequential_cpu_offload_forward_pass(self): + # TODO(YiYi) need to fix later + pass + + def test_sequential_offload_forward_pass_twice(self): + # TODO(YiYi) need to fix later + pass + + def test_inference_batch_single_identical(self): + self._test_inference_batch_single_identical( + expected_max_diff=1e-3, + ) + + def test_save_load_optional_components(self): + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(torch_device) + + prompt = inputs["prompt"] + generator = inputs["generator"] + num_inference_steps = inputs["num_inference_steps"] + output_type = inputs["output_type"] + + ( + prompt_embeds, + negative_prompt_embeds, + prompt_attention_mask, + negative_prompt_attention_mask, + ) = pipe.encode_prompt(prompt, device=torch_device, dtype=torch.float32, text_encoder_index=0) + + ( + prompt_embeds_2, + negative_prompt_embeds_2, + prompt_attention_mask_2, + negative_prompt_attention_mask_2, + ) = pipe.encode_prompt( + prompt, + device=torch_device, + dtype=torch.float32, + text_encoder_index=1, + ) + + # inputs with prompt converted to embeddings + inputs = { + "prompt_embeds": prompt_embeds, + "prompt_attention_mask": prompt_attention_mask, + "negative_prompt_embeds": negative_prompt_embeds, + "negative_prompt_attention_mask": negative_prompt_attention_mask, + "prompt_embeds_2": prompt_embeds_2, + "prompt_attention_mask_2": prompt_attention_mask_2, + "negative_prompt_embeds_2": negative_prompt_embeds_2, + "negative_prompt_attention_mask_2": negative_prompt_attention_mask_2, + "generator": generator, + "num_inference_steps": num_inference_steps, + "output_type": output_type, + "use_resolution_binning": False, + } + + # set all optional components to None + for optional_component in pipe._optional_components: + setattr(pipe, optional_component, None) + + output = pipe(**inputs)[0] + + with tempfile.TemporaryDirectory() as tmpdir: + pipe.save_pretrained(tmpdir) + pipe_loaded = self.pipeline_class.from_pretrained(tmpdir) + pipe_loaded.to(torch_device) + pipe_loaded.set_progress_bar_config(disable=None) + + for optional_component in pipe._optional_components: + self.assertTrue( + getattr(pipe_loaded, optional_component) is None, + f"`{optional_component}` did not stay set to None after loading.", + ) + + inputs = self.get_dummy_inputs(torch_device) + + generator = inputs["generator"] + num_inference_steps = inputs["num_inference_steps"] + output_type = inputs["output_type"] + + # inputs with prompt converted to embeddings + inputs = { + "prompt_embeds": prompt_embeds, + "prompt_attention_mask": prompt_attention_mask, + "negative_prompt_embeds": negative_prompt_embeds, + "negative_prompt_attention_mask": negative_prompt_attention_mask, + "prompt_embeds_2": prompt_embeds_2, + "prompt_attention_mask_2": prompt_attention_mask_2, + "negative_prompt_embeds_2": negative_prompt_embeds_2, + "negative_prompt_attention_mask_2": negative_prompt_attention_mask_2, + "generator": generator, + "num_inference_steps": num_inference_steps, + "output_type": output_type, + "use_resolution_binning": False, + } + + output_loaded = pipe_loaded(**inputs)[0] + + max_diff = np.abs(to_np(output) - to_np(output_loaded)).max() + self.assertLess(max_diff, 1e-4) + + +@slow +@require_torch_gpu +class HunyuanDiTPipelineIntegrationTests(unittest.TestCase): + prompt = "一个宇航员在骑马" + + def setUp(self): + super().setUp() + gc.collect() + torch.cuda.empty_cache() + + def tearDown(self): + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + def test_hunyuan_dit_1024(self): + generator = torch.Generator("cpu").manual_seed(0) + + pipe = HunyuanDiTPipeline.from_pretrained( + "XCLiu/HunyuanDiT-0523", revision="refs/pr/2", torch_dtype=torch.float16 + ) + pipe.enable_model_cpu_offload() + prompt = self.prompt + + image = pipe( + prompt=prompt, height=1024, width=1024, generator=generator, num_inference_steps=2, output_type="np" + ).images + + image_slice = image[0, -3:, -3:, -1] + expected_slice = np.array( + [0.48388672, 0.33789062, 0.30737305, 0.47875977, 0.25097656, 0.30029297, 0.4440918, 0.26953125, 0.30078125] + ) + + max_diff = numpy_cosine_similarity_distance(image_slice.flatten(), expected_slice) + assert max_diff < 1e-3, f"Max diff is too high. got {image_slice.flatten()}"