From 2d329a2ff9add4d380fb1a1cee2bce705259f43a Mon Sep 17 00:00:00 2001 From: Cursor Bot Date: Wed, 21 Jan 2026 18:50:48 +0000 Subject: [PATCH 1/2] fix: preserve tool-call context in tokenization Only splice trainable assistant spans and keep tool_calls in the template; error if tool_calls would be dropped. --- src/art/preprocessing/tokenize.py | 52 +++++++++++++++++++++---------- 1 file changed, 36 insertions(+), 16 deletions(-) diff --git a/src/art/preprocessing/tokenize.py b/src/art/preprocessing/tokenize.py index 722fa446a..e3cb0a4cc 100644 --- a/src/art/preprocessing/tokenize.py +++ b/src/art/preprocessing/tokenize.py @@ -201,24 +201,34 @@ def tokenize_trajectory( set(range(cast(int, tokenizer.vocab_size))) - set(original_token_ids) ) sentinal_token = tokenizer.decode(sentinal_token_id) + token_template_messages: list[dict[str, Any]] = [] + for original, message in zip(messages_and_choices, messages): + trainable_assistant = ( + (not isinstance(original, dict) and original.logprobs is not None) + or ( + allow_training_without_logprobs + and isinstance(original, dict) + and original.get("role") == "assistant" + ) + ) + if trainable_assistant: + token_template_messages.append( + { + "role": "assistant", + "content": sentinal_token, + **( + {"tool_calls": message["tool_calls"]} + if message.get("tool_calls") + else {} + ), + } + ) + else: + token_template_messages.append(message) token_ids = cast( list[int], tokenizer.apply_chat_template( - cast( - list[dict], - [ - ( - message_or_choice - if isinstance(message_or_choice, dict) - and not message_or_choice["role"] == "assistant" - else { - "role": "assistant", - "content": sentinal_token, - } - ) - for message_or_choice in messages_and_choices - ], - ), + cast(list[dict], token_template_messages), tools=tools, # type: ignore continue_final_message=True, ), @@ -226,7 +236,12 @@ def tokenize_trajectory( assistant_mask: list[int] = [0] * len(token_ids) logprobs = [float("nan")] * len(token_ids) for message in messages_and_choices: - if isinstance(message, dict) and not message["role"] == "assistant": + if isinstance(message, dict): + if message["role"] != "assistant": + continue + if not allow_training_without_logprobs: + continue + elif message.logprobs is None and not allow_training_without_logprobs: continue start = token_ids.index(sentinal_token_id) end = start + 1 @@ -235,6 +250,11 @@ def tokenize_trajectory( except IndexError: end_token_id = None if isinstance(message, dict): + if message.get("tool_calls"): + raise ValueError( + "Assistant message has tool_calls but is being tokenized " + "via tokenizer.encode(content). This path ignores tool calls." + ) content = message.get("content") assert isinstance(content, str) content_token_ids = tokenizer.encode( From 25ffbaf4b093da1bbbabab5f6fc0840ee7a22900 Mon Sep 17 00:00:00 2001 From: Cursor Bot Date: Wed, 21 Jan 2026 19:12:10 +0000 Subject: [PATCH 2/2] Fix type errors in tool-call tokenization - Use .get() instead of direct [] access for tool_calls to handle message types that don't have this key - Cast message to dict[str, Any] when appending to token_template_messages --- src/art/preprocessing/tokenize.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/src/art/preprocessing/tokenize.py b/src/art/preprocessing/tokenize.py index e3cb0a4cc..3816f66f4 100644 --- a/src/art/preprocessing/tokenize.py +++ b/src/art/preprocessing/tokenize.py @@ -204,12 +204,11 @@ def tokenize_trajectory( token_template_messages: list[dict[str, Any]] = [] for original, message in zip(messages_and_choices, messages): trainable_assistant = ( - (not isinstance(original, dict) and original.logprobs is not None) - or ( - allow_training_without_logprobs - and isinstance(original, dict) - and original.get("role") == "assistant" - ) + not isinstance(original, dict) and original.logprobs is not None + ) or ( + allow_training_without_logprobs + and isinstance(original, dict) + and original.get("role") == "assistant" ) if trainable_assistant: token_template_messages.append( @@ -217,14 +216,14 @@ def tokenize_trajectory( "role": "assistant", "content": sentinal_token, **( - {"tool_calls": message["tool_calls"]} - if message.get("tool_calls") + {"tool_calls": message.get("tool_calls")} # type: ignore[call-overload] + if message.get("tool_calls") # type: ignore[call-overload] else {} ), } ) else: - token_template_messages.append(message) + token_template_messages.append(cast(dict[str, Any], message)) token_ids = cast( list[int], tokenizer.apply_chat_template(