use nanmean for aggregating loss #44257

Open

winglian wants to merge 1 commit into huggingface:main from winglian:nanmean-loss

Conversation

@winglian (Collaborator) commented:

What does this PR do?

When post-training with context parallelism, some processes may have their chunk of the sample input entirely masked out, leading to a NaN loss on those processes. Using nanmean lets us keep the real losses that aren't NaN.
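A minimal standalone sketch of the failure mode (the gathered values below are made up, and `nested_gather` itself is not reproduced): one fully masked CP rank contributes a NaN, which poisons a plain `.mean()` over the gathered losses, while `.nanmean()` averages only the finite entries.

```python
import torch

# Hypothetical per-process losses after gathering: rank 1's context-parallel
# chunk of the sample was fully masked, so its local loss came back as NaN.
gathered = torch.tensor([0.5, float("nan"), 0.7])

print(gathered.mean().item())     # nan: one masked rank poisons the average
print(gathered.nanmean().item())  # ≈ 0.6: averages only the finite losses
```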

@SunMarc

Fixes # (issue)

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@winglian winglian requested a review from SunMarc February 24, 2026 14:56
@HuggingFaceDocBuilderDev commented:

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@SunMarc (Member) left a comment:

Thanks, left a question

@@ -2054,7 +2101,7 @@ def _maybe_log_save_evaluate(
     logs: dict[str, float] = {}

     # all_gather + mean() to get average loss over all processes
-    tr_loss_scalar = nested_gather(tr_loss, self.args.parallel_mode).mean().item()
+    tr_loss_scalar = nested_gather(tr_loss, self.args.parallel_mode).nanmean().item()
@SunMarc (Member) commented:

we have the following already. Not sure how we are still getting the nan

                if (
                    self.args.logging_nan_inf_filter
                    and not is_torch_xla_available()
                    and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step))
                ):
                    # if loss is nan or inf simply add the average of previous logged losses
                    self._tr_loss += self._tr_loss / (1 + self.state.global_step - self._globalstep_last_logged)

@winglian (Collaborator, Author) replied:

This was a fix we had to make last year when we added CP support (https://github.com/axolotl-ai-cloud/axolotl/pull/3033/changes), even though that code block above has been around for a couple of years. My assumption is that it has something to do with the multi-GPU gather.

[Screenshot attached: 2026-02-24, 10:53 AM]

@SunMarc (Member) replied:

yeah most likely! Thanks!

@SunMarc (Member) left a comment:

Thanks, just a minor nit


@@ -2054,7 +2101,7 @@ def _maybe_log_save_evaluate(
     logs: dict[str, float] = {}

     # all_gather + mean() to get average loss over all processes
-    tr_loss_scalar = nested_gather(tr_loss, self.args.parallel_mode).mean().item()
+    tr_loss_scalar = nested_gather(tr_loss, self.args.parallel_mode).nanmean().item()
@SunMarc (Member) commented:

can you add a check for self.args.logging_nan_inf_filter if we are using nanmean?
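The requested gating could look like the sketch below (a hypothetical helper, not the trainer's actual code; names and the standalone signature are assumptions): skip NaNs only when the user has opted into the NaN/inf filter, otherwise let them surface in the logs.

```python
import torch

def aggregate_gathered_loss(gathered: torch.Tensor, logging_nan_inf_filter: bool) -> float:
    """Average gathered per-process losses; ignore NaNs only when the
    NaN/inf filter is enabled (mirrors the gating requested above)."""
    if logging_nan_inf_filter:
        return gathered.nanmean().item()
    return gathered.mean().item()

losses = torch.tensor([0.52, float("nan"), 0.48])
print(aggregate_gathered_loss(losses, logging_nan_inf_filter=True))   # ≈ 0.5
print(aggregate_gathered_loss(losses, logging_nan_inf_filter=False))  # nan
```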

Comment on lines +2797 to +2799

     metrics[f"{metric_key_prefix}_loss"] = np.nanmean(np.concatenate(all_losses)).item()
 elif isinstance(all_losses, np.ndarray):
-    metrics[f"{metric_key_prefix}_loss"] = all_losses.mean().item()
+    metrics[f"{metric_key_prefix}_loss"] = np.nanmean(all_losses).item()
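For the eval path both branches now go through np.nanmean. A small sketch with made-up values (the list-vs-ndarray split mirrors the two branches in the diff; the data is hypothetical):

```python
import numpy as np

# Eval-side losses may arrive as a list of arrays (one per gather step)
# or as a single ndarray; np.nanmean handles NaNs in both branches.
as_list = [np.array([0.4, np.nan]), np.array([0.6])]
as_array = np.array([0.4, np.nan, 0.6])

print(np.nanmean(np.concatenate(as_list)).item())  # ≈ 0.5
print(np.nanmean(as_array).item())                 # ≈ 0.5
```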
@SunMarc (Member) commented:

same here

@SunMarc SunMarc requested a review from qgallouedec February 25, 2026 16:19
@qgallouedec (Member) left a comment:

lgtm, cc @kashif
