From 38a88e52c01ecfacef12d817052538d190a2437f Mon Sep 17 00:00:00 2001 From: YifanShenSZ Date: Mon, 8 Dec 2025 14:27:05 -0800 Subject: [PATCH 1/3] torch.export does not support such vmap usage, so use simple causal mask construction --- src/transformers/masking_utils.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/src/transformers/masking_utils.py b/src/transformers/masking_utils.py index 99306bd94c88..e44070f2d351 100644 --- a/src/transformers/masking_utils.py +++ b/src/transformers/masking_utils.py @@ -383,13 +383,9 @@ def sdpa_mask_recent_torch( if padding_mask is not None: mask_function = and_masks(mask_function, padding_mask_function(padding_mask)) - batch_arange = torch.arange(batch_size, device=cache_position.device) - head_arange = torch.arange(1, device=cache_position.device) - # This creates the 4D mask easily. Note that we need this context manager as vmap cannot handle slicing a tensor from - # scalar tensor (it internally calls `.item()` which vmap does not allow, but this context works around it - # We don't need to add an offset to the mask_function either, as we vmap directly the correct indices for k and kv indices - with TransformGetItemToIndex(): - causal_mask = _vmap_for_bhqkv(mask_function)(batch_arange, head_arange, cache_position, kv_arange) + q_indices = torch.arange(kv_length - q_length, kv_length) + k_indices = torch.arange(kv_length) + causal_mask = q_indices[:, None] >= k_indices[None, :] return causal_mask From d4669ff3e6ff3feef2aefe526bb803f616d4102e Mon Sep 17 00:00:00 2001 From: YifanShenSZ Date: Mon, 8 Dec 2025 15:49:45 -0800 Subject: [PATCH 2/3] models such as GPT-OSS require 4D mask --- src/transformers/masking_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/transformers/masking_utils.py b/src/transformers/masking_utils.py index e44070f2d351..8dd9730f8ae0 100644 --- a/src/transformers/masking_utils.py +++ b/src/transformers/masking_utils.py @@ -386,6 +386,7 @@ def sdpa_mask_recent_torch( q_indices = torch.arange(kv_length - q_length, kv_length) k_indices = torch.arange(kv_length) causal_mask = q_indices[:, None] >= k_indices[None, :] + causal_mask = causal_mask.unsqueeze(0).unsqueeze(0) return causal_mask From d26311d768c36de11cce67b02124d0d74dd8fbb6 Mon Sep 17 00:00:00 2001 From: YifanShenSZ Date: Mon, 8 Dec 2025 15:50:07 -0800 Subject: [PATCH 3/3] use -1 in reshape to dodge Llama4 dynamic-shape torch.export failure --- src/transformers/models/llama4/modeling_llama4.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/llama4/modeling_llama4.py b/src/transformers/models/llama4/modeling_llama4.py index a974ed81ba2f..05e130b186ec 100644 --- a/src/transformers/models/llama4/modeling_llama4.py +++ b/src/transformers/models/llama4/modeling_llama4.py @@ -340,7 +340,7 @@ def forward( attn_scales = ( torch.log1p(torch.floor((cache_position.float() + 1.0) / self.floor_scale)) * self.attn_scale + 1.0 ) - attn_scales = attn_scales.view((1, input_shape[-1], 1, 1)).expand((*input_shape, 1, 1)) # batch size > 1 + attn_scales = attn_scales.view((1, -1, 1, 1)).expand((*input_shape, 1, 1)) # batch size > 1 query_states = (query_states * attn_scales).to(query_states.dtype) query_states = query_states.transpose(1, 2)