diff --git a/src/transformers/loss/loss_utils.py b/src/transformers/loss/loss_utils.py index df269477e9ec..1d94cb53f9c1 100644 --- a/src/transformers/loss/loss_utils.py +++ b/src/transformers/loss/loss_utils.py @@ -30,10 +30,21 @@ def fixed_cross_entropy( target: torch.Tensor, num_items_in_batch: torch.Tensor | None = None, ignore_index: int = -100, - **kwargs, + weight: torch.Tensor | None = None, + label_smoothing: float = 0.0, + **_kwargs, ) -> torch.Tensor: reduction = "sum" if num_items_in_batch is not None else "mean" - loss = nn.functional.cross_entropy(source, target, ignore_index=ignore_index, reduction=reduction) + + loss = nn.functional.cross_entropy( + source, + target, + ignore_index=ignore_index, + reduction=reduction, + label_smoothing=label_smoothing, + weight=weight, + ) + if reduction == "sum": # just in case users pass an int for num_items_in_batch, which could be the case for custom trainer if torch.is_tensor(num_items_in_batch):