30 changes: 25 additions & 5 deletions src/transformers/processing_utils.py
@@ -15,6 +15,7 @@
Processing saving/loading class for common processors.
"""

import bisect
import copy
import inspect
import json
@@ -1468,6 +1469,8 @@ def apply_chat_template(
# It's a template string, render it directly
chat_template = chat_template

is_tokenizers_fast = hasattr(self, "tokenizer") and self.tokenizer.__class__.__name__.endswith("Fast")

Comment on lines +1472 to +1473
Member Author
I see that the tokenizer never checks this, probably because all new LLMs support fast tokenizers. However, users can force use_fast=False for some reason, and the error message in that case is not informative.

Should I add the check to the tokenizer's apply_chat_template as well, WDYT?
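
For illustration only, a minimal sketch of what such a guard could look like if the tokenizer's own apply_chat_template validated the flag up front — this is not the actual Hugging Face tokenizer code, and `ToySlowTokenizer` is a hypothetical stand-in for a slow tokenizer class:

```python
# Hypothetical sketch -- not the actual tokenizer implementation.
class ToySlowTokenizer:
    """Stand-in for a slow (pure-Python) tokenizer class."""

    def apply_chat_template(self, conversation, **kwargs):
        if kwargs.get("return_assistant_tokens_mask", False):
            # Fast tokenizer classes end with "Fast" and expose offset mappings;
            # slow ones don't, so fail early with an informative message.
            if not self.__class__.__name__.endswith("Fast"):
                raise ValueError(
                    "`return_assistant_tokens_mask` requires a fast tokenizer. "
                    "Make sure `tokenizers` is installed and avoid `use_fast=False`."
                )
        ...  # template rendering and tokenization would follow here


try:
    ToySlowTokenizer().apply_chat_template([], return_assistant_tokens_mask=True)
except ValueError as err:
    print(err)  # clear error instead of a confusing downstream failure
```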

if kwargs.get("continue_final_message", False):
if kwargs.get("add_generation_prompt", False):
raise ValueError(
@@ -1476,6 +1479,15 @@
if kwargs.get("return_assistant_tokens_mask", False):
raise ValueError("continue_final_message is not compatible with return_assistant_tokens_mask.")

if kwargs.get("return_assistant_tokens_mask", False):
if not is_tokenizers_fast:
raise ValueError(
"`return_assistant_tokens_mask` is not possible with slow tokenizers. Make sure you have `tokenizers` installed. "
"If the error persists, open an issue to support a Fast tokenizer for your model."
)
else:
kwargs["return_offsets_mapping"] = True # force offset mapping so we can infer token boundaries

# Fill sets of kwargs that should be used by different parts of template
processed_kwargs = {
"mm_load_kwargs": {},
@@ -1605,19 +1617,27 @@
video_metadata=batch_video_metadata,
**kwargs,
)

if return_dict:
if processed_kwargs["template_kwargs"].get("return_assistant_tokens_mask", False):
assistant_masks = []
offset_mapping = out.pop("offset_mapping")
input_ids = out["input_ids"]
for i in range(len(input_ids)):
current_mask = [0] * len(input_ids[i])
offsets = offset_mapping[i]
offset_starts = [start for start, end in offsets]
for assistant_start_char, assistant_end_char in generation_indices[i]:
start_token = out.char_to_token(i, assistant_start_char)
end_token = out.char_to_token(i, assistant_end_char - 1)
if start_token is None:
start_pos = bisect.bisect_left(offset_starts, assistant_start_char)
end_pos = bisect.bisect_left(offset_starts, assistant_end_char)

if not (
start_pos >= 0
and offsets[start_pos][0] <= assistant_start_char < offsets[start_pos][1]
):
# start_token is out of bounds maybe due to truncation.
break
for token_id in range(start_token, end_token + 1 if end_token else len(input_ids[i])):
continue
for token_id in range(start_pos, end_pos if end_pos else len(input_ids[i])):
current_mask[token_id] = 1
assistant_masks.append(current_mask)
out["assistant_masks"] = assistant_masks
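For readers following the new masking logic, here is a self-contained sketch of the character-span to token-index mapping performed above. It is illustrative only: the toy whitespace tokenizer, example string, and helper names are made up and merely stand in for a fast tokenizer's `return_offsets_mapping=True` output.

```python
import bisect


def toy_tokenize_with_offsets(text):
    """Toy whitespace tokenizer: returns (tokens, [(start_char, end_char), ...])."""
    tokens, offsets, pos = [], [], 0
    for word in text.split():
        start = text.index(word, pos)
        tokens.append(word)
        offsets.append((start, start + len(word)))
        pos = start + len(word)
    return tokens, offsets


text = "user: What is the capital? assistant: It is Paris."
# Character span of the assistant ("generation") segment, as the chat-template
# tracking would report it.
generation_spans = [(text.index("It is Paris."), len(text))]

tokens, offsets = toy_tokenize_with_offsets(text)
offset_starts = [start for start, _ in offsets]

mask = [0] * len(tokens)
for start_char, end_char in generation_spans:
    start_pos = bisect.bisect_left(offset_starts, start_char)
    end_pos = bisect.bisect_left(offset_starts, end_char)
    # Skip spans that fall outside the tokenized text (e.g. after truncation).
    if not (start_pos < len(offsets) and offsets[start_pos][0] <= start_char < offsets[start_pos][1]):
        continue
    for token_id in range(start_pos, end_pos if end_pos else len(tokens)):
        mask[token_id] = 1

print(list(zip(tokens, mask)))
# only the tokens of "It is Paris." are flagged with 1
```

Searching the sorted list of token start offsets with `bisect_left` finds the first token at or after the span start, and the bounds check skips spans whose tokens were cut off, e.g. by truncation.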
5 changes: 5 additions & 0 deletions tests/models/csm/test_processor_csm.py
@@ -137,3 +137,8 @@ def test_apply_chat_template(self):
[[128000, 58, 15, 60, 2028, 374, 264, 1296, 11914, 13, 128001, 128002, 128002, 128002, 128003]]
)
torch.testing.assert_close(input_ids, expected_ids)

@require_torch
@unittest.skip("CSM doesn't need assistant masks as an audio generation model")
def test_apply_chat_template_assistant_mask(self):
pass
4 changes: 4 additions & 0 deletions tests/models/shieldgemma2/test_processing_shieldgemma2.py
@@ -206,3 +206,7 @@ def test_kwargs_overrides_default_tokenizer_kwargs(self):
@unittest.skip("Parent test needs to be adapted for ShieldGemma 2.")
def test_kwargs_overrides_default_image_processor_kwargs(self):
pass

@unittest.skip("ShieldGemma requires images in input, and fails in text-only processing")
def test_apply_chat_template_assistant_mask(self):
pass
80 changes: 79 additions & 1 deletion tests/test_processing_common.py
@@ -100,7 +100,11 @@ def get_component(self, attribute, **kwargs):
assert attribute in self.processor_class.attributes
component_class_name = getattr(self.processor_class, f"{attribute}_class")
if isinstance(component_class_name, tuple):
component_class_name = component_class_name[0]
if attribute == "image_processor":
# TODO: @yoni, change logic in v4.52 (when use_fast set to True by default)
component_class_name = component_class_name[0]
else:
component_class_name = component_class_name[-1]

component_class = processor_class_from_name(component_class_name)
component = component_class.from_pretrained(self.tmpdirname, **kwargs) # noqa
@@ -1149,3 +1153,77 @@ def test_chat_template_jinja_kwargs(self):
)
expected_prompt = "You are a helpful assistant.<|special_start|>user\nWhich of these animals is making the sound?<|special_end|>\nYou are a helpful assistant.<|special_start|>assistant\nIt is a cow.<|special_end|>\n"
self.assertEqual(formatted_prompt, expected_prompt)

@require_torch
def test_apply_chat_template_assistant_mask(self):
processor = self.get_processor()

if processor.chat_template is None:
self.skipTest("Processor has no chat template")

messages = [
[
{
"role": "user",
"content": [
{"type": "text", "text": "What is the capital of France?"},
],
},
{
"role": "assistant",
"content": [
{"type": "text", "text": "The capital of France is Paris."},
],
},
{
"role": "user",
"content": [
{"type": "text", "text": "What about Italy?"},
],
},
{
"role": "assistant",
"content": [
{"type": "text", "text": "The capital of Italy is Rome."},
],
},
]
]

dummy_template = (
"{% for message in messages %}"
"{% if (message['role'] != 'assistant') %}"
"{{'<|special_start|>' + message['role'] + '\n' + message['content'][0]['text'] + '<|special_end|>' + '\n'}}"
"{% elif (message['role'] == 'assistant')%}"
"{{'<|special_start|>' + message['role'] + '\n'}}"
"{% generation %}"
"{{message['content'][0]['text'] + '<|special_end|>' + '\n'}}"
"{% endgeneration %}"
"{% endif %}"
"{% endfor %}"
)

inputs = processor.apply_chat_template(
messages,
add_generation_prompt=False,
tokenize=True,
return_dict=True,
return_tensors="pt",
return_assistant_tokens_mask=True,
chat_template=dummy_template,
)
self.assertTrue("assistant_masks" in inputs)
self.assertEqual(len(inputs["assistant_masks"]), len(inputs["input_ids"]))

mask = inputs["assistant_masks"].bool()
assistant_ids = inputs["input_ids"][mask]

assistant_text = (
"The capital of France is Paris.<|special_end|>\nThe capital of Italy is Rome.<|special_end|>\n"
)

# Some tokenizers add extra spaces which aren't then removed when decoding, so we need to check token ids
# if we can't get identical text outputs
text_is_same = assistant_text == processor.decode(assistant_ids, clean_up_tokenization_spaces=True)
ids_is_same = processor.tokenizer.encode(assistant_text, add_special_tokens=False) == assistant_ids.tolist()
self.assertTrue(text_is_same or ids_is_same)