20 changes: 7 additions & 13 deletions src/twinkle/model/megatron/megatron.py
@@ -155,17 +155,6 @@ def _should_bind_device_id_for_process_group(self, backend: str) -> bool:
# PG so Megatron's later Gloo DP groups stay decoupled on NPU.
return backend == 'nccl'

@staticmethod
def _drop_npu_causal_4d_mask(batch, unwrapped_model):
"""On NPU, drop the generic 4D dense mask so MindSpeed can build
its own compressed causal mask for FlashAttention."""
if Platform.device_prefix() != 'npu':
return
attention_mask = batch.get('attention_mask')
if (isinstance(attention_mask, torch.Tensor) and attention_mask.dim() == 4
and getattr(unwrapped_model.config, 'attention_mask_type', None) == 'causal'):
batch['attention_mask'] = None

def _construct_default_optimizer_group(self):
return MegatronOptimizerGroup(
loss_instance=CrossEntropyLoss(reduction='sum'),
@@ -315,7 +304,13 @@ def forward_backward(self,
if micro_batch_size is None:
# Compatible with DPO
micro_batch_size = min(2, len(inputs))
inputs = processor(inputs, micro_batch_size=micro_batch_size, variable_seq_lengths=self.variable_seq_lengths)
unwrapped_model = self.strategy.unwrap_model(self.model)[0]
inputs = processor(
inputs,
micro_batch_size=micro_batch_size,
variable_seq_lengths=self.variable_seq_lengths,
attention_mask_type=getattr(unwrapped_model.config, 'attention_mask_type', None),
)

# Get parallelism settings for sequence padding and splitting
cp_size = self.device_mesh.cp_world_size
@@ -379,7 +374,6 @@ def forward_step_func(data_iterator, model):
batch = next(data_iterator)
labels = batch.pop('labels', None)
unwrapped_model = self.strategy.unwrap_model([model])[0]
self._drop_npu_causal_4d_mask(batch, unwrapped_model)
if disable_lora and isinstance(unwrapped_model, PeftModel):
with unwrapped_model.disable_adapter():
output_tensor = model(**batch)
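For orientation only, not part of the diff: a minimal sketch of the new call shape in forward_backward, where the model config's attention_mask_type is forwarded to the processor instead of being patched onto the batch in forward_step_func. The stub processor and the SimpleNamespace config below are illustrative stand-ins, not the real twinkle/Megatron objects, and the input values are made up.

# Illustrative sketch: mirrors the keyword interface added in forward_backward above.
from types import SimpleNamespace
from typing import Any, Dict, List, Optional


def processor(inputs: List[Dict[str, Any]], *, micro_batch_size: int,
              variable_seq_lengths: bool,
              attention_mask_type: Optional[str]) -> List[Dict[str, Any]]:
    # Stand-in for twinkle's processor pipeline: the real one forwards these
    # kwargs to each step, including drop_causal_4d_mask shown in base.py below.
    print(f'mbs={micro_batch_size}, varlen={variable_seq_lengths}, mask_type={attention_mask_type}')
    return inputs


unwrapped_model = SimpleNamespace(config=SimpleNamespace(attention_mask_type='causal'))
inputs = [{'input_ids': [1, 2, 3]}, {'input_ids': [4, 5]}]
inputs = processor(
    inputs,
    micro_batch_size=min(2, len(inputs)),  # DPO-compatible default from the diff
    variable_seq_lengths=False,            # made-up value for the sketch
    attention_mask_type=getattr(unwrapped_model.config, 'attention_mask_type', None),
)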
15 changes: 15 additions & 0 deletions src/twinkle/processor/base.py
@@ -65,6 +65,7 @@ def __init__(self,
self.collate_fn,
self.to_transformers_dict,
self.add_extra_padding_free_args,
self.drop_causal_4d_mask,
self.split_cp,
self.prepare_outputs,
]
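For orientation only: the pipeline list above registers drop_causal_4d_mask between add_extra_padding_free_args and split_cp. A tiny sketch of the assumed dispatch follows; the diff does not show the real runner, so the idea that each step receives the shared **kwargs (e.g. attention_mask_type) is an assumption inferred from the new step's signature.

from typing import Any, Callable, Dict, List

Step = Callable[..., List[Dict[str, Any]]]


def run_pipeline(steps: List[Step], inputs: List[Dict[str, Any]], **kwargs) -> List[Dict[str, Any]]:
    # Assumed behavior: apply each registered step in order, forwarding kwargs
    # such as attention_mask_type so steps like drop_causal_4d_mask can use them.
    for step in steps:
        inputs = step(inputs, **kwargs)
    return inputs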
@@ -236,6 +237,20 @@ def add_extra_padding_free_args(self, inputs: List[InputFeature], **kwargs) -> List[InputFeature]:
_inp['packed_seq_params'] = self._get_packed_seq_params(_inp['position_ids'])
return inputs

def drop_causal_4d_mask(self, inputs: List[InputFeature], **kwargs) -> List[InputFeature]:
"""On NPU, drop the generic 4D dense mask so MindSpeed can build
its own compressed causal mask for FlashAttention."""
if Platform.device_prefix() != 'npu':
return inputs
attention_mask_type = kwargs.get('attention_mask_type')
if attention_mask_type != 'causal':
return inputs
for _inp in inputs:
attention_mask = _inp.get('attention_mask')
if isinstance(attention_mask, torch.Tensor) and attention_mask.dim() == 4:
_inp['attention_mask'] = None
return inputs

@staticmethod
def _pad_sequence(sequences, padding_value, padding_side):
if padding_side == 'right':
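For illustration only, not part of the diff: a standalone sketch of what the new drop_causal_4d_mask step does to a toy batch. FakePlatform is a hypothetical stand-in for twinkle's Platform so the NPU branch runs anywhere, and a plain dict replaces InputFeature; the real step receives attention_mask_type through the processor kwargs shown above.

import torch


class FakePlatform:
    # Hypothetical stand-in for twinkle's Platform; forces the NPU branch.
    @staticmethod
    def device_prefix() -> str:
        return 'npu'


def drop_causal_4d_mask(inputs, platform=FakePlatform, **kwargs):
    # Mirrors the pipeline step above: only on NPU, only for causal masks,
    # and only when the mask is the dense 4D [batch, 1, seq, seq] variant.
    if platform.device_prefix() != 'npu':
        return inputs
    if kwargs.get('attention_mask_type') != 'causal':
        return inputs
    for _inp in inputs:
        attention_mask = _inp.get('attention_mask')
        if isinstance(attention_mask, torch.Tensor) and attention_mask.dim() == 4:
            _inp['attention_mask'] = None  # let MindSpeed build its compressed causal mask
    return inputs


batch = [{'input_ids': torch.tensor([1, 2, 3]),
          'attention_mask': torch.ones(1, 1, 3, 3, dtype=torch.bool)}]
out = drop_causal_4d_mask(batch, attention_mask_type='causal')
assert out[0]['attention_mask'] is None  # 4D causal mask dropped for FlashAttention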