Closed
Changes from all commits
Commits
23 commits
084eb5c
added validating kwargs passed to nn.functional.cross_entropy
jasiecky Jan 13, 2026
51ad984
rollback
jasiecky Jan 13, 2026
763fabd
removed not allowed kwargs
jasiecky Jan 13, 2026
5225729
moved to inspect
jasiecky Jan 13, 2026
ee43e8f
added allowed_kwargs variable
jasiecky Jan 13, 2026
aa8d0ac
added tests
jasiecky Jan 14, 2026
3f7f007
deduplicated code
jasiecky Jan 14, 2026
29eb45b
Merge branch 'main' into fix/43242
jasiecky Jan 14, 2026
660f71e
Merge branch 'main' into fix/43242
jasiecky Jan 14, 2026
f6c526f
Merge branch 'main' into fix/43242
jasiecky Jan 15, 2026
57f3778
Merge branch 'main' into fix/43242
jasiecky Jan 16, 2026
0203462
added only supported parameters
jasiecky Jan 16, 2026
b022749
Merge branch 'fix/43242' of https://github.com/jasiecky/transformers …
jasiecky Jan 16, 2026
7b1e6af
removed unused imports
jasiecky Jan 16, 2026
fceddcf
changed label_smoothing to float
jasiecky Jan 16, 2026
699ff0d
added kwargs
jasiecky Jan 16, 2026
05fb96e
Merge branch 'main' into fix/43242
jasiecky Jan 16, 2026
d3838c9
Merge branch 'main' into fix/43242
jasiecky Jan 16, 2026
294d851
changed to _kwargs
jasiecky Jan 16, 2026
f651994
Merge branch 'fix/43242' of https://github.com/jasiecky/transformers …
jasiecky Jan 16, 2026
6e287b1
Merge branch 'main' into fix/43242
jasiecky Jan 19, 2026
2395023
Merge branch 'main' into fix/43242
jasiecky Feb 2, 2026
1306dac
Merge branch 'pr-43251' into merge-cluster-cluster-43240-3-2026042323…
evalstate Apr 23, 2026
15 changes: 13 additions & 2 deletions src/transformers/loss/loss_utils.py
@@ -30,10 +30,21 @@ def fixed_cross_entropy(
     target: torch.Tensor,
     num_items_in_batch: torch.Tensor | None = None,
     ignore_index: int = -100,
-    **kwargs,
+    weight: torch.Tensor | None = None,
+    label_smoothing: float = 0.0,
+    **_kwargs,
 ) -> torch.Tensor:
     reduction = "sum" if num_items_in_batch is not None else "mean"
-    loss = nn.functional.cross_entropy(source, target, ignore_index=ignore_index, reduction=reduction)
+
+    loss = nn.functional.cross_entropy(
+        source,
+        target,
+        ignore_index=ignore_index,
+        reduction=reduction,
+        label_smoothing=label_smoothing,
+        weight=weight,
+    )
+
     if reduction == "sum":
         # just in case users pass an int for num_items_in_batch, which could be the case for custom trainer
         if torch.is_tensor(num_items_in_batch):