From 61112a12b70867cd8007b04eef68b856544e4631 Mon Sep 17 00:00:00 2001 From: raushan Date: Tue, 3 Jun 2025 09:19:37 +0200 Subject: [PATCH 1/4] messed up the git history, squash commits --- src/transformers/processing_utils.py | 19 ++++++-- tests/test_processing_common.py | 68 ++++++++++++++++++++++++++++ 2 files changed, 82 insertions(+), 5 deletions(-) diff --git a/src/transformers/processing_utils.py b/src/transformers/processing_utils.py index 8ee7ce5adbb6..917101e32eb2 100644 --- a/src/transformers/processing_utils.py +++ b/src/transformers/processing_utils.py @@ -15,6 +15,7 @@ Processing saving/loading class for common processors. """ +import bisect import copy import inspect import json @@ -1602,21 +1603,29 @@ def apply_chat_template( images=batch_images if batch_images else None, videos=batch_videos if batch_videos else None, audio=batch_audios if batch_audios else None, + return_offsets_mapping=True, **kwargs, ) + if return_dict: if processed_kwargs["template_kwargs"].get("return_assistant_tokens_mask", False): assistant_masks = [] input_ids = out["input_ids"] for i in range(len(input_ids)): current_mask = [0] * len(input_ids[i]) + offsets = out["offset_mapping"][i] + offset_starts = [start for start, end in offsets] for assistant_start_char, assistant_end_char in generation_indices[i]: - start_token = out.char_to_token(i, assistant_start_char) - end_token = out.char_to_token(i, assistant_end_char - 1) - if start_token is None: + start_pos = bisect.bisect_right(offset_starts, assistant_start_char) - 1 + end_pos = bisect.bisect_right(offset_starts, assistant_end_char) - 1 + + if not ( + start_pos >= 0 + and offsets[start_pos][0] <= assistant_start_char < offsets[start_pos][1] + ): # start_token is out of bounds maybe due to truncation. 
- break - for token_id in range(start_token, end_token + 1 if end_token else len(input_ids[i])): + continue + for token_id in range(start_pos, end_pos + 1 if end_pos else len(input_ids[i])): current_mask[token_id] = 1 assistant_masks.append(current_mask) out["assistant_masks"] = assistant_masks diff --git a/tests/test_processing_common.py b/tests/test_processing_common.py index 610e109b56fd..4943b18094b2 100644 --- a/tests/test_processing_common.py +++ b/tests/test_processing_common.py @@ -1152,3 +1152,71 @@ def test_chat_template_audio_from_video(self): self.assertEqual(len(out_dict["attention_mask"]), 1) # batch-size=1 self.assertEqual(len(out_dict[self.audio_input_name]), 1) # 1 audio in the conversation self.assertEqual(len(out_dict[self.videos_input_name]), 1) # 1 video in the conversation + + @require_torch + def test_apply_chat_template_assistant_mask(self): + processor = self.get_processor() + + if processor.chat_template is None: + self.skipTest("Processor has no chat template") + + messages = [ + [ + { + "role": "user", + "content": [ + {"type": "text", "text": "What is the capital of France?"}, + ], + }, + { + "role": "assistant", + "content": [ + {"type": "text", "text": "The capital of France is Paris."}, + ], + }, + { + "role": "user", + "content": [ + {"type": "text", "text": "What about Italy?"}, + ], + }, + { + "role": "assistant", + "content": [ + {"type": "text", "text": "The capital of Italy is Rome."}, + ], + }, + ] + ] + + dummy_template = ( + "{% for message in messages %}" + "{% if (message['role'] != 'assistant') %}" + "{{'<|im_start|>' + message['role'] + '\n' + message['content'][0]['text'] + '<|im_end|>' + '\n'}}" + "{% elif (message['role'] == 'assistant')%}" + "{{'<|im_start|>' + message['role'] + '\n'}}" + "{% generation %}" + "{{message['content'][0]['text'] + '<|im_end|>'}}" + "{% endgeneration %}" + "{{'\n'}}" + "{% endif %}" + "{% endfor %}" + ) + + inputs = processor.apply_chat_template( + messages, + add_generation_prompt=False, 
+ tokenize=True, + return_dict=True, + return_tensors="pt", + return_assistant_tokens_mask=True, + chat_template=dummy_template, + ) + self.assertTrue("assistant_masks" in inputs) + self.assertEqual(len(inputs["assistant_masks"]), len(inputs["input_ids"])) + + mask = inputs["assistant_masks"].bool() + assistant_ids = inputs["input_ids"][mask] + + assistant_text = "The capital of France is Paris.<|im_end|>\nThe capital of Italy is Rome.<|im_end|>\n" + self.assertEqual(assistant_text, processor.decode(assistant_ids)) From 5ed3442f8fe3e56e3b3670aad9f78fa346e2f452 Mon Sep 17 00:00:00 2001 From: raushan Date: Thu, 12 Jun 2025 11:08:55 +0200 Subject: [PATCH 2/4] raise error if slow and refine tests --- src/transformers/processing_utils.py | 15 +++++++++++++-- tests/test_processing_common.py | 6 +++++- 2 files changed, 18 insertions(+), 3 deletions(-) diff --git a/src/transformers/processing_utils.py b/src/transformers/processing_utils.py index 917101e32eb2..b98fbbac08d8 100644 --- a/src/transformers/processing_utils.py +++ b/src/transformers/processing_utils.py @@ -1461,6 +1461,8 @@ def apply_chat_template( # It's a template string, render it directly chat_template = chat_template + is_tokenizers_fast = hasattr(self, "tokenizer") and self.tokenizer.__class__.__name__.endswith("Fast") + if kwargs.get("continue_final_message", False): if kwargs.get("add_generation_prompt", False): raise ValueError( @@ -1469,6 +1471,15 @@ def apply_chat_template( if kwargs.get("return_assistant_tokens_mask", False): raise ValueError("continue_final_message is not compatible with return_assistant_tokens_mask.") + if kwargs.get("return_assistant_tokens_mask", False): + if not is_tokenizers_fast: + raise ValueError( + "`return_assistant_tokens_mask` is not possible with slow tokenizers. Make sure you have `tokenizers` installed. " + "If the error persists, open an issue to support a Fast tokenizer for your model." 
+ ) + else: + kwargs["return_offsets_mapping"] = True # force offset mapping so we can infer token boundaries + # Fill sets of kwargs that should be used by different parts of template processed_kwargs = { "mm_load_kwargs": {}, @@ -1603,17 +1614,17 @@ def apply_chat_template( images=batch_images if batch_images else None, videos=batch_videos if batch_videos else None, audio=batch_audios if batch_audios else None, - return_offsets_mapping=True, **kwargs, ) if return_dict: if processed_kwargs["template_kwargs"].get("return_assistant_tokens_mask", False): assistant_masks = [] + offset_mapping = out.pop("offset_mapping") input_ids = out["input_ids"] for i in range(len(input_ids)): current_mask = [0] * len(input_ids[i]) - offsets = out["offset_mapping"][i] + offsets = offset_mapping[i] offset_starts = [start for start, end in offsets] for assistant_start_char, assistant_end_char in generation_indices[i]: start_pos = bisect.bisect_right(offset_starts, assistant_start_char) - 1 diff --git a/tests/test_processing_common.py b/tests/test_processing_common.py index 4943b18094b2..87c73042112a 100644 --- a/tests/test_processing_common.py +++ b/tests/test_processing_common.py @@ -100,7 +100,11 @@ def get_component(self, attribute, **kwargs): assert attribute in self.processor_class.attributes component_class_name = getattr(self.processor_class, f"{attribute}_class") if isinstance(component_class_name, tuple): - component_class_name = component_class_name[0] + if attribute == "image_processor": + # TODO: @yoni, change logic in v4.52 (when use_fast set to True by default) + component_class_name = component_class_name[0] + else: + component_class_name = component_class_name[-1] component_class = processor_class_from_name(component_class_name) component = component_class.from_pretrained(self.tmpdirname, **kwargs) # noqa From f5623a7346e4c799208e2c7146629434658d2cf9 Mon Sep 17 00:00:00 2001 From: raushan Date: Thu, 12 Jun 2025 11:12:18 +0200 Subject: [PATCH 3/4] index was off by one 
--- src/transformers/processing_utils.py | 2 +- tests/test_processing_common.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/processing_utils.py b/src/transformers/processing_utils.py index b98fbbac08d8..2eb33713bfe4 100644 --- a/src/transformers/processing_utils.py +++ b/src/transformers/processing_utils.py @@ -1636,7 +1636,7 @@ def apply_chat_template( ): # start_token is out of bounds maybe due to truncation. continue - for token_id in range(start_pos, end_pos + 1 if end_pos else len(input_ids[i])): + for token_id in range(start_pos, end_pos if end_pos else len(input_ids[i])): current_mask[token_id] = 1 assistant_masks.append(current_mask) out["assistant_masks"] = assistant_masks diff --git a/tests/test_processing_common.py b/tests/test_processing_common.py index 87c73042112a..d5d0b219c3dd 100644 --- a/tests/test_processing_common.py +++ b/tests/test_processing_common.py @@ -1222,5 +1222,5 @@ def test_apply_chat_template_assistant_mask(self): mask = inputs["assistant_masks"].bool() assistant_ids = inputs["input_ids"][mask] - assistant_text = "The capital of France is Paris.<|im_end|>\nThe capital of Italy is Rome.<|im_end|>\n" + assistant_text = "The capital of France is Paris.<|im_end|>The capital of Italy is Rome.<|im_end|>" self.assertEqual(assistant_text, processor.decode(assistant_ids)) From 025afdb8b8f6970c916eec8a399af49c5f986d0e Mon Sep 17 00:00:00 2001 From: raushan Date: Thu, 12 Jun 2025 12:18:07 +0200 Subject: [PATCH 4/4] fix the test --- src/transformers/processing_utils.py | 4 ++-- tests/models/csm/test_processor_csm.py | 5 +++++ .../test_processing_shieldgemma2.py | 4 ++++ tests/test_processing_common.py | 18 ++++++++++++------ 4 files changed, 23 insertions(+), 8 deletions(-) diff --git a/src/transformers/processing_utils.py b/src/transformers/processing_utils.py index 2eb33713bfe4..34290f6fa12d 100644 --- a/src/transformers/processing_utils.py +++ b/src/transformers/processing_utils.py @@ -1627,8 +1627,8 
@@ def apply_chat_template( offsets = offset_mapping[i] offset_starts = [start for start, end in offsets] for assistant_start_char, assistant_end_char in generation_indices[i]: - start_pos = bisect.bisect_right(offset_starts, assistant_start_char) - 1 - end_pos = bisect.bisect_right(offset_starts, assistant_end_char) - 1 + start_pos = bisect.bisect_left(offset_starts, assistant_start_char) + end_pos = bisect.bisect_left(offset_starts, assistant_end_char) if not ( start_pos >= 0 diff --git a/tests/models/csm/test_processor_csm.py b/tests/models/csm/test_processor_csm.py index dcd344d12036..2abb7eb2d66f 100644 --- a/tests/models/csm/test_processor_csm.py +++ b/tests/models/csm/test_processor_csm.py @@ -137,3 +137,8 @@ def test_apply_chat_template(self): [[128000, 58, 15, 60, 2028, 374, 264, 1296, 11914, 13, 128001, 128002, 128002, 128002, 128003]] ) torch.testing.assert_close(input_ids, expected_ids) + + @require_torch + @unittest.skip("CSM doesn't need assistant masks as an audio generation model") + def test_apply_chat_template_assistant_mask(self): + pass diff --git a/tests/models/shieldgemma2/test_processing_shieldgemma2.py b/tests/models/shieldgemma2/test_processing_shieldgemma2.py index 86d316fd8895..0bbe17e65607 100644 --- a/tests/models/shieldgemma2/test_processing_shieldgemma2.py +++ b/tests/models/shieldgemma2/test_processing_shieldgemma2.py @@ -206,3 +206,7 @@ def test_kwargs_overrides_default_tokenizer_kwargs(self): @unittest.skip("Parent test needs to be adapted for ShieldGemma 2.") def test_kwargs_overrides_default_image_processor_kwargs(self): pass + + @unittest.skip("ShieldGemma requires images in input, and fails in text-only processing") + def test_apply_chat_template_assistant_mask(self): + pass diff --git a/tests/test_processing_common.py b/tests/test_processing_common.py index d5d0b219c3dd..9a08b1470b9c 100644 --- a/tests/test_processing_common.py +++ b/tests/test_processing_common.py @@ -1196,13 +1196,12 @@ def 
test_apply_chat_template_assistant_mask(self): dummy_template = ( "{% for message in messages %}" "{% if (message['role'] != 'assistant') %}" - "{{'<|im_start|>' + message['role'] + '\n' + message['content'][0]['text'] + '<|im_end|>' + '\n'}}" + "{{'<|special_start|>' + message['role'] + '\n' + message['content'][0]['text'] + '<|special_end|>' + '\n'}}" "{% elif (message['role'] == 'assistant')%}" - "{{'<|im_start|>' + message['role'] + '\n'}}" + "{{'<|special_start|>' + message['role'] + '\n'}}" "{% generation %}" - "{{message['content'][0]['text'] + '<|im_end|>'}}" + "{{message['content'][0]['text'] + '<|special_end|>' + '\n'}}" "{% endgeneration %}" - "{{'\n'}}" "{% endif %}" "{% endfor %}" ) @@ -1222,5 +1221,12 @@ def test_apply_chat_template_assistant_mask(self): mask = inputs["assistant_masks"].bool() assistant_ids = inputs["input_ids"][mask] - assistant_text = "The capital of France is Paris.<|im_end|>The capital of Italy is Rome.<|im_end|>" - self.assertEqual(assistant_text, processor.decode(assistant_ids)) + assistant_text = ( + "The capital of France is Paris.<|special_end|>\nThe capital of Italy is Rome.<|special_end|>\n" + ) + + # Some tokenizers add extra spaces which aren't then removed when decoding, so we need to check token ids + # if we can't get identical text outputs + text_is_same = assistant_text == processor.decode(assistant_ids, clean_up_tokenization_spaces=True) + ids_is_same = processor.tokenizer.encode(assistant_text, add_special_tokens=False) == assistant_ids.tolist() + self.assertTrue(text_is_same or ids_is_same)