diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index e2ab1606b345..fc225567ddc1 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -1695,81 +1695,6 @@ def __call__( return hidden_states -# YiYi to-do: refactor rope related functions/classes -def apply_rope(xq, xk, freqs_cis): - xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2) - xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2) - xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1] - xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1] - return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk) - - -class FluxSingleAttnProcessor2_0: - r""" - Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). - """ - - def __init__(self): - if not hasattr(F, "scaled_dot_product_attention"): - raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") - - def __call__( - self, - attn: Attention, - hidden_states: torch.Tensor, - encoder_hidden_states: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.FloatTensor] = None, - image_rotary_emb: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - input_ndim = hidden_states.ndim - - if input_ndim == 4: - batch_size, channel, height, width = hidden_states.shape - hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) - - batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape - - query = attn.to_q(hidden_states) - if encoder_hidden_states is None: - encoder_hidden_states = hidden_states - - key = attn.to_k(encoder_hidden_states) - value = attn.to_v(encoder_hidden_states) - - inner_dim = key.shape[-1] - head_dim = inner_dim // attn.heads - - query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - if attn.norm_q is not None: - query = attn.norm_q(query) - if attn.norm_k is not None: - key = attn.norm_k(key) - - # Apply RoPE if needed - if image_rotary_emb is not None: - # YiYi to-do: update uising apply_rotary_emb - # from ..embeddings import apply_rotary_emb - # query = apply_rotary_emb(query, image_rotary_emb) - # key = apply_rotary_emb(key, image_rotary_emb) - query, key = apply_rope(query, key, image_rotary_emb) - - # the output of sdp = (batch, num_heads, seq_len, head_dim) - # TODO: add support for attn.scale when we move to Torch 2.1 - hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False) - - hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) - hidden_states = hidden_states.to(query.dtype) - - if input_ndim == 4: - hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) - - return hidden_states - - class FluxAttnProcessor2_0: """Attention processor used typically in processing the SD3-like self-attention projections.""" @@ -1785,16 +1710,7 @@ def __call__( attention_mask: Optional[torch.FloatTensor] = None, image_rotary_emb: Optional[torch.Tensor] = None, ) -> torch.FloatTensor: - input_ndim = hidden_states.ndim - if input_ndim == 4: - batch_size, channel, height, width = hidden_states.shape - hidden_states = hidden_states.view(batch_size, channel, 
height * width).transpose(1, 2) - context_input_ndim = encoder_hidden_states.ndim - if context_input_ndim == 4: - batch_size, channel, height, width = encoder_hidden_states.shape - encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2) - - batch_size = encoder_hidden_states.shape[0] + batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape # `sample` projections. query = attn.to_q(hidden_states) @@ -1813,59 +1729,58 @@ def __call__( if attn.norm_k is not None: key = attn.norm_k(key) - # `context` projections. - encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states) - encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) - encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) + # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states` + if encoder_hidden_states is not None: + # `context` projections. + encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states) + encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) + encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) - encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( - batch_size, -1, attn.heads, head_dim - ).transpose(1, 2) - encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view( - batch_size, -1, attn.heads, head_dim - ).transpose(1, 2) - encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view( - batch_size, -1, attn.heads, head_dim - ).transpose(1, 2) - - if attn.norm_added_q is not None: - encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj) - if attn.norm_added_k is not None: - encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj) + encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) - # attention - query = torch.cat([encoder_hidden_states_query_proj, query], dim=2) - key = torch.cat([encoder_hidden_states_key_proj, key], dim=2) - value = torch.cat([encoder_hidden_states_value_proj, value], dim=2) + if attn.norm_added_q is not None: + encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj) + if attn.norm_added_k is not None: + encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj) + + # attention + query = torch.cat([encoder_hidden_states_query_proj, query], dim=2) + key = torch.cat([encoder_hidden_states_key_proj, key], dim=2) + value = torch.cat([encoder_hidden_states_value_proj, value], dim=2) if image_rotary_emb is not None: - # YiYi to-do: update uising apply_rotary_emb - # from ..embeddings import apply_rotary_emb - # query = apply_rotary_emb(query, image_rotary_emb) - # key = apply_rotary_emb(key, image_rotary_emb) - query, key = apply_rope(query, key, image_rotary_emb) + from .embeddings import apply_rotary_emb + + query = apply_rotary_emb(query, image_rotary_emb) + key = apply_rotary_emb(key, image_rotary_emb) hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False) hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, 
attn.heads * head_dim) hidden_states = hidden_states.to(query.dtype) - encoder_hidden_states, hidden_states = ( - hidden_states[:, : encoder_hidden_states.shape[1]], - hidden_states[:, encoder_hidden_states.shape[1] :], - ) - - # linear proj - hidden_states = attn.to_out[0](hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) - encoder_hidden_states = attn.to_add_out(encoder_hidden_states) + if encoder_hidden_states is not None: + encoder_hidden_states, hidden_states = ( + hidden_states[:, : encoder_hidden_states.shape[1]], + hidden_states[:, encoder_hidden_states.shape[1] :], + ) - if input_ndim == 4: - hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) - if context_input_ndim == 4: - encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + encoder_hidden_states = attn.to_add_out(encoder_hidden_states) - return hidden_states, encoder_hidden_states + return hidden_states, encoder_hidden_states + else: + return hidden_states class XFormersAttnAddedKVProcessor: @@ -4105,6 +4020,17 @@ def __init__(self): pass +class FluxSingleAttnProcessor2_0(FluxAttnProcessor2_0): + r""" + Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). + """ + + def __init__(self): + deprecation_message = "`FluxSingleAttnProcessor2_0` is deprecated and will be removed in a future version. Please use `FluxAttnProcessor2_0` instead." + deprecate("FluxSingleAttnProcessor2_0", "0.32.0", deprecation_message) + super().__init__() + + ADDED_KV_ATTENTION_PROCESSORS = ( AttnAddedKVProcessor, SlicedAttnAddedKVProcessor, diff --git a/src/diffusers/models/controlnet_flux.py b/src/diffusers/models/controlnet_flux.py index ba4933dcad67..b29930f81ea2 100644 --- a/src/diffusers/models/controlnet_flux.py +++ b/src/diffusers/models/controlnet_flux.py @@ -24,9 +24,9 @@ from ..models.modeling_utils import ModelMixin from ..utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers from .controlnet import BaseOutput, zero_module -from .embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings +from .embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed from .modeling_outputs import Transformer2DModelOutput -from .transformers.transformer_flux import EmbedND, FluxSingleTransformerBlock, FluxTransformerBlock +from .transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -59,7 +59,7 @@ def __init__( self.out_channels = in_channels self.inner_dim = num_attention_heads * attention_head_dim - self.pos_embed = EmbedND(dim=self.inner_dim, theta=10000, axes_dim=axes_dims_rope) + self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope) text_time_guidance_cls = ( CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings ) @@ -272,8 +272,20 @@ def forward( ) encoder_hidden_states = self.context_embedder(encoder_hidden_states) - txt_ids = txt_ids.expand(img_ids.size(0), -1, -1) - ids = torch.cat((txt_ids, img_ids), dim=1) + if txt_ids.ndim == 3: + logger.warning( + "Passing `txt_ids` 3d torch.Tensor is deprecated." 
+ "Please remove the batch dimension and pass it as a 2d torch Tensor" + ) + txt_ids = txt_ids[0] + if img_ids.ndim == 3: + logger.warning( + "Passing `img_ids` 3d torch.Tensor is deprecated." + "Please remove the batch dimension and pass it as a 2d torch Tensor" + ) + img_ids = img_ids[0] + + ids = torch.cat((txt_ids, img_ids), dim=0) image_rotary_emb = self.pos_embed(ids) block_samples = () diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 1258964385da..b2f496833176 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -446,6 +446,7 @@ def get_1d_rotary_pos_embed( linear_factor=1.0, ntk_factor=1.0, repeat_interleave_real=True, + freqs_dtype=torch.float32, # torch.float32 (hunyuan, stable audio), torch.float64 (flux) ): """ Precompute the frequency tensor for complex exponentials (cis) with given dimensions. @@ -468,6 +469,8 @@ def get_1d_rotary_pos_embed( repeat_interleave_real (`bool`, *optional*, defaults to `True`): If `True` and `use_real`, real part and imaginary part are each interleaved with themselves to reach `dim`. Otherwise, they are concateanted with themselves. + freqs_dtype (`torch.float32` or `torch.float64`, *optional*, defaults to `torch.float32`): + the dtype of the frequency tensor. Returns: `torch.Tensor`: Precomputed frequency tensor with complex exponentials. [S, D/2] """ @@ -476,19 +479,19 @@ def get_1d_rotary_pos_embed( if isinstance(pos, int): pos = np.arange(pos) theta = theta * ntk_factor - freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) / linear_factor # [D/2] + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype)[: (dim // 2)] / dim)) / linear_factor # [D/2] t = torch.from_numpy(pos).to(freqs.device) # type: ignore # [S] - freqs = torch.outer(t, freqs).float() # type: ignore # [S, D/2] + freqs = torch.outer(t, freqs) # type: ignore # [S, D/2] if use_real and repeat_interleave_real: - freqs_cos = freqs.cos().repeat_interleave(2, dim=1) # [S, D] - freqs_sin = freqs.sin().repeat_interleave(2, dim=1) # [S, D] + freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float() # [S, D] + freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float() # [S, D] return freqs_cos, freqs_sin elif use_real: - freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1) # [S, D] - freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1) # [S, D] + freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1).float() # [S, D] + freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1).float() # [S, D] return freqs_cos, freqs_sin else: - freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 # [S, D/2] + freqs_cis = torch.polar(torch.ones_like(freqs), freqs).float() # complex64 # [S, D/2] return freqs_cis @@ -540,6 +543,31 @@ def apply_rotary_emb( return x_out.type_as(x) +class FluxPosEmbed(nn.Module): + # modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11 + def __init__(self, theta: int, axes_dim: List[int]): + super().__init__() + self.theta = theta + self.axes_dim = axes_dim + + def forward(self, ids: torch.Tensor) -> torch.Tensor: + n_axes = ids.shape[-1] + cos_out = [] + sin_out = [] + pos = ids.squeeze().float().cpu().numpy() + is_mps = ids.device.type == "mps" + freqs_dtype = torch.float32 if is_mps else torch.float64 + for i in range(n_axes): + cos, sin = get_1d_rotary_pos_embed( + self.axes_dim[i], pos[:, i], repeat_interleave_real=True, use_real=True, 
freqs_dtype=freqs_dtype + ) + cos_out.append(cos) + sin_out.append(sin) + freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device) + freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device) + return freqs_cos, freqs_sin + + class TimestepEmbedding(nn.Module): def __init__( self, diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py index 0b7106cae442..3f28f7d134ec 100644 --- a/src/diffusers/models/transformers/transformer_flux.py +++ b/src/diffusers/models/transformers/transformer_flux.py @@ -13,7 +13,7 @@ # limitations under the License. -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, Optional, Tuple, Union import numpy as np import torch @@ -23,52 +23,18 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import FromOriginalModelMixin, PeftAdapterMixin from ...models.attention import FeedForward -from ...models.attention_processor import Attention, FluxAttnProcessor2_0, FluxSingleAttnProcessor2_0 +from ...models.attention_processor import Attention, FluxAttnProcessor2_0 from ...models.modeling_utils import ModelMixin from ...models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers from ...utils.torch_utils import maybe_allow_in_graph -from ..embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings +from ..embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed from ..modeling_outputs import Transformer2DModelOutput logger = logging.get_logger(__name__) # pylint: disable=invalid-name -# YiYi to-do: refactor rope related functions/classes -def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor: - assert dim % 2 == 0, "The dimension must be even." 
-
-    scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
-    omega = 1.0 / (theta**scale)
-
-    batch_size, seq_length = pos.shape
-    out = torch.einsum("...n,d->...nd", pos, omega)
-    cos_out = torch.cos(out)
-    sin_out = torch.sin(out)
-
-    stacked_out = torch.stack([cos_out, -sin_out, sin_out, cos_out], dim=-1)
-    out = stacked_out.view(batch_size, -1, dim // 2, 2, 2)
-    return out.float()
-
-
-# YiYi to-do: refactor rope related functions/classes
-class EmbedND(nn.Module):
-    def __init__(self, dim: int, theta: int, axes_dim: List[int]):
-        super().__init__()
-        self.dim = dim
-        self.theta = theta
-        self.axes_dim = axes_dim
-
-    def forward(self, ids: torch.Tensor) -> torch.Tensor:
-        n_axes = ids.shape[-1]
-        emb = torch.cat(
-            [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
-            dim=-3,
-        )
-        return emb.unsqueeze(1)
-
-
 @maybe_allow_in_graph
 class FluxSingleTransformerBlock(nn.Module):
     r"""
@@ -93,7 +59,7 @@ def __init__(self, dim, num_attention_heads, attention_head_dim, mlp_ratio=4.0):
         self.act_mlp = nn.GELU(approximate="tanh")
         self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)

-        processor = FluxSingleAttnProcessor2_0()
+        processor = FluxAttnProcessor2_0()
         self.attn = Attention(
             query_dim=dim,
             cross_attention_dim=None,
@@ -265,13 +231,14 @@ def __init__(
         joint_attention_dim: int = 4096,
         pooled_projection_dim: int = 768,
         guidance_embeds: bool = False,
-        axes_dims_rope: List[int] = [16, 56, 56],
+        axes_dims_rope: Tuple[int, int, int] = (16, 56, 56),
     ):
         super().__init__()
         self.out_channels = in_channels
         self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim

-        self.pos_embed = EmbedND(dim=self.inner_dim, theta=10000, axes_dim=axes_dims_rope)
+        self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)
+
         text_time_guidance_cls = (
             CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings
         )
@@ -381,8 +348,19 @@ def forward(
         )
         encoder_hidden_states = self.context_embedder(encoder_hidden_states)

-        txt_ids = txt_ids.expand(img_ids.size(0), -1, -1)
-        ids = torch.cat((txt_ids, img_ids), dim=1)
+        if txt_ids.ndim == 3:
+            logger.warning(
+                "Passing `txt_ids` 3d torch.Tensor is deprecated."
+                " Please remove the batch dimension and pass it as a 2d torch Tensor."
+            )
+            txt_ids = txt_ids[0]
+        if img_ids.ndim == 3:
+            logger.warning(
+                "Passing `img_ids` 3d torch.Tensor is deprecated." 
+ "Please remove the batch dimension and pass it as a 2d torch Tensor" + ) + img_ids = img_ids[0] + ids = torch.cat((txt_ids, img_ids), dim=0) image_rotary_emb = self.pos_embed(ids) for index_block, block in enumerate(self.transformer_blocks): diff --git a/src/diffusers/pipelines/flux/pipeline_flux.py b/src/diffusers/pipelines/flux/pipeline_flux.py index 1ec6d6bdb0b1..02458a5b4881 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux.py @@ -331,10 +331,6 @@ def encode_prompt( scale_lora_layers(self.text_encoder_2, lora_scale) prompt = [prompt] if isinstance(prompt, str) else prompt - if prompt is not None: - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] if prompt_embeds is None: prompt_2 = prompt_2 or prompt @@ -364,8 +360,7 @@ def encode_prompt( unscale_lora_layers(self.text_encoder_2, lora_scale) dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype - text_ids = torch.zeros(batch_size, prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) - text_ids = text_ids.repeat(num_images_per_prompt, 1, 1) + text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) return prompt_embeds, pooled_prompt_embeds, text_ids @@ -425,9 +420,8 @@ def _prepare_latent_image_ids(batch_size, height, width, device, dtype): latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape - latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1) latent_image_ids = latent_image_ids.reshape( - batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels + latent_image_id_height * latent_image_id_width, latent_image_id_channels ) return latent_image_ids.to(device=device, dtype=dtype) @@ -724,7 +718,6 @@ def __call__( noise_pred = self.transformer( hidden_states=latents, - # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transforme rmodel (we should not keep it but I want to keep the inputs same for the model for testing) timestep=timestep / 1000, guidance=guidance, pooled_projections=pooled_prompt_embeds, diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py index 84450374cb30..b9e93e720baf 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py @@ -354,10 +354,6 @@ def encode_prompt( scale_lora_layers(self.text_encoder_2, lora_scale) prompt = [prompt] if isinstance(prompt, str) else prompt - if prompt is not None: - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] if prompt_embeds is None: prompt_2 = prompt_2 or prompt @@ -387,8 +383,7 @@ def encode_prompt( unscale_lora_layers(self.text_encoder_2, lora_scale) dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype - text_ids = torch.zeros(batch_size, prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) - text_ids = text_ids.repeat(num_images_per_prompt, 1, 1) + text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) return prompt_embeds, pooled_prompt_embeds, text_ids @@ -449,9 +444,8 @@ def _prepare_latent_image_ids(batch_size, height, width, device, dtype): latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape - latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1) latent_image_ids = latent_image_ids.reshape( - batch_size, 
latent_image_id_height * latent_image_id_width, latent_image_id_channels
+            latent_image_id_height * latent_image_id_width, latent_image_id_channels
         )

         return latent_image_ids.to(device=device, dtype=dtype)
@@ -804,7 +798,6 @@ def __call__(
                 noise_pred = self.transformer(
                     hidden_states=latents,
-                    # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transforme rmodel (we should not keep it but I want to keep the inputs same for the model for testing)
                     timestep=timestep / 1000,
                     guidance=guidance,
                     pooled_projections=pooled_prompt_embeds,
diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py
index 64722e2d9797..0ce01fb93f40 100644
--- a/tests/models/test_modeling_common.py
+++ b/tests/models/test_modeling_common.py
@@ -976,7 +976,6 @@ def test_sharded_checkpoints_device_map(self):
             self.assertTrue(actual_num_shards == expected_num_shards)

             new_model = self.model_class.from_pretrained(tmp_dir, device_map="auto")
-            new_model = new_model.to(torch_device)

             torch.manual_seed(0)
             if "generator" in inputs_dict:
diff --git a/tests/models/transformers/test_models_transformer_flux.py b/tests/models/transformers/test_models_transformer_flux.py
index bda37621c27d..538d158cbcb9 100644
--- a/tests/models/transformers/test_models_transformer_flux.py
+++ b/tests/models/transformers/test_models_transformer_flux.py
@@ -44,8 +44,8 @@ def dummy_input(self):
         hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(torch_device)
         encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
         pooled_prompt_embeds = torch.randn((batch_size, embedding_dim)).to(torch_device)
-        text_ids = torch.randn((batch_size, sequence_length, num_image_channels)).to(torch_device)
-        image_ids = torch.randn((batch_size, height * width, num_image_channels)).to(torch_device)
+        text_ids = torch.randn((sequence_length, num_image_channels)).to(torch_device)
+        image_ids = torch.randn((height * width, num_image_channels)).to(torch_device)
         timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)

         return {
@@ -80,3 +80,31 @@ def prepare_init_args_and_inputs_for_common(self):
         inputs_dict = self.dummy_input

         return init_dict, inputs_dict
+
+    def test_deprecated_inputs_img_txt_ids_3d(self):
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**init_dict)
+        model.to(torch_device)
+        model.eval()
+
+        with torch.no_grad():
+            output_1 = model(**inputs_dict).to_tuple()[0]
+
+        # update inputs_dict with txt_ids and img_ids as 3d tensors (deprecated)
+        text_ids_3d = inputs_dict["txt_ids"].unsqueeze(0)
+        image_ids_3d = inputs_dict["img_ids"].unsqueeze(0)
+
+        assert text_ids_3d.ndim == 3, "text_ids_3d should be a 3d tensor"
+        assert image_ids_3d.ndim == 3, "image_ids_3d should be a 3d tensor"
+
+        inputs_dict["txt_ids"] = text_ids_3d
+        inputs_dict["img_ids"] = image_ids_3d
+
+        with torch.no_grad():
+            output_2 = model(**inputs_dict).to_tuple()[0]
+
+        self.assertEqual(output_1.shape, output_2.shape)
+        self.assertTrue(
+            torch.allclose(output_1, output_2, atol=1e-5),
+            msg="outputs with deprecated 3d `img_ids`/`txt_ids` inputs should match the outputs with 2d inputs",
+        )
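
The processors above now route RoPE through `apply_rotary_emb` from `embeddings.py` instead of the removed module-level `apply_rope`. With `repeat_interleave_real=True`, each frequency appears twice in the cos/sin tables and the rotation acts on interleaved (real, imaginary) channel pairs. Below is a minimal sketch of that interleaved rotation, assuming `[batch, heads, seq, head_dim]` inputs and `[seq, head_dim]` tables; the helper name is invented and this is not the library implementation:

```python
import torch

def rotate_interleaved(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # x: [batch, heads, seq, head_dim]; cos/sin: [seq, head_dim], with every
    # frequency repeated twice, matching the repeat_interleave path above.
    x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)     # even/odd channels
    x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(-2)  # per-pair 90-degree turn
    return (x.float() * cos + x_rotated.float() * sin).to(x.dtype)

# Example: rotate query states with shared tables.
q = torch.randn(1, 8, 24, 64)
cos, sin = torch.randn(24, 64), torch.randn(24, 64)  # normally from get_1d_rotary_pos_embed
assert rotate_interleaved(q, cos, sin).shape == q.shape
```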
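`FluxPosEmbed` builds one 1d rotary table per id axis and concatenates them along the feature dimension, computing frequencies in float64 except on MPS, which lacks float64 support. The following is a rough standalone equivalent of that per-axis loop under the same `use_real=True`/`repeat_interleave_real=True` convention; it is a sketch for intuition, not the library code, which routes through `get_1d_rotary_pos_embed`:

```python
import torch
from typing import List, Tuple

def flux_rope_tables(ids: torch.Tensor, axes_dim: List[int], theta: int = 10000) -> Tuple[torch.Tensor, torch.Tensor]:
    # ids: [seq_len, n_axes], e.g. torch.cat((txt_ids, img_ids), dim=0) from the forward passes above.
    freqs_dtype = torch.float32 if ids.device.type == "mps" else torch.float64
    cos_out, sin_out = [], []
    for i, dim in enumerate(axes_dim):
        freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=ids.device) / dim))
        angles = torch.outer(ids[:, i].to(freqs_dtype), freqs)           # [seq_len, dim // 2]
        cos_out.append(angles.cos().repeat_interleave(2, dim=1).float())  # [seq_len, dim]
        sin_out.append(angles.sin().repeat_interleave(2, dim=1).float())
    return torch.cat(cos_out, dim=-1), torch.cat(sin_out, dim=-1)

# With the Flux defaults, axes_dim=(16, 56, 56) yields tables of width 128, i.e. attention_head_dim.
```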
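Finally, the `txt_ids`/`img_ids` deprecation: since every sample in a batch shares the same token positions, the ids lose their batch dimension throughout the pipelines, transformer, and controlnet; 3d inputs still work but trigger the warnings added above and are reduced to their first row. A small sketch of the new calling convention, with sizes invented for illustration:

```python
import torch

text_len, latent_len = 512, 1024  # hypothetical sequence lengths

# Deprecated: ids carrying a redundant batch dimension; the model now warns
# and keeps only txt_ids[0] / img_ids[0].
txt_ids_3d = torch.zeros(2, text_len, 3)

# New convention: 2d ids, one row of per-token coordinates.
txt_ids = torch.zeros(text_len, 3)
img_ids = torch.zeros(latent_len, 3)

# The model concatenates them once, batch-free, before computing RoPE tables.
ids = torch.cat((txt_ids, img_ids), dim=0)  # [text_len + latent_len, 3]
```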