
Add prefix-preserving training chat template for GPT-OSS #5109

Open
qgallouedec wants to merge 12 commits into main from support-gpt-oss

Conversation


@qgallouedec (Member) commented Feb 17, 2026

PR Description

This PR adds support for tool-calling training with GPT-OSS models (e.g., gpt-oss-20b) in the GRPO agent training pipeline, extending the existing Qwen3 support.

Problem

For context, the main challenge is to ensure that the chat template used for training is prefix-preserving, meaning that when new messages are appended to a conversation, the previous sequence of tokens remains unchanged. This is crucial for multi-turn training.
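
Concretely, the property can be expressed as a string check: rendering any leading slice of a conversation must produce a prefix of the rendering of the next, longer slice. Below is a minimal sketch of such a check; the check_prefix_preserving helper is hypothetical and is not the is_chat_template_prefix_preserving implementation from chat_template_utils.py.

    from transformers import AutoTokenizer

    def check_prefix_preserving(tokenizer, messages, chat_template=None):
        # Render every leading slice of `messages` and verify that each rendering
        # is a string prefix of the rendering of the next, longer slice.
        rendered = [
            tokenizer.apply_chat_template(messages[:i], chat_template=chat_template, tokenize=False)
            for i in range(1, len(messages) + 1)
        ]
        return all(longer.startswith(shorter) for shorter, longer in zip(rendered, rendered[1:]))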

The original GPT-OSS chat template is not prefix-preserving for two reasons:

  • <|return|> vs <|end|>: The last assistant message ends with <|return|> while all other turns use <|end|>. When a conversation is extended with new turns, the previously-final assistant message switches from <|return|> to <|end|>, breaking the prefix.

    >>> from transformers import AutoTokenizer
    >>> tokenizer = AutoTokenizer.from_pretrained("openai/gpt-oss-20b")
    >>> messages = [
    ...     {"role": "user", "content": "What is 2+2?."},
    ...     {"role": "assistant", "content": "4"},
    ...     {"role": "user", "content": "And what about 3+3?"},
    ... ]
    >>> tokenizer.apply_chat_template(messages[:2], tokenize=False)
    '<|start|>system<|message|>You are ChatGPT[...].<|end|><|start|>user<|message|>What is 2+2?.<|end|><|start|>assistant<|channel|>final<|message|>4<|return|>'
    >>> tokenizer.apply_chat_template(messages, tokenize=False)
    '<|start|>system<|message|>You are ChatGPT[...].<|end|><|start|>user<|message|>What is 2+2?.<|end|><|start|>assistant<|channel|>final<|message|>4<|end|><|start|>user<|message|>And what about 3+3?<|end|>'
  • Conditional thinking blocks (same as Qwen3): The thinking field is only rendered for the final assistant turn (via loop.last in the template), so earlier assistant turns lose their thinking content when new messages are appended.

    >>> from transformers import AutoTokenizer
    >>> tokenizer = AutoTokenizer.from_pretrained("openai/gpt-oss-20b")
    >>> messages = [
    ...     {"role": "user", "content": "What is 2+2?."},
    ...     {"role": "assistant", "thinking": "🤔", "content": "4"},
    ...     {"role": "user", "content": "And what about 3+3?"},
    ... ]
    >>> tokenizer.apply_chat_template(messages[:2], tokenize=False)
    '<|start|>system<|message|>You are ChatGPT[...]<|start|>assistant<|channel|>analysis<|message|>🤔<|end|><|start|>assistant<|channel|>final<|message|>4<|return|>'
    >>> tokenizer.apply_chat_template(messages, tokenize=False)
    '<|start|>system<|message|>You are ChatGPT[...]<|start|>assistant<|channel|>final<|message|>4<|end|><|start|>user<|message|>And what about 3+3?<|end|>'

This PR introduces

  • gpt_oss_chat_template — A reference copy of the original GPT-OSS chat template stored in chat_template_utils.py for template matching.
  • gpt_oss_training_chat_template — A modified training-safe template with two key changes (a short sketch of the result follows this list):
    • Replaces <|return|> with <|end|> on the final assistant message to ensure consistent turn delimiters across all turns.
    • Changes loop.last to true so thinking blocks are always rendered, not just on the final turn.
  • Updated get_training_chat_template() — Now recognizes GPT-OSS templates and returns the training variant, alongside the existing Qwen3 support.
  • Updated is_chat_template_prefix_preserving() — Test messages now include both reasoning_content and thinking keys, since GPT-OSS uses thinking while Qwen3 uses reasoning_content.
  • Extended tests — All TestGetTrainingChatTemplate tests are now parameterized over both GPT-OSS and Qwen3, with a helper _assert_equal that accounts for the expected <|return|><|end|> difference.
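
With the training variant, extending the conversation no longer rewrites earlier tokens: the previous assistant turn keeps its <|end|> delimiter and its thinking block. A minimal sketch of the check, assuming gpt_oss_training_chat_template can be imported from TRL's chat_template_utils module (the exact import path is an assumption):

    >>> from transformers import AutoTokenizer
    >>> from trl.chat_template_utils import gpt_oss_training_chat_template  # import path assumed
    >>> tokenizer = AutoTokenizer.from_pretrained("openai/gpt-oss-20b")
    >>> messages = [
    ...     {"role": "user", "content": "What is 2+2?."},
    ...     {"role": "assistant", "thinking": "🤔", "content": "4"},
    ...     {"role": "user", "content": "And what about 3+3?"},
    ... ]
    >>> prefix = tokenizer.apply_chat_template(messages[:2], chat_template=gpt_oss_training_chat_template, tokenize=False)
    >>> full = tokenizer.apply_chat_template(messages, chat_template=gpt_oss_training_chat_template, tokenize=False)
    >>> full.startswith(prefix)
    True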

Design notes (happy to get thoughts on this!)

A consequence of this change is that the <|return|> token is never seen during training. My intuition is that, since GRPO uses a comparative objective (relative ranking within groups), the model's ability to produce <|return|> at inference is not expected to degrade.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@qgallouedec changed the title from "Support tool-calling training for GPT-OSS" to "Add prefix-preserving training chat template for GPT-OSS" on Feb 17, 2026
Base automatically changed from more-test-get_training_chat_template to main on February 18, 2026, 12:45

@albertvillanova (Member) left a comment

Thanks for the fix.

I think:

  • The prefix-preserving fix is good and likely necessary.
  • But I would say the Design note’s justification is a weak point: a comparative objective doesn't guarantee preservation of an unseen special token; this could also impact stopping and parsing.

I would prefer others to comment on this.

@qgallouedec (Member, Author)

> But I would say the Design note’s justification is a weak point: a comparative objective doesn't guarantee preservation of an unseen special token; this could also impact stopping and parsing.

100% agree. At this point it's an intuition. I'll update the comment to be more conservative.

