diff --git a/src/transformers/processing_utils.py b/src/transformers/processing_utils.py
index cb58e5585b75..0dbac1a08a02 100644
--- a/src/transformers/processing_utils.py
+++ b/src/transformers/processing_utils.py
@@ -15,6 +15,7 @@
 Processing saving/loading class for common processors.
 """
 
+import bisect
 import copy
 import inspect
 import json
@@ -1468,6 +1469,8 @@ def apply_chat_template(
             # It's a template string, render it directly
             chat_template = chat_template
 
+        is_tokenizers_fast = hasattr(self, "tokenizer") and self.tokenizer.__class__.__name__.endswith("Fast")
+
         if kwargs.get("continue_final_message", False):
             if kwargs.get("add_generation_prompt", False):
                 raise ValueError(
@@ -1476,6 +1479,15 @@
             if kwargs.get("return_assistant_tokens_mask", False):
                 raise ValueError("continue_final_message is not compatible with return_assistant_tokens_mask.")
 
+        if kwargs.get("return_assistant_tokens_mask", False):
+            if not is_tokenizers_fast:
+                raise ValueError(
+                    "`return_assistant_tokens_mask` is not possible with slow tokenizers. Make sure you have `tokenizers` installed. "
+                    "If the error persists, open an issue to support a Fast tokenizer for your model."
+                )
+            else:
+                kwargs["return_offsets_mapping"] = True  # force offset mapping so we can infer token boundaries
+
         # Fill sets of kwargs that should be used by different parts of template
         processed_kwargs = {
             "mm_load_kwargs": {},
@@ -1605,19 +1617,27 @@
                 video_metadata=batch_video_metadata,
                 **kwargs,
             )
+
            if return_dict:
                if processed_kwargs["template_kwargs"].get("return_assistant_tokens_mask", False):
                    assistant_masks = []
+                   offset_mapping = out.pop("offset_mapping")
                    input_ids = out["input_ids"]
                    for i in range(len(input_ids)):
                        current_mask = [0] * len(input_ids[i])
+                       offsets = offset_mapping[i]
+                       offset_starts = [start for start, end in offsets]
                        for assistant_start_char, assistant_end_char in generation_indices[i]:
-                           start_token = out.char_to_token(i, assistant_start_char)
-                           end_token = out.char_to_token(i, assistant_end_char - 1)
-                           if start_token is None:
+                           start_pos = bisect.bisect_left(offset_starts, assistant_start_char)
+                           end_pos = bisect.bisect_left(offset_starts, assistant_end_char)
+
+                           if not (
+                               start_pos >= 0
+                               and offsets[start_pos][0] <= assistant_start_char < offsets[start_pos][1]
+                           ):
                                # start_token is out of bounds maybe due to truncation.
-                               break
-                           for token_id in range(start_token, end_token + 1 if end_token else len(input_ids[i])):
+                               continue
+                           for token_id in range(start_pos, end_pos if end_pos else len(input_ids[i])):
                                current_mask[token_id] = 1
                        assistant_masks.append(current_mask)
                    out["assistant_masks"] = assistant_masks
diff --git a/tests/models/csm/test_processor_csm.py b/tests/models/csm/test_processor_csm.py
index dcd344d12036..2abb7eb2d66f 100644
--- a/tests/models/csm/test_processor_csm.py
+++ b/tests/models/csm/test_processor_csm.py
@@ -137,3 +137,8 @@ def test_apply_chat_template(self):
             [[128000, 58, 15, 60, 2028, 374, 264, 1296, 11914, 13, 128001, 128002, 128002, 128002, 128003]]
         )
         torch.testing.assert_close(input_ids, expected_ids)
+
+    @require_torch
+    @unittest.skip("CSM doesn't need assistant masks as an audio generation model")
+    def test_apply_chat_template_assistant_mask(self):
+        pass
diff --git a/tests/models/shieldgemma2/test_processing_shieldgemma2.py b/tests/models/shieldgemma2/test_processing_shieldgemma2.py
index 86d316fd8895..0bbe17e65607 100644
--- a/tests/models/shieldgemma2/test_processing_shieldgemma2.py
+++ b/tests/models/shieldgemma2/test_processing_shieldgemma2.py
@@ -206,3 +206,7 @@ def test_kwargs_overrides_default_tokenizer_kwargs(self):
     @unittest.skip("Parent test needs to be adapted for ShieldGemma 2.")
     def test_kwargs_overrides_default_image_processor_kwargs(self):
         pass
+
+    @unittest.skip("ShieldGemma requires images in input, and fails in text-only processing")
+    def test_apply_chat_template_assistant_mask(self):
+        pass
diff --git a/tests/test_processing_common.py b/tests/test_processing_common.py
index f67a9e49f92c..2bb5b9c847b4 100644
--- a/tests/test_processing_common.py
+++ b/tests/test_processing_common.py
@@ -100,7 +100,11 @@ def get_component(self, attribute, **kwargs):
         assert attribute in self.processor_class.attributes
         component_class_name = getattr(self.processor_class, f"{attribute}_class")
         if isinstance(component_class_name, tuple):
-            component_class_name = component_class_name[0]
+            if attribute == "image_processor":
+                # TODO: @yoni, change logic in v4.52 (when use_fast set to True by default)
+                component_class_name = component_class_name[0]
+            else:
+                component_class_name = component_class_name[-1]
 
         component_class = processor_class_from_name(component_class_name)
         component = component_class.from_pretrained(self.tmpdirname, **kwargs)  # noqa
@@ -1149,3 +1153,77 @@ def test_chat_template_jinja_kwargs(self):
         )
         expected_prompt = "You are a helpful assistant.<|special_start|>user\nWhich of these animals is making the sound?<|special_end|>\nYou are a helpful assistant.<|special_start|>assistant\nIt is a cow.<|special_end|>\n"
         self.assertEqual(formatted_prompt, expected_prompt)
+
+    @require_torch
+    def test_apply_chat_template_assistant_mask(self):
+        processor = self.get_processor()
+
+        if processor.chat_template is None:
+            self.skipTest("Processor has no chat template")
+
+        messages = [
+            [
+                {
+                    "role": "user",
+                    "content": [
+                        {"type": "text", "text": "What is the capital of France?"},
+                    ],
+                },
+                {
+                    "role": "assistant",
+                    "content": [
+                        {"type": "text", "text": "The capital of France is Paris."},
+                    ],
+                },
+                {
+                    "role": "user",
+                    "content": [
+                        {"type": "text", "text": "What about Italy?"},
+                    ],
+                },
+                {
+                    "role": "assistant",
+                    "content": [
+                        {"type": "text", "text": "The capital of Italy is Rome."},
+                    ],
+                },
+            ]
+        ]
+
+        dummy_template = (
+            "{% for message in messages %}"
+            "{% if (message['role'] != 'assistant') %}"
+            "{{'<|special_start|>' + message['role'] + '\n' + message['content'][0]['text'] + '<|special_end|>' + '\n'}}"
+            "{% elif (message['role'] == 'assistant')%}"
+            "{{'<|special_start|>' + message['role'] + '\n'}}"
+            "{% generation %}"
+            "{{message['content'][0]['text'] + '<|special_end|>' + '\n'}}"
+            "{% endgeneration %}"
+            "{% endif %}"
+            "{% endfor %}"
+        )
+
+        inputs = processor.apply_chat_template(
+            messages,
+            add_generation_prompt=False,
+            tokenize=True,
+            return_dict=True,
+            return_tensors="pt",
+            return_assistant_tokens_mask=True,
+            chat_template=dummy_template,
+        )
+        self.assertTrue("assistant_masks" in inputs)
+        self.assertEqual(len(inputs["assistant_masks"]), len(inputs["input_ids"]))
+
+        mask = inputs["assistant_masks"].bool()
+        assistant_ids = inputs["input_ids"][mask]
+
+        assistant_text = (
+            "The capital of France is Paris.<|special_end|>\nThe capital of Italy is Rome.<|special_end|>\n"
+        )
+
+        # Some tokenizers add extra spaces which aren't then removed when decoding, so we need to check token ids
+        # if we can't get identical text outputs
+        text_is_same = assistant_text == processor.decode(assistant_ids, clean_up_tokenization_spaces=True)
+        ids_is_same = processor.tokenizer.encode(assistant_text, add_special_tokens=False) == assistant_ids.tolist()
+        self.assertTrue(text_is_same or ids_is_same)
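
For reference, here is a minimal standalone sketch of the character-to-token mapping that the patch performs with `bisect` over the tokenizer's offset mapping. The offsets and the assistant character span below are made-up values, not output from a real tokenizer; they only illustrate the lookup logic.

```python
import bisect

# Hypothetical (start_char, end_char) pairs, one per token, in the shape a fast
# tokenizer returns when called with return_offsets_mapping=True.
offsets = [(0, 4), (4, 10), (10, 16), (16, 22), (22, 28)]
offset_starts = [start for start, _ in offsets]

# Character span of an assistant turn, e.g. produced by a {% generation %} block.
assistant_start_char, assistant_end_char = 10, 22

# bisect_left returns the index of the first token whose start offset is >= the target character.
start_pos = bisect.bisect_left(offset_starts, assistant_start_char)
end_pos = bisect.bisect_left(offset_starts, assistant_end_char)

mask = [0] * len(offsets)
# Only mark tokens when the span start actually falls inside the start token
# (it may not, e.g. after truncation).
if start_pos < len(offsets) and offsets[start_pos][0] <= assistant_start_char < offsets[start_pos][1]:
    for token_id in range(start_pos, end_pos if end_pos else len(offsets)):
        mask[token_id] = 1

print(mask)  # [0, 0, 1, 1, 0]: tokens 2 and 3 cover characters 10-22
```

As in the patch, the `end_pos if end_pos else len(offsets)` fallback extends the mask to the end of the sequence when the end character sorts before every token start, mirroring the behaviour of the previous `char_to_token`-based code.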