diff --git a/src/twinkle/model/megatron/megatron.py b/src/twinkle/model/megatron/megatron.py index 60cc30a8..2e2fc64c 100644 --- a/src/twinkle/model/megatron/megatron.py +++ b/src/twinkle/model/megatron/megatron.py @@ -134,6 +134,12 @@ def calculate_metrics(self, is_training): '.mlp.gate.weight', '.mlp.gate.bias', '.mlp.gate.e_score_correction_bias', + '.in_proj_qkv.weight', + '.in_proj_z.weight', + '.in_proj_a.weight', + '.in_proj_b.weight', + '.out_proj.weight', + '.conv1d.weight', ] diff --git a/src/twinkle/utils/torch_utils.py b/src/twinkle/utils/torch_utils.py index f13eb056..4a721a2a 100644 --- a/src/twinkle/utils/torch_utils.py +++ b/src/twinkle/utils/torch_utils.py @@ -72,6 +72,14 @@ def selective_log_softmax(logits, index) -> 'torch.Tensor': import torch import torch.nn.functional as F + try: + from megatron.core import parallel_state as mpu + if mpu.get_tensor_model_parallel_world_size() > 1: + # clone to avoid modifying the original logits + return _vocab_parallel_selective_log_softmax(logits.clone(), index) + except Exception: + pass + if logits.dtype in [torch.float32, torch.float64]: selected_logits = torch.gather(logits, dim=-1, index=index.unsqueeze(-1)).squeeze(-1) # loop to reduce peak mem consumption