From 7c722ba8a9403964211a479d3fa473b8c58f7d4f Mon Sep 17 00:00:00 2001 From: Matt Date: Tue, 13 Jan 2026 13:51:47 +0000 Subject: [PATCH 1/2] Add supported kwargs to fixed_cross_entropy --- src/transformers/loss/loss_utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/transformers/loss/loss_utils.py b/src/transformers/loss/loss_utils.py index df269477e9ec..21259470e9ca 100644 --- a/src/transformers/loss/loss_utils.py +++ b/src/transformers/loss/loss_utils.py @@ -30,10 +30,12 @@ def fixed_cross_entropy( target: torch.Tensor, num_items_in_batch: torch.Tensor | None = None, ignore_index: int = -100, + weight: torch.Tensor | None = None, + label_smoothing: float = 0.0, **kwargs, ) -> torch.Tensor: reduction = "sum" if num_items_in_batch is not None else "mean" - loss = nn.functional.cross_entropy(source, target, ignore_index=ignore_index, reduction=reduction) + loss = nn.functional.cross_entropy(source, target, ignore_index=ignore_index, weight=weight, reduction=reduction, label_smoothing=label_smoothing) if reduction == "sum": # just in case users pass an int for num_items_in_batch, which could be the case for custom trainer if torch.is_tensor(num_items_in_batch): From afb3f23b458f65ccdd3ce26a604389d6746aaacb Mon Sep 17 00:00:00 2001 From: Matt Date: Tue, 13 Jan 2026 13:53:39 +0000 Subject: [PATCH 2/2] make style --- src/transformers/loss/loss_utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/transformers/loss/loss_utils.py b/src/transformers/loss/loss_utils.py index 21259470e9ca..587fc78aeba2 100644 --- a/src/transformers/loss/loss_utils.py +++ b/src/transformers/loss/loss_utils.py @@ -35,7 +35,9 @@ def fixed_cross_entropy( **kwargs, ) -> torch.Tensor: reduction = "sum" if num_items_in_batch is not None else "mean" - loss = nn.functional.cross_entropy(source, target, ignore_index=ignore_index, weight=weight, reduction=reduction, label_smoothing=label_smoothing) + loss = nn.functional.cross_entropy( + source, target, ignore_index=ignore_index, weight=weight, reduction=reduction, label_smoothing=label_smoothing + ) if reduction == "sum": # just in case users pass an int for num_items_in_batch, which could be the case for custom trainer if torch.is_tensor(num_items_in_batch):