From 43aa11eb9395a3aa68a6a9832c5687349aefc5da Mon Sep 17 00:00:00 2001
From: GitGlimpse895
Date: Sun, 12 Apr 2026 15:40:07 +0530
Subject: [PATCH 1/3] generation/stopping_criteria: short-circuit
 StoppingCriteriaList when all sequences are done

---
 src/transformers/generation/stopping_criteria.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/src/transformers/generation/stopping_criteria.py b/src/transformers/generation/stopping_criteria.py
index b57136f53416..c42f699e9e55 100644
--- a/src/transformers/generation/stopping_criteria.py
+++ b/src/transformers/generation/stopping_criteria.py
@@ -498,6 +498,8 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwa
         is_done = torch.full((input_ids.shape[0],), False, device=input_ids.device, dtype=torch.bool)
         for criteria in self:
             is_done = is_done | criteria(input_ids, scores, **kwargs)
+            if is_done.all():
+                break
         return is_done
 
     @property

From 5fe79c05699cd057be7531f2d00577b07e373220 Mon Sep 17 00:00:00 2001
From: GitGlimpse895
Date: Sun, 12 Apr 2026 15:42:39 +0530
Subject: [PATCH 2/3] tests/generation: add test verifying StoppingCriteriaList
 short-circuits on full completion

---
 tests/generation/test_stopping_criteria.py | 23 +++++++++++++++++++++++
 1 file changed, 23 insertions(+)

diff --git a/tests/generation/test_stopping_criteria.py b/tests/generation/test_stopping_criteria.py
index c120fe77882c..0b3a43ee5c4a 100644
--- a/tests/generation/test_stopping_criteria.py
+++ b/tests/generation/test_stopping_criteria.py
@@ -63,6 +63,29 @@ def test_list_criteria(self):
         input_ids, scores = self._get_tensors(10)
         self.assertTrue(all(criteria(input_ids, scores)))
 
+    def test_list_criteria_short_circuits_when_all_done(self):
+        """Verify that StoppingCriteriaList stops evaluating remaining criteria once all
+        sequences are already marked done, avoiding unnecessary computation."""
+        call_count = 0
+
+        class CountingCriteria(StoppingCriteria):
+            def __call__(self, input_ids, scores, **kwargs):
+                nonlocal call_count
+                call_count += 1
+                return torch.zeros(input_ids.shape[0], dtype=torch.bool, device=input_ids.device)
+
+        # MaxLengthCriteria fires immediately (length == max_length), so CountingCriteria should be skipped
+        input_ids, scores = self._get_tensors(10)
+        criteria = StoppingCriteriaList(
+            [
+                MaxLengthCriteria(max_length=10),
+                CountingCriteria(),
+            ]
+        )
+        result = criteria(input_ids, scores)
+        self.assertTrue(all(result))
+        self.assertEqual(call_count, 0, "CountingCriteria should not be called when first criterion marks all done")
+
     def test_max_length_criteria(self):
         criteria = MaxLengthCriteria(max_length=10)
 

From b0fd56690bf726490c79bfb1d967a724c14f46a9 Mon Sep 17 00:00:00 2001
From: GitGlimpse895
Date: Thu, 16 Apr 2026 15:52:11 +0530
Subject: [PATCH 3/3] Update test_stopping_criteria.py

---
 tests/generation/test_stopping_criteria.py | 17 +++++++++--------
 1 file changed, 9 insertions(+), 8 deletions(-)

diff --git a/tests/generation/test_stopping_criteria.py b/tests/generation/test_stopping_criteria.py
index 0b3a43ee5c4a..9eebcef4e347 100644
--- a/tests/generation/test_stopping_criteria.py
+++ b/tests/generation/test_stopping_criteria.py
@@ -25,14 +25,15 @@
 import torch
 
 from transformers.generation import (
-    ConfidenceCriteria,
-    EosTokenCriteria,
-    MaxLengthCriteria,
-    MaxTimeCriteria,
-    StoppingCriteriaList,
-    StopStringCriteria,
-    validate_stopping_criteria,
-    )
+    ConfidenceCriteria,
+    EosTokenCriteria,
+    MaxLengthCriteria,
+    MaxTimeCriteria,
+    StoppingCriteria,
+    StoppingCriteriaList,
+    StopStringCriteria,
+    validate_stopping_criteria,
+)
 
 
 @require_torch