diff --git a/src/art/preprocessing/tokenize.py b/src/art/preprocessing/tokenize.py index 722fa446a..3816f66f4 100644 --- a/src/art/preprocessing/tokenize.py +++ b/src/art/preprocessing/tokenize.py @@ -201,24 +201,33 @@ def tokenize_trajectory( set(range(cast(int, tokenizer.vocab_size))) - set(original_token_ids) ) sentinal_token = tokenizer.decode(sentinal_token_id) + token_template_messages: list[dict[str, Any]] = [] + for original, message in zip(messages_and_choices, messages): + trainable_assistant = ( + not isinstance(original, dict) and original.logprobs is not None + ) or ( + allow_training_without_logprobs + and isinstance(original, dict) + and original.get("role") == "assistant" + ) + if trainable_assistant: + token_template_messages.append( + { + "role": "assistant", + "content": sentinal_token, + **( + {"tool_calls": message.get("tool_calls")} # type: ignore[call-overload] + if message.get("tool_calls") # type: ignore[call-overload] + else {} + ), + } + ) + else: + token_template_messages.append(cast(dict[str, Any], message)) token_ids = cast( list[int], tokenizer.apply_chat_template( - cast( - list[dict], - [ - ( - message_or_choice - if isinstance(message_or_choice, dict) - and not message_or_choice["role"] == "assistant" - else { - "role": "assistant", - "content": sentinal_token, - } - ) - for message_or_choice in messages_and_choices - ], - ), + cast(list[dict], token_template_messages), tools=tools, # type: ignore continue_final_message=True, ), @@ -226,7 +235,12 @@ def tokenize_trajectory( assistant_mask: list[int] = [0] * len(token_ids) logprobs = [float("nan")] * len(token_ids) for message in messages_and_choices: - if isinstance(message, dict) and not message["role"] == "assistant": + if isinstance(message, dict): + if message["role"] != "assistant": + continue + if not allow_training_without_logprobs: + continue + elif message.logprobs is None and not allow_training_without_logprobs: continue start = token_ids.index(sentinal_token_id) end = start + 1 @@ -235,6 +249,11 @@ def tokenize_trajectory( except IndexError: end_token_id = None if isinstance(message, dict): + if message.get("tool_calls"): + raise ValueError( + "Assistant message has tool_calls but is being tokenized " + "via tokenizer.encode(content). This path ignores tool calls." + ) content = message.get("content") assert isinstance(content, str) content_token_ids = tokenizer.encode(