
Fix(43240): pass kwargs to nn.functional.cross_entropy #43251

Open

jasiecky wants to merge 22 commits into huggingface:main from jasiecky:fix/43242

Conversation


@jasiecky jasiecky commented Jan 13, 2026

What does this PR do?

The problem to be solved is the issue #43240. This PR implements passing weight and label_smoothing parameters of nn.functional.cross_entropy in fixed_cross_entropy function.

Fixes #43240
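For context, here is a minimal sketch of the intended change. It is illustrative only and not the exact diff; the signature is modeled on the existing fixed_cross_entropy wrapper in src/transformers/loss/loss_utils.py.

```python
import torch
import torch.nn as nn


def fixed_cross_entropy(
    source: torch.Tensor,
    target: torch.Tensor,
    num_items_in_batch: torch.Tensor | None = None,
    ignore_index: int = -100,
    weight: torch.Tensor | None = None,  # new: forwarded to nn.functional.cross_entropy
    label_smoothing: float = 0.0,        # new: forwarded to nn.functional.cross_entropy
) -> torch.Tensor:
    # The reduction handling below already exists pre-PR: sum-reduce when the caller
    # supplies the token count, then rescale; otherwise use the default mean reduction.
    reduction = "sum" if num_items_in_batch is not None else "mean"
    loss = nn.functional.cross_entropy(
        source,
        target,
        weight=weight,
        ignore_index=ignore_index,
        reduction=reduction,
        label_smoothing=label_smoothing,
    )
    if reduction == "sum":
        loss = loss / num_items_in_batch
    return loss
```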

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@iamsernine @ArthurZucker @stas00 @cyyever

@jasiecky jasiecky changed the title Fix(43242): pass kwargs to nn.functional.cross_entropy Fix(43240): pass kwargs to nn.functional.cross_entropy Jan 13, 2026
Contributor

@stas00 stas00 left a comment


I have a hard time matching the description of this PR to the proposed change. It looks like you want to pass additional kwargs, which at the moment are dropped by this wrapper.

Looking at https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.cross_entropy.html, those would be weight and label_smoothing; is that what you're trying to pass?

And of course you need a test to support your PR, which would also self-document what you're trying to accomplish.

> ensures consistent loss scaling by controlling the reduction mode when num_items_in_batch is provided

This is already done. Look at the first line of the function.

Comment thread on src/transformers/loss/loss_utils.py (outdated)
Contributor

stas00 commented Jan 14, 2026

Thank you for adding the tests. I still don't understand what you're trying to solve with this PR.

Your PR description:

> This PR adds validation for keyword arguments passed to cross_entropy and ensures consistent loss scaling by controlling the reduction mode when num_items_in_batch is provided.

The second part is invalid; the pre-PR code already does that.

For the first part, what kwargs do you need to pass for your workload? Your tests exercise the two keys that have been ignored, but do you actually need them? It's been a long time since this function was added; perhaps fixed_ implies that it does a limited scope of things. I'm not sure. The Trainer, for example, handles label_smoothing here:

    if self.args.label_smoothing_factor != 0:
        self.label_smoother = LabelSmoother(epsilon=self.args.label_smoothing_factor)
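For reference, the smoother created above is then applied to the model output in compute_loss, roughly like this (paraphrased from the Trainer, not an exact excerpt):

```python
labels = inputs.pop("labels")
outputs = model(**inputs)
# shift_labels=True for causal LM heads, so logits at position i predict labels at i+1
loss = self.label_smoother(outputs, labels, shift_labels=True)
```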

Contributor

stas00 commented Jan 14, 2026

Thinking more about it and stepping away from this particular PR, one of the remaining issues in the HF Transformers API is that some keys in kwargs are silently dropped by some of the APIs. For example, if you call from_config and set config.attn_implementation, it will be silently ignored; it should assert and tell the user to pass attn_implementation as its own kwarg to from_config rather than as part of the config object.

So what I suggest is: if you don't have a particular problem to solve in this PR, it can be made useful by asserting when unexpected kwargs are passed; those must not be silently dropped. The user needs to know when they are not using the API correctly.
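A minimal sketch of the kind of check being suggested, purely illustrative; the helper name and the set of accepted keys are assumptions, not part of the PR:

```python
_ALLOWED_LOSS_KWARGS = {"weight", "label_smoothing"}  # assumed allow-list for illustration


def _validate_loss_kwargs(kwargs: dict) -> None:
    # Fail loudly instead of silently dropping unknown keys.
    unexpected = set(kwargs) - _ALLOWED_LOSS_KWARGS
    if unexpected:
        raise TypeError(
            f"fixed_cross_entropy got unexpected kwargs: {sorted(unexpected)}. "
            f"Supported extra kwargs: {sorted(_ALLOWED_LOSS_KWARGS)}."
        )
```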

The decision of whether fixed_cross_entropy should support weight and label_smoothing kwargs I will leave to the current maintainers.

@jasiecky
Author

> Thinking more about it and stepping away from this particular PR, one of the remaining issues in the HF Transformers API is that some keys in kwargs are silently dropped by some of the APIs. For example, if you call from_config and set config.attn_implementation, it will be silently ignored; it should assert and tell the user to pass attn_implementation as its own kwarg to from_config rather than as part of the config object.
>
> So what I suggest is: if you don't have a particular problem to solve in this PR, it can be made useful by asserting when unexpected kwargs are passed; those must not be silently dropped. The user needs to know when they are not using the API correctly.
>
> The decision of whether fixed_cross_entropy should support weight and label_smoothing kwargs I will leave to the current maintainers.

The problem to be solved is issue #43240. The thing is that currently we are not able to pass kwargs through to nn.functional.cross_entropy, so using weight and label_smoothing is impossible. If you think it would be a better solution, I can change kwargs to these two explicit parameters and pass them into the mentioned function.
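To illustrate what callers would gain once these are forwarded, both parameters are plain nn.functional.cross_entropy options (hypothetical usage, values made up):

```python
import torch
from torch.nn import functional as F

logits = torch.randn(8, 5)                # (batch, num_classes)
labels = torch.randint(0, 5, (8,))
class_weights = torch.tensor([1.0, 2.0, 1.0, 0.5, 1.0])  # per-class rescaling

# A class-weighted, label-smoothed loss: exactly what the wrapper currently cannot express.
loss = F.cross_entropy(logits, labels, weight=class_weights, label_smoothing=0.1)
```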

Contributor

stas00 commented Jan 15, 2026

That's helpful. Then that should be the description of the PR.

And you have a competitor here: #43254

@jasiecky
Author

> That's helpful. Then that should be the description of the PR.
>
> And you have a competitor here: #43254

I updated the code and the description, ready for review.

Contributor

@stas00 stas00 left a comment


LGTM

    weight: torch.Tensor | None = None,
    **kwargs,
    label_smoothing: float = 0.0,
    **_kwargs,
Contributor

@stas00 stas00 Jan 16, 2026


huh? _?

I'd say remove it altogether, since it's being silently ignored and that's bad for the caller.

Author


Do you mean to remove kwargs? You accepted the code containing them ;) If we don't accept them, the function isn't compatible with some parts of the repo, so I renamed the parameter to _kwargs to show explicitly that the kwargs are ignored.

Contributor


I have no idea how renaming to _kwargs implies that it is ignored. When something is ignored, it shouldn't be there.

As I shared earlier, my opinion is that if **kwargs is in the API, it should be introspected and any unexpected keys should be asserted on. **kwargs is useful when a function is an intermediary that passes it on. In this case the kwargs aren't passed on and thus shouldn't be there.

> You accepted the code containing them ;)

I'm not a current maintainer, so my vote isn't binding. You'll want to engage the current maintainers instead.

Author


It's a naming convention; it doesn't imply anything by itself, indeed. Let's wait for the maintainers ;)

@iamsernine @ArthurZucker @cyyever

@github-actions
Contributor

View the CircleCI Test Summary for this PR:

https://huggingface.co/spaces/transformers-community/circle-ci-viz?pr=43251&sha=6e287b

Contributor

@iamsernine iamsernine left a comment


didn't know why I'm being mentioned here, but LGTM


Development

Successfully merging this pull request may close these issues.

kwargs are not passed to loss calculation function.

3 participants