diff --git a/HunYuan/__init__.py b/HunYuan/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/HunYuan/conf.py b/HunYuan/conf.py new file mode 100644 index 0000000..a405ddc --- /dev/null +++ b/HunYuan/conf.py @@ -0,0 +1,55 @@ +""" +List of all DiT model types / settings +""" +sampling_settings = { + "beta_schedule" : "linear", + "linear_start" : 0.00085, + "linear_end" : 0.03, + "timesteps" : 1000, + 'steps_offset': 1, + 'clip_sample': False, + 'clip_sample_range': 1.0, + 'beta_start': 0.00085, + 'beta_end': 0.03, + 'prediction_type': 'v_prediction', +} + +dit_conf = { + "DiT-g/2": { # DiT-g/2 + "unet_config": { + "depth" : 40, + "num_heads" : 16, + "patch_size" : 2, + "hidden_size" : 1408, + 'mlp_ratio': 4.3637, + }, + "sampling_settings" : sampling_settings, + }, + "DiT-XL/2": { # DiT_XL_2 + "unet_config": { + "depth" : 28, + "num_heads" : 16, + "patch_size" : 2, + "hidden_size" : 1152, + }, + "sampling_settings" : sampling_settings, + }, + "DiT-L/2": { # DiT_L_2 + "unet_config": { + "depth" : 24, + "num_heads" : 16, + "patch_size" : 2, + "hidden_size" : 1024, + }, + "sampling_settings" : sampling_settings, + }, + "DiT-B/2": { # DiT_B_2 + "unet_config": { + "depth" : 12, + "num_heads" : 12, + "patch_size" : 2, + "hidden_size" : 768, + }, + "sampling_settings" : sampling_settings, + }, +} diff --git a/HunYuan/loader.py b/HunYuan/loader.py new file mode 100644 index 0000000..901953f --- /dev/null +++ b/HunYuan/loader.py @@ -0,0 +1,93 @@ +import comfy.supported_models_base +import comfy.latent_formats +import comfy.model_patcher +import comfy.model_base +import comfy.utils +import torch +from comfy import model_management +from ..PixArt.diffusers_convert import convert_state_dict + +class EXM_DiT(comfy.supported_models_base.BASE): + unet_config = {} + unet_extra_config = {} + latent_format = comfy.latent_formats.SDXL + + def __init__(self, model_conf): + self.model_target = model_conf.get("target") + self.unet_config = model_conf.get("unet_config", {}) + self.sampling_settings = model_conf.get("sampling_settings", {}) + self.latent_format = self.latent_format() + # UNET is handled by extension + self.unet_config["disable_unet_model_creation"] = True + + def model_type(self, state_dict, prefix=""): + return comfy.model_base.ModelType.V_PREDICTION + +class EXM_Dit_Model(comfy.model_base.BaseModel): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def extra_conds(self, **kwargs): + out = super().extra_conds(**kwargs) + + clip_prompt_embeds = kwargs.get("clip_prompt_embeds", None) + if clip_prompt_embeds is not None: + out["clip_prompt_embeds"] = comfy.conds.CONDRegular(torch.tensor(clip_prompt_embeds)) + + clip_attention_mask = kwargs.get("clip_attention_mask", None) + if clip_attention_mask is not None: + out["clip_attention_mask"] = comfy.conds.CONDRegular(torch.tensor(clip_attention_mask)) + + mt5_prompt_embeds = kwargs.get("mt5_prompt_embeds", None) + if mt5_prompt_embeds is not None: + out["mt5_prompt_embeds"] = comfy.conds.CONDRegular(torch.tensor(mt5_prompt_embeds)) + + mt5_attention_mask = kwargs.get("mt5_attention_mask", None) + if mt5_attention_mask is not None: + out["mt5_attention_mask"] = comfy.conds.CONDRegular(torch.tensor(mt5_attention_mask)) + + return out + +def load_dit(model_path, model_conf): + from comfy.diffusers_convert import convert_unet_state_dict + state_dict = comfy.utils.load_torch_file(model_path) + #state_dict=convert_unet_state_dict(state_dict) + #state_dict = state_dict.get("model", state_dict) + + parameters = 
comfy.utils.calculate_parameters(state_dict) + unet_dtype = torch.float16 #model_management.unet_dtype(model_params=parameters) + load_device = comfy.model_management.get_torch_device() + offload_device = comfy.model_management.unet_offload_device() + + # ignore fp8/etc and use directly for now + #manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device) + #if manual_cast_dtype: + # print(f"DiT: falling back to {manual_cast_dtype}") + # unet_dtype = manual_cast_dtype + + #model_conf["unet_config"]["num_classes"] = state_dict["y_embedder.embedding_table.weight"].shape[0] - 1 # adj. for empty + + model_conf = EXM_DiT(model_conf) + + model = EXM_Dit_Model( # same as comfy.model_base.BaseModel + model_conf, + model_type=comfy.model_base.ModelType.V_PREDICTION, + device=model_management.get_torch_device() + ) + + from .models.models import HunYuan + model.diffusion_model = HunYuan(**model_conf.unet_config) + model.latent_format = comfy.latent_formats.SDXL() + + model.diffusion_model.load_state_dict(state_dict) + model.diffusion_model.dtype = unet_dtype + model.diffusion_model.eval() + model.diffusion_model.to(unet_dtype) + + model_patcher = comfy.model_patcher.ModelPatcher( + model, + load_device = load_device, + offload_device = offload_device, + current_device = "cpu", + ) + return model_patcher diff --git a/HunYuan/models/__init__.py b/HunYuan/models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/HunYuan/models/attn_layers.py b/HunYuan/models/attn_layers.py new file mode 100644 index 0000000..4308af9 --- /dev/null +++ b/HunYuan/models/attn_layers.py @@ -0,0 +1,377 @@ +import torch +import torch.nn as nn +from typing import Tuple, Union, Optional + +try: + import flash_attn + if hasattr(flash_attn, '__version__') and int(flash_attn.__version__[0]) == 2: + from flash_attn.flash_attn_interface import flash_attn_kvpacked_func + from flash_attn.modules.mha import FlashSelfAttention, FlashCrossAttention + else: + from flash_attn.flash_attn_interface import flash_attn_unpadded_kvpacked_func + from flash_attn.modules.mha import FlashSelfAttention, FlashCrossAttention +except Exception as e: + print(f'flash_attn import failed: {e}') + + +def reshape_for_broadcast(freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]], x: torch.Tensor, head_first=False): + """ + Reshape frequency tensor for broadcasting it with another tensor. + + This function reshapes the frequency tensor to have the same shape as the target tensor 'x' + for the purpose of broadcasting the frequency tensor during element-wise operations. + + Args: + freqs_cis (Union[torch.Tensor, Tuple[torch.Tensor]]): Frequency tensor to be reshaped. + x (torch.Tensor): Target tensor for broadcasting compatibility. + head_first (bool): head dimension first (except batch dim) or not. + + Returns: + torch.Tensor: Reshaped frequency tensor. + + Raises: + AssertionError: If the frequency tensor doesn't match the expected shape. + AssertionError: If the target tensor 'x' doesn't have the expected number of dimensions. 
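+
+    Example (illustrative sketch; the tensor sizes below are assumed for demonstration only):
+
+        xq = torch.randn(2, 64, 16, 88)                        # [B, S, H, D], head_first=False
+        cos, sin = torch.randn(64, 88), torch.randn(64, 88)    # per-position (cos, sin), each [S, D]
+        cos_b, sin_b = reshape_for_broadcast((cos, sin), xq)   # each viewed as [1, 64, 1, 88] for broadcasting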
+ """ + ndim = x.ndim + assert 0 <= 1 < ndim + + if isinstance(freqs_cis, tuple): + # freqs_cis: (cos, sin) in real space + if head_first: + assert freqs_cis[0].shape == (x.shape[-2], x.shape[-1]), f'freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}' + shape = [d if i == ndim - 2 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + else: + assert freqs_cis[0].shape == (x.shape[1], x.shape[-1]), f'freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}' + shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + return freqs_cis[0].view(*shape), freqs_cis[1].view(*shape) + else: + # freqs_cis: values in complex space + if head_first: + assert freqs_cis.shape == (x.shape[-2], x.shape[-1]), f'freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}' + shape = [d if i == ndim - 2 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + else: + assert freqs_cis.shape == (x.shape[1], x.shape[-1]), f'freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}' + shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + return freqs_cis.view(*shape) + + +def rotate_half(x): + x_real, x_imag = x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2] + return torch.stack([-x_imag, x_real], dim=-1).flatten(3) + + +def apply_rotary_emb( + xq: torch.Tensor, + xk: Optional[torch.Tensor], + freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]], + head_first: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Apply rotary embeddings to input tensors using the given frequency tensor. + + This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided + frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor + is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are + returned as real tensors. + + Args: + xq (torch.Tensor): Query tensor to apply rotary embeddings. [B, S, H, D] + xk (torch.Tensor): Key tensor to apply rotary embeddings. [B, S, H, D] + freqs_cis (Union[torch.Tensor, Tuple[torch.Tensor]]): Precomputed frequency tensor for complex exponentials. + head_first (bool): head dimension first (except batch dim) or not. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings. + + """ + xk_out = None + if isinstance(freqs_cis, tuple): + cos, sin = reshape_for_broadcast(freqs_cis, xq, head_first) # [S, D] + cos, sin = cos.to(xq.device), sin.to(xq.device) + xq_out = (xq.float() * cos + rotate_half(xq.float()) * sin).type_as(xq) + if xk is not None: + xk_out = (xk.float() * cos + rotate_half(xk.float()) * sin).type_as(xk) + else: + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # [B, S, H, D//2] + freqs_cis = reshape_for_broadcast(freqs_cis, xq_, head_first).to(xq.device) # [S, D//2] --> [1, S, 1, D//2] + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3).type_as(xq) + if xk is not None: + xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) # [B, S, H, D//2] + xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3).type_as(xk) + + return xq_out, xk_out + + +class FlashSelfMHAModified(nn.Module): + """ + Use QK Normalization. 
+ """ + def __init__(self, + dim, + num_heads, + qkv_bias=True, + qk_norm=False, + attn_drop=0.0, + proj_drop=0.0, + device=None, + dtype=None, + norm_layer=nn.LayerNorm, + ): + factory_kwargs = {'device': device, 'dtype': dtype} + super().__init__() + self.dim = dim + self.num_heads = num_heads + assert self.dim % num_heads == 0, "self.kdim must be divisible by num_heads" + self.head_dim = self.dim // num_heads + assert self.head_dim % 8 == 0 and self.head_dim <= 128, "Only support head_dim <= 128 and divisible by 8" + + self.Wqkv = nn.Linear(dim, 3 * dim, bias=qkv_bias, **factory_kwargs) + # TODO: eps should be 1 / 65530 if using fp16 + self.q_norm = norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity() + self.k_norm = norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity() + self.inner_attn = FlashSelfAttention(attention_dropout=attn_drop) + self.out_proj = nn.Linear(dim, dim, bias=qkv_bias, **factory_kwargs) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x, freqs_cis_img=None): + """ + Parameters + ---------- + x: torch.Tensor + (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) + freqs_cis_img: torch.Tensor + (batch, hidden_dim // 2), RoPE for image + """ + b, s, d = x.shape + + qkv = self.Wqkv(x) + qkv = qkv.view(b, s, 3, self.num_heads, self.head_dim) # [b, s, 3, h, d] + q, k, v = qkv.unbind(dim=2) # [b, s, h, d] + q = self.q_norm(q).half() # [b, s, h, d] + k = self.k_norm(k).half() + + # Apply RoPE if needed + if freqs_cis_img is not None: + qq, kk = apply_rotary_emb(q, k, freqs_cis_img) + assert qq.shape == q.shape and kk.shape == k.shape, f'qq: {qq.shape}, q: {q.shape}, kk: {kk.shape}, k: {k.shape}' + q, k = qq, kk + + qkv = torch.stack([q, k, v], dim=2) # [b, s, 3, h, d] + context = self.inner_attn(qkv) + out = self.out_proj(context.view(b, s, d)) + out = self.proj_drop(out) + + out_tuple = (out,) + + return out_tuple + + +class FlashCrossMHAModified(nn.Module): + """ + Use QK Normalization. 
+ """ + def __init__(self, + qdim, + kdim, + num_heads, + qkv_bias=True, + qk_norm=False, + attn_drop=0.0, + proj_drop=0.0, + device=None, + dtype=None, + norm_layer=nn.LayerNorm, + ): + factory_kwargs = {'device': device, 'dtype': dtype} + super().__init__() + self.qdim = qdim + self.kdim = kdim + self.num_heads = num_heads + assert self.qdim % num_heads == 0, "self.qdim must be divisible by num_heads" + self.head_dim = self.qdim // num_heads + assert self.head_dim % 8 == 0 and self.head_dim <= 128, "Only support head_dim <= 128 and divisible by 8" + + self.scale = self.head_dim ** -0.5 + + self.q_proj = nn.Linear(qdim, qdim, bias=qkv_bias, **factory_kwargs) + self.kv_proj = nn.Linear(kdim, 2 * qdim, bias=qkv_bias, **factory_kwargs) + + # TODO: eps should be 1 / 65530 if using fp16 + self.q_norm = norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity() + self.k_norm = norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity() + + self.inner_attn = FlashCrossAttention(attention_dropout=attn_drop) + self.out_proj = nn.Linear(qdim, qdim, bias=qkv_bias, **factory_kwargs) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x, y, freqs_cis_img=None): + """ + Parameters + ---------- + x: torch.Tensor + (batch, seqlen1, hidden_dim) (where hidden_dim = num_heads * head_dim) + y: torch.Tensor + (batch, seqlen2, hidden_dim2) + freqs_cis_img: torch.Tensor + (batch, hidden_dim // num_heads), RoPE for image + """ + b, s1, _ = x.shape # [b, s1, D] + _, s2, _ = y.shape # [b, s2, 1024] + + q = self.q_proj(x).view(b, s1, self.num_heads, self.head_dim) # [b, s1, h, d] + kv = self.kv_proj(y).view(b, s2, 2, self.num_heads, self.head_dim) # [b, s2, 2, h, d] + k, v = kv.unbind(dim=2) # [b, s2, h, d] + q = self.q_norm(q).half() # [b, s1, h, d] + k = self.k_norm(k).half() # [b, s2, h, d] + + # Apply RoPE if needed + if freqs_cis_img is not None: + qq, _ = apply_rotary_emb(q, None, freqs_cis_img) + assert qq.shape == q.shape, f'qq: {qq.shape}, q: {q.shape}' + q = qq # [b, s1, h, d] + kv = torch.stack([k, v], dim=2) # [b, s1, 2, h, d] + context = self.inner_attn(q, kv) # [b, s1, h, d] + context = context.view(b, s1, -1) # [b, s1, D] + + out = self.out_proj(context) + out = self.proj_drop(out) + + out_tuple = (out,) + + return out_tuple + + +class CrossAttention(nn.Module): + """ + Use QK Normalization. 
+ """ + def __init__(self, + qdim, + kdim, + num_heads, + qkv_bias=True, + qk_norm=False, + attn_drop=0.0, + proj_drop=0.0, + device=None, + dtype=None, + norm_layer=nn.LayerNorm, + ): + factory_kwargs = {'device': device, 'dtype': dtype} + super().__init__() + self.qdim = qdim + self.kdim = kdim + self.num_heads = num_heads + assert self.qdim % num_heads == 0, "self.qdim must be divisible by num_heads" + self.head_dim = self.qdim // num_heads + assert self.head_dim % 8 == 0 and self.head_dim <= 128, "Only support head_dim <= 128 and divisible by 8" + self.scale = self.head_dim ** -0.5 + + self.q_proj = nn.Linear(qdim, qdim, bias=qkv_bias, **factory_kwargs) + self.kv_proj = nn.Linear(kdim, 2 * qdim, bias=qkv_bias, **factory_kwargs) + + # TODO: eps should be 1 / 65530 if using fp16 + self.q_norm = norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity() + self.k_norm = norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity() + self.attn_drop = nn.Dropout(attn_drop) + self.out_proj = nn.Linear(qdim, qdim, bias=qkv_bias, **factory_kwargs) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x, y, freqs_cis_img=None): + """ + Parameters + ---------- + x: torch.Tensor + (batch, seqlen1, hidden_dim) (where hidden_dim = num heads * head dim) + y: torch.Tensor + (batch, seqlen2, hidden_dim2) + freqs_cis_img: torch.Tensor + (batch, hidden_dim // 2), RoPE for image + """ + b, s1, c = x.shape # [b, s1, D] + _, s2, c = y.shape # [b, s2, 1024] + + q = self.q_proj(x).view(b, s1, self.num_heads, self.head_dim) # [b, s1, h, d] + kv = self.kv_proj(y).view(b, s2, 2, self.num_heads, self.head_dim) # [b, s2, 2, h, d] + k, v = kv.unbind(dim=2) # [b, s, h, d] + q = self.q_norm(q) + k = self.k_norm(k) + + # Apply RoPE if needed + if freqs_cis_img is not None: + qq, _ = apply_rotary_emb(q, None, freqs_cis_img) + assert qq.shape == q.shape, f'qq: {qq.shape}, q: {q.shape}' + q = qq + + q = q * self.scale + q = q.transpose(-2, -3).contiguous() # q -> B, L1, H, C - B, H, L1, C + k = k.permute(0, 2, 3, 1).contiguous() # k -> B, L2, H, C - B, H, C, L2 + attn = q @ k # attn -> B, H, L1, L2 + attn = attn.softmax(dim=-1) # attn -> B, H, L1, L2 + attn = self.attn_drop(attn) + x = attn @ v.transpose(-2, -3) # v -> B, L2, H, C - B, H, L2, C x-> B, H, L1, C + context = x.transpose(1, 2) # context -> B, H, L1, C - B, L1, H, C + + context = context.contiguous().view(b, s1, -1) + + out = self.out_proj(context) # context.reshape - B, L1, -1 + out = self.proj_drop(out) + + out_tuple = (out,) + + return out_tuple + + +class Attention(nn.Module): + """ + We rename some layer names to align with flash attention + """ + def __init__(self, dim, num_heads, qkv_bias=True, qk_norm=False, attn_drop=0., proj_drop=0., + norm_layer=nn.LayerNorm, + ): + super().__init__() + self.dim = dim + self.num_heads = num_heads + assert self.dim % num_heads == 0, 'dim should be divisible by num_heads' + self.head_dim = self.dim // num_heads + # This assertion is aligned with flash attention + assert self.head_dim % 8 == 0 and self.head_dim <= 128, "Only support head_dim <= 128 and divisible by 8" + self.scale = self.head_dim ** -0.5 + + # qkv --> Wqkv + self.Wqkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + # TODO: eps should be 1 / 65530 if using fp16 + self.q_norm = norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity() + self.k_norm = norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity() + self.attn_drop 
= nn.Dropout(attn_drop) + self.out_proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x, freqs_cis_img=None): + B, N, C = x.shape + qkv = self.Wqkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) # [3, b, h, s, d] + q, k, v = qkv.unbind(0) # [b, h, s, d] + q = self.q_norm(q) # [b, h, s, d] + k = self.k_norm(k) # [b, h, s, d] + + # Apply RoPE if needed + if freqs_cis_img is not None: + qq, kk = apply_rotary_emb(q, k, freqs_cis_img, head_first=True) + assert qq.shape == q.shape and kk.shape == k.shape, \ + f'qq: {qq.shape}, q: {q.shape}, kk: {kk.shape}, k: {k.shape}' + q, k = qq, kk + + q = q * self.scale + attn = q @ k.transpose(-2, -1) # [b, h, s, d] @ [b, h, d, s] + attn = attn.softmax(dim=-1) # [b, h, s, s] + attn = self.attn_drop(attn) + x = attn @ v # [b, h, s, d] + + x = x.transpose(1, 2).reshape(B, N, C) # [b, s, h, d] + x = self.out_proj(x) + x = self.proj_drop(x) + + out_tuple = (x,) + + return out_tuple diff --git a/HunYuan/models/embedders.py b/HunYuan/models/embedders.py new file mode 100644 index 0000000..9fe08cb --- /dev/null +++ b/HunYuan/models/embedders.py @@ -0,0 +1,111 @@ +import math +import torch +import torch.nn as nn +from einops import repeat + +from timm.models.layers import to_2tuple + + +class PatchEmbed(nn.Module): + """ 2D Image to Patch Embedding + + Image to Patch Embedding using Conv2d + + A convolution based approach to patchifying a 2D image w/ embedding projection. + + Based on the impl in https://github.com/google-research/vision_transformer + + Hacked together by / Copyright 2020 Ross Wightman + + Remove the _assert function in forward function to be compatible with multi-resolution images. + """ + def __init__( + self, + img_size=224, + patch_size=16, + in_chans=3, + embed_dim=768, + norm_layer=None, + flatten=True, + bias=True, + ): + super().__init__() + if isinstance(img_size, int): + img_size = to_2tuple(img_size) + elif isinstance(img_size, (tuple, list)) and len(img_size) == 2: + img_size = tuple(img_size) + else: + raise ValueError(f"img_size must be int or tuple/list of length 2. Got {img_size}") + patch_size = to_2tuple(patch_size) + self.img_size = img_size + self.patch_size = patch_size + self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) + self.num_patches = self.grid_size[0] * self.grid_size[1] + self.flatten = flatten + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + + def update_image_size(self, img_size): + self.img_size = img_size + self.grid_size = (img_size[0] // self.patch_size[0], img_size[1] // self.patch_size[1]) + self.num_patches = self.grid_size[0] * self.grid_size[1] + + def forward(self, x): + # B, C, H, W = x.shape + # _assert(H == self.img_size[0], f"Input image height ({H}) doesn't match model ({self.img_size[0]}).") + # _assert(W == self.img_size[1], f"Input image width ({W}) doesn't match model ({self.img_size[1]}).") + x = self.proj(x) + if self.flatten: + x = x.flatten(2).transpose(1, 2) # BCHW -> BNC + x = self.norm(x) + return x + + +def timestep_embedding(t, dim, max_period=10000, repeat_only=False): + """ + Create sinusoidal timestep embeddings. + :param t: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an (N, D) Tensor of positional embeddings. 
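+
+    Example (illustrative):
+        t = torch.tensor([0, 250, 999])
+        emb = timestep_embedding(t, dim=256)   # emb: [3, 256], cosine half followed by sine half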
+ """ + # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py + if not repeat_only: + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) + * torch.arange(start=0, end=half, dtype=torch.float32) + / half + ).to(device=t.device) # size: [dim/2], 一个指数衰减的曲线 + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat( + [embedding, torch.zeros_like(embedding[:, :1])], dim=-1 + ) + else: + embedding = repeat(t, "b -> b d", d=dim) + return embedding + + +class TimestepEmbedder(nn.Module): + """ + Embeds scalar timesteps into vector representations. + """ + def __init__(self, hidden_size, frequency_embedding_size=256, out_size=None): + super().__init__() + if out_size is None: + out_size = hidden_size + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, hidden_size, bias=True), + nn.SiLU(), + nn.Linear(hidden_size, out_size, bias=True), + ) + self.frequency_embedding_size = frequency_embedding_size + + def forward(self, t): + t_freq = timestep_embedding(t, self.frequency_embedding_size).type(self.mlp[0].weight.dtype) + t_emb = self.mlp(t_freq) + return t_emb diff --git a/HunYuan/models/models.py b/HunYuan/models/models.py new file mode 100644 index 0000000..b2b6fa0 --- /dev/null +++ b/HunYuan/models/models.py @@ -0,0 +1,486 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from timm.models.vision_transformer import Mlp + +from .attn_layers import Attention, FlashCrossMHAModified, FlashSelfMHAModified, CrossAttention +from .embedders import TimestepEmbedder, PatchEmbed, timestep_embedding +from .norm_layers import RMSNorm +from .poolers import AttentionPool +import folder_paths + + +class Resolution: + def __init__(self, width, height): + self.width = width + self.height = height + + def __str__(self): + return f'{self.height}x{self.width}' + + +class ResolutionGroup: + def __init__(self): + self.data = [ + Resolution(768, 768), # 1:1 + Resolution(1024, 1024), # 1:1 + Resolution(1280, 1280), # 1:1 + Resolution(1024, 768), # 4:3 + Resolution(1152, 864), # 4:3 + Resolution(1280, 960), # 4:3 + Resolution(768, 1024), # 3:4 + Resolution(864, 1152), # 3:4 + Resolution(960, 1280), # 3:4 + Resolution(1280, 768), # 16:9 + Resolution(768, 1280), # 9:16 + ] + self.supported_sizes = set([(r.width, r.height) for r in self.data]) + + def is_valid(self, width, height): + return (width, height) in self.supported_sizes + +def modulate(x, shift, scale): + return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) + + +class FP32_Layernorm(nn.LayerNorm): + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + origin_dtype = inputs.dtype + return F.layer_norm(inputs.float(), self.normalized_shape, self.weight.float(), self.bias.float(), + self.eps).to(origin_dtype) + + +class FP32_SiLU(nn.SiLU): + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + return torch.nn.functional.silu(inputs.float(), inplace=False).to(inputs.dtype) + + +class HunYuanDiTBlock(nn.Module): + """ + A HunYuanDiT block with `add` conditioning. 
+ """ + def __init__(self, + hidden_size, + c_emb_size, + num_heads, + mlp_ratio=4.0, + text_states_dim=1024, + use_flash_attn=False, + qk_norm=False, + norm_type="layer", + skip=False, + ): + super().__init__() + self.use_flash_attn = use_flash_attn + use_ele_affine = True + + if norm_type == "layer": + norm_layer = FP32_Layernorm + elif norm_type == "rms": + norm_layer = RMSNorm + else: + raise ValueError(f"Unknown norm_type: {norm_type}") + + # ========================= Self-Attention ========================= + self.norm1 = norm_layer(hidden_size, elementwise_affine=use_ele_affine, eps=1e-6) + if use_flash_attn: + self.attn1 = FlashSelfMHAModified(hidden_size, num_heads=num_heads, qkv_bias=True, qk_norm=qk_norm) + else: + self.attn1 = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, qk_norm=qk_norm) + + # ========================= FFN ========================= + self.norm2 = norm_layer(hidden_size, elementwise_affine=use_ele_affine, eps=1e-6) + mlp_hidden_dim = int(hidden_size * mlp_ratio) + approx_gelu = lambda: nn.GELU(approximate="tanh") + self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0) + + # ========================= Add ========================= + # Simply use add like SDXL. + self.default_modulation = nn.Sequential( + FP32_SiLU(), + nn.Linear(c_emb_size, hidden_size, bias=True) + ) + + # ========================= Cross-Attention ========================= + if use_flash_attn: + self.attn2 = FlashCrossMHAModified(hidden_size, text_states_dim, num_heads=num_heads, qkv_bias=True, + qk_norm=qk_norm) + else: + self.attn2 = CrossAttention(hidden_size, text_states_dim, num_heads=num_heads, qkv_bias=True, + qk_norm=qk_norm) + self.norm3 = norm_layer(hidden_size, elementwise_affine=True, eps=1e-6) + + # ========================= Skip Connection ========================= + if skip: + self.skip_norm = norm_layer(2 * hidden_size, elementwise_affine=True, eps=1e-6) + self.skip_linear = nn.Linear(2 * hidden_size, hidden_size) + else: + self.skip_linear = None + + def forward(self, x, c=None, text_states=None, freq_cis_img=None, skip=None): + # Long Skip Connection + if self.skip_linear is not None: + cat = torch.cat([x, skip], dim=-1) + cat = self.skip_norm(cat) + x = self.skip_linear(cat) + + # Self-Attention + shift_msa = self.default_modulation(c).unsqueeze(dim=1) + attn_inputs = ( + self.norm1(x) + shift_msa, freq_cis_img, + ) + x = x + self.attn1(*attn_inputs)[0] + + # Cross-Attention + cross_inputs = ( + self.norm3(x), text_states, freq_cis_img + ) + x = x + self.attn2(*cross_inputs)[0] + + # FFN Layer + mlp_inputs = self.norm2(x) + x = x + self.mlp(mlp_inputs) + + return x + + +class FinalLayer(nn.Module): + """ + The final layer of HunYuanDiT. 
+ """ + def __init__(self, final_hidden_size, c_emb_size, patch_size, out_channels): + super().__init__() + self.norm_final = nn.LayerNorm(final_hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear(final_hidden_size, patch_size * patch_size * out_channels, bias=True) + self.adaLN_modulation = nn.Sequential( + FP32_SiLU(), + nn.Linear(c_emb_size, 2 * final_hidden_size, bias=True) + ) + + def forward(self, x, c): + shift, scale = self.adaLN_modulation(c).chunk(2, dim=1) + x = modulate(self.norm_final(x), shift, scale) + x = self.linear(x) + return x + +class HunYuan(nn.Module): + def __init__( + self, + input_size=(32, 32), + patch_size=2, + in_channels=4, + hidden_size=1152, + depth=28, + num_heads=16, + mlp_ratio=4.0, + learn_sigma=True, + text_states_dim=1024, + text_states_dim_t5=2048, + text_len=77, + text_len_t5=256, + norm="layer", + infer_mode="torch", + use_fp16=True, + device="cuda", + **kwargs, + ): + super().__init__() + self.device = device + self.use_fp16=use_fp16 + self.dtype = torch.float16 + self.depth = depth + self.learn_sigma = learn_sigma + self.in_channels = in_channels + self.out_channels = in_channels * 2 if learn_sigma else in_channels + self.patch_size = patch_size + self.num_heads = num_heads + self.hidden_size = hidden_size + self.text_states_dim = text_states_dim + self.text_states_dim_t5 = text_states_dim_t5 + self.text_len = text_len + self.text_len_t5 = text_len_t5 + self.norm = norm + self.head_size = self.hidden_size // self.num_heads + + use_flash_attn = infer_mode == 'fa' + qk_norm = True # See http://arxiv.org/abs/2302.05442 for details. + + self.mlp_t5 = nn.Sequential( + nn.Linear(self.text_states_dim_t5, self.text_states_dim_t5 * 4, bias=True), + FP32_SiLU(), + nn.Linear(self.text_states_dim_t5 * 4, self.text_states_dim, bias=True), + ) + # learnable replace + self.text_embedding_padding = nn.Parameter( + torch.randn(self.text_len + self.text_len_t5, self.text_states_dim, dtype=torch.float32)) + + # Attention pooling + self.pooler = AttentionPool(self.text_len_t5, self.text_states_dim_t5, num_heads=8, output_dim=1024) + + # Here we use a default learned embedder layer for future extension. 
+ self.style_embedder = nn.Embedding(1, hidden_size) + + # Image size and crop size conditions + self.extra_in_dim = 256 * 6 + hidden_size + + # Text embedding for `add` + self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size) + self.t_embedder = TimestepEmbedder(hidden_size) + self.extra_in_dim += 1024 + self.extra_embedder = nn.Sequential( + nn.Linear(self.extra_in_dim, hidden_size * 4), + FP32_SiLU(), + nn.Linear(hidden_size * 4, hidden_size, bias=True), + ) + + # Image embedding + num_patches = self.x_embedder.num_patches + + # HUnYuanDiT Blocks + self.blocks = nn.ModuleList([ + HunYuanDiTBlock(hidden_size=hidden_size, + c_emb_size=hidden_size, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + text_states_dim=self.text_states_dim, + use_flash_attn=use_flash_attn, + qk_norm=qk_norm, + norm_type=self.norm, + skip=layer > depth // 2, + ) + for layer in range(depth) + ]) + + self.final_layer = FinalLayer(hidden_size, hidden_size, patch_size, self.out_channels) + self.unpatchify_channels = self.out_channels + + self.initialize_weights() + + def extra_conds(self, **kwargs): + out = {} + + return out + + def calc_rope(self, height, width): + from .posemb_layers import get_2d_rotary_pos_embed, get_fill_resize_and_crop + th = height // 8 // self.patch_size + tw = width // 8 // self.patch_size + base_size = 512 // 8 // self.patch_size + start, stop = get_fill_resize_and_crop((th, tw), base_size) + sub_args = [start, stop, (th, tw)] + rope = get_2d_rotary_pos_embed(self.head_size, *sub_args) + return rope + + def standard_shapes(self): + resolutions = ResolutionGroup() + freqs_cis_img = {} + for reso in resolutions.data: + freqs_cis_img[str(reso)] = self.calc_rope(reso.height, reso.width) + return resolutions, freqs_cis_img + + def forward(self, x, timesteps, context, clip_prompt_embeds=None, clip_attention_mask=None,mt5_prompt_embeds=None, mt5_attention_mask=None, **kwargs): + #with torch.cuda.amp.autocast(): + context = context[:, 0] + + ## run original forward pass + out = self.forward_raw( + x = x.to(self.dtype), + t = timesteps.to(self.dtype), + y = context.to(torch.int), + encoder_hidden_states=clip_prompt_embeds.to(self.dtype), + text_embedding_mask=clip_attention_mask.to(self.dtype), + encoder_hidden_states_t5=mt5_prompt_embeds.to(self.dtype), + text_embedding_mask_t5=mt5_attention_mask.to(self.dtype), + ) + + ## only return EPS + out = out.to(torch.float16) + eps, rest = out[:, :self.in_channels], out[:, self.in_channels:] + return eps[:x.shape[0]] + + + def forward_raw(self, + x, + t, + y, + encoder_hidden_states=None, + text_embedding_mask=None, + encoder_hidden_states_t5=None, + text_embedding_mask_t5=None, + image_meta_size=None, + style=None, + cos_cis_img=None, + sin_cis_img=None, + return_dict=False, + ): + """ + Forward pass of the encoder. + + Parameters + ---------- + x: torch.Tensor + (B, D, H, W) + t: torch.Tensor + (B) + y: (N, 1, 120, C) tensor of class labels + encoder_hidden_states: torch.Tensor + CLIP text embedding, (B, L_clip, D) + text_embedding_mask: torch.Tensor + CLIP text embedding mask, (B, L_clip) + encoder_hidden_states_t5: torch.Tensor + T5 text embedding, (B, L_t5, D) + text_embedding_mask_t5: torch.Tensor + T5 text embedding mask, (B, L_t5) + image_meta_size: torch.Tensor + (B, 6) + style: torch.Tensor + (B) + cos_cis_img: torch.Tensor + sin_cis_img: torch.Tensor + return_dict: bool + Whether to return a dictionary. 
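+
+        Example (illustrative call sketch; `model` is assumed to be a HunYuan instance in fp16 on
+        CUDA, and every tensor below is a random placeholder with an assumed shape):
+
+            x  = torch.randn(2, 4, 128, 128, device='cuda', dtype=torch.float16)   # latent of a 1024x1024 image
+            t  = torch.full((2,), 500, device='cuda', dtype=torch.float16)
+            y  = torch.zeros(2, dtype=torch.int, device='cuda')                    # unused class-label placeholder
+            clip   = torch.randn(2, 77, 1024, device='cuda', dtype=torch.float16)  # CLIP embeddings
+            clip_m = torch.ones(2, 77, device='cuda')
+            t5     = torch.randn(2, 256, 2048, device='cuda', dtype=torch.float16) # mT5 embeddings
+            t5_m   = torch.ones(2, 256, device='cuda')
+            out = model.forward_raw(x, t, y, clip, clip_m, t5, t5_m)
+            # out: [2, 8, 128, 128] -- eps plus predicted variance channels when learn_sigma=True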
+ """ + + ob, _, oh, ow = x.shape + batch_size=ob//2 + + text_states = encoder_hidden_states # 2,77,1024 + text_states_t5 = encoder_hidden_states_t5 # 2,256,2048 + text_states_mask = text_embedding_mask.bool() # 2,77 + text_states_t5_mask = text_embedding_mask_t5.bool() # 2,256 + b_t5, l_t5, c_t5 = text_states_t5.shape + text_states_t5 = self.mlp_t5(text_states_t5.view(-1, c_t5)) + text_states = torch.cat([text_states, text_states_t5.view(b_t5, l_t5, -1)], dim=1) # 2,205,1024 + clip_t5_mask = torch.cat([text_states_mask, text_states_t5_mask], dim=-1) + + clip_t5_mask = clip_t5_mask + text_states = torch.where(clip_t5_mask.unsqueeze(2), text_states, self.text_embedding_padding.to(text_states)) + + th, tw = oh // self.patch_size, ow // self.patch_size + + # ========================= Build time and image embedding ========================= + #t=t.repeat(2) + #x=x.repeat(2,1,1,1) + t = self.t_embedder(t) + x = self.x_embedder(x) + #y = y.to(self.dtype) + + # Get image RoPE embedding according to `reso`lution. + freqs_cis_img = (cos_cis_img, sin_cis_img) + + # ========================= Concatenate all extra vectors ========================= + # Build text tokens with pooling + + extra_vec = self.pooler(encoder_hidden_states_t5) + + height=oh*8 + width=ow*8 + target_height = int((height // 16) * 16) + target_width = int((width // 16) * 16) + + # Build image meta size tokens + size_cond = list((1024,1024)) + [target_width, target_height, 0, 0] + image_meta_size = torch.as_tensor([size_cond] * 2 * batch_size, device=self.device) + + image_meta_size = timestep_embedding(image_meta_size.view(-1), 256) # [B * 6, 256] + if self.use_fp16: + image_meta_size = image_meta_size.half() + image_meta_size = image_meta_size.view(-1, 6 * 256) + + extra_vec = torch.cat([extra_vec, image_meta_size], dim=1) # [B, D + 6 * 256] + + # Build style tokens + style = torch.as_tensor([0, 0] * batch_size, device=self.device) + + style_embedding = self.style_embedder(style) + extra_vec = torch.cat([extra_vec, style_embedding], dim=1) + + # Concatenate all extra vectors + resolutions, freqs_cis_img = self.standard_shapes() + + reso = f'{target_height}x{target_width}' + if reso in freqs_cis_img: + freqs_cis_img = freqs_cis_img[reso] + else: + freqs_cis_img = self.calc_rope(target_height, target_width) + + c = t + self.extra_embedder(extra_vec) # [B, D] + + # ========================= Forward pass through HunYuanDiT blocks ========================= + skips = [] + for layer, block in enumerate(self.blocks): + if layer > self.depth // 2: + skip = skips.pop() + x = block(x, c, text_states, freqs_cis_img, skip) # (N, L, D) + else: + x = block(x, c, text_states, freqs_cis_img) # (N, L, D) + + if layer < (self.depth // 2 - 1): + skips.append(x) + + # ========================= Final layer ========================= + x = self.final_layer(x, c) # (N, L, patch_size ** 2 * out_channels) + x = self.unpatchify(x, th, tw) # (N, out_channels, H, W) + + if return_dict: + return {'x': x} + return x + + def initialize_weights(self): + # Initialize transformer layers: + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + self.apply(_basic_init) + + # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): + w = self.x_embedder.proj.weight.data + nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + nn.init.constant_(self.x_embedder.proj.bias, 0) + + # Initialize label embedding table: + 
nn.init.normal_(self.extra_embedder[0].weight, std=0.02) + nn.init.normal_(self.extra_embedder[2].weight, std=0.02) + + # Initialize timestep embedding MLP: + nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) + + # Zero-out adaLN modulation layers in HunYuanDiT blocks: + for block in self.blocks: + nn.init.constant_(block.default_modulation[-1].weight, 0) + nn.init.constant_(block.default_modulation[-1].bias, 0) + + # Zero-out output layers: + nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) + nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) + nn.init.constant_(self.final_layer.linear.weight, 0) + nn.init.constant_(self.final_layer.linear.bias, 0) + + def unpatchify(self, x, h, w): + """ + x: (N, T, patch_size**2 * C) + imgs: (N, H, W, C) + """ + c = self.unpatchify_channels + p = self.x_embedder.patch_size[0] + # h = w = int(x.shape[1] ** 0.5) + assert h * w == x.shape[1] + + x = x.reshape(shape=(x.shape[0], h, w, p, p, c)) + x = torch.einsum('nhwpqc->nchpwq', x) + imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p)) + return imgs + +################################################################################# +# HunYuanDiT Configs # +################################################################################# + +HUNYUAN_DIT_CONFIG = { + 'DiT-g/2': {'depth': 40, 'hidden_size': 1408, 'patch_size': 2, 'num_heads': 16, 'mlp_ratio': 4.3637}, + 'DiT-XL/2': {'depth': 28, 'hidden_size': 1152, 'patch_size': 2, 'num_heads': 16}, + 'DiT-L/2': {'depth': 24, 'hidden_size': 1024, 'patch_size': 2, 'num_heads': 16}, + 'DiT-B/2': {'depth': 12, 'hidden_size': 768, 'patch_size': 2, 'num_heads': 12}, +} diff --git a/HunYuan/models/norm_layers.py b/HunYuan/models/norm_layers.py new file mode 100644 index 0000000..5204ad9 --- /dev/null +++ b/HunYuan/models/norm_layers.py @@ -0,0 +1,68 @@ +import torch +import torch.nn as nn + + +class RMSNorm(nn.Module): + def __init__(self, dim: int, elementwise_affine=True, eps: float = 1e-6): + """ + Initialize the RMSNorm normalization layer. + + Args: + dim (int): The dimension of the input tensor. + eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6. + + Attributes: + eps (float): A small value added to the denominator for numerical stability. + weight (nn.Parameter): Learnable scaling parameter. + + """ + super().__init__() + self.eps = eps + if elementwise_affine: + self.weight = nn.Parameter(torch.ones(dim)) + + def _norm(self, x): + """ + Apply the RMSNorm normalization to the input tensor. + + Args: + x (torch.Tensor): The input tensor. + + Returns: + torch.Tensor: The normalized tensor. + + """ + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x): + """ + Forward pass through the RMSNorm layer. + + Args: + x (torch.Tensor): The input tensor. + + Returns: + torch.Tensor: The output tensor after applying RMSNorm. + + """ + output = self._norm(x.float()).type_as(x) + if hasattr(self, "weight"): + output = output * self.weight + return output + + +class GroupNorm32(nn.GroupNorm): + def __init__(self, num_groups, num_channels, eps=1e-5, dtype=None): + super().__init__(num_groups=num_groups, num_channels=num_channels, eps=eps, dtype=dtype) + + def forward(self, x): + y = super().forward(x).to(x.dtype) + return y + +def normalization(channels, dtype=None): + """ + Make a standard normalization layer. + :param channels: number of input channels. 
+ :return: an nn.Module for normalization. + """ + return GroupNorm32(num_channels=channels, num_groups=32, dtype=dtype) diff --git a/HunYuan/models/poolers.py b/HunYuan/models/poolers.py new file mode 100644 index 0000000..a4adcac --- /dev/null +++ b/HunYuan/models/poolers.py @@ -0,0 +1,39 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class AttentionPool(nn.Module): + def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): + super().__init__() + self.positional_embedding = nn.Parameter(torch.randn(spacial_dim + 1, embed_dim) / embed_dim ** 0.5) + self.k_proj = nn.Linear(embed_dim, embed_dim) + self.q_proj = nn.Linear(embed_dim, embed_dim) + self.v_proj = nn.Linear(embed_dim, embed_dim) + self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) + self.num_heads = num_heads + + def forward(self, x): + x = x.permute(1, 0, 2) # NLC -> LNC + x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (L+1)NC + x = x + self.positional_embedding[:, None, :].to(x.dtype) # (L+1)NC + x, _ = F.multi_head_attention_forward( + query=x[:1], key=x, value=x, + embed_dim_to_check=x.shape[-1], + num_heads=self.num_heads, + q_proj_weight=self.q_proj.weight, + k_proj_weight=self.k_proj.weight, + v_proj_weight=self.v_proj.weight, + in_proj_weight=None, + in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), + bias_k=None, + bias_v=None, + add_zero_attn=False, + dropout_p=0, + out_proj_weight=self.c_proj.weight, + out_proj_bias=self.c_proj.bias, + use_separate_proj_weight=True, + training=self.training, + need_weights=False + ) + return x.squeeze(0) diff --git a/HunYuan/models/posemb_layers.py b/HunYuan/models/posemb_layers.py new file mode 100644 index 0000000..62c83df --- /dev/null +++ b/HunYuan/models/posemb_layers.py @@ -0,0 +1,225 @@ +import torch +import numpy as np +from typing import Union + + +def _to_tuple(x): + if isinstance(x, int): + return x, x + else: + return x + + +def get_fill_resize_and_crop(src, tgt): # src 来源的分辨率 tgt base 分辨率 + th, tw = _to_tuple(tgt) + h, w = _to_tuple(src) + + tr = th / tw # base 分辨率 + r = h / w # 目标分辨率 + + # resize + if r > tr: + resize_height = th + resize_width = int(round(th / h * w)) + else: + resize_width = tw + resize_height = int(round(tw / w * h)) # 根据base分辨率,将目标分辨率resize下来 + + crop_top = int(round((th - resize_height) / 2.0)) + crop_left = int(round((tw - resize_width) / 2.0)) + + return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width) + + +def get_meshgrid(start, *args): + if len(args) == 0: + # start is grid_size + num = _to_tuple(start) + start = (0, 0) + stop = num + elif len(args) == 1: + # start is start, args[0] is stop, step is 1 + start = _to_tuple(start) + stop = _to_tuple(args[0]) + num = (stop[0] - start[0], stop[1] - start[1]) + elif len(args) == 2: + # start is start, args[0] is stop, args[1] is num + start = _to_tuple(start) # 左上角 eg: 12,0 + stop = _to_tuple(args[0]) # 右下角 eg: 20,32 + num = _to_tuple(args[1]) # 目标大小 eg: 32,124 + else: + raise ValueError(f"len(args) should be 0, 1 or 2, but got {len(args)}") + + grid_h = np.linspace(start[0], stop[0], num[0], endpoint=False, dtype=np.float32) # 12-20 中间差值32份 0-32 中间差值124份 + grid_w = np.linspace(start[1], stop[1], num[1], endpoint=False, dtype=np.float32) + grid = np.meshgrid(grid_w, grid_h) # here w goes first + grid = np.stack(grid, axis=0) # [2, W, H] + return grid + +################################################################################# +# Sine/Cosine Positional 
Embedding Functions # +################################################################################# +# https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py + +def get_2d_sincos_pos_embed(embed_dim, start, *args, cls_token=False, extra_tokens=0): + """ + grid_size: int of the grid height and width + return: + pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) + """ + grid = get_meshgrid(start, *args) # [2, H, w] + # grid_h = np.arange(grid_size, dtype=np.float32) + # grid_w = np.arange(grid_size, dtype=np.float32) + # grid = np.meshgrid(grid_w, grid_h) # here w goes first + # grid = np.stack(grid, axis=0) # [2, W, H] + + grid = grid.reshape([2, 1, *grid.shape[1:]]) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + if cls_token and extra_tokens > 0: + pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0) + return pos_embed + + +def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): + assert embed_dim % 2 == 0 + + # use half of dimensions to encode grid_h + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) + + emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) + return emb + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + """ + embed_dim: output dimension for each position + pos: a list of positions to be encoded: size (W,H) + out: (M, D) + """ + assert embed_dim % 2 == 0 + omega = np.arange(embed_dim // 2, dtype=np.float64) + omega /= embed_dim / 2. + omega = 1. / 10000**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + return emb + + +################################################################################# +# Rotary Positional Embedding Functions # +################################################################################# +# https://github.com/facebookresearch/llama/blob/main/llama/model.py#L443 + +def get_2d_rotary_pos_embed(embed_dim, start, *args, use_real=True): + """ + This is a 2d version of precompute_freqs_cis, which is a RoPE for image tokens with 2d structure. + + Parameters + ---------- + embed_dim: int + embedding dimension size + start: int or tuple of int + If len(args) == 0, start is num; If len(args) == 1, start is start, args[0] is stop, step is 1; + If len(args) == 2, start is start, args[0] is stop, args[1] is num. + use_real: bool + If True, return real part and imaginary part separately. Otherwise, return complex numbers. 
+ + Returns + ------- + pos_embed: torch.Tensor + [HW, D/2] + """ + grid = get_meshgrid(start, *args) # [2, H, w] + grid = grid.reshape([2, 1, *grid.shape[1:]]) # 返回一个采样矩阵 分辨率与目标分辨率一致 + pos_embed = get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=use_real) + return pos_embed + + +def get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=False): + assert embed_dim % 4 == 0 + + # use half of dimensions to encode grid_h + emb_h = get_1d_rotary_pos_embed(embed_dim // 2, grid[0].reshape(-1), use_real=use_real) # (H*W, D/4) + emb_w = get_1d_rotary_pos_embed(embed_dim // 2, grid[1].reshape(-1), use_real=use_real) # (H*W, D/4) + + if use_real: + cos = torch.cat([emb_h[0], emb_w[0]], dim=1) # (H*W, D/2) + sin = torch.cat([emb_h[1], emb_w[1]], dim=1) # (H*W, D/2) + return cos, sin + else: + emb = torch.cat([emb_h, emb_w], dim=1) # (H*W, D/2) + return emb + + +def get_1d_rotary_pos_embed(dim: int, pos: Union[np.ndarray, int], theta: float = 10000.0, use_real=False): + """ + Precompute the frequency tensor for complex exponentials (cis) with given dimensions. + + This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' + and the end index 'end'. The 'theta' parameter scales the frequencies. + The returned tensor contains complex values in complex64 data type. + + Args: + dim (int): Dimension of the frequency tensor. + pos (np.ndarray, int): Position indices for the frequency tensor. [S] or scalar + theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0. + use_real (bool, optional): If True, return real part and imaginary part separately. + Otherwise, return complex numbers. + + Returns: + torch.Tensor: Precomputed frequency tensor with complex exponentials. [S, D/2] + + """ + if isinstance(pos, int): + pos = np.arange(pos) + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) # [D/2] + t = torch.from_numpy(pos).to(freqs.device) # type: ignore # [S] + freqs = torch.outer(t, freqs).float() # type: ignore # [S, D/2] + if use_real: + freqs_cos = freqs.cos().repeat_interleave(2, dim=1) # [S, D] + freqs_sin = freqs.sin().repeat_interleave(2, dim=1) # [S, D] + return freqs_cos, freqs_sin + else: + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 # [S, D/2] + return freqs_cis + + + +def calc_sizes(rope_img, patch_size, th, tw): + """ 计算 RoPE 的尺寸. """ + if rope_img == 'extend': + # 拓展模式 + sub_args = [(th, tw)] + elif rope_img.startswith('base'): + # 基于一个尺寸, 其他尺寸插值获得. 
+ base_size = int(rope_img[4:]) // 8 // patch_size # 基于512作为base,其他根据512差值得到 + start, stop = get_fill_resize_and_crop((th, tw), base_size) # 需要在32x32里面 crop的左上角和右下角 + sub_args = [start, stop, (th, tw)] + else: + raise ValueError(f"Unknown rope_img: {rope_img}") + return sub_args + + +def init_image_posemb(rope_img, + resolutions, + patch_size, + hidden_size, + num_heads, + log_fn, + rope_real=True, + ): + freqs_cis_img = {} + for reso in resolutions: + th, tw = reso.height // 8 // patch_size, reso.width // 8 // patch_size + sub_args = calc_sizes(rope_img, patch_size, th, tw) # [左上角, 右下角, 目标高宽] 需要在32x32里面 crop的左上角和右下角 + freqs_cis_img[str(reso)] = get_2d_rotary_pos_embed(hidden_size // num_heads, *sub_args, use_real=rope_real) + log_fn(f" Using image RoPE ({rope_img}) ({'real' if rope_real else 'complex'}): {sub_args} | ({reso}) " + f"{freqs_cis_img[str(reso)][0].shape if rope_real else freqs_cis_img[str(reso)].shape}") + return freqs_cis_img diff --git a/HunYuan/models/text_encoder.py b/HunYuan/models/text_encoder.py new file mode 100644 index 0000000..d07c773 --- /dev/null +++ b/HunYuan/models/text_encoder.py @@ -0,0 +1,96 @@ +import torch +import torch.nn as nn +from transformers import AutoTokenizer, T5EncoderModel, T5ForConditionalGeneration + + +class MT5Embedder(nn.Module): + available_models = ["t5-v1_1-xxl"] + + def __init__( + self, + model_dir="t5-v1_1-xxl", + model_kwargs=None, + torch_dtype=None, + use_tokenizer_only=False, + conditional_generation=False, + max_length=128, + device="cuda", + ): + super().__init__() + self.device = device #"cuda" if torch.cuda.is_available() else "cpu" + self.torch_dtype = torch_dtype or torch.bfloat16 + self.max_length = max_length + if model_kwargs is None: + model_kwargs = { + # "low_cpu_mem_usage": True, + "torch_dtype": self.torch_dtype, + } + model_kwargs["device_map"] = {"shared": self.device, "encoder": self.device} + self.tokenizer = AutoTokenizer.from_pretrained(model_dir) + if use_tokenizer_only: + return + if conditional_generation: + self.model = None + self.generation_model = T5ForConditionalGeneration.from_pretrained( + model_dir + ) + return + self.model = T5EncoderModel.from_pretrained(model_dir, **model_kwargs).eval().to(self.torch_dtype) + + def get_tokens_and_mask(self, texts): + text_tokens_and_mask = self.tokenizer( + texts, + max_length=self.max_length, + padding="max_length", + truncation=True, + return_attention_mask=True, + add_special_tokens=True, + return_tensors="pt", + ) + tokens = text_tokens_and_mask["input_ids"][0] + mask = text_tokens_and_mask["attention_mask"][0] + # tokens = torch.tensor(tokens).clone().detach() + # mask = torch.tensor(mask, dtype=torch.bool).clone().detach() + return tokens, mask + + def get_text_embeddings(self, texts, attention_mask=True, layer_index=-1): + text_tokens_and_mask = self.tokenizer( + texts, + max_length=self.max_length, + padding="max_length", + truncation=True, + return_attention_mask=True, + add_special_tokens=True, + return_tensors="pt", + ) + + with torch.no_grad(): + outputs = self.model( + input_ids=text_tokens_and_mask["input_ids"].to(self.device), + attention_mask=text_tokens_and_mask["attention_mask"].to(self.device) + if attention_mask + else None, + output_hidden_states=True, + ) + text_encoder_embs = outputs["hidden_states"][layer_index].detach() + + return text_encoder_embs, text_tokens_and_mask["attention_mask"].to(self.device) + + @torch.no_grad() + def __call__(self, tokens, attention_mask, layer_index=-1): + with torch.cuda.amp.autocast(): + outputs = 
self.model( + input_ids=tokens, + attention_mask=attention_mask, + output_hidden_states=True, + ) + + z = outputs.hidden_states[layer_index].detach() + return z + + def general(self, text: str): + # input_ids = input_ids = torch.tensor([list(text.encode("utf-8"))]) + num_special_tokens + input_ids = self.tokenizer(text, max_length=128).input_ids + print(input_ids) + outputs = self.generation_model(input_ids) + return outputs \ No newline at end of file diff --git a/HunYuan/nodes.py b/HunYuan/nodes.py new file mode 100644 index 0000000..8e084fc --- /dev/null +++ b/HunYuan/nodes.py @@ -0,0 +1,127 @@ +import os +import json +import torch +import folder_paths + +from .conf import dit_conf +from .loader import load_dit +from .models.text_encoder import MT5Embedder +from transformers import BertModel, BertTokenizer + +class MT5Loader: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "HunyuanDiTfolder": (os.listdir(os.path.join(folder_paths.models_dir,"diffusers")), {"default": "HunyuanDiT"}), + "device": (["cpu", "cuda"], {"default": "cuda"}), + } + } + RETURN_TYPES = ("MT5","CLIP","Tokenizer",) + FUNCTION = "load_model" + CATEGORY = "ExtraModels/T5" + TITLE = "MT5 Loader" + + def load_model(self, HunyuanDiTfolder, device): + HunyuanDiTfolder=os.path.join(os.path.join(folder_paths.models_dir,"diffusers"),HunyuanDiTfolder) + mt5folder=os.path.join(HunyuanDiTfolder,"t2i/mt5") + clipfolder=os.path.join(HunyuanDiTfolder,"t2i/clip_text_encoder") + tokenizerfolder=os.path.join(HunyuanDiTfolder,"t2i/tokenizer") + torch_dtype=torch.float16 + if device=="cpu": + torch_dtype=torch.float32 + clip_text_encoder = BertModel.from_pretrained(str(clipfolder), False, revision=None).to(device) + tokenizer = BertTokenizer.from_pretrained(str(tokenizerfolder)) + embedder_t5 = MT5Embedder(mt5folder, torch_dtype=torch_dtype, max_length=256, device=device) + + return (embedder_t5,clip_text_encoder,tokenizer,) + +def clip_get_text_embeddings(clip_text_encoder,tokenizer,text,device): + max_length=tokenizer.model_max_length + text_inputs = tokenizer( + text, + padding="max_length", + max_length=max_length, + truncation=True, + return_attention_mask=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + attention_mask = text_inputs.attention_mask.to(device) + prompt_embeds = clip_text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + attention_mask = attention_mask.repeat(1, 1) + + return (prompt_embeds,attention_mask) + +class MT5TextEncode: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "embedder_t5": ("MT5",), + "clip_text_encoder": ("CLIP",), + "tokenizer": ("Tokenizer",), + "prompt": ("STRING", {"multiline": True}), + "negative_prompt": ("STRING", {"multiline": True}), + } + } + + RETURN_TYPES = ("CONDITIONING","CONDITIONING",) + RETURN_NAMES = ("positive","negative",) + FUNCTION = "encode" + CATEGORY = "ExtraModels/T5" + TITLE = "MT5 Text Encode" + + def encode(self, embedder_t5, clip_text_encoder, tokenizer, prompt, negative_prompt): + print(f'prompt{prompt}') + clip_prompt_embeds,clip_attention_mask = clip_get_text_embeddings(clip_text_encoder,tokenizer,prompt,embedder_t5.device) + + clip_negative_prompt_embeds,clip_negative_attention_mask = clip_get_text_embeddings(clip_text_encoder,tokenizer,negative_prompt,embedder_t5.device) + + mt5_prompt_embeds,mt5_attention_mask = embedder_t5.get_text_embeddings(prompt) + + mt5_negative_prompt_embeds,mt5_negative_attention_mask = 
embedder_t5.get_text_embeddings(negative_prompt) + + return ([[clip_prompt_embeds, {"clip_prompt_embeds":clip_prompt_embeds,"clip_attention_mask":clip_attention_mask,"mt5_prompt_embeds":mt5_prompt_embeds,"mt5_attention_mask":mt5_attention_mask}]],[[clip_negative_prompt_embeds, {"clip_prompt_embeds":clip_negative_prompt_embeds,"clip_attention_mask":clip_negative_attention_mask,"mt5_prompt_embeds":mt5_negative_prompt_embeds,"mt5_attention_mask":mt5_negative_attention_mask}]], ) + +class HunYuanDitCheckpointLoader: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "HunyuanDiTfolder": (os.listdir(os.path.join(folder_paths.models_dir,"diffusers")), {"default": "HunyuanDiT"}), + "model": (list(dit_conf.keys()),), + "image_size_width": ("INT",{"default":1024}), + "image_size_height": ("INT",{"default":1024}), + # "num_classes": ("INT", {"default": 1000, "min": 0,}), + } + } + RETURN_TYPES = ("MODEL",) + RETURN_NAMES = ("model",) + FUNCTION = "load_checkpoint" + CATEGORY = "ExtraModels/DiT" + TITLE = "HunYuanDitCheckpointLoader" + + def load_checkpoint(self, HunyuanDiTfolder, model, image_size_width, image_size_height): + image_size_width = int((image_size_width // 16) * 16) + image_size_height = int((image_size_height // 16) * 16) + HunyuanDiTfolder=os.path.join(os.path.join(folder_paths.models_dir,"diffusers"),HunyuanDiTfolder) + ckpt_path=os.path.join(HunyuanDiTfolder,"t2i/model/pytorch_model_ema.pt") + model_conf = dit_conf[model] + model_conf["unet_config"]["input_size"] = (image_size_height // 8, image_size_width // 8) + # model_conf["unet_config"]["num_classes"] = num_classes + dit = load_dit( + model_path = ckpt_path, + model_conf = model_conf, + ) + return (dit,) + +NODE_CLASS_MAPPINGS = { + "HunYuanDitCheckpointLoader" : HunYuanDitCheckpointLoader, + "MT5Loader" : MT5Loader, + "MT5TextEncode" : MT5TextEncode, +} diff --git a/HunYuan/wf.json b/HunYuan/wf.json new file mode 100644 index 0000000..4e59c93 --- /dev/null +++ b/HunYuan/wf.json @@ -0,0 +1,669 @@ +{ + "last_node_id": 50, + "last_link_id": 88, + "nodes": [ + { + "id": 32, + "type": "MT5Loader", + "pos": [ + -199, + 203 + ], + "size": { + "0": 315, + "1": 122 + }, + "flags": {}, + "order": 0, + "mode": 0, + "outputs": [ + { + "name": "MT5", + "type": "MT5", + "links": [ + 58 + ], + "shape": 3, + "slot_index": 0 + }, + { + "name": "CLIP", + "type": "CLIP", + "links": [ + 59 + ], + "shape": 3, + "slot_index": 1 + }, + { + "name": "Tokenizer", + "type": "Tokenizer", + "links": [ + 60 + ], + "shape": 3, + "slot_index": 2 + } + ], + "properties": { + "Node name for S&R": "MT5Loader" + }, + "widgets_values": [ + "HunyuanDiT", + "cuda" + ] + }, + { + "id": 35, + "type": "HunYuanDitCheckpointLoader", + "pos": [ + 430, + -221 + ], + "size": { + "0": 315, + "1": 130 + }, + "flags": {}, + "order": 6, + "mode": 0, + "inputs": [ + { + "name": "image_size_width", + "type": "INT", + "link": 78, + "widget": { + "name": "image_size_width" + } + }, + { + "name": "image_size_height", + "type": "INT", + "link": 80, + "widget": { + "name": "image_size_height" + } + } + ], + "outputs": [ + { + "name": "model", + "type": "MODEL", + "links": [ + 77 + ], + "shape": 3, + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "HunYuanDitCheckpointLoader" + }, + "widgets_values": [ + "HunyuanDiT", + "DiT-g/2", + 1024, + 768 + ] + }, + { + "id": 46, + "type": "PrimitiveNode", + "pos": [ + 17, + -217 + ], + "size": { + "0": 210, + "1": 82 + }, + "flags": {}, + "order": 1, + "mode": 0, + "outputs": [ + { + "name": "INT", + "type": 
"INT", + "links": [ + 78, + 83 + ], + "slot_index": 0, + "widget": { + "name": "image_size_width" + } + } + ], + "properties": { + "Run widget replace on values": false + }, + "widgets_values": [ + 1024, + "fixed" + ] + }, + { + "id": 47, + "type": "PrimitiveNode", + "pos": [ + 16, + -59 + ], + "size": { + "0": 210, + "1": 82 + }, + "flags": {}, + "order": 2, + "mode": 0, + "outputs": [ + { + "name": "INT", + "type": "INT", + "links": [ + 80, + 82 + ], + "slot_index": 0, + "widget": { + "name": "image_size_height" + } + } + ], + "properties": { + "Run widget replace on values": false + }, + "widgets_values": [ + 768, + "fixed" + ] + }, + { + "id": 8, + "type": "VAEDecode", + "pos": [ + 1223, + 195 + ], + "size": { + "0": 210, + "1": 46 + }, + "flags": {}, + "order": 10, + "mode": 0, + "inputs": [ + { + "name": "samples", + "type": "LATENT", + "link": 7 + }, + { + "name": "vae", + "type": "VAE", + "link": 74 + } + ], + "outputs": [ + { + "name": "IMAGE", + "type": "IMAGE", + "links": [ + 84 + ], + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "VAEDecode" + } + }, + { + "id": 49, + "type": "LoadImage", + "pos": [ + 796.3945147047091, + 571.7084114574003 + ], + "size": { + "0": 315, + "1": 314.0000305175781 + }, + "flags": {}, + "order": 3, + "mode": 0, + "outputs": [ + { + "name": "IMAGE", + "type": "IMAGE", + "links": [ + 85 + ], + "shape": 3, + "slot_index": 0 + }, + { + "name": "MASK", + "type": "MASK", + "links": null, + "shape": 3 + } + ], + "properties": { + "Node name for S&R": "LoadImage" + }, + "widgets_values": [ + "ComfyUI_00061_.png", + "image" + ] + }, + { + "id": 22, + "type": "VAELoader", + "pos": [ + 243, + 571 + ], + "size": { + "0": 315, + "1": 58 + }, + "flags": {}, + "order": 4, + "mode": 0, + "outputs": [ + { + "name": "VAE", + "type": "VAE", + "links": [ + 74, + 86 + ], + "shape": 3, + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "VAELoader" + }, + "widgets_values": [ + "diffusion_pytorch_model.safetensors" + ] + }, + { + "id": 50, + "type": "VAEEncode", + "pos": [ + 1251.4522304568147, + 481.7803521272774 + ], + "size": { + "0": 210, + "1": 46 + }, + "flags": {}, + "order": 8, + "mode": 0, + "inputs": [ + { + "name": "pixels", + "type": "IMAGE", + "link": 85 + }, + { + "name": "vae", + "type": "VAE", + "link": 86 + } + ], + "outputs": [ + { + "name": "LATENT", + "type": "LATENT", + "links": [], + "shape": 3, + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "VAEEncode" + } + }, + { + "id": 48, + "type": "PreviewImage", + "pos": [ + 1648, + 193 + ], + "size": { + "0": 210, + "1": 246 + }, + "flags": {}, + "order": 11, + "mode": 0, + "inputs": [ + { + "name": "images", + "type": "IMAGE", + "link": 84 + } + ], + "properties": { + "Node name for S&R": "PreviewImage" + } + }, + { + "id": 3, + "type": "KSampler", + "pos": [ + 858, + 182 + ], + "size": { + "0": 315, + "1": 262 + }, + "flags": {}, + "order": 9, + "mode": 0, + "inputs": [ + { + "name": "model", + "type": "MODEL", + "link": 77 + }, + { + "name": "positive", + "type": "CONDITIONING", + "link": 65 + }, + { + "name": "negative", + "type": "CONDITIONING", + "link": 66 + }, + { + "name": "latent_image", + "type": "LATENT", + "link": 88 + } + ], + "outputs": [ + { + "name": "LATENT", + "type": "LATENT", + "links": [ + 7 + ], + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "KSampler" + }, + "widgets_values": [ + 50, + "fixed", + 20, + 6, + "ddim", + "ddim_uniform", + 1 + ] + }, + { + "id": 33, + "type": "MT5TextEncode", + "pos": [ + 208, + 
204 + ], + "size": { + "0": 400, + "1": 200 + }, + "flags": {}, + "order": 5, + "mode": 0, + "inputs": [ + { + "name": "embedder_t5", + "type": "MT5", + "link": 58 + }, + { + "name": "clip_text_encoder", + "type": "CLIP", + "link": 59 + }, + { + "name": "tokenizer", + "type": "Tokenizer", + "link": 60 + } + ], + "outputs": [ + { + "name": "positive", + "type": "CONDITIONING", + "links": [ + 65 + ], + "shape": 3, + "slot_index": 0 + }, + { + "name": "negative", + "type": "CONDITIONING", + "links": [ + 66 + ], + "shape": 3, + "slot_index": 1 + } + ], + "properties": { + "Node name for S&R": "MT5TextEncode" + }, + "widgets_values": [ + "一位年轻女子站在春季的火车站月台上。\n她身着蓝灰色长风衣,白色衬衫。她的深棕色头发扎成低马尾,几缕碎发随风飘扬。\n她的眼神充满期待,阳光洒在她温暖的脸庞上。", + "错误的眼睛,糟糕的人脸,毁容,糟糕的艺术,变形,多余的肢体,模糊的颜色,模糊,重复,病态,残缺," + ] + }, + { + "id": 5, + "type": "EmptyLatentImage", + "pos": [ + 427, + -31 + ], + "size": { + "0": 315, + "1": 106 + }, + "flags": {}, + "order": 7, + "mode": 0, + "inputs": [ + { + "name": "width", + "type": "INT", + "link": 83, + "widget": { + "name": "width" + } + }, + { + "name": "height", + "type": "INT", + "link": 82, + "widget": { + "name": "height" + } + } + ], + "outputs": [ + { + "name": "LATENT", + "type": "LATENT", + "links": [ + 88 + ], + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "EmptyLatentImage" + }, + "widgets_values": [ + 1024, + 768, + 2 + ] + } + ], + "links": [ + [ + 7, + 3, + 0, + 8, + 0, + "LATENT" + ], + [ + 58, + 32, + 0, + 33, + 0, + "MT5" + ], + [ + 59, + 32, + 1, + 33, + 1, + "CLIP" + ], + [ + 60, + 32, + 2, + 33, + 2, + "Tokenizer" + ], + [ + 65, + 33, + 0, + 3, + 1, + "CONDITIONING" + ], + [ + 66, + 33, + 1, + 3, + 2, + "CONDITIONING" + ], + [ + 74, + 22, + 0, + 8, + 1, + "VAE" + ], + [ + 77, + 35, + 0, + 3, + 0, + "MODEL" + ], + [ + 78, + 46, + 0, + 35, + 0, + "INT" + ], + [ + 80, + 47, + 0, + 35, + 1, + "INT" + ], + [ + 82, + 47, + 0, + 5, + 1, + "INT" + ], + [ + 83, + 46, + 0, + 5, + 0, + "INT" + ], + [ + 84, + 8, + 0, + 48, + 0, + "IMAGE" + ], + [ + 85, + 49, + 0, + 50, + 0, + "IMAGE" + ], + [ + 86, + 22, + 0, + 50, + 1, + "VAE" + ], + [ + 88, + 5, + 0, + 3, + 3, + "LATENT" + ] + ], + "groups": [], + "config": {}, + "extra": { + "ds": { + "scale": 0.6303940863128614, + "offset": [ + 565.5430225631707, + 449.5915559841989 + ] + } + }, + "version": 0.4 +} \ No newline at end of file diff --git a/HunYuan/wf.png b/HunYuan/wf.png new file mode 100644 index 0000000..a3674e9 Binary files /dev/null and b/HunYuan/wf.png differ diff --git a/README.md b/README.md index 667aae0..5e687bf 100644 --- a/README.md +++ b/README.md @@ -1,201 +1,15 @@ -# Extra Models for ComfyUI +# ComfyUI HunyuanDiT (WIP) -This repository aims to add support for various different image diffusion models to ComfyUI. +[HunyuanDiT](https://github.com/Tencent/HunyuanDiT) -## Installation - -Simply clone this repo to your custom_nodes folder using the following command: - -`git clone https://github.com/city96/ComfyUI_ExtraModels custom_nodes/ComfyUI_ExtraModels` - -You will also have to install the requirements from the provided file by running `pip install -r requirements.txt` inside your VENV/conda env. If you downloaded the standalone version of ComfyUI, then follow the steps below. - -### Standalone ComfyUI - -I haven't tested this completely, so if you know what you're doing, use the regular venv/`git clone` install option when installing ComfyUI. - -Go to the where you unpacked `ComfyUI_windows_portable` to (where your run_nvidia_gpu.bat file is) and open a command line window. 
Press `CTRL+SHIFT+Right click` in an empty space and click "Open PowerShell window here". - -Clone the repository to your custom nodes folder, assuming haven't installed in through the manager. - -`git clone https://github.com/city96/ComfyUI_ExtraModels .\ComfyUI\custom_nodes\ComfyUI_ExtraModels` - -To install the requirements on windows, run these commands in the same window: -``` -.\python_embeded\python.exe -s -m pip install -r .\ComfyUI\custom_nodes\ComfyUI_ExtraModels\requirements.txt -.\python_embeded\python.exe -s -m pip install bitsandbytes --prefer-binary --extra-index-url=https://jllllll.github.io/bitsandbytes-windows-webui ``` - -To update, open the command line window like before and run the following commands: - -``` -cd .\ComfyUI\custom_nodes\ComfyUI_ExtraModels\ -git pull +huggingface-cli download --resume-download Tencent-Hunyuan/HunyuanDiT --local-dir ComfyUI/models/diffusers --local-dir-use-symlinks False ``` -Alternatively, use the manager, assuming it has an update function. - - - -## PixArt - -[Original Repo](https://github.com/PixArt-alpha/PixArt-alpha) - -### Model info / implementation -- Uses T5 text encoder instead of clip -- Available in 512 and 1024 versions, needs specific pre-defined resolutions to work correctly -- Same latent space as SD1.5 (works with the SD1.5 VAE) -- Attention needs optimization, images look worse without xformers. - -### Usage - -1. Download the model weights from the [PixArt alpha repo](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main) - you most likely want the 1024px one - `PixArt-XL-2-1024-MS.pth` -2. Place them in your checkpoints folder -3. Load them with the correct PixArt checkpoint loader -4. **Follow the T5v11 section of this readme** to set up the T5 text encoder - -> [!TIP] -> You should be able to use the model with the default KSampler if you're on the latest version of the node. -> In theory, this should allow you to use longer prompts as well as things like doing img2img. - -Limitations: -- `PixArt DPM Sampler` requires the negative prompt to be shorter than the positive prompt. -- `PixArt DPM Sampler` can only work with a batch size of 1. -- `PixArt T5 Text Encode` is from the reference implementation, therefore it doesn't support weights. `T5 Text Encode` support weights, but I can't attest to the correctness of the implementation. - -> [!IMPORTANT] -> Installing `xformers` is optional but strongly recommended as torch SDP is only partially implemented, if that. - -[Sample workflow here](https://github.com/city96/ComfyUI_ExtraModels/files/13617463/PixArtV3.json) - -![PixArtT12](https://github.com/city96/ComfyUI_ExtraModels/assets/125218114/eb1a02f9-6114-47eb-a066-261c39c55615) - -### PixArt Sigma - -The Sigma models work just like the normal ones. Out of the released checkpoints, the 512, 1024 and 2K one are supported. - -You can find the [1024 checkpoint here](https://huggingface.co/PixArt-alpha/PixArt-Sigma/blob/main/PixArt-Sigma-XL-2-1024-MS.pth). Place it in your models folder and **select the appropriate type in the model loader / resolution selection node.** - -> [!IMPORTANT] -> Make sure to select an SDXL VAE for PixArt Sigma! - -### PixArt LCM - -The LCM model also works if you're on the latest version. To use it: - -1. Download the [PixArt LCM model](https://huggingface.co/PixArt-alpha/PixArt-LCM-XL-2-1024-MS/blob/main/transformer/diffusion_pytorch_model.safetensors) and place it in your checkpoints folder. -2. Add a `ModelSamplingDiscrete` node and set "sampling" to "lcm" -3. 
Adjust the KSampler settings - Set the sampler to "lcm". Your CFG should be fairly low (1.1-1.5), your steps should be around 5. - -Everything else can be the same the same as in the example above. - -![PixArtLCM](https://github.com/city96/ComfyUI_ExtraModels/assets/125218114/558f8b30-449b-4973-ad7e-6aa69832adcb) - - - -## DiT - -[Original Repo](https://github.com/facebookresearch/DiT) - -### Model info / implementation -- Uses class labels instead of prompts -- Limited to 256x256 or 512x512 images -- Same latent space as SD1.5 (works with the SD1.5 VAE) -- Works in FP16, but no other optimization - -### Usage - -1. Download the original model weights from the [DiT Repo](https://github.com/facebookresearch/DiT) or the converted [FP16 safetensor ones from Huggingface](https://huggingface.co/city96/DiT/tree/main). -2. Place them in your checkpoints folder. (You may need to move them if you had them in `ComfyUI\models\dit` before) -3. Load the model and select the class labels as shown in the image below -4. **Make sure to use the Empty label conditioning for the Negative input of the KSampler!** - -ConditioningCombine nodes *should* work for combining multiple labels. The area ones don't since the model currently can't handle dynamic input dimensions. - -[Sample workflow here](https://github.com/city96/ComfyUI_ExtraModels/files/13619259/DiTV2.json) - -![DIT_WORKFLOW_IMG](https://github.com/city96/ComfyUI_ExtraModels/assets/125218114/cdd4ec94-b0eb-436a-bf23-a3bcef8d7b90) - - - -## T5 - -### T5v11 - -The model files can be downloaded from the [DeepFloyd/t5-v1_1-xxl](https://huggingface.co/DeepFloyd/t5-v1_1-xxl/tree/main) repository. - -You will need to download the following 4 files: - - `config.json` - - `pytorch_model-00001-of-00002.bin` - - `pytorch_model-00002-of-00002.bin` - - `pytorch_model.bin.index.json` - -Place them in your `ComfyUI/models/t5` folder. You can put them in a subfolder called "t5-v1.1-xxl" though it doesn't matter. There are int8 safetensor files in the other DeepFloyd repo, thought they didn't work for me. - -For faster loading/smaller file sizes, you may pick one of the following alternative downloads: -- [FP16 converted version](https://huggingface.co/theunlikely/t5-v1_1-xxl-fp16/tree/main) - Same layout as the original, download both safetensor files as well as the `*.index.json` and `config.json` files. -- [BF16 converter version](https://huggingface.co/city96/t5-v1_1-xxl-encoder-bf16/tree/main) - Merged into a single safetensor, only `model.safetensors` (+`config.json` for folder mode) are reqired. - -To move T5 to a different drive/folder, do the same as you would when moving checkpoints, but add ` t5: t5` to `extra_model_paths.yaml` and create a directory called "t5" in the alternate path specified in the `base_path` variable. - -### Usage - -Loaded onto the CPU, it'll use about 22GBs of system RAM. Depending on which weights you use, it might use slightly more during loading. - -If you have a second GPU, selecting "cuda:1" as the device will allow you to use it for T5, freeing at least some VRAM/System RAM. Using FP16 as the dtype is recommended. - -Loaded in bnb4bit mode, it only takes around 6GB VRAM, making it work with 12GB cards. The only drawback is that it'll constantly stay in VRAM since BitsAndBytes doesn't allow moving the weights to the system RAM temporarily. Switching to a different workflow *should* still release the VRAM as expected. Pascal cards (1080ti, P40) seem to struggle with 4bit. Select "cpu" if you encounter issues. 
- -On windows, you may need a newer version of bitsandbytes for 4bit. Try `python -m pip install bitsandbytes` - -> [!IMPORTANT] -> You may also need to upgrade transformers and install spiece for the tokenizer. `pip install -r requirements.txt` - - - -## VAE - -A few custom VAE models are supported. The option to select a different dtype when loading is also possible, which can be useful for testing/comparisons. You can load the models listed below using the "ExtraVAELoader" node. - -**Models like PixArt/DiT do NOT need a special VAE. Unless mentioned, use one of the following as you would with any other model:** -- [VAE for SD1.X, DiT and PixArt alpha](https://huggingface.co/stabilityai/sd-vae-ft-mse-original/blob/main/vae-ft-mse-840000-ema-pruned.safetensors). -- [VAE for SDXL and PixArt sigma](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix/blob/main/diffusion_pytorch_model.safetensors) - -### Consistency Decoder - -[Original Repo](https://github.com/openai/consistencydecoder) - -This now works thanks to the work of @mrsteyk and @madebyollin - [Gist with more info](https://gist.github.com/madebyollin/865fa6a18d9099351ddbdfbe7299ccbf). - -- Download the converted safetensor VAE from [this HF repository](https://huggingface.co/mrsteyk/consistency-decoder-sd15/blob/main/stk_consistency_decoder_amalgamated.safetensors). If you downloaded the OpenAI model before, it won't work, as it is a TorchScript file. Feel free to delete it. -- Put the file in your VAE folder -- Load it with the ExtraVAELoader -- Set it to fp16 or bf16 to not run out of VRAM -- Use tiled VAE decode if required - -### Deflickering Decoder / VideoDecoder - -This is the VAE that comes baked into the [Stable Video Diffusion](https://stability.ai/news/stable-video-diffusion-open-ai-video-model) model. - -It doesn't seem particularly good as a normal VAE (color issues, pretty bad with finer details). - -Still for completeness sake the code to run it is mostly implemented. To obtain the weights just extract them from the sdv model: - -```py -from safetensors.torch import load_file, save_file - -pf = "first_stage_model." # Key prefix -sd = load_file("svd_xt.safetensors") -vae = {k.replace(pf, ''):v for k,v in sd.items() if k.startswith(pf)} -save_file(vae, "svd_xt_vae.safetensors") -``` - -### AutoencoderKL / VQModel - -`kl-f4/8/16/32` from the [compvis/latent diffusion repo](https://github.com/CompVis/latent-diffusion/tree/main#pretrained-autoencoding-models). +sdxl vae -`vq-f4/8/16` from the taming transformers repo, weights for both vq and kl models available [here](https://ommer-lab.com/files/latent-diffusion/) +## workflow -`vq-f8` can accepts latents from the SD unet but just like xl with v1 latents, output largely garbage. The rest are completely useless without a matching UNET that uses the correct channel count. 
+[Recommended complete Workflow](https://github.com/chaojie/ComfyUI_ExtraModels/blob/main/HunYuan/wf.json) -![VAE_TEST](https://github.com/city96/ComfyUI_ExtraModels/assets/125218114/316c7029-ee78-4ff7-a46a-b56ef91477eb) + \ No newline at end of file diff --git a/__init__.py b/__init__.py index 38967a2..fd679e3 100644 --- a/__init__.py +++ b/__init__.py @@ -10,6 +10,10 @@ # from .DeciDiffusion.nodes import NODE_CLASS_MAPPINGS as DeciDiffusion_Nodes # NODE_CLASS_MAPPINGS.update(DeciDiffusion_Nodes) + # HunYuan + from .HunYuan.nodes import NODE_CLASS_MAPPINGS as HunYuan_Nodes + NODE_CLASS_MAPPINGS.update(HunYuan_Nodes) + # DiT from .DiT.nodes import NODE_CLASS_MAPPINGS as DiT_Nodes NODE_CLASS_MAPPINGS.update(DiT_Nodes) diff --git a/requirements.txt b/requirements.txt index 51c2340..258d091 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,3 +2,4 @@ timm==0.6.13 sentencepiece>=0.1.97 transformers>=4.34.1 accelerate>=0.23.0 +einops \ No newline at end of file
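As a quick sanity check before running the workflow, here is a minimal sketch that walks the folder layout the new `MT5Loader` and `HunYuanDitCheckpointLoader` nodes expect under `models/diffusers/HunyuanDiT`. The relative paths mirror the ones hard-coded in `HunYuan/nodes.py`; the `models_dir` value is an assumed default and should be adjusted to your install.

```python
# Minimal sketch: confirm the HunyuanDiT files the nodes look for are present.
# Relative paths mirror those hard-coded in HunYuan/nodes.py; "models_dir" is
# an assumed ComfyUI models directory, not something the extension defines.
import os

models_dir = "ComfyUI/models"  # adjust to your ComfyUI install
root = os.path.join(models_dir, "diffusers", "HunyuanDiT")

expected = [
    "t2i/mt5",                         # MT5Embedder weights (MT5Loader)
    "t2i/clip_text_encoder",           # BertModel text encoder (MT5Loader)
    "t2i/tokenizer",                   # BertTokenizer files (MT5Loader)
    "t2i/model/pytorch_model_ema.pt",  # DiT checkpoint (HunYuanDitCheckpointLoader)
]

for rel in expected:
    path = os.path.join(root, rel)
    print("ok     " if os.path.exists(path) else "MISSING", path)
```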
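`HunYuanDitCheckpointLoader` floors the requested width and height to multiples of 16 and hands the DiT an `input_size` expressed in latent pixels (pixel size divided by 8). A small worked example of that arithmetic, using a hypothetical helper name:

```python
# Worked example of the size handling in HunYuanDitCheckpointLoader:
# pixel dimensions are floored to multiples of 16, and input_size is the
# latent grid (pixels // 8). "dit_input_size" is an illustrative name only.
def dit_input_size(width: int, height: int) -> tuple[int, int]:
    width = (width // 16) * 16
    height = (height // 16) * 16
    return (height // 8, width // 8)

print(dit_input_size(1024, 768))  # (96, 128)
print(dit_input_size(1000, 700))  # (86, 124), after flooring to 992 x 688
```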
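The two CONDITIONING outputs of `MT5TextEncode` are ordinary ComfyUI conditioning lists: each entry is an `[embeds, extras]` pair whose extras dict bundles the four tensors (`clip_prompt_embeds`, `clip_attention_mask`, `mt5_prompt_embeds`, `mt5_attention_mask`) that travel with the prompt. A rough sketch of that structure, with dummy tensors standing in for the real encoder outputs (the shapes are illustrative assumptions, not values taken from the extension):

```python
# Sketch of the conditioning layout returned by MT5TextEncode.encode().
# Dummy tensors replace the BERT/mT5 encoder outputs; the shapes below are
# illustrative assumptions only.
import torch

clip_prompt_embeds = torch.zeros(1, 77, 1024)
clip_attention_mask = torch.ones(1, 77, dtype=torch.long)
mt5_prompt_embeds = torch.zeros(1, 256, 2048)
mt5_attention_mask = torch.ones(1, 256, dtype=torch.long)

positive = [[
    clip_prompt_embeds,
    {
        "clip_prompt_embeds": clip_prompt_embeds,
        "clip_attention_mask": clip_attention_mask,
        "mt5_prompt_embeds": mt5_prompt_embeds,
        "mt5_attention_mask": mt5_attention_mask,
    },
]]

# Inspect what a downstream node receives.
embeds, extras = positive[0]
for name, tensor in extras.items():
    print(name, tuple(tensor.shape))
```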