padding-free / packed-sequence support for Qwen3.5. #186

meichangsu1 wants to merge 2 commits into modelscope:main from
Conversation
The `is_packed` flag was ambiguous and only inferred from position IDs. Now `padding_free` is explicitly passed as input, making the intent clearer and enabling early validation of attention backend compatibility.
Simplify the logic for returning logits in `forward` and `forward_only` methods by removing redundant `_outputs` copy and `logits` variable. The new logic directly modifies `outputs` and creates a single copy for return, reducing code complexity and potential bugs.
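To make the first commit's rationale concrete, here is a minimal, hypothetical sketch (not the repository's code) of why inferring packed-ness from position IDs alone is ambiguous:

```python
import torch

# Hypothetical heuristic: treat any position-id restart as evidence of packing.
def looks_packed(position_ids: torch.Tensor) -> bool:
    return bool((position_ids[:, 1:] < position_ids[:, :-1]).any())

# A packed row holding three sequences of lengths 3, 2 and 4:
packed = torch.tensor([[0, 1, 2, 0, 1, 0, 1, 2, 3]])
# A single padding-free sequence of length 9 is indistinguishable from a
# regular (non-packed) batch row under this heuristic:
single = torch.tensor([[0, 1, 2, 3, 4, 5, 6, 7, 8]])

print(looks_packed(packed))  # True
print(looks_packed(single))  # False, even though the caller meant padding_free
```

An explicit `padding_free` input removes this ambiguity and lets attention backend compatibility be validated before the forward pass.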
Code Review
This pull request introduces support for padding-free and packed sequence inputs for Qwen 3.5 models, specifically targeting GatedDeltaNet and linear attention within a sequence parallel context. Key changes include a new patching mechanism for Qwen 3.5, refactored attention logic to handle variable sequence lengths without padding, and fallback implementations for linear attention kernels when specialized libraries are missing. Feedback highlights a regression in how sequence boundaries are determined in the attention strategy and identifies inconsistencies in the return types and activation handling within the new torch-based fallback for causal convolution.
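For context on the diff below, here is a minimal sketch of how cumulative sequence lengths can be recovered from packed `position_ids`. It assumes every sub-sequence restarts at position 0 and is not the repository's `get_cu_seqlens_from_position_ids` implementation:

```python
import torch
import torch.nn.functional as F

def cu_seqlens_from_position_ids(position_ids: torch.Tensor) -> torch.Tensor:
    # position_ids: shape (1, total_tokens) for a padding-free / packed batch.
    pos = position_ids.flatten()
    starts = torch.nonzero(pos == 0).flatten()                # each sub-sequence restarts at 0
    ends = torch.cat([starts[1:], torch.tensor([pos.numel()])])
    lengths = ends - starts
    return F.pad(lengths.cumsum(0), (1, 0)).to(torch.int32)   # prepend the leading 0

position_ids = torch.tensor([[0, 1, 2, 0, 1, 0, 1, 2, 3]])    # sequences of length 3, 2, 4
print(cu_seqlens_from_position_ids(position_ids))             # tensor([0, 3, 5, 9], dtype=torch.int32)
```

Flash attention's varlen kernels consume exactly this `cu_seqlens` layout, which is why the derivation matters in the review comment that follows.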
position_ids = kwargs.get('position_ids')
if position_ids is None:
    position_ids = self.real_position_ids
if position_ids is None:
    raise ValueError('Packed/varlen flash_attention_2 requires position_ids to derive '
                     'cu_seq_lens_q.')
position_ids = self._extract_real_position_ids(position_ids)
position_ids = self.pad(position_ids, padding_value=-1, position_ids=position_ids)
cu_seqlens = get_cu_seqlens_from_position_ids(position_ids).to(
    dtype=torch.int32,
    device=query.device,
)
The logic for obtaining `cu_seqlens` has been changed to always derive it from `position_ids`, ignoring any `cu_seq_lens_q` provided in kwargs. This is a regression in flexibility and potentially correctness. If `cu_seq_lens_q` is provided by the processor, it likely describes the real sequence boundaries. Deriving it from the SP-padded `position_ids` (line 231) will include the padding tokens in the last sequence segment, causing Flash Attention to process them unnecessarily. It is recommended to restore the check for `cu_seq_lens_q` in kwargs first.
cu_seqlens = kwargs.get('cu_seq_lens_q')
if cu_seqlens is None:
    position_ids = kwargs.get('position_ids')
    if position_ids is None:
        position_ids = self.real_position_ids
    if position_ids is None:
        raise ValueError('Packed/varlen flash_attention_2 requires position_ids to derive '
                         'cu_seq_lens_q.')
    position_ids = self._extract_real_position_ids(position_ids)
    position_ids = self.pad(position_ids, padding_value=-1, position_ids=position_ids)
    cu_seqlens = get_cu_seqlens_from_position_ids(position_ids).to(
        dtype=torch.int32,
        device=query.device,
    )
else:
    cu_seqlens = cu_seqlens.to(dtype=torch.int32, device=query.device)

def _torch_causal_conv1d_fn(
    *,
    x,
    weight,
    bias=None,
    activation=None,
    seq_idx=None,
    backend=None,
    cu_seqlens=None,
):
    # Fallback priority:
    # 1. flash-linear-attention kernels handle padding_free/packed cu_seqlens and are selected above.
    # 2. causal-conv1d package accelerates non-packed convolution when flash-linear-attention is unavailable.
    # 3. plain torch conv1d is the final non-packed fallback.
    del backend
    if cu_seqlens is not None:
        raise NotImplementedError(
            'Qwen3.5 linear attention sequence parallel with padding_free/packed inputs requires '
            'flash-linear-attention. The torch fallback only supports non-packed sequences. '
            'Please install flash-linear-attention or disable padding_free/packing.')
    if origin_causal_conv1d_fn is not None:
        out = origin_causal_conv1d_fn(
            x=x.transpose(1, 2).contiguous(),
            weight=weight,
            bias=bias,
            activation=activation,
            seq_idx=seq_idx,
        )
        return out.transpose(1, 2).contiguous()
    seq_len = x.shape[1]
    x = x.transpose(1, 2).contiguous()
    out = F.conv1d(x, weight.unsqueeze(1), bias, padding=weight.shape[-1] - 1, groups=x.shape[1])
    out = F.silu(out[:, :, :seq_len]).transpose(1, 2).contiguous()
    return out, None
The `_torch_causal_conv1d_fn` fallback implementation has an inconsistent return type and ignores the `activation` argument. It returns a tensor at line 98 but a tuple `(tensor, None)` at line 103. Additionally, it hardcodes `F.silu` instead of using the provided `activation` parameter. This can lead to incorrect results if the model uses a different activation function.
def _torch_causal_conv1d_fn(
    *,
    x,
    weight,
    bias=None,
    activation=None,
    seq_idx=None,
    backend=None,
    cu_seqlens=None,
):
    # Fallback priority:
    # 1. flash-linear-attention kernels handle padding_free/packed cu_seqlens and are selected above.
    # 2. causal-conv1d package accelerates non-packed convolution when flash-linear-attention is unavailable.
    # 3. plain torch conv1d is the final non-packed fallback.
    del backend
    if cu_seqlens is not None:
        raise NotImplementedError(
            'Qwen3.5 linear attention sequence parallel with padding_free/packed inputs requires '
            'flash-linear-attention. The torch fallback only supports non-packed sequences. '
            'Please install flash-linear-attention or disable padding_free/packing.')
    if origin_causal_conv1d_fn is not None:
        out = origin_causal_conv1d_fn(
            x=x.transpose(1, 2).contiguous(),
            weight=weight,
            bias=bias,
            activation=activation,
            seq_idx=seq_idx,
        )
        return out.transpose(1, 2).contiguous()
    seq_len = x.shape[1]
    x = x.transpose(1, 2).contiguous()
    out = F.conv1d(x, weight.unsqueeze(1), bias, padding=weight.shape[-1] - 1, groups=x.shape[1])
    out = out[:, :, :seq_len]
    if activation == 'silu':
        out = F.silu(out)
    return out.transpose(1, 2).contiguous()
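As a quick way to exercise the suggested fallback, here is a hypothetical sanity check (not part of the PR) comparing the pad-and-slice formulation against an explicit left-padded causal convolution; `activation=None` is assumed so only the convolution itself is compared:

```python
import torch
import torch.nn.functional as F

def reference_causal_conv1d(x, weight, bias=None):
    # x: (batch, seq_len, channels); weight: (channels, kernel_size), depthwise.
    k = weight.shape[-1]
    x_t = F.pad(x.transpose(1, 2), (k - 1, 0))                 # left-pad so outputs stay causal
    out = F.conv1d(x_t, weight.unsqueeze(1), bias, groups=x.shape[-1])
    return out.transpose(1, 2)

batch, seq_len, channels, kernel = 2, 16, 8, 4
x = torch.randn(batch, seq_len, channels)
w = torch.randn(channels, kernel)

# Same computation as the fallback: pad both sides, then keep the first seq_len steps.
fallback = F.conv1d(x.transpose(1, 2), w.unsqueeze(1), None,
                    padding=kernel - 1, groups=channels)[:, :, :seq_len].transpose(1, 2)
print(torch.allclose(reference_causal_conv1d(x, w), fallback, atol=1e-6))  # True
```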
PR type
PR information
This PR improves padding-free / packed-sequence support for the Transformers backend, especially Qwen3.5.
Main changes:
- `position_ids` as padding-free inputs in the processor pipeline.
- Packed-sequence support for `causal_conv1d` and `chunk_gated_delta_rule`.
- Handling of `cu_seq_lens_q` in the sequence-parallel attention strategy.
- Simplified logits handling in `forward` and `forward_only` (see the sketch after this list).
- Derivation of `cu_seqlens` from the SP-adjusted `position_ids`.
- Torch fallback when `flash-linear-attention` is unavailable.
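The logits simplification mentioned above follows a pattern like the sketch below (the names and the post-processing step are illustrative, not the repository's exact code):

```python
import copy

def forward_old(outputs: dict) -> dict:
    _outputs = copy.copy(outputs)        # redundant intermediate copy
    logits = _outputs['logits']          # redundant alias
    _outputs['logits'] = logits * 1.0    # some post-processing of the logits
    return copy.copy(_outputs)           # second copy on return

def forward_new(outputs: dict) -> dict:
    outputs['logits'] = outputs['logits'] * 1.0   # modify outputs directly
    return copy.copy(outputs)                     # single copy for the caller
```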
Experiment results

python -m py_compile src/twinkle/patch/qwen35_gdn_padding_free.py
python -m py_compile src/twinkle/model/transformers/transformers.py
python -m py_compile src/twinkle/model/transformers/strategy/sequence_parallel/linear_attention_sp.py