From 54c979765e39727f345731850bd96988b1948dcf Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Thu, 15 Aug 2024 15:51:12 +0530
Subject: [PATCH 1/4] start fusing flux.

---
 src/diffusers/models/attention_processor.py  | 169 ++++++++++++++++++
 .../models/transformers/transformer_flux.py  | 108 ++++++++++-
 2 files changed, 276 insertions(+), 1 deletion(-)

diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py
index e2ab1606b345..220bd9f67abb 100644
--- a/src/diffusers/models/attention_processor.py
+++ b/src/diffusers/models/attention_processor.py
@@ -1770,6 +1770,71 @@ def __call__(
         return hidden_states
 
 
+class FusedFluxSingleAttnProcessor2_0:
+    r"""
+    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
+    """
+
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError(
+                "FusedFluxSingleAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0."
+            )
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        image_rotary_emb: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        input_ndim = hidden_states.ndim
+
+        if input_ndim == 4:
+            batch_size, channel, height, width = hidden_states.shape
+            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+        batch_size, _, _ = hidden_states.shape
+
+        qkv = attn.to_qkv(hidden_states)
+        split_size = qkv.shape[-1] // 3
+        query, key, value = torch.split(qkv, split_size, dim=-1)
+
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
+        # Apply RoPE if needed
+        if image_rotary_emb is not None:
+            # YiYi to-do: update using apply_rotary_emb
+            # from ..embeddings import apply_rotary_emb
+            # query = apply_rotary_emb(query, image_rotary_emb)
+            # key = apply_rotary_emb(key, image_rotary_emb)
+            query, key = apply_rope(query, key, image_rotary_emb)
+
+        # the output of sdp = (batch, num_heads, seq_len, head_dim)
+        # TODO: add support for attn.scale when we move to Torch 2.1
+        hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
+
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.to(query.dtype)
+
+        if input_ndim == 4:
+            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+        return hidden_states
+
+
 class FluxAttnProcessor2_0:
     """Attention processor used typically in processing the SD3-like self-attention projections."""
 
@@ -1868,6 +1933,110 @@ def __call__(
         return hidden_states, encoder_hidden_states
 
 
+class FusedFluxAttnProcessor2_0:
+    """Attention processor used typically in processing the SD3-like self-attention projections."""
+
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError(
+                "FusedFluxAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0."
+            )
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: torch.FloatTensor = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        image_rotary_emb: Optional[torch.Tensor] = None,
+    ) -> torch.FloatTensor:
+        input_ndim = hidden_states.ndim
+        if input_ndim == 4:
+            batch_size, channel, height, width = hidden_states.shape
+            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+        context_input_ndim = encoder_hidden_states.ndim
+        if context_input_ndim == 4:
+            batch_size, channel, height, width = encoder_hidden_states.shape
+            encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+        batch_size = encoder_hidden_states.shape[0]
+
+        # `sample` projections.
+        qkv = attn.to_qkv(hidden_states)
+        split_size = qkv.shape[-1] // 3
+        query, key, value = torch.split(qkv, split_size, dim=-1)
+
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
+        # `context` projections.
+        encoder_qkv = attn.to_added_qkv(encoder_hidden_states)
+        split_size = encoder_qkv.shape[-1] // 3
+        (
+            encoder_hidden_states_query_proj,
+            encoder_hidden_states_key_proj,
+            encoder_hidden_states_value_proj,
+        ) = torch.split(encoder_qkv, split_size, dim=-1)
+
+        encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
+            batch_size, -1, attn.heads, head_dim
+        ).transpose(1, 2)
+        encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
+            batch_size, -1, attn.heads, head_dim
+        ).transpose(1, 2)
+        encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
+            batch_size, -1, attn.heads, head_dim
+        ).transpose(1, 2)
+
+        if attn.norm_added_q is not None:
+            encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
+        if attn.norm_added_k is not None:
+            encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
+
+        # attention
+        query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
+        key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
+        value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
+
+        if image_rotary_emb is not None:
+            # YiYi to-do: update using apply_rotary_emb
+            # from ..embeddings import apply_rotary_emb
+            # query = apply_rotary_emb(query, image_rotary_emb)
+            # key = apply_rotary_emb(key, image_rotary_emb)
+            query, key = apply_rope(query, key, image_rotary_emb)
+
+        hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.to(query.dtype)
+
+        encoder_hidden_states, hidden_states = (
+            hidden_states[:, : encoder_hidden_states.shape[1]],
+            hidden_states[:, encoder_hidden_states.shape[1] :],
+        )
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+        encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+        if input_ndim == 4:
+            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+        if context_input_ndim == 4:
+            encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+        return hidden_states, encoder_hidden_states
+
+
 class XFormersAttnAddedKVProcessor:
     r"""
     Processor for implementing memory efficient attention using xFormers.
diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py
index 3168fd9a625f..6f3e0e5f1667 100644
--- a/src/diffusers/models/transformers/transformer_flux.py
+++ b/src/diffusers/models/transformers/transformer_flux.py
@@ -23,7 +23,14 @@
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
 from ...models.attention import FeedForward
-from ...models.attention_processor import Attention, FluxAttnProcessor2_0, FluxSingleAttnProcessor2_0
+from ...models.attention_processor import (
+    Attention,
+    AttentionProcessor,
+    FluxAttnProcessor2_0,
+    FluxSingleAttnProcessor2_0,
+    FusedFluxAttnProcessor2_0,
+    FusedFluxSingleAttnProcessor2_0,
+)
 from ...models.modeling_utils import ModelMixin
 from ...models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
 from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
@@ -308,6 +315,105 @@ def __init__(
 
         self.gradient_checkpointing = False
 
+    @property
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
+    def attn_processors(self) -> Dict[str, AttentionProcessor]:
+        r"""
+        Returns:
+            `dict` of attention processors: A dictionary containing all attention processors used in the model,
+            indexed by their weight names.
+        """
+        # set recursively
+        processors = {}
+
+        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+            if hasattr(module, "get_processor"):
+                processors[f"{name}.processor"] = module.get_processor()
+
+            for sub_name, child in module.named_children():
+                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+            return processors
+
+        for name, module in self.named_children():
+            fn_recursive_add_processors(name, module, processors)
+
+        return processors
+
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
+    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+        r"""
+        Sets the attention processor to use to compute attention.
+
+        Parameters:
+            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+                The instantiated processor class or a dictionary of processor classes that will be set as the processor
+                for **all** `Attention` layers.
+
+                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+                processor. This is strongly recommended when setting trainable attention processors.
+
+        """
+        count = len(self.attn_processors.keys())
+
+        if isinstance(processor, dict) and len(processor) != count:
+            raise ValueError(
+                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+            )
+
+        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+            if hasattr(module, "set_processor"):
+                if not isinstance(processor, dict):
+                    module.set_processor(processor)
+                else:
+                    module.set_processor(processor.pop(f"{name}.processor"))
+
+            for sub_name, child in module.named_children():
+                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+        for name, module in self.named_children():
+            fn_recursive_attn_processor(name, module, processor)
+
+    def fuse_qkv_projections(self):
+        """
+        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
+        are fused. For cross-attention modules, key and value projection matrices are fused.
+
+        <Tip warning={true}>
+
+        This API is 🧪 experimental.
+
+        </Tip>
+        """
+        self.original_attn_processors = None
+
+        for _, attn_processor in self.attn_processors.items():
+            if "Added" in str(attn_processor.__class__.__name__):
+                raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
+
+        self.original_attn_processors = self.attn_processors
+
+        for module in self.modules():
+            if isinstance(module, Attention):
+                module.fuse_projections(fuse=True)
+
+        self.set_attn_processor(FusedFluxAttnProcessor2_0(), FusedFluxSingleAttnProcessor2_0())
+
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
+    def unfuse_qkv_projections(self):
+        """Disables the fused QKV projection if enabled.
+
+        <Tip warning={true}>
+
+        This API is 🧪 experimental.
+
+        </Tip>
+
+        """
+        if self.original_attn_processors is not None:
+            self.set_attn_processor(self.original_attn_processors)
+
     def _set_gradient_checkpointing(self, module, value=False):
         if hasattr(module, "gradient_checkpointing"):
             module.gradient_checkpointing = value

From b0544a1902e471126ed88b3316d7c12033f89c55 Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Thu, 15 Aug 2024 16:03:50 +0530
Subject: [PATCH 2/4] test

---
 tests/pipelines/flux/test_pipeline_flux.py | 47 +++++++++++++++++++++-
 1 file changed, 46 insertions(+), 1 deletion(-)

diff --git a/tests/pipelines/flux/test_pipeline_flux.py b/tests/pipelines/flux/test_pipeline_flux.py
index b2744e3f0ad4..94b55f6a071a 100644
--- a/tests/pipelines/flux/test_pipeline_flux.py
+++ b/tests/pipelines/flux/test_pipeline_flux.py
@@ -13,7 +13,12 @@
     torch_device,
 )
 
-from ..test_pipelines_common import PipelineTesterMixin
+from ..test_pipelines_common import (
+    PipelineTesterMixin,
+    check_qkv_fusion_matches_attn_procs_length,
+    check_qkv_fusion_processors_exist,
+)
+
 
 
 @unittest.skipIf(torch_device == "mps", "Flux has a float64 operation which is not supported in MPS.")
@@ -143,6 +148,46 @@ def test_flux_prompt_embeds(self):
         max_diff = np.abs(output_with_prompt - output_with_embeds).max()
         assert max_diff < 1e-4
 
+    def test_fused_qkv_projections(self):
+        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
+        components = self.get_dummy_components()
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(device)
+        pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_dummy_inputs(device)
+        image = pipe(**inputs).images
+        original_image_slice = image[0, -3:, -3:, -1]
+
+        # TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added
+        # to the pipeline level.
+        pipe.transformer.fuse_qkv_projections()
+        assert check_qkv_fusion_processors_exist(
+            pipe.transformer
+        ), "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
+        assert check_qkv_fusion_matches_attn_procs_length(
+            pipe.transformer, pipe.transformer.original_attn_processors
+        ), "Something wrong with the attention processors concerning the fused QKV projections."
+
+        inputs = self.get_dummy_inputs(device)
+        image = pipe(**inputs).images
+        image_slice_fused = image[0, -3:, -3:, -1]
+
+        pipe.transformer.unfuse_qkv_projections()
+        inputs = self.get_dummy_inputs(device)
+        image = pipe(**inputs).images
+        image_slice_disabled = image[0, -3:, -3:, -1]
+
+        assert np.allclose(
+            original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3
+        ), "Fusion of QKV projections shouldn't affect the outputs."
+        assert np.allclose(
+            image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3
+        ), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
+        assert np.allclose(
+            original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
+        ), "Original outputs should match when fused QKV projections are disabled."
+
 
 @slow
 @require_torch_gpu

From 709690ad0fc3f9addd5041a1c206eecb7a93cc85 Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Thu, 15 Aug 2024 16:32:28 +0530
Subject: [PATCH 3/4] finish fusion

---
 .../models/transformers/transformer_flux.py | 12 ++++--------
 tests/pipelines/flux/test_pipeline_flux.py  |  1 -
 2 files changed, 4 insertions(+), 9 deletions(-)

diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py
index 6f3e0e5f1667..e06c5cbf2dd7 100644
--- a/src/diffusers/models/transformers/transformer_flux.py
+++ b/src/diffusers/models/transformers/transformer_flux.py
@@ -386,19 +386,15 @@ def fuse_qkv_projections(self):
 
         </Tip>
         """
-        self.original_attn_processors = None
-
-        for _, attn_processor in self.attn_processors.items():
-            if "Added" in str(attn_processor.__class__.__name__):
-                raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
-
         self.original_attn_processors = self.attn_processors
 
         for module in self.modules():
             if isinstance(module, Attention):
                 module.fuse_projections(fuse=True)
-
-        self.set_attn_processor(FusedFluxAttnProcessor2_0(), FusedFluxSingleAttnProcessor2_0())
+                if isinstance(module.processor, FluxAttnProcessor2_0):
+                    module.set_processor(FusedFluxAttnProcessor2_0())
+                else:
+                    module.set_processor(FusedFluxSingleAttnProcessor2_0())
 
     # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
     def unfuse_qkv_projections(self):
diff --git a/tests/pipelines/flux/test_pipeline_flux.py b/tests/pipelines/flux/test_pipeline_flux.py
index 94b55f6a071a..a594862fea45 100644
--- a/tests/pipelines/flux/test_pipeline_flux.py
+++ b/tests/pipelines/flux/test_pipeline_flux.py
@@ -20,7 +20,6 @@
 )
 
 
-
 @unittest.skipIf(torch_device == "mps", "Flux has a float64 operation which is not supported in MPS.")
 class FluxPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
     pipeline_class = FluxPipeline

From 2bdfdde12b0a1f9a3b3af2a786035d01aad62a8e Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Thu, 22 Aug 2024 07:28:58 +0530
Subject: [PATCH 4/4] fix-copies

---
 src/diffusers/models/transformers/transformer_flux.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py
index 6e5cbfa8c540..fd0881a14880 100644
--- a/src/diffusers/models/transformers/transformer_flux.py
+++ b/src/diffusers/models/transformers/transformer_flux.py
@@ -341,7 +341,7 @@ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
         for name, module in self.named_children():
             fn_recursive_attn_processor(name, module, processor)
 
-    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedJointAttnProcessor2_0
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedFluxAttnProcessor2_0
     def fuse_qkv_projections(self):
         """
         Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
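
With all four patches applied, the Flux transformer exposes `fuse_qkv_projections()` and `unfuse_qkv_projections()`. A minimal sketch of how the new API can be exercised from the pipeline side (the checkpoint id, prompt, and sampling settings below are illustrative, not part of the patches):

    import torch
    from diffusers import FluxPipeline

    # Illustrative checkpoint; any Flux checkpoint that FluxPipeline can load works the same way.
    pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
    pipe.to("cuda")

    # Fuse the separate Q/K/V (and added-KV) projections into single matmuls; this swaps in
    # the FusedFluxAttnProcessor2_0 / FusedFluxSingleAttnProcessor2_0 processors added above.
    pipe.transformer.fuse_qkv_projections()
    image = pipe("a tiny astronaut hatching from an egg on the moon", num_inference_steps=28).images[0]

    # Restore the original projections and attention processors.
    pipe.transformer.unfuse_qkv_projections()

As the new test asserts, outputs before fusion, with fusion enabled, and after unfusing should all match to within tolerance.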