
Nit about model_accepts_loss_kwargs for loss #35113

Closed

ArthurZucker wants to merge 1 commit into main from nit-ga-condition

Conversation

@ArthurZucker Collaborator

What does this PR do?

There was a typo in #34915; since the tests were passing, I did not pay attention to it. Thanks to @techkang for reporting it.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@ArthurZucker Collaborator Author

import inspect

from transformers import Wav2Vec2Model

# Check whether the forward signature has a **kwargs (VAR_KEYWORD) parameter.
sig = inspect.signature(Wav2Vec2Model.forward)
print(any(p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values()))

This returns False. But in the torch example, the unwrapped model's forward does seem to accept loss_kwargs.
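For contrast, here is the same inspect check applied to two toy forwards (hypothetical stand-ins, not the real Wav2Vec2 signature) to show exactly what flips the result:

```python
import inspect

# Hypothetical stand-in: no **kwargs in the signature.
def forward_plain(self, input_values, attention_mask=None):
    pass

# Hypothetical stand-in: a **loss_kwargs catch-all is present.
def forward_with_kwargs(self, input_ids, labels=None, **loss_kwargs):
    pass

def accepts_loss_kwargs(fn):
    # True iff the signature has a **kwargs-style (VAR_KEYWORD) parameter.
    sig = inspect.signature(fn)
    return any(p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values())

print(accepts_loss_kwargs(forward_plain))        # → False
print(accepts_loss_kwargs(forward_with_kwargs))  # → True
```

So the check only looks at the forward signature; any `**kwargs` parameter, whatever its name, makes it return True.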
cc @muellerzr

@techkang Contributor

techkang commented Dec 6, 2024

This PR still confuses me, as I mentioned in the previous PR.
There are two ways to use num_items_in_batch to fix the GA loss issue:

  1. Use num_items_in_batch in the loss function defined by the model. In this case, model_accepts_loss_kwargs is True.
  2. The model doesn't have a loss function, or the user has a self-defined loss function (compute_loss_func).

However, this PR makes the second method invalid.
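A minimal sketch of the two paths (all names here are hypothetical illustrations, not the Trainer's actual code), assuming the GA fix works by dividing a summed loss by num_items_in_batch:

```python
# Way 1: the model's own forward consumes num_items_in_batch via **kwargs
# and normalizes the summed loss itself.
def model_forward_with_loss(summed_loss, label_count, **loss_kwargs):
    n = loss_kwargs.get("num_items_in_batch")
    if n is not None:
        return summed_loss / n        # GA-correct normalization
    return summed_loss / label_count  # fallback: per-batch mean

# Way 2: a user-supplied loss function normalizes outside the model,
# so the model's forward never needs to accept loss kwargs.
def user_compute_loss(summed_loss, num_items_in_batch=None):
    if num_items_in_batch is not None:
        return summed_loss / num_items_in_batch
    return summed_loss

print(model_forward_with_loss(12.0, 4, num_items_in_batch=3))  # → 4.0
print(user_compute_loss(12.0, num_items_in_batch=3))           # → 4.0
```

The point of contention: gating on the forward signature alone covers way 1, while way 2 needs num_items_in_batch even when the forward has no **kwargs.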

@ArthurZucker Collaborator Author

Hey!

Use num_items_in_batch in loss function defined by model. In this case, model_accepts_loss_kwargs is True.

No, model_accepts_loss_kwargs only depends on the model's forward pass, not on the loss function defined by the model. If the model's forward accepts **kwargs, we pass num_items_in_batch; otherwise we cannot.
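The dispatch being described can be sketched like this (hypothetical helper names; the real Trainer logic is more involved):

```python
import inspect

# Hypothetical forwards: one rejects extra kwargs, one catches them.
def forward_no_kwargs(x):
    return x

def forward_kwargs(x, **loss_kwargs):
    return x, loss_kwargs.get("num_items_in_batch")

def call_forward(forward, inputs, num_items_in_batch):
    # Only pass num_items_in_batch when the signature has a **kwargs
    # (VAR_KEYWORD) parameter; otherwise it would raise a TypeError.
    sig = inspect.signature(forward)
    accepts = any(p.kind == inspect.Parameter.VAR_KEYWORD
                  for p in sig.parameters.values())
    if accepts:
        return forward(**inputs, num_items_in_batch=num_items_in_batch)
    return forward(**inputs)

print(call_forward(forward_no_kwargs, {"x": 1}, 8))  # → 1
print(call_forward(forward_kwargs, {"x": 1}, 8))     # → (1, 8)
```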

The model doesn't have loss function or user has self-defined loss function, which is compute_loss_func.
However, this PR will make the second method invalid.

Even if the user has a self-defined loss function, if the forward pass does not support num_items_in_batch, we ought not to pass kwargs. I might be confusing this with the Trainer's compute_loss_func; in that case, I'll let @muellerzr take a look.

@muellerzr Contributor

@ArthurZucker I disagree with 2: it was explicitly designed that way to catch the case where the model doesn't have this figured out yet, so the user can/should handle it manually themselves instead.
