Trainer: set skip_logits for loss-only eval when liger enabled #44981
AkshajKashyap wants to merge 7 commits into huggingface:main
Conversation
SunMarc left a comment
did you see memory gain with this?
Yep, I measured a clear peak GPU memory reduction on my RTX 3050 (4GB) using a small Llama model with Liger enabled. Benchmark setup:
Results:
This matches the intent of the fix: in loss-only eval, skipping logits avoids materializing the logits tensor and enables fused loss paths for implementations that use skip_logits (like Liger integrations). Savings should generally increase with longer sequence length / larger vocab. Appreciate your reply, and really just trying to help and make myself useful.
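For reference, a minimal sketch of how such a peak-memory comparison can be taken, assuming a `trainer` built with `use_liger_kernel=True`, an eval dataset, and no `compute_metrics` (so `evaluate()` runs loss-only):

```python
import torch

# `trainer` is assumed to exist (see above). With no compute_metrics set,
# Trainer.evaluate() runs with prediction_loss_only=True, which is the path
# this PR changes. Run once with use_liger_kernel=True and once without.
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()

metrics = trainer.evaluate()

peak_mib = torch.cuda.max_memory_allocated() / 1024**2
print(f"eval_loss={metrics['eval_loss']:.4f}  peak={peak_mib:.1f} MiB")
```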
Quick follow-up since it’s been about a week and CI is still green: is this change directionally OK, or would you prefer a different guard / placement? Happy to adjust quickly if you want this behind an additional condition (for example, only for specific model types) or moved to a different part of the eval path. @SunMarc
```python
try:
    forward_sig = inspect.signature(unwrap_model(model).forward)
    if "skip_logits" in forward_sig.parameters:
        inputs["skip_logits"] = True
except (TypeError, ValueError):
    pass
```
why try/except? in which cases do we hit ValueError or TypeError?
Also, this is an arg from liger, so it should always be there, no? Or not?
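For context on the exceptions question: per the `inspect` docs, `signature()` raises `TypeError` when the object type is not supported and `ValueError` when no signature can be provided, which is common for C-implemented callables. A minimal illustration:

```python
import inspect

# ValueError: many C-implemented builtins expose no introspectable signature.
try:
    inspect.signature(dict)
except ValueError as e:
    print("ValueError:", e)

# TypeError: the object is not a supported callable at all.
try:
    inspect.signature(42)
except TypeError as e:
    print("TypeError:", e)
```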
```diff
@@ -0,0 +1,94 @@
+import tempfile
```
don't create a new file, put it in an existing file. Also add one test only, it should be enough. Try to make the tests as simple and small as possible.
```python
# Enable Liger fused loss path during eval when we only need the loss (no logits).
if (
    prediction_loss_only
    and getattr(self.args, "use_liger_kernel", False)
    and inputs.get("labels") is not None
```
like here, we can just put this piece of code in the correct place so that we don't have to check that: https://github.com/huggingface/transformers/pull/45273/changes
| and "skip_logits" not in inputs | ||
| ): |
not sure why skip_logits will be in inputs
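For readability, the diff fragments above assemble into the full guard as proposed, inside `Trainer.prediction_step`:

```python
# Enable Liger fused loss path during eval when we only need the loss (no logits).
if (
    prediction_loss_only
    and getattr(self.args, "use_liger_kernel", False)
    and inputs.get("labels") is not None
    and "skip_logits" not in inputs
):
    try:
        forward_sig = inspect.signature(unwrap_model(model).forward)
        if "skip_logits" in forward_sig.parameters:
            inputs["skip_logits"] = True
    except (TypeError, ValueError):
        pass
```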
Fixes #43039
What does this PR do?
When `prediction_loss_only=True` during evaluation and `use_liger_kernel=True`, `Trainer.prediction_step` now passes `skip_logits=True` to the model forward if the forward signature supports it and labels are present. This avoids materializing logits during loss-only eval and enables fused loss paths for implementations that use `skip_logits` (for example, Liger kernel integrations), which can reduce memory usage during evaluation.
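To make the mechanism concrete, here is a hedged sketch (toy model, all names assumed; not Liger's actual API) of a `skip_logits`-aware forward, where the fused path computes the loss without returning the full `[batch, seq, vocab]` logits tensor:

```python
import torch
import torch.nn.functional as F

class TinyCausalLM(torch.nn.Module):
    """Toy model with a skip_logits-aware forward (all names assumed)."""

    def __init__(self, hidden_size=64, vocab_size=1000):
        super().__init__()
        self.embed = torch.nn.Embedding(vocab_size, hidden_size)
        self.lm_head = torch.nn.Linear(hidden_size, vocab_size, bias=False)

    def forward(self, input_ids, labels=None, skip_logits=False):
        hidden = self.embed(input_ids)  # [batch, seq, hidden]
        if skip_logits and labels is not None:
            # A fused kernel (e.g. Liger's fused linear + cross-entropy)
            # computes the loss from hidden states and lm_head weights in
            # chunks, never holding the full [batch, seq, vocab] logits
            # tensor at once; this toy path just avoids returning it.
            logits = hidden @ self.lm_head.weight.T
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.view(-1))
            return {"loss": loss, "logits": None}
        logits = self.lm_head(hidden)  # full logits materialized and returned
        loss = None
        if labels is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.view(-1))
        return {"loss": loss, "logits": logits}
```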
Implementation details

Injects `skip_logits=True` into `inputs` only when:

- `prediction_loss_only` is true
- `use_liger_kernel` is enabled
- `labels` are present
- the model forward has a `skip_logits` parameter (checked via signature)
- `inputs` does not already contain `skip_logits`

Tests
Added CPU-only unit tests:
- `test_trainer_sets_skip_logits_for_loss_only_eval_when_liger_enabled`
- `test_trainer_does_not_set_skip_logits_when_no_labels_but_return_loss_true`
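For illustration, a stripped-down, hypothetical version of what the first test asserts (dummy model and names assumed; the actual test uses `Trainer` fixtures):

```python
import inspect

class DummyLigerModel:
    """Stand-in for a liger-patched model whose forward accepts skip_logits."""

    def forward(self, input_ids=None, labels=None, skip_logits=False):
        self.saw_skip_logits = skip_logits
        return {"loss": 0.0, "logits": None if skip_logits else [[0.0]]}

model = DummyLigerModel()
inputs = {"input_ids": [[1, 2, 3]], "labels": [[1, 2, 3]]}

# Same signature check the PR adds to Trainer.prediction_step:
if (
    "skip_logits" in inspect.signature(model.forward).parameters
    and inputs.get("labels") is not None
):
    inputs["skip_logits"] = True

out = model.forward(**inputs)
assert model.saw_skip_logits is True
assert out["logits"] is None
```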
Run with:

```bash
python -m pytest -q tests/trainer/test_skip_logits_eval.py -x
```

Code Agent Policy
Before submitting
Who can review?
Tagging: @SunMarc