src/art/preprocessing/tokenize.py: 51 changes (35 additions, 16 deletions)
@@ -201,32 +201,46 @@ def tokenize_trajectory(
         set(range(cast(int, tokenizer.vocab_size))) - set(original_token_ids)
     )
     sentinal_token = tokenizer.decode(sentinal_token_id)
+    token_template_messages: list[dict[str, Any]] = []
+    for original, message in zip(messages_and_choices, messages):
+        trainable_assistant = (
+            not isinstance(original, dict) and original.logprobs is not None
+        ) or (
+            allow_training_without_logprobs
+            and isinstance(original, dict)
+            and original.get("role") == "assistant"
+        )
+        if trainable_assistant:
+            token_template_messages.append(
+                {
+                    "role": "assistant",
+                    "content": sentinal_token,
+                    **(
+                        {"tool_calls": message.get("tool_calls")}  # type: ignore[call-overload]
+                        if message.get("tool_calls")  # type: ignore[call-overload]
+                        else {}
+                    ),
+                }
+            )
+        else:
+            token_template_messages.append(cast(dict[str, Any], message))
     token_ids = cast(
         list[int],
         tokenizer.apply_chat_template(
-            cast(
-                list[dict],
-                [
-                    (
-                        message_or_choice
-                        if isinstance(message_or_choice, dict)
-                        and not message_or_choice["role"] == "assistant"
-                        else {
-                            "role": "assistant",
-                            "content": sentinal_token,
-                        }
-                    )
-                    for message_or_choice in messages_and_choices
-                ],
-            ),
+            cast(list[dict], token_template_messages),
             tools=tools,  # type: ignore
             continue_final_message=True,
         ),
     )
     assistant_mask: list[int] = [0] * len(token_ids)
     logprobs = [float("nan")] * len(token_ids)
     for message in messages_and_choices:
-        if isinstance(message, dict) and not message["role"] == "assistant":
+        if isinstance(message, dict):
+            if message["role"] != "assistant":
+                continue
+            if not allow_training_without_logprobs:
+                continue
+        elif message.logprobs is None and not allow_training_without_logprobs:
             continue
         start = token_ids.index(sentinal_token_id)
         end = start + 1
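
Reviewer sketch (not part of the diff): the trick in this hunk is that every trainable assistant turn is stood in for by a single sentinel token, an id chosen so it cannot occur anywhere in the real conversation. After apply_chat_template runs, token_ids.index(sentinal_token_id) pinpoints exactly where the real assistant tokens belong, so the mask and logprobs can be spliced in afterwards. Precomputing token_template_messages, instead of the old inline comprehension, also lets tool_calls ride along on the placeholder message rather than being dropped. A minimal illustration of the idea, assuming a Hugging Face tokenizer with a chat template; the model name is illustrative only:

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
    messages = [
        {"role": "user", "content": "Hi"},
        {"role": "assistant", "content": "Hello there!"},
    ]

    # Ids of the real conversation; the sentinel must not collide with any of them.
    original_token_ids = tokenizer.apply_chat_template(messages)
    sentinel_token_id = next(
        iter(set(range(tokenizer.vocab_size)) - set(original_token_ids))
    )
    sentinel_token = tokenizer.decode(sentinel_token_id)

    # Template with the sentinel standing in for the trainable assistant turn.
    token_ids = tokenizer.apply_chat_template(
        [messages[0], {"role": "assistant", "content": sentinel_token}],
        continue_final_message=True,
    )

    # The sentinel's position is where the real assistant tokens belong.
    start = token_ids.index(sentinel_token_id)

This relies on the sentinel's decode/encode round-tripping to the same single id, the same assumption the file itself makes.
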
@@ -235,6 +249,11 @@ def tokenize_trajectory(
         except IndexError:
             end_token_id = None
         if isinstance(message, dict):
+            if message.get("tool_calls"):
+                raise ValueError(
+                    "Assistant message has tool_calls but is being tokenized "
+                    "via tokenizer.encode(content). This path ignores tool calls."
+                )
             content = message.get("content")
             assert isinstance(content, str)
             content_token_ids = tokenizer.encode(
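
Reviewer sketch (not part of the diff): the new guard matters because dict assistant messages on this path are tokenized with tokenizer.encode(content), which only sees text; any tool_calls would vanish silently, so raising is the safe behavior. Continuing the sketch above, a rough picture of the splice that follows; the hunk is truncated at tokenizer.encode(, so add_special_tokens=False and the mask/logprob splices below are assumptions, not the file's actual code:

    message = {"role": "assistant", "content": "Hello there!"}
    if message.get("tool_calls"):
        # encode(content) only sees text, so tool calls would be lost silently.
        raise ValueError("tool_calls cannot take the encode(content) path")

    assistant_mask = [0] * len(token_ids)
    logprobs = [float("nan")] * len(token_ids)

    # Replace the sentinel with the real assistant tokens and mark them trainable.
    content_token_ids = tokenizer.encode(message["content"], add_special_tokens=False)
    token_ids[start : start + 1] = content_token_ids
    assistant_mask[start : start + 1] = [1] * len(content_token_ids)
    logprobs[start : start + 1] = [float("nan")] * len(content_token_ids)
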