27 changes: 15 additions & 12 deletions QEfficient/diffusers/models/modeling_utils.py
@@ -42,6 +42,7 @@ def apply_head_blocking(
     v: torch.FloatTensor,
     head_block_size: int,
     attention_mask: Optional[torch.FloatTensor] = None,
+    is_cross_attention: bool = False,
 ) -> torch.FloatTensor:
     """
     Forward pass with head-only blocking (default mode).
@@ -67,8 +68,7 @@ def apply_head_blocking(
     num_head_blocks = math.ceil(NH / head_block_size)
 
     # Optimization: Handle small sequences with standard attention
-    BS, NH, K_CL, DH = k.shape
-    if K_CL <= 512:
+    if is_cross_attention:
         scores = torch.matmul(q, k.transpose(-2, -1)) * scale_factor
         if attention_mask is not None:
             scores = torch.where(attention_mask, scores, torch.tensor(-1e4, dtype=scores.dtype, device=scores.device))
@@ -110,6 +110,7 @@ def apply_kv_blocking(
     head_block_size: int,
     num_kv_blocks: int,
     attention_mask: Optional[torch.FloatTensor] = None,
+    is_cross_attention: bool = False,
 ) -> torch.FloatTensor:
     """
     Forward pass with Key-Value blocking and head blocking.
@@ -137,8 +138,7 @@ def apply_kv_blocking(
     block_positions = [(i * CL) // num_kv_blocks for i in range(num_kv_blocks)]
 
     # Handle small sequences with standard attention
-    BS, NH, K_CL, DH = k.shape
-    if K_CL <= 512:
+    if is_cross_attention:
         scores = torch.matmul(q, k.transpose(-2, -1)) * scale_factor
         if attention_mask is not None:
             scores = torch.where(attention_mask, scores, torch.tensor(-1e4, dtype=scores.dtype, device=scores.device))
@@ -214,6 +214,7 @@ def apply_q_blocking(
     head_block_size: int,
     num_q_blocks: int,
     attention_mask: Optional[torch.FloatTensor] = None,
+    is_cross_attention: bool = False,
 ) -> torch.FloatTensor:
     """
     Forward pass with Query blocking and head blocking.
@@ -241,8 +242,7 @@ def apply_q_blocking(
     q_block_positions = [(i * CL) // num_q_blocks for i in range(num_q_blocks)]
 
     # Handle small sequences with standard attention
-    BS, NH, K_CL, DH = k.shape
-    if K_CL <= 512:
+    if is_cross_attention:
         scores = torch.matmul(q, k.transpose(-2, -1)) * scale_factor
         if attention_mask is not None:
             scores = torch.where(attention_mask, scores, torch.tensor(-1e4, dtype=scores.dtype, device=scores.device))
@@ -298,6 +298,7 @@ def apply_qkv_blocking(
     num_kv_blocks: int,
     num_q_blocks: int,
     attention_mask: Optional[torch.FloatTensor] = None,
+    is_cross_attention: bool = False,
 ) -> torch.FloatTensor:
     """
     Forward pass with combined Query, Key, Value blocking and head blocking.
@@ -328,8 +329,7 @@ def apply_qkv_blocking(
     q_block_positions = [(i * CL) // num_q_blocks for i in range(num_q_blocks)]
 
     # Optimization: Use standard attention for small sequences
-    BS, NH, K_CL, DH = k.shape
-    if K_CL <= 512:
+    if is_cross_attention:
         scores = torch.matmul(q, k.transpose(-2, -1)) * scale_factor
         if attention_mask is not None:
             scores = torch.where(attention_mask, scores, torch.tensor(-1e4, dtype=scores.dtype, device=scores.device))
@@ -429,6 +429,7 @@ def compute_blocked_attention(
     num_q_blocks: int,
     blocking_mode: str = "default",
     attention_mask: Optional[torch.FloatTensor] = None,
+    is_cross_attention: bool = False,
 ) -> torch.FloatTensor:
     """
     Main dispatcher function for different attention blocking strategies.
@@ -447,10 +448,12 @@
         torch.FloatTensor: Attention output of shape (BS, NH, CL, DH)
     """
     if blocking_mode == "kv":
-        return apply_kv_blocking(q, k, v, head_block_size, num_kv_blocks, attention_mask)
+        return apply_kv_blocking(q, k, v, head_block_size, num_kv_blocks, attention_mask, is_cross_attention)
     elif blocking_mode == "q":
-        return apply_q_blocking(q, k, v, head_block_size, num_q_blocks, attention_mask)
+        return apply_q_blocking(q, k, v, head_block_size, num_q_blocks, attention_mask, is_cross_attention)
     elif blocking_mode == "qkv":
-        return apply_qkv_blocking(q, k, v, head_block_size, num_kv_blocks, num_q_blocks, attention_mask)
+        return apply_qkv_blocking(
+            q, k, v, head_block_size, num_kv_blocks, num_q_blocks, attention_mask, is_cross_attention
+        )
     else:  # default
-        return apply_head_blocking(q, k, v, head_block_size, attention_mask)
+        return apply_head_blocking(q, k, v, head_block_size, attention_mask, is_cross_attention)
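
Note for reviewers: every apply_*_blocking variant above shares the same short-circuit to plain scaled-dot-product attention, and this change only swaps the trigger from a key-length heuristic (K_CL <= 512) to the explicit is_cross_attention flag. A minimal sketch of that pattern, assuming the usual 1/sqrt(head_dim) scale (the helper name is illustrative; the real functions inline the fallback):

import math
from typing import Optional

import torch


def _standard_attention_fallback(
    q: torch.FloatTensor,
    k: torch.FloatTensor,
    v: torch.FloatTensor,
    attention_mask: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
    # Plain scaled-dot-product attention, matching the masked-fill style
    # visible in the hunks above (-1e4 where the mask is False).
    # Assumption: scale_factor is 1/sqrt(head_dim), as is standard.
    scale_factor = 1.0 / math.sqrt(q.shape[-1])
    scores = torch.matmul(q, k.transpose(-2, -1)) * scale_factor
    if attention_mask is not None:
        scores = torch.where(
            attention_mask, scores, torch.tensor(-1e4, dtype=scores.dtype, device=scores.device)
        )
    return torch.matmul(torch.softmax(scores, dim=-1), v)


def apply_blocking_sketch(q, k, v, is_cross_attention=False, attention_mask=None):
    # Before this PR the fallback fired on a shape heuristic:
    #     BS, NH, K_CL, DH = k.shape; if K_CL <= 512: ...
    # After it, the caller states intent explicitly.
    if is_cross_attention:
        return _standard_attention_fallback(q, k, v, attention_mask)
    ...  # blocked attention path (head/KV/Q blocking) continues here

The upshot of the explicit flag is that the decision no longer depends on tensor shape: a short self-attention sequence still takes the blocked path, and cross-attention can never be misrouted by the hard-coded 512 threshold.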
5 changes: 5 additions & 0 deletions QEfficient/diffusers/models/transformers/transformer_wan.py
@@ -79,6 +79,10 @@ def __call__(
         Returns:
             torch.Tensor: Processed hidden states after attention
         """
+        is_cross_attention = False
+        # encoder_hidden_states is only passed for cross-attention
+        if encoder_hidden_states is not None:
+            is_cross_attention = True
         # Project inputs to query, key, value
         query, key, value = _get_qkv_projections(attn, hidden_states, encoder_hidden_states)
 
@@ -126,6 +130,7 @@ def apply_rotary_emb(
             num_q_blocks,
             blocking_mode=blocking_mode,
             attention_mask=attention_mask,
+            is_cross_attention=is_cross_attention,
         )
 
         # Reshape back to original format
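
On the caller side the whole change reduces to a presence check on encoder_hidden_states, which is then forwarded to compute_blocked_attention. A self-contained sketch of that contract (the function name and tensor shapes are illustrative, not the processor's actual API):

from typing import Optional

import torch


def detect_cross_attention(encoder_hidden_states: Optional[torch.Tensor]) -> bool:
    # Mirrors the check added in __call__: encoder_hidden_states is only
    # supplied for cross-attention, so its presence is the signal passed
    # down as is_cross_attention.
    return encoder_hidden_states is not None


# Self-attention step: no encoder states supplied.
assert detect_cross_attention(None) is False
# Cross-attention step: text-encoder states (shape is a placeholder).
assert detect_cross_attention(torch.randn(1, 77, 4096)) is True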