From 40ad9c806c1e9ab8bf511a38cca876870a2d4289 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 24 Jul 2024 07:01:09 +0530 Subject: [PATCH 1/6] add fusion support to pixart --- .../transformers/pixart_transformer_2d.py | 103 +++++++++++++++++- 1 file changed, 102 insertions(+), 1 deletion(-) diff --git a/src/diffusers/models/transformers/pixart_transformer_2d.py b/src/diffusers/models/transformers/pixart_transformer_2d.py index 9c8f9b09083f..821c9a18c2f1 100644 --- a/src/diffusers/models/transformers/pixart_transformer_2d.py +++ b/src/diffusers/models/transformers/pixart_transformer_2d.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, Optional +from typing import Any, Dict, Optional, Union import torch from torch import nn @@ -19,6 +19,7 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...utils import is_torch_version, logging from ..attention import BasicTransformerBlock +from ..attention_processor import Attention, FusedAttnProcessor2_0, AttentionProcessor from ..embeddings import PatchEmbed, PixArtAlphaTextProjection from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin @@ -186,6 +187,106 @@ def _set_gradient_checkpointing(self, module, value=False): if hasattr(module, "gradient_checkpointing"): module.gradient_checkpointing = value + @property + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor() + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections + def fuse_qkv_projections(self): + """ + Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value) + are fused. For cross-attention modules, key and value projection matrices are fused. + + + + This API is 🧪 experimental. + + + """ + self.original_attn_processors = None + + for _, attn_processor in self.attn_processors.items(): + if "Added" in str(attn_processor.__class__.__name__): + raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.") + + self.original_attn_processors = self.attn_processors + + for module in self.modules(): + if isinstance(module, Attention): + module.fuse_projections(fuse=True) + + self.set_attn_processor(FusedAttnProcessor2_0()) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections + def unfuse_qkv_projections(self): + """Disables the fused QKV projection if enabled. + + + + This API is 🧪 experimental. + + + + """ + if self.original_attn_processors is not None: + self.set_attn_processor(self.original_attn_processors) + def forward( self, hidden_states: torch.Tensor, From 48dfbb6f00bb0864acf91d7265c0fae1d3ea486a Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 24 Jul 2024 07:12:33 +0530 Subject: [PATCH 2/6] add to auraflow. --- src/diffusers/models/attention_processor.py | 96 +++++++++++++++++ .../transformers/auraflow_transformer_2d.py | 102 +++++++++++++++++- 2 files changed, 197 insertions(+), 1 deletion(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 87bb9e07dc75..a1ef640537db 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -1262,6 +1262,102 @@ def __call__( return hidden_states +class FusedAuraFlowAttnProcessor2_0: + """Attention processor used typically in processing Aura Flow with fused projections.""" + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention") and is_torch_version("<", "2.1"): + raise ImportError( + "FusedAuraFlowAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to at least 2.1 or above as we use `scale` in `F.scaled_dot_product_attention()`. " + ) + + def __call__( + self, + attn: Attention, + hidden_states: torch.FloatTensor, + encoder_hidden_states: torch.FloatTensor = None, + *args, + **kwargs, + ) -> torch.FloatTensor: + batch_size = hidden_states.shape[0] + + # `sample` projections. + qkv = attn.to_qkv(hidden_states) + split_size = qkv.shape[-1] // 3 + query, key, value = torch.split(qkv, split_size, dim=-1) + + # `context` projections. + if encoder_hidden_states is not None: + encoder_qkv = attn.to_added_qkv(encoder_hidden_states) + split_size = encoder_qkv.shape[-1] // 3 + ( + encoder_hidden_states_query_proj, + encoder_hidden_states_key_proj, + encoder_hidden_states_value_proj, + ) = torch.split(encoder_qkv, split_size, dim=-1) + + # Reshape. + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + query = query.view(batch_size, -1, attn.heads, head_dim) + key = key.view(batch_size, -1, attn.heads, head_dim) + value = value.view(batch_size, -1, attn.heads, head_dim) + + # Apply QK norm. + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # Concatenate the projections. + if encoder_hidden_states is not None: + encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( + batch_size, -1, attn.heads, head_dim + ) + encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(batch_size, -1, attn.heads, head_dim) + encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view( + batch_size, -1, attn.heads, head_dim + ) + + if attn.norm_added_q is not None: + encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj) + if attn.norm_added_k is not None: + encoder_hidden_states_key_proj = attn.norm_added_q(encoder_hidden_states_key_proj) + + query = torch.cat([encoder_hidden_states_query_proj, query], dim=1) + key = torch.cat([encoder_hidden_states_key_proj, key], dim=1) + value = torch.cat([encoder_hidden_states_value_proj, value], dim=1) + + query = query.transpose(1, 2) + key = key.transpose(1, 2) + value = value.transpose(1, 2) + + # Attention. + hidden_states = F.scaled_dot_product_attention( + query, key, value, dropout_p=0.0, scale=attn.scale, is_causal=False + ) + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # Split the attention outputs. + if encoder_hidden_states is not None: + hidden_states, encoder_hidden_states = ( + hidden_states[:, encoder_hidden_states.shape[1] :], + hidden_states[:, : encoder_hidden_states.shape[1]], + ) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + if encoder_hidden_states is not None: + encoder_hidden_states = attn.to_add_out(encoder_hidden_states) + + if encoder_hidden_states is not None: + return hidden_states, encoder_hidden_states + else: + return hidden_states + class XFormersAttnAddedKVProcessor: r""" Processor for implementing memory efficient attention using xFormers. diff --git a/src/diffusers/models/transformers/auraflow_transformer_2d.py b/src/diffusers/models/transformers/auraflow_transformer_2d.py index 342373b4c11d..7de7821bf09c 100644 --- a/src/diffusers/models/transformers/auraflow_transformer_2d.py +++ b/src/diffusers/models/transformers/auraflow_transformer_2d.py @@ -22,7 +22,7 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...utils import is_torch_version, logging from ...utils.torch_utils import maybe_allow_in_graph -from ..attention_processor import Attention, AuraFlowAttnProcessor2_0 +from ..attention_processor import Attention, AuraFlowAttnProcessor2_0, FusedAuraFlowAttnProcessor2_0, AttentionProcessor from ..embeddings import TimestepEmbedding, Timesteps from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin @@ -320,6 +320,106 @@ def __init__( self.gradient_checkpointing = False + @property + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor() + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedAuraFlowAttnProcessor2_0 + def fuse_qkv_projections(self): + """ + Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value) + are fused. For cross-attention modules, key and value projection matrices are fused. + + + + This API is 🧪 experimental. + + + """ + self.original_attn_processors = None + + for _, attn_processor in self.attn_processors.items(): + if "Added" in str(attn_processor.__class__.__name__): + raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.") + + self.original_attn_processors = self.attn_processors + + for module in self.modules(): + if isinstance(module, Attention): + module.fuse_projections(fuse=True) + + self.set_attn_processor(FusedAuraFlowAttnProcessor2_0()) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections + def unfuse_qkv_projections(self): + """Disables the fused QKV projection if enabled. + + + + This API is 🧪 experimental. + + + + """ + if self.original_attn_processors is not None: + self.set_attn_processor(self.original_attn_processors) + def _set_gradient_checkpointing(self, module, value=False): if hasattr(module, "gradient_checkpointing"): module.gradient_checkpointing = value From 2ff7a4a3dad03f97ae6b7db9e70e072397c0ece0 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 24 Jul 2024 07:21:07 +0530 Subject: [PATCH 3/6] add tests --- src/diffusers/models/attention_processor.py | 15 ++++-- .../transformers/auraflow_transformer_2d.py | 7 ++- .../transformers/pixart_transformer_2d.py | 2 +- .../aura_flow/test_pipeline_aura_dlow.py | 46 +++++++++++++++++- tests/pipelines/pixart_sigma/test_pixart.py | 47 ++++++++++++++++++- 5 files changed, 108 insertions(+), 9 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index a1ef640537db..fe0d36c61fbf 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -219,6 +219,7 @@ def __init__( self.to_v = None if self.added_kv_proj_dim is not None: + self.added_proj_bias = added_proj_bias self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias) self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias) if self.context_pre_only is not None: @@ -685,12 +686,15 @@ def fuse_projections(self, fuse=True): in_features = concatenated_weights.shape[1] out_features = concatenated_weights.shape[0] - self.to_added_qkv = nn.Linear(in_features, out_features, bias=True, device=device, dtype=dtype) - self.to_added_qkv.weight.copy_(concatenated_weights) - concatenated_bias = torch.cat( - [self.add_q_proj.bias.data, self.add_k_proj.bias.data, self.add_v_proj.bias.data] + self.to_added_qkv = nn.Linear( + in_features, out_features, bias=self.added_proj_bias, device=device, dtype=dtype ) - self.to_added_qkv.bias.copy_(concatenated_bias) + self.to_added_qkv.weight.copy_(concatenated_weights) + if self.added_proj_bias: + concatenated_bias = torch.cat( + [self.add_q_proj.bias.data, self.add_k_proj.bias.data, self.add_v_proj.bias.data] + ) + self.to_added_qkv.bias.copy_(concatenated_bias) self.fused_projections = fuse @@ -1358,6 +1362,7 @@ def __call__( else: return hidden_states + class XFormersAttnAddedKVProcessor: r""" Processor for implementing memory efficient attention using xFormers. diff --git a/src/diffusers/models/transformers/auraflow_transformer_2d.py b/src/diffusers/models/transformers/auraflow_transformer_2d.py index 7de7821bf09c..1f709805089d 100644 --- a/src/diffusers/models/transformers/auraflow_transformer_2d.py +++ b/src/diffusers/models/transformers/auraflow_transformer_2d.py @@ -22,7 +22,12 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...utils import is_torch_version, logging from ...utils.torch_utils import maybe_allow_in_graph -from ..attention_processor import Attention, AuraFlowAttnProcessor2_0, FusedAuraFlowAttnProcessor2_0, AttentionProcessor +from ..attention_processor import ( + Attention, + AttentionProcessor, + AuraFlowAttnProcessor2_0, + FusedAuraFlowAttnProcessor2_0, +) from ..embeddings import TimestepEmbedding, Timesteps from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin diff --git a/src/diffusers/models/transformers/pixart_transformer_2d.py b/src/diffusers/models/transformers/pixart_transformer_2d.py index 821c9a18c2f1..5c9c61243c07 100644 --- a/src/diffusers/models/transformers/pixart_transformer_2d.py +++ b/src/diffusers/models/transformers/pixart_transformer_2d.py @@ -19,7 +19,7 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...utils import is_torch_version, logging from ..attention import BasicTransformerBlock -from ..attention_processor import Attention, FusedAttnProcessor2_0, AttentionProcessor +from ..attention_processor import Attention, AttentionProcessor, FusedAttnProcessor2_0 from ..embeddings import PatchEmbed, PixArtAlphaTextProjection from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin diff --git a/tests/pipelines/aura_flow/test_pipeline_aura_dlow.py b/tests/pipelines/aura_flow/test_pipeline_aura_dlow.py index 9a2f1846f1d9..3694a733163c 100644 --- a/tests/pipelines/aura_flow/test_pipeline_aura_dlow.py +++ b/tests/pipelines/aura_flow/test_pipeline_aura_dlow.py @@ -9,7 +9,11 @@ torch_device, ) -from ..test_pipelines_common import PipelineTesterMixin +from ..test_pipelines_common import ( + PipelineTesterMixin, + check_qkv_fusion_matches_attn_procs_length, + check_qkv_fusion_processors_exist, +) class AuraFlowPipelineFastTests(unittest.TestCase, PipelineTesterMixin): @@ -119,3 +123,43 @@ def test_attention_slicing_forward_pass(self): # Attention slicing needs to implemented differently for this because how single DiT and MMDiT # blocks interfere with each other. return + + def test_fused_qkv_projections(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe = pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + image = pipe(**inputs).images + original_image_slice = image[0, -3:, -3:, -1] + + # TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added + # to the pipeline level. + pipe.transformer.fuse_qkv_projections() + assert check_qkv_fusion_processors_exist( + pipe.transformer + ), "Something wrong with the fused attention processors. Expected all the attention processors to be fused." + assert check_qkv_fusion_matches_attn_procs_length( + pipe.transformer, pipe.transformer.original_attn_processors + ), "Something wrong with the attention processors concerning the fused QKV projections." + + inputs = self.get_dummy_inputs(device) + image = pipe(**inputs).images + image_slice_fused = image[0, -3:, -3:, -1] + + pipe.transformer.unfuse_qkv_projections() + inputs = self.get_dummy_inputs(device) + image = pipe(**inputs).images + image_slice_disabled = image[0, -3:, -3:, -1] + + assert np.allclose( + original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3 + ), "Fusion of QKV projections shouldn't affect the outputs." + assert np.allclose( + image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3 + ), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled." + assert np.allclose( + original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2 + ), "Original outputs should match when fused QKV projections are disabled." diff --git a/tests/pipelines/pixart_sigma/test_pixart.py b/tests/pipelines/pixart_sigma/test_pixart.py index a4cc60d12588..a92e99366ee3 100644 --- a/tests/pipelines/pixart_sigma/test_pixart.py +++ b/tests/pipelines/pixart_sigma/test_pixart.py @@ -36,7 +36,12 @@ ) from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS -from ..test_pipelines_common import PipelineTesterMixin, to_np +from ..test_pipelines_common import ( + PipelineTesterMixin, + check_qkv_fusion_matches_attn_procs_length, + check_qkv_fusion_processors_exist, + to_np, +) enable_full_determinism() @@ -308,6 +313,46 @@ def test_inference_with_multiple_images_per_prompt(self): def test_inference_batch_single_identical(self): self._test_inference_batch_single_identical(expected_max_diff=1e-3) + def test_fused_qkv_projections(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe = pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + image = pipe(**inputs).images + original_image_slice = image[0, -3:, -3:, -1] + + # TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added + # to the pipeline level. + pipe.transformer.fuse_qkv_projections() + assert check_qkv_fusion_processors_exist( + pipe.transformer + ), "Something wrong with the fused attention processors. Expected all the attention processors to be fused." + assert check_qkv_fusion_matches_attn_procs_length( + pipe.transformer, pipe.transformer.original_attn_processors + ), "Something wrong with the attention processors concerning the fused QKV projections." + + inputs = self.get_dummy_inputs(device) + image = pipe(**inputs).images + image_slice_fused = image[0, -3:, -3:, -1] + + pipe.transformer.unfuse_qkv_projections() + inputs = self.get_dummy_inputs(device) + image = pipe(**inputs).images + image_slice_disabled = image[0, -3:, -3:, -1] + + assert np.allclose( + original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3 + ), "Fusion of QKV projections shouldn't affect the outputs." + assert np.allclose( + image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3 + ), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled." + assert np.allclose( + original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2 + ), "Original outputs should match when fused QKV projections are disabled." + @slow @require_torch_gpu From 62f2af1be86f9af8c51cd03fcbe9bacfa4c5a4d8 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 26 Jul 2024 07:22:38 +0530 Subject: [PATCH 4/6] apply review feedback. --- src/diffusers/models/attention_processor.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index afacd1a59720..cbf2e53c0029 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -218,8 +218,8 @@ def __init__( self.to_k = None self.to_v = None + self.added_proj_bias = added_proj_bias if self.added_kv_proj_dim is not None: - self.added_proj_bias = added_proj_bias self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias) self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias) if self.context_pre_only is not None: @@ -1279,8 +1279,6 @@ def __call__( attn: Attention, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor = None, - *args, - **kwargs, ) -> torch.FloatTensor: batch_size = hidden_states.shape[0] From 33310efcdfbf24d8c168834811d78359ede79681 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 26 Jul 2024 07:44:00 +0530 Subject: [PATCH 5/6] add back args and kwargs --- src/diffusers/models/attention_processor.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index cbf2e53c0029..3a85ee0c2ba9 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -1279,6 +1279,8 @@ def __call__( attn: Attention, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor = None, + *args, + **kwargs ) -> torch.FloatTensor: batch_size = hidden_states.shape[0] From 0e5036db0e2f444ecbde1f2c7080682dd902e81f Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 26 Jul 2024 07:44:21 +0530 Subject: [PATCH 6/6] style --- src/diffusers/models/attention_processor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 3a85ee0c2ba9..bb3ebec36a68 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -1280,7 +1280,7 @@ def __call__( hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor = None, *args, - **kwargs + **kwargs, ) -> torch.FloatTensor: batch_size = hidden_states.shape[0]