@@ -67,7 +67,7 @@ def _extract_real_position_ids(position_ids: Optional[torch.Tensor]) -> Optional

def _update_packed_varlen_metadata(self, real_position_ids: Optional[torch.Tensor]) -> None:
self.extra_kwargs.pop('cu_seq_lens_q', None)
if real_position_ids is None or not self._is_packed_position_ids(real_position_ids):
if not self.extra_kwargs.get('padding_free', False) or real_position_ids is None:
return
position_ids = self._extract_real_position_ids(real_position_ids)
if position_ids is None or not torch.is_tensor(position_ids):
@@ -220,17 +220,19 @@ def _attention(query, key, value, *args, **kwargs):
window_size=kwargs.get('sliding_window') or (-1, -1),
group=self._rp_group,
)
elif self.extra_kwargs.get('is_packed', False) or 'cu_seq_lens_q' in kwargs:
cu_seqlens = kwargs.get('cu_seq_lens_q')
if cu_seqlens is None:
position_ids = kwargs.get('position_ids')
if position_ids is None:
position_ids = self.real_position_ids
position_ids = self._extract_real_position_ids(position_ids)
position_ids = self.pad(position_ids, padding_value=-1, position_ids=position_ids)
cu_seqlens = get_cu_seqlens_from_position_ids(position_ids).to(torch.int32)
else:
cu_seqlens = cu_seqlens.to(dtype=torch.int32, device=query.device)
elif self.extra_kwargs.get('padding_free', False) or 'cu_seq_lens_q' in kwargs:
position_ids = kwargs.get('position_ids')
if position_ids is None:
position_ids = self.real_position_ids
if position_ids is None:
raise ValueError('Packed/varlen flash_attention_2 requires position_ids to derive '
'cu_seq_lens_q.')
position_ids = self._extract_real_position_ids(position_ids)
position_ids = self.pad(position_ids, padding_value=-1, position_ids=position_ids)
cu_seqlens = get_cu_seqlens_from_position_ids(position_ids).to(
dtype=torch.int32,
device=query.device,
)
Comment on lines +224 to +235
medium

The logic for obtaining cu_seqlens has been changed to always derive it from position_ids, ignoring any cu_seq_lens_q provided in kwargs. This is a regression in flexibility and potentially correctness. If cu_seq_lens_q is provided by the processor, it likely describes the real sequence boundaries. Deriving it from the SP-padded position_ids (line 231) will include the padding tokens in the last sequence segment, causing Flash Attention to process them unnecessarily. It is recommended to restore the check for cu_seq_lens_q in kwargs first.

                        cu_seqlens = kwargs.get('cu_seq_lens_q')
                        if cu_seqlens is None:
                            position_ids = kwargs.get('position_ids')
                            if position_ids is None:
                                position_ids = self.real_position_ids
                            if position_ids is None:
                                raise ValueError('Packed/varlen flash_attention_2 requires position_ids to derive '
                                                 'cu_seq_lens_q.')
                            position_ids = self._extract_real_position_ids(position_ids)
                            position_ids = self.pad(position_ids, padding_value=-1, position_ids=position_ids)
                            cu_seqlens = get_cu_seqlens_from_position_ids(position_ids).to(
                                dtype=torch.int32,
                                device=query.device,
                            )
                        else:
                            cu_seqlens = cu_seqlens.to(dtype=torch.int32, device=query.device)
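
For context, the concern above hinges on how cu_seqlens are derived from packed position_ids. Below is a minimal, standalone sketch (a simplified stand-in for get_cu_seqlens_from_position_ids, not the repository's implementation) that places a boundary wherever a position id restarts at 0, consistent with the behavior this comment describes. It shows how deriving boundaries from the SP-padded row folds the pad tokens into the final segment:

import torch

def cu_seqlens_from_position_ids(position_ids: torch.Tensor) -> torch.Tensor:
    # Boundary at every index whose position id is 0 (a packed-sequence restart).
    flat = position_ids.reshape(-1)
    starts = torch.nonzero(flat == 0, as_tuple=False).flatten()
    end = torch.tensor([flat.numel()], device=flat.device)
    return torch.cat([starts, end]).to(torch.int32)

# Two packed sequences of length 4 and 3.
packed = torch.tensor([[0, 1, 2, 3, 0, 1, 2]])
print(cu_seqlens_from_position_ids(packed))  # tensor([0, 4, 7], dtype=torch.int32)

# The same row after SP padding with -1: the pad token is absorbed into the last
# segment, so flash attention would attend over it as if it were real content.
padded = torch.tensor([[0, 1, 2, 3, 0, 1, 2, -1]])
print(cu_seqlens_from_position_ids(padded))  # tensor([0, 4, 8], dtype=torch.int32)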

max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
total_tokens = int(cu_seqlens[-1].item())
if query.shape[2] != total_tokens:
@@ -241,7 +243,7 @@ def _attention(query, key, value, *args, **kwargs):
kwargs['cu_seq_lens_k'] = cu_seqlens
kwargs['max_length_q'] = max_seqlen
kwargs['max_length_k'] = max_seqlen
if self.extra_kwargs.get('is_packed', False) and len(args) > 0:
if self.extra_kwargs.get('padding_free', False) and len(args) > 0:
args = (None, *args[1:])
return ALL_ATTENTION_FUNCTIONS['flash_attention_2_origin'](module, query, key, value, *args,
**kwargs)[0]
@@ -261,10 +263,10 @@ def local_sdpa_attn(module: torch.nn.Module, query_states, key_states, value_sta
# Policy: packed (PackingDataset/padding-free) batches require FlashAttention2 varlen/packed semantics.
# SDPA does not have a native packed/varlen interface; supporting packed batches would require building a
# large block-diagonal causal mask (slow / memory heavy).
if self.extra_kwargs.get('is_packed', False):
if self.extra_kwargs.get('padding_free', False):
raise RuntimeError(
'SequenceParallel: detected packed batch (position_ids contains multiple sequences). '
'SDPA backend is not supported for packed batches; please use flash_attention_2.')
'SequenceParallel: detected padding_free/packed batch. '
'SDPA backend is not supported for padding_free/packed batches; please use flash_attention_2.')
if dist_attn.local_attn is None:

def _attention(query, key, value, *args, **kwargs):
@@ -688,8 +690,6 @@ def pad_and_split_inputs(self,
"""
tokenizer = self.tokenizer
real_position_ids = real_position_ids if real_position_ids is not None else position_ids
# Track packed batches to drive attention backend behavior (packed => require flash_attention_2 varlen).
self.extra_kwargs['is_packed'] = self._is_packed_position_ids(real_position_ids)
self._update_packed_varlen_metadata(real_position_ids)
extra_values = []
batch_size = input_ids.shape[
@@ -804,10 +804,16 @@ def prepare_inputs(self, inputs):
"""
input_ids = inputs.get('input_ids')
position_ids = inputs.get('position_ids')
padding_free = bool(inputs.pop('padding_free', False))
if padding_free and self.attn_implementation not in ('flash_attention_2', 'flash_attention_3'):
raise RuntimeError('Transformers SequenceParallel does not support padding_free/packed inputs with '
f'attn_implementation={self.attn_implementation!r}. '
'Use flash_attention_2 or flash_attention_3, or disable padding_free/packing. '
'SDPA/eager attention cannot safely preserve packed sequence boundaries in this path.')
real_position_ids = self._extract_real_position_ids(position_ids)
if real_position_ids is not None and input_ids is not None and real_position_ids.shape[0] == input_ids.shape[0]:
self.extra_kwargs['position_ids'] = real_position_ids.clone()
self.extra_kwargs['is_packed'] = self._is_packed_position_ids(real_position_ids)
self.extra_kwargs['padding_free'] = padding_free
self._update_packed_varlen_metadata(real_position_ids)
if input_ids is not None:
self.extra_kwargs['input_ids'] = input_ids.clone()
Expand Down Expand Up @@ -919,6 +925,23 @@ def postprocess_outputs(self, outputs: Any) -> Any:
outputs['logits'] = gathered
return outputs

@staticmethod
def _trim_gathered_sequence_padding(tensor: torch.Tensor, real_position_ids: torch.Tensor) -> torch.Tensor:
if real_position_ids is None or not torch.is_tensor(real_position_ids) or real_position_ids.dim() < 2:
return tensor
if sequence_parallel.rp_world_size > 1:
cu_seqlens = get_cu_seqlens_from_position_ids(real_position_ids)
pieces = []
padded_offset = 0
divisor = sequence_parallel.world_size * 2
for i in range(len(cu_seqlens) - 1):
real_len = int((cu_seqlens[i + 1] - cu_seqlens[i]).item())
padded_len = math.ceil(real_len / divisor) * divisor
pieces.append(tensor[:, padded_offset:padded_offset + real_len])
padded_offset += padded_len
return torch.cat(pieces, dim=1).contiguous() if pieces else tensor[:, :0].contiguous()
return tensor[:, :real_position_ids.shape[-1]].contiguous()

def gather_loss_tensors(
self,
inputs: Dict[str, Any],
@@ -939,6 +962,8 @@ def gather_loss_tensors(
outputs = copy(outputs)
real_position_ids = sequence_parallel.real_position_ids
gathered_logps, gathered_labels = GatherLoss.apply(logps, labels, 1, real_position_ids)
gathered_logps = self._trim_gathered_sequence_padding(gathered_logps, real_position_ids)
gathered_labels = self._trim_gathered_sequence_padding(gathered_labels, real_position_ids)
outputs['logps'] = gathered_logps
inputs['labels'] = gathered_labels
return inputs, outputs
@@ -1,10 +1,12 @@
import torch
import torch.distributed as dist
import torch.nn.functional as F
from transformers.utils.import_utils import is_flash_linear_attention_available
import warnings
from transformers.utils.import_utils import is_causal_conv1d_available, is_flash_linear_attention_available
from typing import Any, Optional, Tuple

from twinkle.model.transformers.strategy.sequence_parallel.utils import head_to_seq_shard, seq_to_head_shard
from twinkle.model.transformers.strategy.sequence_parallel.utils import (
get_packed_cu_seqlens_from_sequence_parallel_context, head_to_seq_shard, seq_to_head_shard)
from twinkle.patch import Patch

if is_flash_linear_attention_available():
@@ -14,8 +16,15 @@
_FLA_CAUSAL_CONV1D_FN = None
_FLA_CHUNK_GATED_DELTA_RULE = None

_SP_LINEAR_KERNEL_IMPORT_ERROR = ('Qwen3.5 linear attention sequence parallel requires flash-linear-attention. '
'Install: https://github.com/fla-org/flash-linear-attention#installation')
if is_causal_conv1d_available():
from causal_conv1d import causal_conv1d_fn as _CAUSAL_CONV1D_FN
else:
_CAUSAL_CONV1D_FN = None

_SP_LINEAR_KERNEL_FALLBACK_WARNING = (
'flash-linear-attention is not available; falling back to torch implementations for Qwen3.5 linear attention '
'sequence parallel. This fallback only supports non-packed sequences.')
_SP_LINEAR_KERNEL_FALLBACK_WARNED = False


def _sp_is_enabled(sequence_parallel_context) -> bool:
@@ -45,10 +54,67 @@ def _get_local_padding_mask(


def _ensure_linear_attention_kernels(mod: torch.nn.Module):
mod.causal_conv1d_fn = _FLA_CAUSAL_CONV1D_FN
mod.chunk_gated_delta_rule = _FLA_CHUNK_GATED_DELTA_RULE
if mod.chunk_gated_delta_rule is None or mod.causal_conv1d_fn is None:
raise ImportError(_SP_LINEAR_KERNEL_IMPORT_ERROR)
if _FLA_CAUSAL_CONV1D_FN is not None and _FLA_CHUNK_GATED_DELTA_RULE is not None:
mod.causal_conv1d_fn = _FLA_CAUSAL_CONV1D_FN
mod.chunk_gated_delta_rule = _FLA_CHUNK_GATED_DELTA_RULE
return False

from transformers.models.qwen3_5.modeling_qwen3_5 import torch_chunk_gated_delta_rule
origin_causal_conv1d_fn = _CAUSAL_CONV1D_FN or getattr(mod, '_twinkle_origin_causal_conv1d_fn', None)
if origin_causal_conv1d_fn is None:
origin_causal_conv1d_fn = getattr(mod, 'causal_conv1d_fn', None)
if getattr(origin_causal_conv1d_fn, '_twinkle_torch_fallback', False):
origin_causal_conv1d_fn = None
mod._twinkle_origin_causal_conv1d_fn = origin_causal_conv1d_fn

def _torch_causal_conv1d_fn(
*,
x,
weight,
bias=None,
activation=None,
seq_idx=None,
backend=None,
cu_seqlens=None,
):
# Fallback priority:
# 1. flash-linear-attention kernels handle padding_free/packed cu_seqlens and are selected above.
# 2. causal-conv1d package accelerates non-packed convolution when flash-linear-attention is unavailable.
# 3. plain torch conv1d is the final non-packed fallback.
del backend
if cu_seqlens is not None:
raise NotImplementedError(
'Qwen3.5 linear attention sequence parallel with padding_free/packed inputs requires '
'flash-linear-attention. The torch fallback only supports non-packed sequences. '
'Please install flash-linear-attention or disable padding_free/packing.')
if origin_causal_conv1d_fn is not None:
out = origin_causal_conv1d_fn(
x=x.transpose(1, 2).contiguous(),
weight=weight,
bias=bias,
activation=activation,
seq_idx=seq_idx,
)
return out.transpose(1, 2).contiguous()
seq_len = x.shape[1]
x = x.transpose(1, 2).contiguous()
out = F.conv1d(x, weight.unsqueeze(1), bias, padding=weight.shape[-1] - 1, groups=x.shape[1])
out = F.silu(out[:, :, :seq_len]).transpose(1, 2).contiguous()
return out, None
Comment on lines +70 to +103
medium

The _torch_causal_conv1d_fn fallback implementation has an inconsistent return type and ignores the activation argument. It returns a tensor at line 98 but a tuple (tensor, None) at line 103. Additionally, it hardcodes F.silu instead of using the provided activation parameter. This can lead to incorrect results if the model uses a different activation function.

    def _torch_causal_conv1d_fn(
        *,
        x,
        weight,
        bias=None,
        activation=None,
        seq_idx=None,
        backend=None,
        cu_seqlens=None,
    ):
        # Fallback priority:
        # 1. flash-linear-attention kernels handle padding_free/packed cu_seqlens and are selected above.
        # 2. causal-conv1d package accelerates non-packed convolution when flash-linear-attention is unavailable.
        # 3. plain torch conv1d is the final non-packed fallback.
        del backend
        if cu_seqlens is not None:
            raise NotImplementedError(
                'Qwen3.5 linear attention sequence parallel with padding_free/packed inputs requires '
                'flash-linear-attention. The torch fallback only supports non-packed sequences. '
                'Please install flash-linear-attention or disable padding_free/packing.')
        if origin_causal_conv1d_fn is not None:
            out = origin_causal_conv1d_fn(
                x=x.transpose(1, 2).contiguous(),
                weight=weight,
                bias=bias,
                activation=activation,
                seq_idx=seq_idx,
            )
            return out.transpose(1, 2).contiguous()
        seq_len = x.shape[1]
        x = x.transpose(1, 2).contiguous()
        out = F.conv1d(x, weight.unsqueeze(1), bias, padding=weight.shape[-1] - 1, groups=x.shape[1])
        out = out[:, :, :seq_len]
        if activation == 'silu':
            out = F.silu(out)
        return out.transpose(1, 2).contiguous()
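
As a sanity check on the pure-torch branch of the suggestion above, the standalone snippet below (made-up shapes, not the module's real tensors) verifies that a grouped F.conv1d with padding of kernel_size - 1 followed by truncation to seq_len behaves as a causal depthwise convolution: output step t only depends on inputs up to t. Activation handling is omitted here, matching the suggestion's gating on the activation argument.

import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, seq_len, channels, kernel = 2, 6, 4, 3
x = torch.randn(batch, seq_len, channels)      # (B, L, C), the layout passed to causal_conv1d_fn
weight = torch.randn(channels, kernel)         # one kernel per channel (depthwise)

xt = x.transpose(1, 2)                         # (B, C, L) for F.conv1d
out = F.conv1d(xt, weight.unsqueeze(1), padding=kernel - 1, groups=channels)
out = out[:, :, :seq_len]                      # drop the right overhang -> causal

# Reference: step t sees only inputs t - kernel + 1 .. t (zeros before the start).
padded = F.pad(xt, (kernel - 1, 0))
ref = torch.stack([(padded[:, :, t:t + kernel] * weight).sum(-1) for t in range(seq_len)], dim=-1)

assert torch.allclose(out, ref, atol=1e-5)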


_torch_causal_conv1d_fn._twinkle_torch_fallback = True
mod.causal_conv1d_fn = _torch_causal_conv1d_fn
mod.chunk_gated_delta_rule = torch_chunk_gated_delta_rule
_warn_linear_attention_kernel_fallback_once()
return True


def _warn_linear_attention_kernel_fallback_once():
global _SP_LINEAR_KERNEL_FALLBACK_WARNED
if _SP_LINEAR_KERNEL_FALLBACK_WARNED:
return
warnings.warn(_SP_LINEAR_KERNEL_FALLBACK_WARNING, stacklevel=2)
_SP_LINEAR_KERNEL_FALLBACK_WARNED = True


def _get_local_conv_weights(
Expand Down Expand Up @@ -90,10 +156,9 @@ def _run_forward(
cache_params=None,
cache_position=None,
attention_mask: Optional[torch.Tensor] = None,
cu_seq_lens_q: Optional[torch.Tensor] = None,
sequence_parallel_context=None,
) -> torch.Tensor:
_ensure_linear_attention_kernels(mod)
using_torch_fallback = _ensure_linear_attention_kernels(mod)
from transformers.models.qwen3_5.modeling_qwen3_5 import apply_mask_to_padding_states

local_attention_mask = attention_mask
@@ -159,22 +224,22 @@ def _run_forward(
conv_weight = mod.conv1d.weight.squeeze(1)
conv_bias = getattr(mod.conv1d, 'bias', None)

packed_cu_seqlens = None
if cu_seq_lens_q is not None:
packed_cu_seqlens = cu_seq_lens_q.to(dtype=torch.int32, device=mixed_qkv.device)
elif sequence_parallel_context is not None:
packed_cu_seqlens = getattr(sequence_parallel_context, 'extra_kwargs', {}).get('cu_seq_lens_q')
if packed_cu_seqlens is not None:
packed_cu_seqlens = packed_cu_seqlens.to(dtype=torch.int32, device=mixed_qkv.device)
if bool(getattr(sequence_parallel_context, 'extra_kwargs', {}).get('is_packed',
False)) and packed_cu_seqlens is None:
raise ValueError(
'Packed Qwen3.5 linear attention sequence parallel requires cu_seq_lens_q to be populated by '
'sequence parallel input preparation.')
packed_cu_seqlens = get_packed_cu_seqlens_from_sequence_parallel_context(
sequence_parallel_context,
device=mixed_qkv.device,
)
extra_kwargs = getattr(sequence_parallel_context, 'extra_kwargs', {})
if bool(extra_kwargs.get('padding_free', False)) and packed_cu_seqlens is None:
raise ValueError('Qwen3.5 sequence parallel with padding_free/packed inputs requires cu_seq_lens_q.')
if using_torch_fallback and packed_cu_seqlens is not None:
raise NotImplementedError(
'Qwen3.5 linear attention sequence parallel with padding_free/packed inputs requires '
'flash-linear-attention. The torch fallback only supports non-packed sequences. '
'Please install flash-linear-attention or disable padding_free/packing.')
if cache_params is not None:
cache_params.conv_states[mod.layer_idx] = F.pad(
mixed_qkv.transpose(1, 2).contiguous(), (mod.conv_kernel_size - mixed_qkv.shape[1], 0))
mixed_qkv, _ = mod.causal_conv1d_fn(
mixed_qkv = mod.causal_conv1d_fn(
x=mixed_qkv,
weight=conv_weight,
bias=conv_bias,
@@ -183,6 +248,8 @@
backend='triton',
cu_seqlens=packed_cu_seqlens,
)
if isinstance(mixed_qkv, tuple):
mixed_qkv = mixed_qkv[0]
if mixed_qkv.dim() == 2:
mixed_qkv = mixed_qkv.unsqueeze(0)
if mixed_qkv.dim() != 3:
@@ -253,9 +320,6 @@ def sp_linear_forward(
**extra_kwargs,
):
sequence_parallel_context = extra_kwargs.pop('sequence_parallel_context', sequence_parallel)
cu_seq_lens_q = extra_kwargs.pop('cu_seq_lens_q', None)
if cu_seq_lens_q is None and sequence_parallel_context is not None:
cu_seq_lens_q = getattr(sequence_parallel_context, 'extra_kwargs', {}).get('cu_seq_lens_q')
if not _sp_is_enabled(sequence_parallel_context):
return origin_forward(
mod,
@@ -270,7 +334,6 @@
cache_params=cache_params,
cache_position=cache_position,
attention_mask=attention_mask,
cu_seq_lens_q=cu_seq_lens_q,
sequence_parallel_context=sequence_parallel_context,
)

20 changes: 20 additions & 0 deletions src/twinkle/model/transformers/strategy/sequence_parallel/utils.py
@@ -34,6 +34,26 @@ def get_cu_seqlens_from_position_ids(position_ids: torch.LongTensor):
return cu_seqlens


def get_packed_cu_seqlens_from_sequence_parallel_context(
sequence_parallel_context,
*,
device: torch.device,
) -> Optional[torch.Tensor]:
if sequence_parallel_context is None:
return None

extra_kwargs = getattr(sequence_parallel_context, 'extra_kwargs', {})
if extra_kwargs.get('padding_free', False):
position_ids = getattr(sequence_parallel_context, 'real_position_ids', None)
if position_ids is not None:
position_ids = sequence_parallel_context._extract_real_position_ids(position_ids)
position_ids = sequence_parallel_context.pad(position_ids, padding_value=-1, position_ids=position_ids)
return get_cu_seqlens_from_position_ids(position_ids).to(dtype=torch.int32, device=device)

cu_seqlens = extra_kwargs.get('cu_seq_lens_q')
return cu_seqlens.to(dtype=torch.int32, device=device) if cu_seqlens is not None else None


def _get_raw_data_world_size(device_mesh: DeviceMesh) -> int:
dp_world_size = device_mesh.dp_world_size or 1
fsdp_world_size = device_mesh.fsdp_world_size or 1