diff --git a/src/twinkle/model/megatron/megatron.py b/src/twinkle/model/megatron/megatron.py
index dcece0ff..cdd3b655 100644
--- a/src/twinkle/model/megatron/megatron.py
+++ b/src/twinkle/model/megatron/megatron.py
@@ -155,17 +155,6 @@ def _should_bind_device_id_for_process_group(self, backend: str) -> bool:
         # PG so Megatron's later Gloo DP groups stay decoupled on NPU.
         return backend == 'nccl'
 
-    @staticmethod
-    def _drop_npu_causal_4d_mask(batch, unwrapped_model):
-        """On NPU, drop the generic 4D dense mask so MindSpeed can build
-        its own compressed causal mask for FlashAttention."""
-        if Platform.device_prefix() != 'npu':
-            return
-        attention_mask = batch.get('attention_mask')
-        if (isinstance(attention_mask, torch.Tensor) and attention_mask.dim() == 4
-                and getattr(unwrapped_model.config, 'attention_mask_type', None) == 'causal'):
-            batch['attention_mask'] = None
-
     def _construct_default_optimizer_group(self):
         return MegatronOptimizerGroup(
             loss_instance=CrossEntropyLoss(reduction='sum'),
@@ -315,7 +304,13 @@ def forward_backward(self,
         if micro_batch_size is None:
             # Compatible with DPO
             micro_batch_size = min(2, len(inputs))
-        inputs = processor(inputs, micro_batch_size=micro_batch_size, variable_seq_lengths=self.variable_seq_lengths)
+        unwrapped_model = self.strategy.unwrap_model(self.model)[0]
+        inputs = processor(
+            inputs,
+            micro_batch_size=micro_batch_size,
+            variable_seq_lengths=self.variable_seq_lengths,
+            attention_mask_type=getattr(unwrapped_model.config, 'attention_mask_type', None),
+        )
 
         # Get parallelism settings for sequence padding and splitting
         cp_size = self.device_mesh.cp_world_size
@@ -379,7 +374,6 @@ def forward_step_func(data_iterator, model):
             batch = next(data_iterator)
             labels = batch.pop('labels', None)
             unwrapped_model = self.strategy.unwrap_model([model])[0]
-            self._drop_npu_causal_4d_mask(batch, unwrapped_model)
             if disable_lora and isinstance(unwrapped_model, PeftModel):
                 with unwrapped_model.disable_adapter():
                     output_tensor = model(**batch)
diff --git a/src/twinkle/processor/base.py b/src/twinkle/processor/base.py
index eb182bf5..5906aa13 100644
--- a/src/twinkle/processor/base.py
+++ b/src/twinkle/processor/base.py
@@ -65,6 +65,7 @@ def __init__(self,
            self.collate_fn,
            self.to_transformers_dict,
            self.add_extra_padding_free_args,
+            self.drop_causal_4d_mask,
            self.split_cp,
            self.prepare_outputs,
        ]
@@ -236,6 +237,20 @@ def add_extra_padding_free_args(self, inputs: List[InputFeature], **kwargs) -> L
             _inp['packed_seq_params'] = self._get_packed_seq_params(_inp['position_ids'])
         return inputs
 
+    def drop_causal_4d_mask(self, inputs: List[InputFeature], **kwargs) -> List[InputFeature]:
+        """On NPU, drop the generic 4D dense mask so MindSpeed can build
+        its own compressed causal mask for FlashAttention."""
+        if Platform.device_prefix() != 'npu':
+            return inputs
+        attention_mask_type = kwargs.get('attention_mask_type')
+        if attention_mask_type != 'causal':
+            return inputs
+        for _inp in inputs:
+            attention_mask = _inp.get('attention_mask')
+            if isinstance(attention_mask, torch.Tensor) and attention_mask.dim() == 4:
+                _inp['attention_mask'] = None
+        return inputs
+
     @staticmethod
     def _pad_sequence(sequences, padding_value, padding_side):
         if padding_side == 'right':