diff --git a/QEfficient/diffusers/models/modeling_utils.py b/QEfficient/diffusers/models/modeling_utils.py
index 59727be2de..799acbffdc 100644
--- a/QEfficient/diffusers/models/modeling_utils.py
+++ b/QEfficient/diffusers/models/modeling_utils.py
@@ -42,6 +42,7 @@ def apply_head_blocking(
     v: torch.FloatTensor,
     head_block_size: int,
     attention_mask: Optional[torch.FloatTensor] = None,
+    is_cross_attention: bool = False,
 ) -> torch.FloatTensor:
     """
     Forward pass with head-only blocking (default mode).
@@ -67,8 +68,7 @@ def apply_head_blocking(
     num_head_blocks = math.ceil(NH / head_block_size)
 
     # Optimization: Handle small sequences with standard attention
-    BS, NH, K_CL, DH = k.shape
-    if K_CL <= 512:
+    if is_cross_attention:
         scores = torch.matmul(q, k.transpose(-2, -1)) * scale_factor
         if attention_mask is not None:
             scores = torch.where(attention_mask, scores, torch.tensor(-1e4, dtype=scores.dtype, device=scores.device))
@@ -110,6 +110,7 @@ def apply_kv_blocking(
     head_block_size: int,
     num_kv_blocks: int,
     attention_mask: Optional[torch.FloatTensor] = None,
+    is_cross_attention: bool = False,
 ) -> torch.FloatTensor:
     """
     Forward pass with Key-Value blocking and head blocking.
@@ -137,8 +138,7 @@ def apply_kv_blocking(
     block_positions = [(i * CL) // num_kv_blocks for i in range(num_kv_blocks)]
 
     # Handle small sequences with standard attention
-    BS, NH, K_CL, DH = k.shape
-    if K_CL <= 512:
+    if is_cross_attention:
         scores = torch.matmul(q, k.transpose(-2, -1)) * scale_factor
         if attention_mask is not None:
             scores = torch.where(attention_mask, scores, torch.tensor(-1e4, dtype=scores.dtype, device=scores.device))
@@ -214,6 +214,7 @@ def apply_q_blocking(
     head_block_size: int,
     num_q_blocks: int,
     attention_mask: Optional[torch.FloatTensor] = None,
+    is_cross_attention: bool = False,
 ) -> torch.FloatTensor:
     """
     Forward pass with Query blocking and head blocking.
@@ -241,8 +242,7 @@ def apply_q_blocking(
     q_block_positions = [(i * CL) // num_q_blocks for i in range(num_q_blocks)]
 
     # Handle small sequences with standard attention
-    BS, NH, K_CL, DH = k.shape
-    if K_CL <= 512:
+    if is_cross_attention:
         scores = torch.matmul(q, k.transpose(-2, -1)) * scale_factor
         if attention_mask is not None:
             scores = torch.where(attention_mask, scores, torch.tensor(-1e4, dtype=scores.dtype, device=scores.device))
@@ -298,6 +298,7 @@ def apply_qkv_blocking(
     num_kv_blocks: int,
     num_q_blocks: int,
     attention_mask: Optional[torch.FloatTensor] = None,
+    is_cross_attention: bool = False,
 ) -> torch.FloatTensor:
     """
     Forward pass with combined Query, Key, Value blocking and head blocking.
@@ -328,8 +329,7 @@ def apply_qkv_blocking(
     q_block_positions = [(i * CL) // num_q_blocks for i in range(num_q_blocks)]
 
     # Optimization: Use standard attention for small sequences
-    BS, NH, K_CL, DH = k.shape
-    if K_CL <= 512:
+    if is_cross_attention:
         scores = torch.matmul(q, k.transpose(-2, -1)) * scale_factor
         if attention_mask is not None:
             scores = torch.where(attention_mask, scores, torch.tensor(-1e4, dtype=scores.dtype, device=scores.device))
@@ -429,6 +429,7 @@ def compute_blocked_attention(
     num_q_blocks: int,
     blocking_mode: str = "default",
     attention_mask: Optional[torch.FloatTensor] = None,
+    is_cross_attention: bool = False,
 ) -> torch.FloatTensor:
     """
     Main dispatcher function for different attention blocking strategies.
@@ -447,10 +448,12 @@ def compute_blocked_attention(
         torch.FloatTensor: Attention output of shape (BS, NH, CL, DH)
     """
     if blocking_mode == "kv":
-        return apply_kv_blocking(q, k, v, head_block_size, num_kv_blocks, attention_mask)
+        return apply_kv_blocking(q, k, v, head_block_size, num_kv_blocks, attention_mask, is_cross_attention)
     elif blocking_mode == "q":
-        return apply_q_blocking(q, k, v, head_block_size, num_q_blocks, attention_mask)
+        return apply_q_blocking(q, k, v, head_block_size, num_q_blocks, attention_mask, is_cross_attention)
     elif blocking_mode == "qkv":
-        return apply_qkv_blocking(q, k, v, head_block_size, num_kv_blocks, num_q_blocks, attention_mask)
+        return apply_qkv_blocking(
+            q, k, v, head_block_size, num_kv_blocks, num_q_blocks, attention_mask, is_cross_attention
+        )
     else:  # default
-        return apply_head_blocking(q, k, v, head_block_size, attention_mask)
+        return apply_head_blocking(q, k, v, head_block_size, attention_mask, is_cross_attention)
diff --git a/QEfficient/diffusers/models/transformers/transformer_wan.py b/QEfficient/diffusers/models/transformers/transformer_wan.py
index 9200997d71..930c5e90a3 100644
--- a/QEfficient/diffusers/models/transformers/transformer_wan.py
+++ b/QEfficient/diffusers/models/transformers/transformer_wan.py
@@ -79,6 +79,10 @@ def __call__(
         Returns:
             torch.Tensor: Processed hidden states after attention
         """
+        is_cross_attention = False
+        # encoder_hidden_states is only passed for cross-attention
+        if encoder_hidden_states is not None:
+            is_cross_attention = True
 
         # Project inputs to query, key, value
         query, key, value = _get_qkv_projections(attn, hidden_states, encoder_hidden_states)
@@ -126,6 +130,7 @@ def apply_rotary_emb(
             num_q_blocks,
             blocking_mode=blocking_mode,
             attention_mask=attention_mask,
+            is_cross_attention=is_cross_attention,
         )
 
         # Reshape back to original format
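
Below is a minimal usage sketch (not part of the patch) of how the new is_cross_attention flag reaches compute_blocked_attention from a cross-attention call site. The tensor shapes and block counts are illustrative assumptions, not values taken from the repository; the keyword names follow the parameter names visible in the diff and the dispatch calls.

import torch
from QEfficient.diffusers.models.modeling_utils import compute_blocked_attention

# Illustrative shapes only: a long latent query sequence attending to a short
# text-conditioning sequence (cross-attention), 12 heads, head_dim 64.
q = torch.randn(1, 12, 4096, 64)
k = torch.randn(1, 12, 226, 64)
v = torch.randn(1, 12, 226, 64)

# With is_cross_attention=True the blocking path is skipped and plain
# scaled-dot-product attention runs, replacing the old K_CL <= 512 shape check.
out = compute_blocked_attention(
    q,
    k,
    v,
    head_block_size=4,
    num_kv_blocks=4,
    num_q_blocks=4,
    blocking_mode="kv",
    attention_mask=None,
    is_cross_attention=True,
)
# out has shape (BS, NH, CL, DH) == (1, 12, 4096, 64), per the dispatcher docstring.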