diff --git a/src/transformers/loss/loss_utils.py b/src/transformers/loss/loss_utils.py index df269477e9ec..587fc78aeba2 100644 --- a/src/transformers/loss/loss_utils.py +++ b/src/transformers/loss/loss_utils.py @@ -30,10 +30,14 @@ def fixed_cross_entropy( target: torch.Tensor, num_items_in_batch: torch.Tensor | None = None, ignore_index: int = -100, + weight: torch.Tensor | None = None, + label_smoothing: float = 0.0, **kwargs, ) -> torch.Tensor: reduction = "sum" if num_items_in_batch is not None else "mean" - loss = nn.functional.cross_entropy(source, target, ignore_index=ignore_index, reduction=reduction) + loss = nn.functional.cross_entropy( + source, target, ignore_index=ignore_index, weight=weight, reduction=reduction, label_smoothing=label_smoothing + ) if reduction == "sum": # just in case users pass an int for num_items_in_batch, which could be the case for custom trainer if torch.is_tensor(num_items_in_batch):