diff --git a/src/transformers/generation/stopping_criteria.py b/src/transformers/generation/stopping_criteria.py
index b57136f53416..c42f699e9e55 100644
--- a/src/transformers/generation/stopping_criteria.py
+++ b/src/transformers/generation/stopping_criteria.py
@@ -498,6 +498,9 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwa
         is_done = torch.full((input_ids.shape[0],), False, device=input_ids.device, dtype=torch.bool)
         for criteria in self:
             is_done = is_done | criteria(input_ids, scores, **kwargs)
+            # Once every sequence is flagged, OR-ing in further criteria cannot change the result.
+            if is_done.all():
+                break
         return is_done
 
     @property
diff --git a/tests/generation/test_stopping_criteria.py b/tests/generation/test_stopping_criteria.py
index c120fe77882c..9eebcef4e347 100644
--- a/tests/generation/test_stopping_criteria.py
+++ b/tests/generation/test_stopping_criteria.py
@@ -25,14 +25,15 @@
     import torch
 
     from transformers.generation import (
         ConfidenceCriteria,
         EosTokenCriteria,
         MaxLengthCriteria,
         MaxTimeCriteria,
+        StoppingCriteria,
         StoppingCriteriaList,
         StopStringCriteria,
         validate_stopping_criteria,
     )
 
 
 @require_torch
@@ -63,5 +64,28 @@ def test_list_criteria(self):
         input_ids, scores = self._get_tensors(10)
         self.assertTrue(all(criteria(input_ids, scores)))
 
+    def test_list_criteria_short_circuits_when_all_done(self):
+        """Verify that StoppingCriteriaList stops evaluating remaining criteria once all
+        sequences are already marked done, avoiding unnecessary computation."""
+        call_count = 0
+
+        class CountingCriteria(StoppingCriteria):
+            def __call__(self, input_ids, scores, **kwargs):
+                nonlocal call_count
+                call_count += 1
+                return torch.zeros(input_ids.shape[0], dtype=torch.bool, device=input_ids.device)
+
+        # MaxLengthCriteria fires immediately (length == max_length), so CountingCriteria should be skipped
+        input_ids, scores = self._get_tensors(10)
+        criteria = StoppingCriteriaList(
+            [
+                MaxLengthCriteria(max_length=10),
+                CountingCriteria(),
+            ]
+        )
+        result = criteria(input_ids, scores)
+        self.assertTrue(all(result))
+        self.assertEqual(call_count, 0, "CountingCriteria should not be called when first criterion marks all done")
+
     def test_max_length_criteria(self):
         criteria = MaxLengthCriteria(max_length=10)