From 084eb5c65e2f5462655eb11d2a1b1582a57ee6e4 Mon Sep 17 00:00:00 2001 From: Jan Andrusikiewicz Date: Tue, 13 Jan 2026 12:31:08 +0100 Subject: [PATCH 01/12] added validating kwargs passed to nn.functional.cross_entropy --- src/transformers/loss/loss_utils.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/src/transformers/loss/loss_utils.py b/src/transformers/loss/loss_utils.py index df269477e9ec..efb2d9af7686 100644 --- a/src/transformers/loss/loss_utils.py +++ b/src/transformers/loss/loss_utils.py @@ -32,16 +32,26 @@ def fixed_cross_entropy( ignore_index: int = -100, **kwargs, ) -> torch.Tensor: + allowed = {"weight", "size_average", "reduce", "label_smoothing"} + unknown = set(kwargs) - allowed + if unknown: + raise TypeError(f"Unexpected kwargs for nn.functional.cross_entropy: {unknown}") + reduction = "sum" if num_items_in_batch is not None else "mean" - loss = nn.functional.cross_entropy(source, target, ignore_index=ignore_index, reduction=reduction) + + loss = nn.functional.cross_entropy( + source, + target, + ignore_index=ignore_index, + **kwargs, + ) + if reduction == "sum": - # just in case users pass an int for num_items_in_batch, which could be the case for custom trainer if torch.is_tensor(num_items_in_batch): num_items_in_batch = num_items_in_batch.to(loss.device) loss = loss / num_items_in_batch return loss - def ForCausalLMLoss( logits, labels, From 51ad984540ea2d522326d32f3fef7b817056afad Mon Sep 17 00:00:00 2001 From: Jan Andrusikiewicz Date: Tue, 13 Jan 2026 12:37:59 +0100 Subject: [PATCH 02/12] rollback --- src/transformers/loss/loss_utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/transformers/loss/loss_utils.py b/src/transformers/loss/loss_utils.py index efb2d9af7686..ccbb34809ef4 100644 --- a/src/transformers/loss/loss_utils.py +++ b/src/transformers/loss/loss_utils.py @@ -47,11 +47,13 @@ def fixed_cross_entropy( ) if reduction == "sum": + # just in case users pass an int for num_items_in_batch, which could be the case for custom trainer if torch.is_tensor(num_items_in_batch): num_items_in_batch = num_items_in_batch.to(loss.device) loss = loss / num_items_in_batch return loss + def ForCausalLMLoss( logits, labels, From 763fabd19bd6c03edd279d83a1e1b70b2d1feaab Mon Sep 17 00:00:00 2001 From: Jan Andrusikiewicz Date: Tue, 13 Jan 2026 13:44:27 +0100 Subject: [PATCH 03/12] removed not allowed kwargs --- src/transformers/loss/loss_utils.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/transformers/loss/loss_utils.py b/src/transformers/loss/loss_utils.py index ccbb34809ef4..385c96be29e8 100644 --- a/src/transformers/loss/loss_utils.py +++ b/src/transformers/loss/loss_utils.py @@ -33,9 +33,6 @@ def fixed_cross_entropy( **kwargs, ) -> torch.Tensor: allowed = {"weight", "size_average", "reduce", "label_smoothing"} - unknown = set(kwargs) - allowed - if unknown: - raise TypeError(f"Unexpected kwargs for nn.functional.cross_entropy: {unknown}") reduction = "sum" if num_items_in_batch is not None else "mean" @@ -43,7 +40,8 @@ def fixed_cross_entropy( source, target, ignore_index=ignore_index, - **kwargs, + reduction=reduction, + **(kwargs & allowed), ) if reduction == "sum": From 52257290871541053f41a263893bedb06143d026 Mon Sep 17 00:00:00 2001 From: Jan Andrusikiewicz Date: Tue, 13 Jan 2026 13:46:08 +0100 Subject: [PATCH 04/12] moved to inspect --- src/transformers/loss/loss_utils.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/transformers/loss/loss_utils.py b/src/transformers/loss/loss_utils.py index 385c96be29e8..9a6f2184fdc0 100644 --- a/src/transformers/loss/loss_utils.py +++ b/src/transformers/loss/loss_utils.py @@ -13,6 +13,8 @@ # limitations under the License. +import inspect + import torch import torch.nn as nn from torch.nn import BCEWithLogitsLoss, MSELoss @@ -32,16 +34,16 @@ def fixed_cross_entropy( ignore_index: int = -100, **kwargs, ) -> torch.Tensor: - allowed = {"weight", "size_average", "reduce", "label_smoothing"} - reduction = "sum" if num_items_in_batch is not None else "mean" + ce_params = inspect.signature(nn.functional.cross_entropy).parameters + loss = nn.functional.cross_entropy( source, target, ignore_index=ignore_index, - reduction=reduction, - **(kwargs & allowed), + reduction="sum" if num_items_in_batch else "mean", + **{k: v for k, v in kwargs.items() if k in ce_params}, ) if reduction == "sum": From ee43e8f9917670948a73ad5e1ddaf2ae7fe3789d Mon Sep 17 00:00:00 2001 From: Jan Andrusikiewicz Date: Tue, 13 Jan 2026 13:47:01 +0100 Subject: [PATCH 05/12] added allowed_kwargs variable --- src/transformers/loss/loss_utils.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/transformers/loss/loss_utils.py b/src/transformers/loss/loss_utils.py index 9a6f2184fdc0..810eb6769ace 100644 --- a/src/transformers/loss/loss_utils.py +++ b/src/transformers/loss/loss_utils.py @@ -35,15 +35,16 @@ def fixed_cross_entropy( **kwargs, ) -> torch.Tensor: reduction = "sum" if num_items_in_batch is not None else "mean" - + ce_params = inspect.signature(nn.functional.cross_entropy).parameters + allowed_kwargs = {k: v for k, v in kwargs.items() if k in ce_params} loss = nn.functional.cross_entropy( source, target, ignore_index=ignore_index, reduction="sum" if num_items_in_batch else "mean", - **{k: v for k, v in kwargs.items() if k in ce_params}, + **allowed_kwargs, ) if reduction == "sum": From aa8d0acb5315b8e046f339409416f803d0cbd6b9 Mon Sep 17 00:00:00 2001 From: Jan Andrusikiewicz Date: Wed, 14 Jan 2026 09:43:15 +0100 Subject: [PATCH 06/12] added tests --- tests/loss/test_loss_utils.py | 117 ++++++++++++++++++++++++++++++++++ 1 file changed, 117 insertions(+) create mode 100644 tests/loss/test_loss_utils.py diff --git a/tests/loss/test_loss_utils.py b/tests/loss/test_loss_utils.py new file mode 100644 index 000000000000..f91e604641e0 --- /dev/null +++ b/tests/loss/test_loss_utils.py @@ -0,0 +1,117 @@ +# Copyright 2026 The HuggingFace Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from transformers.testing_utils import require_torch +from transformers.utils import is_torch_available + +from transformers.loss import fixed_cross_entropy + +if is_torch_available(): + import torch + import torch.nn as nn + + +@require_torch +class FixedCrossEntropyTester(unittest.TestCase): + def setUp(self): + torch.manual_seed(0) + + def test_ignores_unknown_kwargs(self): + source = torch.randn(4, 10, requires_grad=True) + target = torch.randint(0, 10, (4,)) + + loss = fixed_cross_entropy( + source, + target, + some_unknown_kwarg=123, + another_one="ignored", + ) + + expected = nn.functional.cross_entropy(source, target) + + self.assertTrue(torch.allclose(loss, expected)) + + def test_sum_reduction_and_tensor_normalization(self): + source = torch.randn(6, 5, requires_grad=True) + target = torch.randint(0, 5, (6,)) + num_items = torch.tensor(6) + + loss = fixed_cross_entropy( + source, + target, + num_items_in_batch=num_items, + ) + + expected = ( + nn.functional.cross_entropy(source, target, reduction="sum") + / num_items + ) + + self.assertTrue(torch.allclose(loss, expected)) + + def test_sum_reduction_and_int_normalization(self): + source = torch.randn(8, 3, requires_grad=True) + target = torch.randint(0, 3, (8,)) + num_items = 8 + + loss = fixed_cross_entropy( + source, + target, + num_items_in_batch=num_items, + ) + + expected = ( + nn.functional.cross_entropy(source, target, reduction="sum") + / num_items + ) + + self.assertTrue(torch.allclose(loss, expected)) + + def test_passes_valid_kwargs_only(self): + source = torch.randn(5, 4, requires_grad=True) + target = torch.randint(0, 4, (5,)) + + weight = torch.rand(4) + + loss = fixed_cross_entropy( + source, + target, + weight=weight, + label_smoothing=0.1, + invalid_kwarg=True, + ) + + expected = nn.functional.cross_entropy( + source, + target, + weight=weight, + label_smoothing=0.1, + ) + + self.assertTrue(torch.allclose(loss, expected)) + + def test_loss_device_matches_input(self): + source = torch.randn(4, 5) + target = torch.randint(0, 5, (4,)) + num_items = torch.tensor(4) + + loss = fixed_cross_entropy( + source, + target, + num_items_in_batch=num_items, + ) + + self.assertEqual(loss.device, source.device) From 3f7f00704031140653a417f2fa460eefaa44855c Mon Sep 17 00:00:00 2001 From: Jan Andrusikiewicz Date: Wed, 14 Jan 2026 09:45:21 +0100 Subject: [PATCH 07/12] reduplicated code --- src/transformers/loss/loss_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/loss/loss_utils.py b/src/transformers/loss/loss_utils.py index 810eb6769ace..1003761b7888 100644 --- a/src/transformers/loss/loss_utils.py +++ b/src/transformers/loss/loss_utils.py @@ -35,7 +35,7 @@ def fixed_cross_entropy( **kwargs, ) -> torch.Tensor: reduction = "sum" if num_items_in_batch is not None else "mean" - + ce_params = inspect.signature(nn.functional.cross_entropy).parameters allowed_kwargs = {k: v for k, v in kwargs.items() if k in ce_params} @@ -43,7 +43,7 @@ def fixed_cross_entropy( source, target, ignore_index=ignore_index, - reduction="sum" if num_items_in_batch else "mean", + reduction=reduction, **allowed_kwargs, ) From 0203462f99a78f19952668bf8c75b6a3aa49bf61 Mon Sep 17 00:00:00 2001 From: Jan Andrusikiewicz Date: Fri, 16 Jan 2026 15:23:38 +0100 Subject: [PATCH 08/12] added only supported parameters --- src/transformers/loss/loss_utils.py | 9 +-- tests/loss/test_loss_utils.py | 117 ---------------------------- 2 files changed, 4 insertions(+), 122 deletions(-) delete mode 100644 tests/loss/test_loss_utils.py diff --git a/src/transformers/loss/loss_utils.py b/src/transformers/loss/loss_utils.py index 1003761b7888..621d876d1a6b 100644 --- a/src/transformers/loss/loss_utils.py +++ b/src/transformers/loss/loss_utils.py @@ -32,19 +32,18 @@ def fixed_cross_entropy( target: torch.Tensor, num_items_in_batch: torch.Tensor | None = None, ignore_index: int = -100, - **kwargs, + label_smoothing: float | None = None, + weight: torch.Tensor | None = None, ) -> torch.Tensor: reduction = "sum" if num_items_in_batch is not None else "mean" - ce_params = inspect.signature(nn.functional.cross_entropy).parameters - allowed_kwargs = {k: v for k, v in kwargs.items() if k in ce_params} - loss = nn.functional.cross_entropy( source, target, ignore_index=ignore_index, reduction=reduction, - **allowed_kwargs, + label_smoothing=label_smoothing, + weight=weight, ) if reduction == "sum": diff --git a/tests/loss/test_loss_utils.py b/tests/loss/test_loss_utils.py deleted file mode 100644 index f91e604641e0..000000000000 --- a/tests/loss/test_loss_utils.py +++ /dev/null @@ -1,117 +0,0 @@ -# Copyright 2026 The HuggingFace Team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import unittest - -from transformers.testing_utils import require_torch -from transformers.utils import is_torch_available - -from transformers.loss import fixed_cross_entropy - -if is_torch_available(): - import torch - import torch.nn as nn - - -@require_torch -class FixedCrossEntropyTester(unittest.TestCase): - def setUp(self): - torch.manual_seed(0) - - def test_ignores_unknown_kwargs(self): - source = torch.randn(4, 10, requires_grad=True) - target = torch.randint(0, 10, (4,)) - - loss = fixed_cross_entropy( - source, - target, - some_unknown_kwarg=123, - another_one="ignored", - ) - - expected = nn.functional.cross_entropy(source, target) - - self.assertTrue(torch.allclose(loss, expected)) - - def test_sum_reduction_and_tensor_normalization(self): - source = torch.randn(6, 5, requires_grad=True) - target = torch.randint(0, 5, (6,)) - num_items = torch.tensor(6) - - loss = fixed_cross_entropy( - source, - target, - num_items_in_batch=num_items, - ) - - expected = ( - nn.functional.cross_entropy(source, target, reduction="sum") - / num_items - ) - - self.assertTrue(torch.allclose(loss, expected)) - - def test_sum_reduction_and_int_normalization(self): - source = torch.randn(8, 3, requires_grad=True) - target = torch.randint(0, 3, (8,)) - num_items = 8 - - loss = fixed_cross_entropy( - source, - target, - num_items_in_batch=num_items, - ) - - expected = ( - nn.functional.cross_entropy(source, target, reduction="sum") - / num_items - ) - - self.assertTrue(torch.allclose(loss, expected)) - - def test_passes_valid_kwargs_only(self): - source = torch.randn(5, 4, requires_grad=True) - target = torch.randint(0, 4, (5,)) - - weight = torch.rand(4) - - loss = fixed_cross_entropy( - source, - target, - weight=weight, - label_smoothing=0.1, - invalid_kwarg=True, - ) - - expected = nn.functional.cross_entropy( - source, - target, - weight=weight, - label_smoothing=0.1, - ) - - self.assertTrue(torch.allclose(loss, expected)) - - def test_loss_device_matches_input(self): - source = torch.randn(4, 5) - target = torch.randint(0, 5, (4,)) - num_items = torch.tensor(4) - - loss = fixed_cross_entropy( - source, - target, - num_items_in_batch=num_items, - ) - - self.assertEqual(loss.device, source.device) From 7b1e6af2893c354c66d649100a74a68bc7eb4d1f Mon Sep 17 00:00:00 2001 From: Jan Andrusikiewicz Date: Fri, 16 Jan 2026 15:26:39 +0100 Subject: [PATCH 09/12] removed unused imports --- src/transformers/loss/loss_utils.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/transformers/loss/loss_utils.py b/src/transformers/loss/loss_utils.py index 621d876d1a6b..65393a87fdac 100644 --- a/src/transformers/loss/loss_utils.py +++ b/src/transformers/loss/loss_utils.py @@ -13,8 +13,6 @@ # limitations under the License. -import inspect - import torch import torch.nn as nn from torch.nn import BCEWithLogitsLoss, MSELoss From fceddcfe332b3b42a1ee32123b043eb1c96c3216 Mon Sep 17 00:00:00 2001 From: Jan Andrusikiewicz Date: Fri, 16 Jan 2026 16:00:44 +0100 Subject: [PATCH 10/12] changed label_smoothing to float --- src/transformers/loss/loss_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/loss/loss_utils.py b/src/transformers/loss/loss_utils.py index 65393a87fdac..fce92dc50578 100644 --- a/src/transformers/loss/loss_utils.py +++ b/src/transformers/loss/loss_utils.py @@ -30,7 +30,7 @@ def fixed_cross_entropy( target: torch.Tensor, num_items_in_batch: torch.Tensor | None = None, ignore_index: int = -100, - label_smoothing: float | None = None, + label_smoothing: float = 0.0, weight: torch.Tensor | None = None, ) -> torch.Tensor: reduction = "sum" if num_items_in_batch is not None else "mean" From 699ff0dd429463546eb3b0367750c49c84808bc3 Mon Sep 17 00:00:00 2001 From: Jan Andrusikiewicz Date: Fri, 16 Jan 2026 16:29:54 +0100 Subject: [PATCH 11/12] added kwargs --- src/transformers/loss/loss_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/transformers/loss/loss_utils.py b/src/transformers/loss/loss_utils.py index fce92dc50578..82aafc008838 100644 --- a/src/transformers/loss/loss_utils.py +++ b/src/transformers/loss/loss_utils.py @@ -32,6 +32,7 @@ def fixed_cross_entropy( ignore_index: int = -100, label_smoothing: float = 0.0, weight: torch.Tensor | None = None, + **kwargs, ) -> torch.Tensor: reduction = "sum" if num_items_in_batch is not None else "mean" From 294d8510eb9fb67c5170c09314af9dfa2bcd306e Mon Sep 17 00:00:00 2001 From: Jan Andrusikiewicz Date: Fri, 16 Jan 2026 18:12:08 +0100 Subject: [PATCH 12/12] changed to _kwargs --- src/transformers/loss/loss_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/loss/loss_utils.py b/src/transformers/loss/loss_utils.py index 82aafc008838..1d94cb53f9c1 100644 --- a/src/transformers/loss/loss_utils.py +++ b/src/transformers/loss/loss_utils.py @@ -30,9 +30,9 @@ def fixed_cross_entropy( target: torch.Tensor, num_items_in_batch: torch.Tensor | None = None, ignore_index: int = -100, - label_smoothing: float = 0.0, weight: torch.Tensor | None = None, - **kwargs, + label_smoothing: float = 0.0, + **_kwargs, ) -> torch.Tensor: reduction = "sum" if num_items_in_batch is not None else "mean"