From e86a86fd087fb993d446308b352f4fb9636e2f3c Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Thu, 24 Oct 2024 12:39:32 -0400
Subject: [PATCH] handle updated gradient accumulation fixes from transformers

---
 .../transformers/fused_linear_cross_entropy.py | 4 ++--
 src/liger_kernel/transformers/model/llama.py   | 9 ++++++++-
 2 files changed, 10 insertions(+), 3 deletions(-)

diff --git a/src/liger_kernel/transformers/fused_linear_cross_entropy.py b/src/liger_kernel/transformers/fused_linear_cross_entropy.py
index 0e3331565..c69eca362 100644
--- a/src/liger_kernel/transformers/fused_linear_cross_entropy.py
+++ b/src/liger_kernel/transformers/fused_linear_cross_entropy.py
@@ -9,7 +9,7 @@ class LigerFusedLinearCrossEntropyLoss(CrossEntropyLoss):
     def __init__(self, *args, **kwargs):
         super(LigerFusedLinearCrossEntropyLoss, self).__init__(*args, **kwargs)
 
-    def forward(self, lin_weight, _input, target, bias=None):
+    def forward(self, lin_weight, _input, target, bias=None, reduction=None):
         return LigerFusedLinearCrossEntropyFunction.apply(
             _input,
             lin_weight,
@@ -17,5 +17,5 @@ def forward(self, lin_weight, _input, target, bias=None):
             bias,
             self.ignore_index,
             self.label_smoothing,
-            self.reduction,
+            reduction or self.reduction,
         )
diff --git a/src/liger_kernel/transformers/model/llama.py b/src/liger_kernel/transformers/model/llama.py
index d0a5daee8..9ad4a057e 100644
--- a/src/liger_kernel/transformers/model/llama.py
+++ b/src/liger_kernel/transformers/model/llama.py
@@ -35,6 +35,8 @@ def lce_forward(
     output_hidden_states: Optional[bool] = None,
     return_dict: Optional[bool] = None,
     cache_position: Optional[torch.LongTensor] = None,
+    num_logits_to_keep: int = 0,
+    **loss_kwargs,
 ) -> Union[Tuple, CausalLMOutputWithPast]:
     r"""
     Copy paste llama forward but replace torch cross entropy with liger fused linear cross entropy
@@ -106,7 +108,12 @@ def lce_forward(
         shift_labels = shift_labels.view(-1)
 
         lce = LigerFusedLinearCrossEntropyLoss()
-        loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
+        lce_kwargs = {}
+        if "num_items_in_batch" in loss_kwargs:
+            lce_kwargs["reduction"] = "sum"
+        loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels, **lce_kwargs)
+        if "num_items_in_batch" in loss_kwargs:
+            loss = loss / loss_kwargs["num_items_in_batch"]
 
     else:
         if self.config.pretraining_tp > 1:
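
Note on the approach, for reviewers: transformers' gradient accumulation fix
forwards a num_items_in_batch count (the total number of non-ignored label
tokens across all accumulation steps) into the model's loss kwargs. Averaging
a per-micro-batch mean over the accumulation steps over-weights tokens from
micro-batches that contain fewer real tokens, which is why this patch switches
the fused loss to reduction="sum" and normalizes once by the global count.
Below is a minimal standalone sketch of the difference; it uses plain PyTorch
cross entropy rather than the fused Liger kernel, and the micro-batch shapes
and values are made up for illustration.

import torch
import torch.nn.functional as F

torch.manual_seed(0)
vocab, hidden = 16, 8
lm_head = torch.randn(vocab, hidden)

# Two micro-batches with unequal token counts, as happens with padding
# or sample packing.
micro_batches = [
    (torch.randn(5, hidden), torch.randint(vocab, (5,))),  # 5 target tokens
    (torch.randn(3, hidden), torch.randint(vocab, (3,))),  # 3 target tokens
]
num_items_in_batch = sum(t.numel() for _, t in micro_batches)  # 8

# Old behavior: mean per micro-batch, averaged over steps. Each token in
# the 3-token micro-batch carries weight 1/6; in the 5-token one, 1/10.
buggy = sum(
    F.cross_entropy(h @ lm_head.T, t, reduction="mean") for h, t in micro_batches
) / len(micro_batches)

# Patched behavior: sum per micro-batch, then a single division by the
# global token count, mirroring reduction="sum" + loss / num_items_in_batch.
fixed = sum(
    F.cross_entropy(h @ lm_head.T, t, reduction="sum") for h, t in micro_batches
) / num_items_in_batch

# Reference: the mean over the whole batch processed in one step. The
# "fixed" accumulation matches it exactly; the "buggy" one does not.
full_h = torch.cat([h for h, _ in micro_batches])
full_t = torch.cat([t for _, t in micro_batches])
reference = F.cross_entropy(full_h @ lm_head.T, full_t, reduction="mean")

assert torch.allclose(fixed, reference)
print(f"buggy={buggy:.4f} fixed={fixed:.4f} reference={reference:.4f}")

Threading reduction through forward() as an optional per-call override,
rather than mutating self.reduction on the module, keeps a single loss
instance correct for both the accumulation and non-accumulation paths.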