diff --git a/QEfficient/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py b/QEfficient/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py index 322bbfaf9..114bcb3bd 100644 --- a/QEfficient/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +++ b/QEfficient/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py @@ -55,16 +55,11 @@ def qeff_apply_interleaved_mrope(freqs, mrope_section): returns: x_t: (bs, seq_len, head_dim // 2) """ - freqs_t = freqs[0] # just overwrite the first dimension T - half_shape = freqs.shape[-1] // 2 + freqs_t = freqs[0].clone() for dim, offset in enumerate((1, 2), start=1): # H, W length = mrope_section[dim] * 3 idx = slice(offset, length, 3) freqs_t[..., idx] = freqs[dim, ..., idx] - offset += half_shape - length += half_shape - idx = slice(offset, length, 3) - freqs_t[..., idx] = freqs[dim, ..., idx] return freqs_t @@ -100,8 +95,12 @@ def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, mrope_section, unsqu Returns: `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. """ - cos = cos[position_ids] - sin = sin[position_ids] + # Safe gather: map padded -1 IDs to 0 for gather, then zero them out after interleave. + invalid_pos_mask = position_ids < 0 + safe_position_ids = torch.where(invalid_pos_mask, torch.zeros_like(position_ids), position_ids) + flat_pos = safe_position_ids.reshape(-1) + cos = cos.index_select(0, flat_pos).reshape(*safe_position_ids.shape, cos.shape[-1]) + sin = sin.index_select(0, flat_pos).reshape(*safe_position_ids.shape, sin.shape[-1]) cos = qeff_apply_interleaved_mrope(cos, mrope_section) sin = qeff_apply_interleaved_mrope(sin, mrope_section) cos = cos.unsqueeze(unsqueeze_dim)