generation/stopping_criteria: short-circuit StoppingCriteriaList when all sequences are done#45384
Closed
GitGlimpse895 wants to merge 3 commits intohuggingface:mainfrom
Closed
Conversation
78ed29b to
b400828
Compare
Member
|
Does this not create data-dependent control flow that requires a compiler break? |
Author
|
You're right that calling .all() on a tensor and using it in a Python if/break creates data-dependent control flow. Under torch.compile this would cause a graph break. However, StoppingCriteriaList.call is invoked from the Python-level generation loop in utils.py, which itself runs outside any compiled region — the model forward pass is compiled, but the decoding loop that calls stopping criteria is not. So no graph break is introduced in practice. That said, I'm happy to add a # pragma: no compile note or restructure if you'd prefer a different approach. |
b400828 to
e4613da
Compare
… all sequences are done
…its on full completion
e4613da to
5fe79c0
Compare
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
What does this PR do?
StoppingCriteriaList.__call__previously evaluated every registered criterionunconditionally on every generation step, even after
is_donewas alreadyTruefor all sequences in the batch. This adds a single
if is_done.all(): breakguardinside the loop to skip remaining criteria once the entire batch is finished.
This is semantically safe because OR-ing any further
Truevalues into anall-
Truetensor cannot change the result. The saving scales with(number_of_criteria − 1) × remaining_steps_after_completion, and is mostimpactful when
StopStringCriteria(which runs a full vocab embedding lookup viaF.embeddingeach call) appears after a cheaper criterion likeMaxLengthCriteria.A new test
test_list_criteria_short_circuits_when_all_doneis added to verifythe behaviour using a
CountingCriteriasentinel.Before submitting