From bc2663d71daeda620dd952b68556c7adcecc11ee Mon Sep 17 00:00:00 2001
From: meichangsu1 <1484603386@qq.com>
Date: Thu, 30 Apr 2026 16:51:47 +0800
Subject: [PATCH 1/2] refactor: replace `is_packed` with `padding_free` flag for clarity

The `is_packed` flag was ambiguous and only inferred from position IDs.
Now `padding_free` is explicitly passed as input, making the intent clearer
and enabling early validation of attention backend compatibility.

---
 .../strategy/sequence_parallel/__init__.py    |  63 ++++--
 .../sequence_parallel/linear_attention_sp.py  | 117 ++++++++---
 .../strategy/sequence_parallel/utils.py       |  20 ++
 .../model/transformers/transformers.py        |  16 +-
 src/twinkle/patch/qwen35_gdn_padding_free.py  | 194 ++++++++++++++++++
 src/twinkle/processor/base.py                 |  43 +++-
 6 files changed, 403 insertions(+), 50 deletions(-)
 create mode 100644 src/twinkle/patch/qwen35_gdn_padding_free.py

diff --git a/src/twinkle/model/transformers/strategy/sequence_parallel/__init__.py b/src/twinkle/model/transformers/strategy/sequence_parallel/__init__.py
index 3c6f1a4e..49f7050d 100644
--- a/src/twinkle/model/transformers/strategy/sequence_parallel/__init__.py
+++ b/src/twinkle/model/transformers/strategy/sequence_parallel/__init__.py
@@ -67,7 +67,7 @@ def _extract_real_position_ids(position_ids: Optional[torch.Tensor]) -> Optional

     def _update_packed_varlen_metadata(self, real_position_ids: Optional[torch.Tensor]) -> None:
         self.extra_kwargs.pop('cu_seq_lens_q', None)
-        if real_position_ids is None or not self._is_packed_position_ids(real_position_ids):
+        if not self.extra_kwargs.get('padding_free', False) or real_position_ids is None:
             return
         position_ids = self._extract_real_position_ids(real_position_ids)
         if position_ids is None or not torch.is_tensor(position_ids):
@@ -220,17 +220,19 @@ def _attention(query, key, value, *args, **kwargs):
                     window_size=kwargs.get('sliding_window') or (-1, -1),
                     group=self._rp_group,
                 )
-            elif self.extra_kwargs.get('is_packed', False) or 'cu_seq_lens_q' in kwargs:
-                cu_seqlens = kwargs.get('cu_seq_lens_q')
-                if cu_seqlens is None:
-                    position_ids = kwargs.get('position_ids')
-                    if position_ids is None:
-                        position_ids = self.real_position_ids
-                    position_ids = self._extract_real_position_ids(position_ids)
-                    position_ids = self.pad(position_ids, padding_value=-1, position_ids=position_ids)
-                    cu_seqlens = get_cu_seqlens_from_position_ids(position_ids).to(torch.int32)
-                else:
-                    cu_seqlens = cu_seqlens.to(dtype=torch.int32, device=query.device)
+            elif self.extra_kwargs.get('padding_free', False) or 'cu_seq_lens_q' in kwargs:
+                position_ids = kwargs.get('position_ids')
+                if position_ids is None:
+                    position_ids = self.real_position_ids
+                if position_ids is None:
+                    raise ValueError('Packed/varlen flash_attention_2 requires position_ids to derive '
+                                     'cu_seq_lens_q.')
+                position_ids = self._extract_real_position_ids(position_ids)
+                position_ids = self.pad(position_ids, padding_value=-1, position_ids=position_ids)
+                cu_seqlens = get_cu_seqlens_from_position_ids(position_ids).to(
+                    dtype=torch.int32,
+                    device=query.device,
+                )
                 max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
                 total_tokens = int(cu_seqlens[-1].item())
                 if query.shape[2] != total_tokens:
@@ -241,7 +243,7 @@ def _attention(query, key, value, *args, **kwargs):
                 kwargs['cu_seq_lens_k'] = cu_seqlens
                 kwargs['max_length_q'] = max_seqlen
                 kwargs['max_length_k'] = max_seqlen
-                if self.extra_kwargs.get('is_packed', False) and len(args) > 0:
+                if self.extra_kwargs.get('padding_free', False) and len(args) > 0:
                     args = (None, *args[1:])
            return ALL_ATTENTION_FUNCTIONS['flash_attention_2_origin'](module, query, key, value, *args, **kwargs)[0]

@@ -261,10 +263,10 @@ def local_sdpa_attn(module: torch.nn.Module, query_states, key_states, value_sta
        # Policy: packed (PackingDataset/padding-free) batches require FlashAttention2 varlen/packed semantics.
        # SDPA does not have a native packed/varlen interface; supporting packed batches would require building a
        # large block-diagonal causal mask (slow / memory heavy).
-        if self.extra_kwargs.get('is_packed', False):
+        if self.extra_kwargs.get('padding_free', False):
            raise RuntimeError(
-                'SequenceParallel: detected packed batch (position_ids contains multiple sequences). '
-                'SDPA backend is not supported for packed batches; please use flash_attention_2.')
+                'SequenceParallel: detected padding_free/packed batch. '
+                'SDPA backend is not supported for padding_free/packed batches; please use flash_attention_2.')

        if dist_attn.local_attn is None:

            def _attention(query, key, value, *args, **kwargs):
@@ -688,8 +690,6 @@ def pad_and_split_inputs(self,
        """
        tokenizer = self.tokenizer
        real_position_ids = real_position_ids if real_position_ids is not None else position_ids
-        # Track packed batches to drive attention backend behavior (packed => require flash_attention_2 varlen).
-        self.extra_kwargs['is_packed'] = self._is_packed_position_ids(real_position_ids)
        self._update_packed_varlen_metadata(real_position_ids)
        extra_values = []
        batch_size = input_ids.shape[
@@ -804,10 +804,16 @@ def prepare_inputs(self, inputs):
        """
        input_ids = inputs.get('input_ids')
        position_ids = inputs.get('position_ids')
+        padding_free = bool(inputs.pop('padding_free', False))
+        if padding_free and self.attn_implementation not in ('flash_attention_2', 'flash_attention_3'):
+            raise RuntimeError('Transformers SequenceParallel does not support padding_free/packed inputs with '
+                               f'attn_implementation={self.attn_implementation!r}. '
+                               'Use flash_attention_2 or flash_attention_3, or disable padding_free/packing. '
+                               'SDPA/eager attention cannot safely preserve packed sequence boundaries in this path.')
        real_position_ids = self._extract_real_position_ids(position_ids)
        if real_position_ids is not None and input_ids is not None and real_position_ids.shape[0] == input_ids.shape[0]:
            self.extra_kwargs['position_ids'] = real_position_ids.clone()
-        self.extra_kwargs['is_packed'] = self._is_packed_position_ids(real_position_ids)
+        self.extra_kwargs['padding_free'] = padding_free
        self._update_packed_varlen_metadata(real_position_ids)
        if input_ids is not None:
            self.extra_kwargs['input_ids'] = input_ids.clone()
@@ -919,6 +925,23 @@ def postprocess_outputs(self, outputs: Any) -> Any:
            outputs['logits'] = gathered
        return outputs

+    @staticmethod
+    def _trim_gathered_sequence_padding(tensor: torch.Tensor, real_position_ids: torch.Tensor) -> torch.Tensor:
+        if real_position_ids is None or not torch.is_tensor(real_position_ids) or real_position_ids.dim() < 2:
+            return tensor
+        if sequence_parallel.rp_world_size > 1:
+            cu_seqlens = get_cu_seqlens_from_position_ids(real_position_ids)
+            pieces = []
+            padded_offset = 0
+            divisor = sequence_parallel.world_size * 2
+            for i in range(len(cu_seqlens) - 1):
+                real_len = int((cu_seqlens[i + 1] - cu_seqlens[i]).item())
+                padded_len = math.ceil(real_len / divisor) * divisor
+                pieces.append(tensor[:, padded_offset:padded_offset + real_len])
+                padded_offset += padded_len
+            return torch.cat(pieces, dim=1).contiguous() if pieces else tensor[:, :0].contiguous()
+        return tensor[:, :real_position_ids.shape[-1]].contiguous()
+
    def gather_loss_tensors(
            self,
            inputs: Dict[str, Any],
@@ -939,6 +962,8 @@ def gather_loss_tensors(
        outputs = copy(outputs)
        real_position_ids = sequence_parallel.real_position_ids
        gathered_logps, gathered_labels = GatherLoss.apply(logps, labels, 1, real_position_ids)
+        gathered_logps = self._trim_gathered_sequence_padding(gathered_logps, real_position_ids)
+        gathered_labels = self._trim_gathered_sequence_padding(gathered_labels, real_position_ids)
        outputs['logps'] = gathered_logps
        inputs['labels'] = gathered_labels
        return inputs, outputs
diff --git a/src/twinkle/model/transformers/strategy/sequence_parallel/linear_attention_sp.py b/src/twinkle/model/transformers/strategy/sequence_parallel/linear_attention_sp.py
index 7b0d960b..23d4c24a 100644
--- a/src/twinkle/model/transformers/strategy/sequence_parallel/linear_attention_sp.py
+++ b/src/twinkle/model/transformers/strategy/sequence_parallel/linear_attention_sp.py
@@ -1,10 +1,12 @@
 import torch
 import torch.distributed as dist
 import torch.nn.functional as F
-from transformers.utils.import_utils import is_flash_linear_attention_available
+import warnings
+from transformers.utils.import_utils import is_causal_conv1d_available, is_flash_linear_attention_available
 from typing import Any, Optional, Tuple

-from twinkle.model.transformers.strategy.sequence_parallel.utils import head_to_seq_shard, seq_to_head_shard
+from twinkle.model.transformers.strategy.sequence_parallel.utils import (
+    get_packed_cu_seqlens_from_sequence_parallel_context, head_to_seq_shard, seq_to_head_shard)
 from twinkle.patch import Patch

 if is_flash_linear_attention_available():
@@ -14,8 +16,15 @@
 _FLA_CAUSAL_CONV1D_FN = None
 _FLA_CHUNK_GATED_DELTA_RULE = None

-_SP_LINEAR_KERNEL_IMPORT_ERROR = ('Qwen3.5 linear attention sequence parallel requires flash-linear-attention. '
-                                  'Install: https://github.com/fla-org/flash-linear-attention#installation')
+if is_causal_conv1d_available():
+    from causal_conv1d import causal_conv1d_fn as _CAUSAL_CONV1D_FN
+else:
+    _CAUSAL_CONV1D_FN = None
+
+_SP_LINEAR_KERNEL_FALLBACK_WARNING = (
+    'flash-linear-attention is not available; falling back to torch implementations for Qwen3.5 linear attention '
+    'sequence parallel. This fallback only supports non-packed sequences.')
+_SP_LINEAR_KERNEL_FALLBACK_WARNED = False


 def _sp_is_enabled(sequence_parallel_context) -> bool:
@@ -45,10 +54,67 @@ def _get_local_padding_mask(


 def _ensure_linear_attention_kernels(mod: torch.nn.Module):
-    mod.causal_conv1d_fn = _FLA_CAUSAL_CONV1D_FN
-    mod.chunk_gated_delta_rule = _FLA_CHUNK_GATED_DELTA_RULE
-    if mod.chunk_gated_delta_rule is None or mod.causal_conv1d_fn is None:
-        raise ImportError(_SP_LINEAR_KERNEL_IMPORT_ERROR)
+    if _FLA_CAUSAL_CONV1D_FN is not None and _FLA_CHUNK_GATED_DELTA_RULE is not None:
+        mod.causal_conv1d_fn = _FLA_CAUSAL_CONV1D_FN
+        mod.chunk_gated_delta_rule = _FLA_CHUNK_GATED_DELTA_RULE
+        return False
+
+    from transformers.models.qwen3_5.modeling_qwen3_5 import torch_chunk_gated_delta_rule
+    origin_causal_conv1d_fn = _CAUSAL_CONV1D_FN or getattr(mod, '_twinkle_origin_causal_conv1d_fn', None)
+    if origin_causal_conv1d_fn is None:
+        origin_causal_conv1d_fn = getattr(mod, 'causal_conv1d_fn', None)
+        if getattr(origin_causal_conv1d_fn, '_twinkle_torch_fallback', False):
+            origin_causal_conv1d_fn = None
+    mod._twinkle_origin_causal_conv1d_fn = origin_causal_conv1d_fn
+
+    def _torch_causal_conv1d_fn(
+        *,
+        x,
+        weight,
+        bias=None,
+        activation=None,
+        seq_idx=None,
+        backend=None,
+        cu_seqlens=None,
+    ):
+        # Fallback priority:
+        # 1. flash-linear-attention kernels handle padding_free/packed cu_seqlens and are selected above.
+        # 2. causal-conv1d package accelerates non-packed convolution when flash-linear-attention is unavailable.
+        # 3. plain torch conv1d is the final non-packed fallback.
+        del backend
+        if cu_seqlens is not None:
+            raise NotImplementedError(
+                'Qwen3.5 linear attention sequence parallel with padding_free/packed inputs requires '
+                'flash-linear-attention. The torch fallback only supports non-packed sequences. '
+                'Please install flash-linear-attention or disable padding_free/packing.')
+        if origin_causal_conv1d_fn is not None:
+            out = origin_causal_conv1d_fn(
+                x=x.transpose(1, 2).contiguous(),
+                weight=weight,
+                bias=bias,
+                activation=activation,
+                seq_idx=seq_idx,
+            )
+            return out.transpose(1, 2).contiguous()
+        seq_len = x.shape[1]
+        x = x.transpose(1, 2).contiguous()
+        out = F.conv1d(x, weight.unsqueeze(1), bias, padding=weight.shape[-1] - 1, groups=x.shape[1])
+        out = F.silu(out[:, :, :seq_len]).transpose(1, 2).contiguous()
+        return out, None
+
+    _torch_causal_conv1d_fn._twinkle_torch_fallback = True
+    mod.causal_conv1d_fn = _torch_causal_conv1d_fn
+    mod.chunk_gated_delta_rule = torch_chunk_gated_delta_rule
+    _warn_linear_attention_kernel_fallback_once()
+    return True
+
+
+def _warn_linear_attention_kernel_fallback_once():
+    global _SP_LINEAR_KERNEL_FALLBACK_WARNED
+    if _SP_LINEAR_KERNEL_FALLBACK_WARNED:
+        return
+    warnings.warn(_SP_LINEAR_KERNEL_FALLBACK_WARNING, stacklevel=2)
+    _SP_LINEAR_KERNEL_FALLBACK_WARNED = True


 def _get_local_conv_weights(
@@ -90,10 +156,9 @@ def _run_forward(
     cache_params=None,
     cache_position=None,
     attention_mask: Optional[torch.Tensor] = None,
-    cu_seq_lens_q: Optional[torch.Tensor] = None,
     sequence_parallel_context=None,
 ) -> torch.Tensor:
-    _ensure_linear_attention_kernels(mod)
+    using_torch_fallback = _ensure_linear_attention_kernels(mod)
     from transformers.models.qwen3_5.modeling_qwen3_5 import apply_mask_to_padding_states

     local_attention_mask = attention_mask
@@ -159,22 +224,22 @@ def _run_forward(
     conv_weight = mod.conv1d.weight.squeeze(1)
     conv_bias = getattr(mod.conv1d, 'bias', None)

-    packed_cu_seqlens = None
-    if cu_seq_lens_q is not None:
-        packed_cu_seqlens = cu_seq_lens_q.to(dtype=torch.int32, device=mixed_qkv.device)
-    elif sequence_parallel_context is not None:
-        packed_cu_seqlens = getattr(sequence_parallel_context, 'extra_kwargs', {}).get('cu_seq_lens_q')
-        if packed_cu_seqlens is not None:
-            packed_cu_seqlens = packed_cu_seqlens.to(dtype=torch.int32, device=mixed_qkv.device)
-    if bool(getattr(sequence_parallel_context, 'extra_kwargs', {}).get('is_packed',
-                                                                       False)) and packed_cu_seqlens is None:
-        raise ValueError(
-            'Packed Qwen3.5 linear attention sequence parallel requires cu_seq_lens_q to be populated by '
-            'sequence parallel input preparation.')
+    packed_cu_seqlens = get_packed_cu_seqlens_from_sequence_parallel_context(
+        sequence_parallel_context,
+        device=mixed_qkv.device,
+    )
+    extra_kwargs = getattr(sequence_parallel_context, 'extra_kwargs', {})
+    if bool(extra_kwargs.get('padding_free', False)) and packed_cu_seqlens is None:
+        raise ValueError('Qwen3.5 sequence parallel with padding_free/packed inputs requires cu_seq_lens_q.')
+    if using_torch_fallback and packed_cu_seqlens is not None:
+        raise NotImplementedError(
+            'Qwen3.5 linear attention sequence parallel with padding_free/packed inputs requires '
+            'flash-linear-attention. The torch fallback only supports non-packed sequences. '
+            'Please install flash-linear-attention or disable padding_free/packing.')
     if cache_params is not None:
         cache_params.conv_states[mod.layer_idx] = F.pad(
             mixed_qkv.transpose(1, 2).contiguous(), (mod.conv_kernel_size - mixed_qkv.shape[1], 0))
-    mixed_qkv, _ = mod.causal_conv1d_fn(
+    mixed_qkv = mod.causal_conv1d_fn(
         x=mixed_qkv,
         weight=conv_weight,
         bias=conv_bias,
@@ -183,6 +248,8 @@ def _run_forward(
         backend='triton',
         cu_seqlens=packed_cu_seqlens,
     )
+    if isinstance(mixed_qkv, tuple):
+        mixed_qkv = mixed_qkv[0]
     if mixed_qkv.dim() == 2:
         mixed_qkv = mixed_qkv.unsqueeze(0)
     if mixed_qkv.dim() != 3:
@@ -253,9 +320,6 @@ def sp_linear_forward(
     **extra_kwargs,
 ):
     sequence_parallel_context = extra_kwargs.pop('sequence_parallel_context', sequence_parallel)
-    cu_seq_lens_q = extra_kwargs.pop('cu_seq_lens_q', None)
-    if cu_seq_lens_q is None and sequence_parallel_context is not None:
-        cu_seq_lens_q = getattr(sequence_parallel_context, 'extra_kwargs', {}).get('cu_seq_lens_q')
     if not _sp_is_enabled(sequence_parallel_context):
         return origin_forward(
             mod,
@@ -270,7 +334,6 @@ def sp_linear_forward(
         cache_params=cache_params,
         cache_position=cache_position,
         attention_mask=attention_mask,
-        cu_seq_lens_q=cu_seq_lens_q,
         sequence_parallel_context=sequence_parallel_context,
     )
diff --git a/src/twinkle/model/transformers/strategy/sequence_parallel/utils.py b/src/twinkle/model/transformers/strategy/sequence_parallel/utils.py
index adbb945d..9832778d 100644
--- a/src/twinkle/model/transformers/strategy/sequence_parallel/utils.py
+++ b/src/twinkle/model/transformers/strategy/sequence_parallel/utils.py
@@ -34,6 +34,26 @@ def get_cu_seqlens_from_position_ids(position_ids: torch.LongTensor):
     return cu_seqlens


+def get_packed_cu_seqlens_from_sequence_parallel_context(
+    sequence_parallel_context,
+    *,
+    device: torch.device,
+) -> Optional[torch.Tensor]:
+    if sequence_parallel_context is None:
+        return None
+
+    extra_kwargs = getattr(sequence_parallel_context, 'extra_kwargs', {})
+    if extra_kwargs.get('padding_free', False):
+        position_ids = getattr(sequence_parallel_context, 'real_position_ids', None)
+        if position_ids is not None:
+            position_ids = sequence_parallel_context._extract_real_position_ids(position_ids)
+            position_ids = sequence_parallel_context.pad(position_ids, padding_value=-1, position_ids=position_ids)
+            return get_cu_seqlens_from_position_ids(position_ids).to(dtype=torch.int32, device=device)
+
+    cu_seqlens = extra_kwargs.get('cu_seq_lens_q')
+    return cu_seqlens.to(dtype=torch.int32, device=device) if cu_seqlens is not None else None
+
+
 def _get_raw_data_world_size(device_mesh: DeviceMesh) -> int:
     dp_world_size = device_mesh.dp_world_size or 1
     fsdp_world_size = device_mesh.fsdp_world_size or 1
diff --git a/src/twinkle/model/transformers/transformers.py b/src/twinkle/model/transformers/transformers.py
index cbf56c3c..aedf8717 100644
--- a/src/twinkle/model/transformers/transformers.py
+++ b/src/twinkle/model/transformers/transformers.py
@@ -369,7 +369,13 @@ def forward(self, *, inputs: Union[InputFeature, List[InputFeature], List[Trajec
         loss_instance = optimizer_config.loss_instance
         loss_require_logits = (hasattr(loss_instance, 'require_logits') and loss_instance.require_logits)
         assert isinstance(processor, InputProcessor), 'Set a correct `InputProcessor` before forwarding'
-        inputs: Dict[str, Any] = processor(inputs, sp_strategy=self.sp_strategy)
+        inputs: Dict[str, Any] = processor(
+            inputs,
+            sp_strategy=self.sp_strategy,
+            model=self.model,
+            hf_config=self.hf_config,
+            enable_sp=getattr(self, '_enable_sp', False),
+        )
         labels: torch.Tensor = inputs.pop('labels', None)
         optimizer_config.accumulate_metrics(True)
         outputs = self.model(**inputs)
@@ -434,7 +440,13 @@ def forward_only(self, *, inputs: Union[InputFeature, List[InputFeature], List[T
         assert isinstance(processor, InputProcessor), 'Set InputProcessor correctly before forwarding'
         loss_instance = optimizer_config.loss_instance
         loss_require_logits = (hasattr(loss_instance, 'require_logits') and loss_instance.require_logits)
-        inputs: Dict[str, Any] = processor(inputs, sp_strategy=self.sp_strategy)
+        inputs: Dict[str, Any] = processor(
+            inputs,
+            sp_strategy=self.sp_strategy,
+            model=self.model,
+            hf_config=self.hf_config,
+            enable_sp=getattr(self, '_enable_sp', False),
+        )
         labels = inputs.pop('labels', None)
         optimizer_config.accumulate_metrics(False)
         unwrapped_model = self.strategy.unwrap_model(self.model)
diff --git a/src/twinkle/patch/qwen35_gdn_padding_free.py b/src/twinkle/patch/qwen35_gdn_padding_free.py
new file mode 100644
index 00000000..598395c9
--- /dev/null
+++ b/src/twinkle/patch/qwen35_gdn_padding_free.py
@@ -0,0 +1,194 @@
+from typing import Optional
+
+import torch
+from transformers.utils.import_utils import is_flash_linear_attention_available
+
+from twinkle.patch import Patch
+
+
+def _is_qwen35_model(hf_config) -> bool:
+    return 'qwen3_5' in getattr(hf_config, 'model_type', '')
+
+
+def _get_real_position_ids(position_ids: torch.Tensor) -> torch.Tensor:
+    return position_ids[0] if position_ids.dim() == 3 else position_ids
+
+
+def _is_packed_position_ids(position_ids: torch.Tensor) -> bool:
+    if position_ids is None or not torch.is_tensor(position_ids):
+        return False
+    position_ids = _get_real_position_ids(position_ids)
+    if position_ids.dim() == 1:
+        position_ids = position_ids.unsqueeze(0)
+    if position_ids.dim() != 2:
+        return False
+    for i in range(position_ids.shape[0]):
+        row = position_ids[i]
+        if int((row == 0).sum()) > 1 and int((row == 1).sum()) > 1:
+            return True
+    return False
+
+
+def _find_qwen35_classes(module: Optional[torch.nn.Module], hf_config, enable_sp: bool):
+    if module is None or enable_sp or not _is_qwen35_model(hf_config):
+        return None, None
+    try:
+        from transformers.models.qwen3_5.modeling_qwen3_5 import Qwen3_5DecoderLayer, Qwen3_5GatedDeltaNet
+    except Exception:
+        return None, None
+    if any(isinstance(submodule, Qwen3_5GatedDeltaNet) for submodule in module.modules()):
+        return Qwen3_5DecoderLayer, Qwen3_5GatedDeltaNet
+    return None, None
+
+
+def _ensure_flash_linear_attention_available() -> None:
+    if is_flash_linear_attention_available():
+        return
+    raise NotImplementedError(
+        'Qwen3.5 padding_free/packed inputs require flash-linear-attention for GatedDeltaNet. '
+        'The native torch GatedDeltaNet implementation does not reset linear-attention state at packed '
+        'sequence boundaries. Please install flash-linear-attention or disable padding_free/packing.')
+
+
+def _get_flash_linear_attention_kernels():
+    _ensure_flash_linear_attention_available()
+    from fla.modules.convolution import causal_conv1d
+    from fla.ops.gated_delta_rule import chunk_gated_delta_rule
+
+    return causal_conv1d, chunk_gated_delta_rule
+
+
+def _run_with_gdn_conv_and_delta_rule_cu_seqlens(
+    mod: torch.nn.Module,
+    *,
+    cu_seqlens: torch.Tensor,
+    origin_forward,
+    forward_args,
+    forward_kwargs,
+) -> torch.Tensor:
+    causal_conv1d, chunk_gated_delta_rule = _get_flash_linear_attention_kernels()
+    old_conv_fn = mod.causal_conv1d_fn
+    old_chunk_rule = mod.chunk_gated_delta_rule
+
+    def causal_conv1d_wrapper(*args, **kwargs):
+        x = kwargs.pop('x')
+        output = causal_conv1d(
+            *args,
+            x=x.transpose(1, 2).contiguous(),
+            cu_seqlens=cu_seqlens.to(dtype=torch.int32, device=x.device),
+            **kwargs,
+        )
+        if isinstance(output, tuple):
+            output = output[0]
+        return output.transpose(1, 2).contiguous()
+
+    def chunk_gated_delta_rule_wrapper(query, key, value, **kwargs):
+        kwargs['cu_seqlens'] = cu_seqlens.to(dtype=torch.int32, device=query.device)
+        return chunk_gated_delta_rule(query, key, value, **kwargs)
+
+    mod.causal_conv1d_fn = causal_conv1d_wrapper
+    mod.chunk_gated_delta_rule = chunk_gated_delta_rule_wrapper
+    try:
+        return origin_forward(mod, *forward_args, **forward_kwargs)
+    finally:
+        mod.causal_conv1d_fn = old_conv_fn
+        mod.chunk_gated_delta_rule = old_chunk_rule
+
+
+class Qwen35GatedDeltaNetPaddingFreePatch(Patch):
+
+    def __call__(self, module, *args, **kwargs):
+        del args
+        Qwen3_5DecoderLayer, Qwen3_5GatedDeltaNet = _find_qwen35_classes(
+            module,
+            kwargs.get('hf_config'),
+            bool(kwargs.get('enable_sp', False)),
+        )
+        if Qwen3_5DecoderLayer is None or Qwen3_5GatedDeltaNet is None:
+            return
+        if getattr(Qwen3_5GatedDeltaNet, '_twinkle_sp_linear_patched', False):
+            return
+        module._twinkle_qwen35_padding_free_patched = True
+
+        if not getattr(Qwen3_5DecoderLayer, '_twinkle_padding_free_cu_seqlens_patched', False):
+            origin_decoder_forward = Qwen3_5DecoderLayer.forward
+
+            def decoder_forward(
+                layer,
+                hidden_states: torch.Tensor,
+                position_embeddings: tuple[torch.Tensor, torch.Tensor],
+                attention_mask: Optional[torch.Tensor] = None,
+                position_ids: Optional[torch.Tensor] = None,
+                past_key_values=None,
+                cache_position: Optional[torch.Tensor] = None,
+                **extra_kwargs,
+            ):
+                if getattr(layer, 'layer_type', None) != 'linear_attention':
+                    return origin_decoder_forward(
+                        layer,
+                        hidden_states=hidden_states,
+                        position_embeddings=position_embeddings,
+                        attention_mask=attention_mask,
+                        position_ids=position_ids,
+                        past_key_values=past_key_values,
+                        cache_position=cache_position,
+                        **extra_kwargs,
+                    )
+                cu_seq_lens_q = extra_kwargs.pop('cu_seq_lens_q', None)
+
+                residual = hidden_states
+                hidden_states = layer.input_layernorm(hidden_states)
+                hidden_states = layer.linear_attn(
+                    hidden_states=hidden_states,
+                    cache_params=past_key_values,
+                    cache_position=cache_position,
+                    attention_mask=attention_mask,
+                    cu_seq_lens_q=cu_seq_lens_q,
+                )
+                hidden_states = residual + hidden_states
+
+                residual = hidden_states
+                hidden_states = layer.post_attention_layernorm(hidden_states)
+                hidden_states = layer.mlp(hidden_states)
+                hidden_states = residual + hidden_states
+                return hidden_states
+
+            Qwen3_5DecoderLayer.forward = decoder_forward
+            Qwen3_5DecoderLayer._twinkle_padding_free_cu_seqlens_patched = True
+
+        if not getattr(Qwen3_5GatedDeltaNet, '_twinkle_padding_free_gdn_patched', False):
+            origin_forward = Qwen3_5GatedDeltaNet.forward
+
+            def forward(
+                mod,
+                hidden_states: torch.Tensor,
+                cache_params=None,
+                cache_position=None,
+                attention_mask: Optional[torch.Tensor] = None,
+                cu_seq_lens_q: Optional[torch.Tensor] = None,
+                **extra_kwargs,
+            ):
+                if cu_seq_lens_q is None:
+                    return origin_forward(
+                        mod,
+                        hidden_states,
+                        cache_params=cache_params,
+                        cache_position=cache_position,
+                        attention_mask=attention_mask,
+                        **extra_kwargs,
+                    )
+                return _run_with_gdn_conv_and_delta_rule_cu_seqlens(
+                    mod,
+                    cu_seqlens=cu_seq_lens_q,
+                    origin_forward=origin_forward,
+                    forward_args=(hidden_states,),
+                    forward_kwargs={
+                        'cache_params': cache_params,
+                        'cache_position': cache_position,
+                        'attention_mask': attention_mask,
+                        **extra_kwargs,
+                    },
+                )
+
+            Qwen3_5GatedDeltaNet.forward = forward
+            Qwen3_5GatedDeltaNet._twinkle_padding_free_gdn_patched = True
diff --git a/src/twinkle/processor/base.py b/src/twinkle/processor/base.py
index d7922a2b..6dc967a5 100644
--- a/src/twinkle/processor/base.py
+++ b/src/twinkle/processor/base.py
@@ -69,6 +69,7 @@ def __init__(self,
             self.collate_fn,
             self.to_transformers_dict,
             self.add_extra_padding_free_args,
+            self.prepare_transformers_padding_free_patch,
             self.drop_causal_4d_mask,
             self.split_cp,
             self.apply_transformers_sp,
@@ -125,7 +126,13 @@ def apply_transformers_sp(self, inputs: List[InputFeature], **kwargs) -> List[In
         sp_strategy = kwargs.get('sp_strategy')
         if self.framework != 'transformers' or sp_strategy is None:
             return inputs
-        return [InputFeature(**sp_strategy.preprocess_inputs(dict(_input))) for _input in inputs]
+        padding_free = bool(self.padding_free or self._any_packing(inputs))
+        results = []
+        for _input in inputs:
+            payload = dict(_input)
+            payload['padding_free'] = padding_free
+            results.append(InputFeature(**sp_strategy.preprocess_inputs(payload)))
+        return results

     def postprocess_tensor_sp(self, inputs: Dict[str, Any], outputs: Dict[str, Any],
                               **kwargs) -> tuple[Dict[str, Any], Dict[str, Any]]:
@@ -277,11 +284,43 @@ def split_cp_inputs(inputs: torch.Tensor, cu_seqlens: Optional[torch.Tensor], di

     def add_extra_padding_free_args(self, inputs: List[InputFeature], **kwargs) -> List[InputFeature]:
         for _inp in inputs:
-            padding_free = self.padding_free or self._any_packing([_inp])
+            padding_free = bool(self.padding_free or self._any_packing([_inp]))
             if padding_free and self.framework == 'megatron':
                 _inp['packed_seq_params'] = self._get_packed_seq_params(_inp['position_ids'])
         return inputs

+    def prepare_transformers_padding_free_patch(self, inputs: List[InputFeature], **kwargs) -> List[InputFeature]:
+        if self.framework != 'transformers':
+            return inputs
+        model = kwargs.get('model')
+        if model is None:
+            return inputs
+        padding_free = bool(self.padding_free or self._any_packing(inputs))
+        if not padding_free or bool(kwargs.get('enable_sp', False)):
+            return inputs
+
+        from twinkle.patch.qwen35_gdn_padding_free import Qwen35GatedDeltaNetPaddingFreePatch
+        from twinkle.patch import apply_patch
+
+        apply_patch(
+            model,
+            Qwen35GatedDeltaNetPaddingFreePatch,
+            hf_config=kwargs.get('hf_config'),
+            enable_sp=False,
+        )
+        if not getattr(model, '_twinkle_qwen35_padding_free_patched', False):
+            return inputs
+
+        for _inp in inputs:
+            position_ids = _inp.get('position_ids')
+            if position_ids is None or not torch.is_tensor(position_ids):
+                continue
+            _inp['cu_seq_lens_q'] = self._get_packed_seq_params(position_ids).cu_seqlens_q.to(
+                dtype=torch.int32,
+                device=position_ids.device,
+            )
+        return inputs
+
     def drop_causal_4d_mask(self, inputs: List[InputFeature], **kwargs) -> List[InputFeature]:
         """On NPU, drop the generic 4D dense mask so MindSpeed can build its own compressed causal
         mask for FlashAttention."""

From f216a5d420d71b488d55b61c1f92043f6f26b1ae Mon Sep 17 00:00:00 2001
From: meichangsu1 <1484603386@qq.com>
Date: Thu, 30 Apr 2026 18:16:10 +0800
Subject: [PATCH 2/2] refactor: simplify logits handling in forward methods

Simplify the logic for returning logits in `forward` and `forward_only`
methods by removing the redundant `_outputs` copy and `logits` variable.
The new logic directly modifies `outputs` and creates a single copy for
return, reducing code complexity and potential bugs.

---
 .../model/transformers/transformers.py        | 32 +++++++-------------
 src/twinkle/patch/qwen35_gdn_padding_free.py  |  5 ++-
 src/twinkle/processor/base.py                 |  2 +-
 3 files changed, 15 insertions(+), 24 deletions(-)

diff --git a/src/twinkle/model/transformers/transformers.py b/src/twinkle/model/transformers/transformers.py
index aedf8717..284d65fe 100644
--- a/src/twinkle/model/transformers/transformers.py
+++ b/src/twinkle/model/transformers/transformers.py
@@ -387,10 +387,9 @@ def forward(self, *, inputs: Union[InputFeature, List[InputFeature], List[Trajec
             logits = outputs['logits']
             logits.div_(temperature)
             outputs['logps'] = selective_log_softmax(logits, masked_labels)
+            del logits
         outputs['past_key_values'] = None
-        _outputs = copy(outputs)
-        logits = outputs['logits']
-        if not loss_require_logits:
+        if not (return_logits or loss_require_logits):
             outputs['logits'] = None
         inputs, outputs = processor.postprocess_tensor_sp(inputs, outputs, sp_strategy=self.sp_strategy)
         inputs, outputs = processor.unpack_packed_sequences(inputs, outputs)
@@ -398,13 +397,10 @@ def forward(self, *, inputs: Union[InputFeature, List[InputFeature], List[Trajec
         optimizer_config.train_status.outputs = outputs
         optimizer_config.train_status.forward_kwargs = kwargs
         optimizer_config.train_status.loss_value = outputs.get('aux_loss', 0)
-        if return_logits:
-            _outputs['logits'] = logits
-        else:
-            _outputs['logits'] = None
-        if not return_logits and not loss_require_logits:
-            del logits
-        return _outputs
+        return_outputs = copy(outputs)
+        if not return_logits:
+            return_outputs['logits'] = None
+        return return_outputs

     @remote_function(dispatch='slice_dp', collect=collect_tensor_dict)
     def forward_only(self, *, inputs: Union[InputFeature, List[InputFeature], List[Trajectory]], **kwargs):
@@ -463,10 +459,9 @@ def forward_only(self, *, inputs: Union[InputFeature, List[InputFeature], List[T
             logits = outputs['logits']
             logits.div_(temperature)
             outputs['logps'] = selective_log_softmax(logits, masked_labels)
+            del logits
         outputs['past_key_values'] = None
-        _outputs = copy(outputs)
-        logits = outputs['logits']
-        if not loss_require_logits:
+        if not (return_logits or loss_require_logits):
             outputs['logits'] = None
         inputs, outputs = processor.postprocess_tensor_sp(inputs, outputs, sp_strategy=self.sp_strategy)
         inputs, outputs = processor.unpack_packed_sequences(inputs, outputs)
@@ -474,13 +469,10 @@ def forward_only(self, *, inputs: Union[InputFeature, List[InputFeature], List[T
         optimizer_config.eval_status.outputs = outputs
         optimizer_config.eval_status.forward_kwargs = kwargs
         optimizer_config.eval_status.loss_value = outputs.get('aux_loss', 0)
-        if return_logits:
-            _outputs['logits'] = logits
-        else:
-            _outputs['logits'] = None
-        if not return_logits and not loss_require_logits:
-            del logits
-        return _outputs
+        return_outputs = copy(outputs)
+        if not return_logits:
+            return_outputs['logits'] = None
+        return return_outputs

     @remote_function(collect='mean')
     def calculate_loss(self, **kwargs):
diff --git a/src/twinkle/patch/qwen35_gdn_padding_free.py b/src/twinkle/patch/qwen35_gdn_padding_free.py
index 598395c9..4cc155c3 100644
--- a/src/twinkle/patch/qwen35_gdn_padding_free.py
+++ b/src/twinkle/patch/qwen35_gdn_padding_free.py
@@ -1,7 +1,6 @@
-from typing import Optional
-
 import torch
 from transformers.utils.import_utils import is_flash_linear_attention_available
+from typing import Optional

 from twinkle.patch import Patch

@@ -181,7 +180,7 @@ def forward(
                     mod,
                     cu_seqlens=cu_seq_lens_q,
                     origin_forward=origin_forward,
-                    forward_args=(hidden_states,),
+                    forward_args=(hidden_states, ),
                     forward_kwargs={
                         'cache_params': cache_params,
                         'cache_position': cache_position,
diff --git a/src/twinkle/processor/base.py b/src/twinkle/processor/base.py
index 6dc967a5..a7cc835c 100644
--- a/src/twinkle/processor/base.py
+++ b/src/twinkle/processor/base.py
@@ -299,8 +299,8 @@ def prepare_transformers_padding_free_patch(self, inputs: List[InputFeature], **
         if not padding_free or bool(kwargs.get('enable_sp', False)):
             return inputs

-        from twinkle.patch.qwen35_gdn_padding_free import Qwen35GatedDeltaNetPaddingFreePatch
         from twinkle.patch import apply_patch
+        from twinkle.patch.qwen35_gdn_padding_free import Qwen35GatedDeltaNetPaddingFreePatch

         apply_patch(
             model,
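
---

Note on the padding_free flow: both the flash_attention_2 varlen branch and
get_packed_cu_seqlens_from_sequence_parallel_context derive cu_seq_lens_q from
packed position IDs. A minimal, self-contained sketch of that derivation, for
reference (assumptions: each packed sequence restarts its position IDs at 0 and
padding uses -1, matching the pad(padding_value=-1) calls in this diff; the
helper below is illustrative only, not the exact twinkle implementation):

    import torch

    def cu_seqlens_from_position_ids(position_ids: torch.Tensor) -> torch.Tensor:
        # position_ids: (batch, seq) packed rows, e.g. [0, 1, 2, 0, 1, 0, 1, 2, 3].
        flat = position_ids.reshape(-1)
        flat = flat[flat >= 0]  # drop -1 padding positions
        # Every 0 marks the start of a new packed sequence.
        starts = torch.nonzero(flat == 0, as_tuple=False).flatten()
        total = torch.tensor([flat.numel()], device=flat.device)
        return torch.cat([starts, total]).to(torch.int32)

    pos = torch.tensor([[0, 1, 2, 0, 1, 0, 1, 2, 3, -1]])
    print(cu_seqlens_from_position_ids(pos))
    # tensor([0, 3, 5, 9], dtype=torch.int32) -> packed sequence lengths 3, 2 and 4

The real code additionally pads per-rank shards and validates the token count
against cu_seqlens[-1] before calling the varlen attention kernel.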