padding-free / packed-sequence support for Qwen3.5 #186

base: main
```diff
@@ -1,10 +1,12 @@
 import torch
 import torch.distributed as dist
 import torch.nn.functional as F
-from transformers.utils.import_utils import is_flash_linear_attention_available
+import warnings
+from transformers.utils.import_utils import is_causal_conv1d_available, is_flash_linear_attention_available
 from typing import Any, Optional, Tuple
 
-from twinkle.model.transformers.strategy.sequence_parallel.utils import head_to_seq_shard, seq_to_head_shard
+from twinkle.model.transformers.strategy.sequence_parallel.utils import (
+    get_packed_cu_seqlens_from_sequence_parallel_context, head_to_seq_shard, seq_to_head_shard)
 from twinkle.patch import Patch
 
 if is_flash_linear_attention_available():
@@ -14,8 +16,15 @@
 _FLA_CAUSAL_CONV1D_FN = None
 _FLA_CHUNK_GATED_DELTA_RULE = None
 
-_SP_LINEAR_KERNEL_IMPORT_ERROR = ('Qwen3.5 linear attention sequence parallel requires flash-linear-attention. '
-                                  'Install: https://github.com/fla-org/flash-linear-attention#installation')
+if is_causal_conv1d_available():
+    from causal_conv1d import causal_conv1d_fn as _CAUSAL_CONV1D_FN
+else:
+    _CAUSAL_CONV1D_FN = None
+
+_SP_LINEAR_KERNEL_FALLBACK_WARNING = (
+    'flash-linear-attention is not available; falling back to torch implementations for Qwen3.5 linear attention '
+    'sequence parallel. This fallback only supports non-packed sequences.')
+_SP_LINEAR_KERNEL_FALLBACK_WARNED = False
 
 
 def _sp_is_enabled(sequence_parallel_context) -> bool:
@@ -45,10 +54,67 @@ def _get_local_padding_mask(
 
 
 def _ensure_linear_attention_kernels(mod: torch.nn.Module):
-    mod.causal_conv1d_fn = _FLA_CAUSAL_CONV1D_FN
-    mod.chunk_gated_delta_rule = _FLA_CHUNK_GATED_DELTA_RULE
-    if mod.chunk_gated_delta_rule is None or mod.causal_conv1d_fn is None:
-        raise ImportError(_SP_LINEAR_KERNEL_IMPORT_ERROR)
+    if _FLA_CAUSAL_CONV1D_FN is not None and _FLA_CHUNK_GATED_DELTA_RULE is not None:
+        mod.causal_conv1d_fn = _FLA_CAUSAL_CONV1D_FN
+        mod.chunk_gated_delta_rule = _FLA_CHUNK_GATED_DELTA_RULE
+        return False
+
+    from transformers.models.qwen3_5.modeling_qwen3_5 import torch_chunk_gated_delta_rule
+    origin_causal_conv1d_fn = _CAUSAL_CONV1D_FN or getattr(mod, '_twinkle_origin_causal_conv1d_fn', None)
+    if origin_causal_conv1d_fn is None:
+        origin_causal_conv1d_fn = getattr(mod, 'causal_conv1d_fn', None)
+        if getattr(origin_causal_conv1d_fn, '_twinkle_torch_fallback', False):
+            origin_causal_conv1d_fn = None
+    mod._twinkle_origin_causal_conv1d_fn = origin_causal_conv1d_fn
+
+    def _torch_causal_conv1d_fn(
+        *,
+        x,
+        weight,
+        bias=None,
+        activation=None,
+        seq_idx=None,
+        backend=None,
+        cu_seqlens=None,
+    ):
+        # Fallback priority:
+        # 1. flash-linear-attention kernels handle padding_free/packed cu_seqlens and are selected above.
+        # 2. causal-conv1d package accelerates non-packed convolution when flash-linear-attention is unavailable.
+        # 3. plain torch conv1d is the final non-packed fallback.
+        del backend
+        if cu_seqlens is not None:
+            raise NotImplementedError(
+                'Qwen3.5 linear attention sequence parallel with padding_free/packed inputs requires '
+                'flash-linear-attention. The torch fallback only supports non-packed sequences. '
+                'Please install flash-linear-attention or disable padding_free/packing.')
+        if origin_causal_conv1d_fn is not None:
+            out = origin_causal_conv1d_fn(
+                x=x.transpose(1, 2).contiguous(),
+                weight=weight,
+                bias=bias,
+                activation=activation,
+                seq_idx=seq_idx,
+            )
+            return out.transpose(1, 2).contiguous()
+        seq_len = x.shape[1]
+        x = x.transpose(1, 2).contiguous()
+        out = F.conv1d(x, weight.unsqueeze(1), bias, padding=weight.shape[-1] - 1, groups=x.shape[1])
+        out = F.silu(out[:, :, :seq_len]).transpose(1, 2).contiguous()
+        return out, None
```
**Contributor** commented on lines +70 to +103:

The `_torch_causal_conv1d_fn` fallback ignores the `activation` argument (SiLU is applied unconditionally) and returns an `(out, None)` tuple where the other branches return a bare tensor. Suggested version:
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The def _torch_causal_conv1d_fn(
*,
x,
weight,
bias=None,
activation=None,
seq_idx=None,
backend=None,
cu_seqlens=None,
):
# Fallback priority:
# 1. flash-linear-attention kernels handle padding_free/packed cu_seqlens and are selected above.
# 2. causal-conv1d package accelerates non-packed convolution when flash-linear-attention is unavailable.
# 3. plain torch conv1d is the final non-packed fallback.
del backend
if cu_seqlens is not None:
raise NotImplementedError(
'Qwen3.5 linear attention sequence parallel with padding_free/packed inputs requires '
'flash-linear-attention. The torch fallback only supports non-packed sequences. '
'Please install flash-linear-attention or disable padding_free/packing.')
if origin_causal_conv1d_fn is not None:
out = origin_causal_conv1d_fn(
x=x.transpose(1, 2).contiguous(),
weight=weight,
bias=bias,
activation=activation,
seq_idx=seq_idx,
)
return out.transpose(1, 2).contiguous()
seq_len = x.shape[1]
x = x.transpose(1, 2).contiguous()
out = F.conv1d(x, weight.unsqueeze(1), bias, padding=weight.shape[-1] - 1, groups=x.shape[1])
out = out[:, :, :seq_len]
if activation == 'silu':
out = F.silu(out)
return out.transpose(1, 2).contiguous() |
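The suggested version also makes the torch path consistent with the other two branches, which return a bare tensor; the `isinstance(mixed_qkv, tuple)` guard added at the call site then only has to cover kernels that return a second value alongside the output.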
```diff
+
+    _torch_causal_conv1d_fn._twinkle_torch_fallback = True
+    mod.causal_conv1d_fn = _torch_causal_conv1d_fn
+    mod.chunk_gated_delta_rule = torch_chunk_gated_delta_rule
+    _warn_linear_attention_kernel_fallback_once()
+    return True
+
+
+def _warn_linear_attention_kernel_fallback_once():
+    global _SP_LINEAR_KERNEL_FALLBACK_WARNED
+    if _SP_LINEAR_KERNEL_FALLBACK_WARNED:
+        return
+    warnings.warn(_SP_LINEAR_KERNEL_FALLBACK_WARNING, stacklevel=2)
+    _SP_LINEAR_KERNEL_FALLBACK_WARNED = True
 
 
 def _get_local_conv_weights(
@@ -90,10 +156,9 @@ def _run_forward(
     cache_params=None,
     cache_position=None,
     attention_mask: Optional[torch.Tensor] = None,
-    cu_seq_lens_q: Optional[torch.Tensor] = None,
     sequence_parallel_context=None,
 ) -> torch.Tensor:
-    _ensure_linear_attention_kernels(mod)
+    using_torch_fallback = _ensure_linear_attention_kernels(mod)
     from transformers.models.qwen3_5.modeling_qwen3_5 import apply_mask_to_padding_states
 
     local_attention_mask = attention_mask
@@ -159,22 +224,22 @@ def _run_forward(
     conv_weight = mod.conv1d.weight.squeeze(1)
     conv_bias = getattr(mod.conv1d, 'bias', None)
 
-    packed_cu_seqlens = None
-    if cu_seq_lens_q is not None:
-        packed_cu_seqlens = cu_seq_lens_q.to(dtype=torch.int32, device=mixed_qkv.device)
-    elif sequence_parallel_context is not None:
-        packed_cu_seqlens = getattr(sequence_parallel_context, 'extra_kwargs', {}).get('cu_seq_lens_q')
-        if packed_cu_seqlens is not None:
-            packed_cu_seqlens = packed_cu_seqlens.to(dtype=torch.int32, device=mixed_qkv.device)
-        if bool(getattr(sequence_parallel_context, 'extra_kwargs', {}).get('is_packed',
-                                                                           False)) and packed_cu_seqlens is None:
-            raise ValueError(
-                'Packed Qwen3.5 linear attention sequence parallel requires cu_seq_lens_q to be populated by '
-                'sequence parallel input preparation.')
+    packed_cu_seqlens = get_packed_cu_seqlens_from_sequence_parallel_context(
+        sequence_parallel_context,
+        device=mixed_qkv.device,
+    )
+    extra_kwargs = getattr(sequence_parallel_context, 'extra_kwargs', {})
+    if bool(extra_kwargs.get('padding_free', False)) and packed_cu_seqlens is None:
+        raise ValueError('Qwen3.5 sequence parallel with padding_free/packed inputs requires cu_seq_lens_q.')
+    if using_torch_fallback and packed_cu_seqlens is not None:
+        raise NotImplementedError(
+            'Qwen3.5 linear attention sequence parallel with padding_free/packed inputs requires '
+            'flash-linear-attention. The torch fallback only supports non-packed sequences. '
+            'Please install flash-linear-attention or disable padding_free/packing.')
     if cache_params is not None:
         cache_params.conv_states[mod.layer_idx] = F.pad(
             mixed_qkv.transpose(1, 2).contiguous(), (mod.conv_kernel_size - mixed_qkv.shape[1], 0))
-    mixed_qkv, _ = mod.causal_conv1d_fn(
+    mixed_qkv = mod.causal_conv1d_fn(
         x=mixed_qkv,
         weight=conv_weight,
         bias=conv_bias,
@@ -183,6 +248,8 @@ def _run_forward(
         backend='triton',
         cu_seqlens=packed_cu_seqlens,
     )
+    if isinstance(mixed_qkv, tuple):
+        mixed_qkv = mixed_qkv[0]
     if mixed_qkv.dim() == 2:
         mixed_qkv = mixed_qkv.unsqueeze(0)
     if mixed_qkv.dim() != 3:
@@ -253,9 +320,6 @@ def sp_linear_forward(
     **extra_kwargs,
 ):
     sequence_parallel_context = extra_kwargs.pop('sequence_parallel_context', sequence_parallel)
-    cu_seq_lens_q = extra_kwargs.pop('cu_seq_lens_q', None)
-    if cu_seq_lens_q is None and sequence_parallel_context is not None:
-        cu_seq_lens_q = getattr(sequence_parallel_context, 'extra_kwargs', {}).get('cu_seq_lens_q')
    if not _sp_is_enabled(sequence_parallel_context):
        return origin_forward(
            mod,
@@ -270,7 +334,6 @@ def sp_linear_forward(
        cache_params=cache_params,
        cache_position=cache_position,
        attention_mask=attention_mask,
-        cu_seq_lens_q=cu_seq_lens_q,
        sequence_parallel_context=sequence_parallel_context,
    )
```
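The post-conv normalization accepts both kernel conventions: a tuple return (output plus a second value) and, under padding-free packing, a 2-D `(total_tokens, channels)` layout that gets a singleton batch dimension. A small sketch of that contract (the helper name and shapes are illustrative, not from the PR):

```python
import torch

def normalize_conv_output(out):
    """Accept tuple returns and 2-D packed outputs; always yield a 3-D (B, L, C) tensor."""
    if isinstance(out, tuple):
        out = out[0]            # drop the kernel's secondary return value
    if out.dim() == 2:          # packed/padding-free layout: (total_tokens, C)
        out = out.unsqueeze(0)  # treat as a batch of one
    if out.dim() != 3:
        raise ValueError(f'unexpected conv output shape {tuple(out.shape)}')
    return out

print(normalize_conv_output(torch.randn(10, 8)).shape)            # torch.Size([1, 10, 8])
print(normalize_conv_output((torch.randn(2, 5, 8), None)).shape)  # torch.Size([2, 5, 8])
```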
Review comment on the `cu_seqlens` handling:

The logic for obtaining `cu_seqlens` has been changed to always derive it from `position_ids`, ignoring any `cu_seq_lens_q` provided in `kwargs`. This is a regression in flexibility and potentially correctness. If `cu_seq_lens_q` is provided by the processor, it likely describes the real sequence boundaries. Deriving it from the SP-padded `position_ids` (line 231) will include the padding tokens in the last sequence segment, causing Flash Attention to process them unnecessarily. It is recommended to restore the check for `cu_seq_lens_q` in `kwargs` first.
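To make the concern concrete: if the helper reconstructs boundaries from where `position_ids` resets to zero (an assumed implementation; the real logic lives in `get_packed_cu_seqlens_from_sequence_parallel_context`), SP padding appended after the last real sequence is indistinguishable from a continuation of that sequence:

```python
import torch

# Two packed sequences of lengths 4 and 3, followed by 3 SP padding tokens whose
# position_ids keep counting up (one plausible padding scheme, assumed here).
position_ids = torch.tensor([0, 1, 2, 3, 0, 1, 2, 3, 4, 5])

# Boundaries are taken where positions reset to zero, plus the total length.
seq_starts = torch.nonzero(position_ids == 0).flatten()
total_len = torch.tensor([position_ids.numel()])
derived = torch.cat([seq_starts, total_len]).to(torch.int32)
print(derived)  # tensor([ 0,  4, 10], dtype=torch.int32)

# Processor-provided boundaries would have been [0, 4, 7]; the derived version
# folds the 3 padding tokens into the last segment, which is the regression flagged above.
```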