diff --git a/src/transformers/masking_utils.py b/src/transformers/masking_utils.py
index 99306bd94c88..8dd9730f8ae0 100644
--- a/src/transformers/masking_utils.py
+++ b/src/transformers/masking_utils.py
@@ -383,13 +383,10 @@ def sdpa_mask_recent_torch(
     if padding_mask is not None:
         mask_function = and_masks(mask_function, padding_mask_function(padding_mask))
 
-    batch_arange = torch.arange(batch_size, device=cache_position.device)
-    head_arange = torch.arange(1, device=cache_position.device)
-    # This creates the 4D mask easily. Note that we need this context manager as vmap cannot handle slicing a tensor from
-    # scalar tensor (it internally calls `.item()` which vmap does not allow, but this context works around it
-    # We don't need to add an offset to the mask_function either, as we vmap directly the correct indices for k and kv indices
-    with TransformGetItemToIndex():
-        causal_mask = _vmap_for_bhqkv(mask_function)(batch_arange, head_arange, cache_position, kv_arange)
+    q_indices = torch.arange(kv_length - q_length, kv_length)
+    k_indices = torch.arange(kv_length)
+    causal_mask = q_indices[:, None] >= k_indices[None, :]
+    causal_mask = causal_mask.unsqueeze(0).unsqueeze(0)
 
     return causal_mask
 
diff --git a/src/transformers/models/llama4/modeling_llama4.py b/src/transformers/models/llama4/modeling_llama4.py
index a974ed81ba2f..05e130b186ec 100644
--- a/src/transformers/models/llama4/modeling_llama4.py
+++ b/src/transformers/models/llama4/modeling_llama4.py
@@ -340,7 +340,7 @@ def forward(
         attn_scales = (
             torch.log1p(torch.floor((cache_position.float() + 1.0) / self.floor_scale)) * self.attn_scale + 1.0
         )
-        attn_scales = attn_scales.view((1, input_shape[-1], 1, 1)).expand((*input_shape, 1, 1))  # batch size > 1
+        attn_scales = attn_scales.view((1, -1, 1, 1)).expand((*input_shape, 1, 1))  # batch size > 1
         query_states = (query_states * attn_scales).to(query_states.dtype)
         query_states = query_states.transpose(1, 2)
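
For illustration, a minimal standalone sketch (not part of the diff) of the new causal-mask construction added in `sdpa_mask_recent_torch`; the values `q_length = 2` and `kv_length = 5` are assumed purely for this example:

import torch

# Assumed values for illustration only: 2 query positions appended to a 5-position kv cache.
q_length, kv_length = 2, 5

# Query positions are the last q_length positions of the kv sequence.
q_indices = torch.arange(kv_length - q_length, kv_length)
k_indices = torch.arange(kv_length)

# Broadcasting the comparison yields a (q_length, kv_length) causal mask:
# each query may attend to all keys at or before its own position.
causal_mask = q_indices[:, None] >= k_indices[None, :]

# Add singleton batch and head dimensions to obtain the 4D mask.
causal_mask = causal_mask.unsqueeze(0).unsqueeze(0)

print(causal_mask.shape)   # torch.Size([1, 1, 2, 5])
print(causal_mask[0, 0])
# tensor([[ True,  True,  True,  True, False],
#         [ True,  True,  True,  True,  True]])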
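
Similarly, a small sketch (not part of the diff) of the corrected `attn_scales` reshape in the llama4 hunk; `batch_size = 2`, `seq_len = 3`, and the `floor_scale` / `attn_scale` constants are assumed values for illustration:

import torch

# Assumed values for illustration only.
batch_size, seq_len = 2, 3
input_shape = (batch_size, seq_len)
floor_scale, attn_scale = 8192.0, 0.1

cache_position = torch.arange(seq_len)
attn_scales = torch.log1p(torch.floor((cache_position.float() + 1.0) / floor_scale)) * attn_scale + 1.0  # shape (seq_len,)

# view with -1 infers the sequence dimension from the tensor itself, then expand
# broadcasts the per-position scale across the batch dimension.
attn_scales = attn_scales.view((1, -1, 1, 1)).expand((*input_shape, 1, 1))

print(attn_scales.shape)  # torch.Size([2, 3, 1, 1])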