`src/transformers/loss/loss_utils.py` (13 additions, 2 deletions)

```diff
@@ -30,10 +30,21 @@ def fixed_cross_entropy(
     target: torch.Tensor,
     num_items_in_batch: torch.Tensor | None = None,
     ignore_index: int = -100,
-    **kwargs,
+    weight: torch.Tensor | None = None,
+    label_smoothing: float = 0.0,
+    **_kwargs,
```
@stas00 (Contributor) commented on the `**_kwargs,` line, Jan 16, 2026:

huh? `_`?

I'd say remove it altogether, since it's being silently ignored and that's bad for the caller.

Author:

Do you mean to remove `kwargs`? You accepted the code containing them ;) If we don't accept them, the function isn't compatible with some parts of the repo, so I renamed the parameter to `_kwargs` to show explicitly that the kwargs are ignored.

@stas00 (Contributor):

I have no idea how renaming to `_kwargs` implies that it is ignored. When something is ignored, it shouldn't be there.

As I shared earlier, my opinion is that if `**kwargs` is in the API, the keys should be introspected and any unexpected ones should be asserted on. `**kwargs` is useful when a function is an intermediary that passes it on. In this case the kwargs aren't passed on, and thus shouldn't be there.

> You accepted the code containing them ;)

I'm not a current maintainer, so my vote isn't binding. You'll want to engage the current maintainers instead.
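
For illustration, a minimal sketch of the validation pattern described above: keep `**kwargs` only to introspect it, and fail loudly on unexpected keys rather than silently dropping them. This is a hypothetical variant, not the code in this PR.

```python
import torch
import torch.nn as nn


def fixed_cross_entropy(
    source: torch.Tensor,
    target: torch.Tensor,
    num_items_in_batch: torch.Tensor | None = None,
    ignore_index: int = -100,
    weight: torch.Tensor | None = None,
    label_smoothing: float = 0.0,
    **kwargs,
) -> torch.Tensor:
    # Introspect instead of silently ignoring: any unknown key is an error.
    if kwargs:
        raise TypeError(f"fixed_cross_entropy() got unexpected keyword arguments: {sorted(kwargs)}")
    reduction = "sum" if num_items_in_batch is not None else "mean"
    loss = nn.functional.cross_entropy(
        source,
        target,
        ignore_index=ignore_index,
        reduction=reduction,
        label_smoothing=label_smoothing,
        weight=weight,
    )
    if reduction == "sum":
        # Normalize by the global item count; the tensor-vs-int handling from
        # the real implementation is omitted here for brevity.
        loss = loss / num_items_in_batch
    return loss
```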

Author:

It's a naming convention; indeed, it doesn't imply anything by itself. Let's wait for the maintainers ;)

@iamsernine @ArthurZucker @cyyever

```diff
 ) -> torch.Tensor:
     reduction = "sum" if num_items_in_batch is not None else "mean"
-    loss = nn.functional.cross_entropy(source, target, ignore_index=ignore_index, reduction=reduction)
+
+    loss = nn.functional.cross_entropy(
+        source,
+        target,
+        ignore_index=ignore_index,
+        reduction=reduction,
+        label_smoothing=label_smoothing,
+        weight=weight,
+    )
+
     if reduction == "sum":
         # just in case users pass an int for num_items_in_batch, which could be the case for custom trainer
         if torch.is_tensor(num_items_in_batch):
```
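
For context, a usage sketch of the updated signature, assuming the PR is merged as shown; the shapes and weights below are illustrative.

```python
import torch

from transformers.loss.loss_utils import fixed_cross_entropy

logits = torch.randn(8, 5)          # (batch, num_classes)
labels = torch.randint(0, 5, (8,))  # class indices

# weight and label_smoothing are now forwarded to nn.functional.cross_entropy
# instead of being swallowed by **kwargs.
class_weights = torch.ones(5)
class_weights[0] = 2.0  # up-weight class 0
loss = fixed_cross_entropy(logits, labels, weight=class_weights, label_smoothing=0.1)

# With num_items_in_batch set, the loss is summed and normalized by the global
# item count (the gradient-accumulation path).
loss_sum = fixed_cross_entropy(logits, labels, num_items_in_batch=torch.tensor(8))
```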
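A second, self-contained illustration (not from the PR) of why the `reduction="sum"` branch takes `num_items_in_batch`: averaging each micro-batch and then averaging across gradient-accumulation steps weights items unevenly when micro-batches differ in size, whereas summing and dividing by the global count recovers the exact mean.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
a, ya = torch.randn(2, 5), torch.randint(0, 5, (2,))  # micro-batch of 2 items
b, yb = torch.randn(6, 5), torch.randint(0, 5, (6,))  # micro-batch of 6 items

# Naive accumulation: per-micro-batch means, then the mean of those means.
mean_of_means = (F.cross_entropy(a, ya) + F.cross_entropy(b, yb)) / 2

# Exact mean: per-micro-batch sums divided by the total item count (8).
true_mean = (
    F.cross_entropy(a, ya, reduction="sum") + F.cross_entropy(b, yb, reduction="sum")
) / 8

print(mean_of_means.item(), true_mean.item())  # these differ in general
```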