
Trainer: set skip_logits for loss-only eval when liger enabled #44981

Open
AkshajKashyap wants to merge 7 commits into huggingface:main from AkshajKashyap:fix/gh-43039-skip-logits-eval

Conversation

@AkshajKashyap

Fixes #43039

What does this PR do?

When prediction_loss_only=True during evaluation and use_liger_kernel=True, Trainer.prediction_step now passes skip_logits=True to the model forward if the forward signature supports it and labels are present.

This avoids materializing logits during loss-only eval and enables fused loss paths for implementations that use skip_logits (for example, Liger kernel integrations), which can reduce memory usage during evaluation.

Implementation details

  • Injects skip_logits=True into inputs only when:
    • prediction_loss_only is True
    • use_liger_kernel is enabled
    • labels are present
    • the model's forward accepts a skip_logits parameter (checked via its signature)
  • No behavior change when labels are missing or when the model does not support skip_logits (see the sketch below).
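For reference, a minimal standalone sketch of the guard's logic (the helper name maybe_skip_logits is hypothetical; in the PR the check lives inline in Trainer.prediction_step, as the review excerpts further down show):

```python
import inspect

def maybe_skip_logits(model, inputs, prediction_loss_only, use_liger_kernel):
    # Hypothetical standalone version of the PR's inline guard, for illustration.
    # Inject skip_logits=True only for loss-only eval with Liger enabled,
    # labels present, and a forward signature that accepts the argument.
    if (
        prediction_loss_only
        and use_liger_kernel
        and inputs.get("labels") is not None
        and "skip_logits" in inspect.signature(model.forward).parameters
    ):
        inputs["skip_logits"] = True
    return inputs
```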

Tests

Added CPU-only unit tests:

  • test_trainer_sets_skip_logits_for_loss_only_eval_when_liger_enabled
  • test_trainer_does_not_set_skip_logits_when_no_labels_but_return_loss_true

Run with:
```
python -m pytest -q tests/trainer/test_skip_logits_eval.py -x
```
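As a rough illustration of what the first test asserts, here is a hedged, self-contained sketch (TinyLigerLikeModel is a toy stand-in, not the model used in the actual test, which goes through Trainer.prediction_step):

```python
import inspect
import torch

class TinyLigerLikeModel(torch.nn.Module):
    # Toy stand-in for a Liger-patched model: forward accepts skip_logits
    # and returns no logits when it is set.
    def forward(self, input_ids=None, labels=None, skip_logits=False):
        loss = (input_ids.float() ** 2).mean()
        return {"loss": loss, "logits": None if skip_logits else input_ids}

def test_skip_logits_for_loss_only_eval_sketch():
    model = TinyLigerLikeModel()
    inputs = {"input_ids": torch.zeros(1, 4), "labels": torch.zeros(1, 4)}
    # Mirror the Trainer guard: loss-only eval, Liger enabled, labels present,
    # and a forward that supports skip_logits.
    if "skip_logits" in inspect.signature(model.forward).parameters:
        inputs["skip_logits"] = True
    outputs = model(**inputs)
    assert outputs["loss"] is not None and outputs["logits"] is None
```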

Code Agent Policy

  • I confirm that this is not a pure code agent PR.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline, Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Issue: When using the Liger Kernel, torch.nn.functional.cross_entropy is called #43039
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests?

Who can review?

Tagging: @SunMarc

Member

@SunMarc SunMarc left a comment


did you see memory gain with this?

@AkshajKashyap
Author

> did you see memory gain with this?

Yep, I measured a clear peak GPU memory reduction on my RTX 3050 (4GB) using a small Llama model with Liger enabled.

Benchmark setup:

  • Model: hf-internal-testing/tiny-random-LlamaForCausalLM
  • batch=1, seq_len=1024, fp16
  • Measured peak allocated GPU memory during loss-only eval

Results:

  • baseline (forced skip_logits=False): 324.7 MB peak
  • with this PR behavior (Trainer injects skip_logits=True when supported): 10.6 MB peak
  • delta: 314.1 MB (~96.7%)

This matches the intent of the fix: in loss-only eval, skipping logits avoids materializing the logits tensor and enables fused loss paths for implementations that use skip_logits (like Liger integrations). Savings should generally increase with longer sequence length / larger vocab.
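For anyone who wants to reproduce this kind of number, a hedged sketch of the measurement (model loading details are assumptions; skip_logits only has an effect on a Liger-patched model, and this needs a CUDA device):

```python
import torch
from transformers import AutoModelForCausalLM

# Load the tiny test model in fp16 on the GPU (sketch; Liger patching not shown).
model = AutoModelForCausalLM.from_pretrained(
    "hf-internal-testing/tiny-random-LlamaForCausalLM", torch_dtype=torch.float16
).cuda().eval()

input_ids = torch.randint(0, model.config.vocab_size, (1, 1024), device="cuda")

# Measure peak allocated memory for a single loss-only forward pass.
torch.cuda.reset_peak_memory_stats()
with torch.no_grad():
    model(input_ids=input_ids, labels=input_ids)  # pass skip_logits=True on a Liger-patched model
peak_mb = torch.cuda.max_memory_allocated() / 2**20
print(f"peak allocated: {peak_mb:.1f} MB")
```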

Appreciate your reply; I'm really just trying to help and make myself useful.

@AkshajKashyap
Author

> did you see memory gain with this?

Quick follow-up since it’s been about a week and CI is still green: is this change directionally OK, or would you prefer a different guard / placement?

Happy to adjust quickly if you want this behind an additional condition (for example, only for specific model types) or moved to a different part of the eval path. @SunMarc

Member

@SunMarc SunMarc left a comment


Thanks, just a few nits

Comment on lines +2937 to +2942
```python
try:
    forward_sig = inspect.signature(unwrap_model(model).forward)
    if "skip_logits" in forward_sig.parameters:
        inputs["skip_logits"] = True
except (TypeError, ValueError):
    pass
```
Member


why try/except? in which cases do we hit ValueError or TypeError?

Member


Also, this is an arg from Liger, so it should always be there, no?

@@ -0,0 +1,94 @@
import tempfile
Member


don't create a new file, put it in an existing file. Also, add one test only; it should be enough. Try to make the test as simple and small as possible.

```python
# Enable Liger fused loss path during eval when we only need the loss (no logits).
if (
    prediction_loss_only
    and getattr(self.args, "use_liger_kernel", False)
```
Member


we don't need the getattr

```python
if (
    prediction_loss_only
    and getattr(self.args, "use_liger_kernel", False)
    and inputs.get("labels") is not None
```
Member


like here, we can just put this piece of code in the correct place so that we don't have to check that: https://github.com/huggingface/transformers/pull/45273/changes

Comment on lines +2935 to +2936
and "skip_logits" not in inputs
):
Member


not sure why skip_logits would be in inputs
