
Multi-Turn Chat Dataset SFT #2769

Open
haydn-jones wants to merge 18 commits into pytorch:main from haydn-jones:sft-multiturn

Conversation

@haydn-jones (Contributor) commented Mar 31, 2026

Discussed in #2757.

Overview

I've added multi-turn support to SFT. This involved:

  1. Moving away from EOS-token-based detection of document boundaries, relying instead on the RoPE positions (a reset to 0 indicates a new document).
  2. Updating all callers that generate block-causal / other attention masks.
  3. Detecting assistant token spans by rendering messages twice per chat turn, with a heuristic to detect EOT tokens.

To do assistant-only loss (which is the default), you specify train_on in the dataset config:

ChatDataLoader.Config(
  dataset_path="haydn-jones/SynthMultiTurn",
  load_dataset_kwargs={
      "split": "train",
  },
  sample_processor=process_sample,
  train_on=SupervisionMode.ASSISTANT, # <--- Specify mode here
)

This PR ended up touching a lot of files, but by and large the meaningful changes are localized to:

  • tests/unit_tests/test_chat_dataset.py: Adding various tests for the new functionality.
  • torchtitan/hf_datasets/text_datasets.py: Adding support for multi-turn chat datasets.
  • torchtitan/models/common/attention.py: Computing document boundaries from RoPE positions (see the sketch below).
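
For reference, deriving document ids from the positions tensor amounts to something like this (illustrative sketch, not the PR's actual helper):

import torch

def positions_to_doc_ids(positions: torch.Tensor) -> torch.Tensor:
    # A position that resets to 0 marks the start of a new packed document,
    # so the document id is the running count of resets along the sequence.
    resets = (positions == 0).to(torch.int32)
    return torch.cumsum(resets, dim=-1) - 1

# positions_to_doc_ids(torch.tensor([0, 1, 2, 0, 1, 0, 1]))
# -> tensor([0, 0, 0, 1, 1, 2, 2])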

Notes

SigLIP2 uses get_document_mask_mod, but it doesn't actually need to, as images aren't being packed. Instead of deleting it, I just migrated it by building a monotonically increasing positions array. On main it looks like:

mask_mods.append(get_document_mask_mod(pixel_masks, tokenizer.eos_id))
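
The monotonically increasing positions array amounts to roughly the following (illustrative, not the exact diff):

import torch

def monotonic_positions(batch_size: int, seq_len: int) -> torch.Tensor:
    # Positions never reset to 0 mid-sequence, so boundary detection from
    # positions sees exactly one document per sequence (no packing).
    return torch.arange(seq_len).unsqueeze(0).expand(batch_size, -1)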

Validation

I made a simple synthetic dataset (https://huggingface.co/datasets/haydn-jones/SynthMultiTurn) and trained Qwen3-0.6B on it. The model behaved exactly as you would expect. Piping validation-000006 through vLLM gives:

{
  "id": "chatcmpl-816a0ab269352f1d",
  "object": "chat.completion",
  "created": 1775017889,
  "model": "outputs/step-500-hf",
  "choices": [
    {
      "index": 0,
      "message": {
        "role": "assistant",
        "content": "\n\nANSWER: 7 | STATE: A=7, B=-3, C=-7, D=2 | LIST=[]",
        "refusal": null,
        "annotations": null,
        "audio": null,
        "function_call": null,
        "tool_calls": [],
        "reasoning": "\nCompare A, B, C, and D directly and return the largest register value.\n"
      },
    }
  ],
}

@meta-cla (Bot) added the CLA Signed label on Mar 31, 2026
@haydn-jones (Contributor, Author) commented Mar 31, 2026

Geez, ok, there's actually an issue with how I'm detecting conversation boundaries. Because BPE is not compositional, I tried to get around this by progressively tokenizing the chat, but with models like Qwen, which add and drop thinking tokens based on their position in the chat, my span detection is wrong.

I'm not sure exactly what to do about this. I know transformers now uses special jinja directives ({% generation %}) to detect the assistant spans, but plenty of models do not use them (like GPT-OSS 🙄). Marking as WIP until I figure this out.

@haydn-jones haydn-jones marked this pull request as draft March 31, 2026 18:23
@haydn-jones haydn-jones changed the title from "Sft multiturn" to "[WIP] Sft multiturn" Mar 31, 2026
@haydn-jones haydn-jones marked this pull request as ready for review April 1, 2026 04:35
@haydn-jones haydn-jones changed the title from "[WIP] Sft multiturn" to "Multi-Turn Chat Dataset SFT" Apr 1, 2026
@joecummings (Member)

Thanks for the thorough work here @haydn-jones

Couple things that give me pause:

  1. Per-model maintenance: Every new model needs a hand-written regex, and the GPT-OSS pattern is already complex enough to be hard to audit. I'd rather have something that works with at least most templates automatically.
  2. Readability: The char_spans_to_token_spans offset-walking logic is difficult to follow. I had to trace through it multiple times to convince myself it was correct.

I've been looking at how https://github.com/axolotl-ai-cloud/axolotl/blob/main/src/axolotl/prompt_strategies/chat_template.py does it: to find where assistant turn i lives in the token stream, they render the conversation with that turn's content replaced by an empty string, tokenize both renderings, and diff.

Some Claude pseudocode (tokenize and apply_chat_template are stand-ins for the tokenizer calls):

  def _get_assistant_spans(self, messages, full_tokens):
      spans = []
      for i, msg in enumerate(messages):
          if msg["role"] != "assistant":
              continue

          # Same conversation structure, just empty content for this turn.
          # Strips extra fields (e.g. reasoning_content) so template
          # conditionals like Qwen3's <think> insertion also differ.
          dummy_messages = [
              {"role": m["role"], "content": ""} if j == i else m
              for j, m in enumerate(messages)
          ]
          dummy_tokens = tokenize(apply_chat_template(dummy_messages))

          # Scan from the front to find the first index where they diverge.
          start = 0
          while start < len(dummy_tokens) and full_tokens[start] == dummy_tokens[start]:
              start += 1

          # Scan from the back to find where they reconverge; `end` is the
          # last index (in full_tokens) where they still differ.
          end = len(full_tokens) - 1
          back = len(dummy_tokens) - 1
          while back > start and full_tokens[end] == dummy_tokens[back]:
              end -= 1
              back -= 1

          spans.append((start, end + 1))
      return spans

I think this handles the Qwen3 problem because both renderings have the same conversation structure - the template's conditional logic fires identically. The only difference is the content of one turn. It should also handle any future template quirks without needing new regexes.

The tradeoff is 2N extra tokenization calls per sample (N = assistant turns). Tokenization is fast relative to a training step, but I acknowledge it's not free. For train_on="last_assistant" we can reduce this to a single extra call.

Thoughts? Something I might be missing here?

@haydn-jones (Contributor, Author) commented Apr 1, 2026

I think your suggestion is the better approach, and I’ve implemented it. I've pushed so you can look at it.

With respect to the token diffing method: it correctly finds the assistant content span, but I don’t think the diff alone can reliably recover the end of turn token.

Concrete example:

full:   ...<|im_start|>assistant\nHello there!<|im_end|>\n<|im_start|>user...
blank:  ...<|im_start|>assistant\n            <|im_end|>\n<|im_start|>user...

When scanning from the back, the two sequences reconverge at <|im_end|>\n<|im_start|>user..., because that suffix is identical in both renderings. So the diff isolates Hello there!, but not Hello there!<|im_end|>.

As far as I can tell, Axolotl also needs more than just the diff for this. Their find_turn(...) path renders twice per turn (turns_with_empty and turns_with_content) to find the content boundary, and then they have separate logic to locate EOT / EOS near that boundary.

How do you feel about this? If we want to also supervise the end of turn token, I think that requires additional machinery beyond the diff itself, and I’d be happy to add that once we agree on the right approach.

Edit: Not sure how commit hygiene is enforced here; if we settle on an approach I'd be happy to open a separate PR with cleaner diffs.

@joecummings (Member)

Thanks for implementing - looks straightforward. Yeah, we have to supervise EOT. Could it be as simple as a post-hoc extension that checks whether full_tokens[end] is a special token and includes it if so?
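
Roughly something like this (minimal sketch; special_token_ids and the exclusive-end span convention are assumptions):

def extend_span_to_eot(span, full_tokens, special_token_ids):
    # If the token immediately after the diffed content is a special
    # (end-of-turn) token, fold it into the supervised span.
    start, end = span  # end is exclusive
    if end < len(full_tokens) and full_tokens[end] in special_token_ids:
        end += 1
    return start, end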

Don't worry about commit hygiene; we'll squash on merge, and unless you have 50 reverting commits it should be fine.

@haydn-jones (Contributor, Author)

I checked gpt-oss, llama3, llama4, qwen3, and deepseek, and yes the heuristic is correct for all of them. The last things on my mind are:

  1. Push the EOT heuristic
  2. Probably should implement train_on=last_assistant (so we have all, assistant, last_assistant). Or whatever names we want to use.
  3. Reasoning field. gpt-oss expects the reasoning content in a thinking field, qwen3 expects it in a reasoning_content field. Not sure if it should be the responsibility of the dataset to try and handle that mapping at all (e.g., expose reasoning_field_name for the field in the dataset and reasoning_template_name for the name the chat template expects; rough sketch after this list). That would allow you to use the same dataset on Qwen / gpt-oss without having to re-upload it to HF. I don't mind either way.
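
Rough sketch of what I mean for (3) (names are hypothetical, nothing implemented yet):

def remap_reasoning_field(
    messages,
    reasoning_field_name="reasoning",             # key as stored in the dataset
    reasoning_template_name="reasoning_content",  # key the chat template reads (Qwen3 here)
):
    # Rename the dataset's reasoning field to whatever the target model's
    # chat template expects, so the same HF dataset works across models.
    remapped = []
    for msg in messages:
        msg = dict(msg)
        if msg.get("role") == "assistant" and reasoning_field_name in msg:
            msg[reasoning_template_name] = msg.pop(reasoning_field_name)
        remapped.append(msg)
    return remapped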

If you have any comments on the above or anything else, let me know and I'll take care of them so we can hopefully close this out.

@joecummings (Member)

1 and 2 are fine - maybe use an Enum for 2.

WRT 3, my current thinking is to make this the responsibility of the user constructing the dataset (see how I set up the Qwen3 experiment config).

Merge commit noting conflicts in:
  torchtitan/components/validate.py
  torchtitan/models/common/decoder.py
  torchtitan/models/gpt_oss/model.py
  torchtitan/models/llama4/model.py
@haydn-jones (Contributor, Author)

Looks like I've done everything I intended to. Merged main in and hit some issues in the validation pipeline. I'm not certain this is correct for context parallel, so this might need some scrutiny.

0582ecc#diff-a05647646c2ceafcb0c9312bebbefe6a7952fbb16e6ffa1c0bf7367680cba184R201

Comment thread on torchtitan/hf_datasets/text_datasets.py (outdated), lines +66 to +68:

ALL = "all"
ASSISTANT = "assistant"
LAST_ASSISTANT = "last_assistant"
Contributor:

As someone who doesn't know SFT well in general, can I ask when we would want each of the three modes? E.g., my naive thinking is that ASSISTANT mode should always be used.

Contributor (Author):

Generally, you only want to train on assistant completions as far as I know, though I don't have any insight into real CPT after standard pretraining. I definitely always default to assistant-only loss.

I have the other two to match Hugging Face's TRL and Axolotl. TRL defaults to ALL, and Axolotl technically supports LAST_ASSISTANT but it's not really advertised (you have to set roles_to_train: []).

Contributor:

Thanks. I would lean on the side of "keeping it simple unless we know it's going to be useful". But maybe @joecummings knows more.

@joecummings (Member) commented Apr 6, 2026:

It's not uncommon to be able to specify weighting per turn (OpenAI also allows this). The best rationale is datasets where there's a lot of "setup" and we're really only interested in having the model predict the last turn.

It is definitely not as common as just training on all assistant turns.

In the spirit of "as minimal as possible", we can omit it from this PR, but please leave a TODO to revisit at some point. Sorry, I got overeager here :)

sample_processor: Annotated[Callable, tyro.conf.Suppress]
"""Callable(sample_dict) -> list[message_dict]. Set in config functions."""

train_on: Annotated[SupervisionMode, tyro.conf.EnumChoicesFromValues] = (
Contributor:

In terms of naming, can we do
supervision_mode: Literal["assistant", ...]

Contributor:

oops, name seems not addressed -- but we can wait until the decision is made on whether or not to have this field.

@haydn-jones (Contributor, Author)

Need to do a training run on my silly dataset to make sure everything is looking good.

@haydn-jones (Contributor, Author) commented Apr 5, 2026

A couple of things:

  1. The VLM CI is broken because the VLM dataset doesn't emit positions, and because of some changes that happened while I was working on this, my fix no longer works around that. The correct fix would be to add position info into the VLM dataset; a simpler quick fix would be to emit dummy position information from the VLM collator.
  2. My E2E experiment isn't working properly with the token-span diffing method of finding assistant spans. For Qwen3, we can't supervise the opening <think> token in the assistant message because it shows up in both the blanked render and the real render. When serving, vLLM will use something like add_generation_prompt=True to start off the assistant message, but this does not include a <think> token, so my trained model seems to be getting confused and not thinking / responding as it should. I can get it to do what it should by prefilling the assistant response with a thinking token, but that's not great. Thoughts @joecummings?
  3. Will fix the SupervisionMode thing once we decide if we want to keep configurable supervision modes.

@haydn-jones haydn-jones marked this pull request as draft April 7, 2026 01:54
@haydn-jones (Contributor, Author)

Marked this as draft so CI won't run. Points (1) and (2) still need to be addressed. Will wait a few days for comments before checking back in.

@haydn-jones (Contributor, Author)

@joecummings I'm a bit swamped with NeurIPS deadline stuff, but I do think I have a solution to the small tokenization issue here. Will implement when I have time. Maybe this weekend?

Merge commit noting conflicts in:
  tests/unit_tests/test_chat_dataset.py
@pytorch-bot (Bot) commented Apr 18, 2026

Workflows were awaiting approval. CI has now been triggered for the ciflow labels on this PR.

@haydn-jones (Contributor, Author)

@joecummings Long story short, the issue I saw with my toy dataset and Qwen3-0.6B was more an issue with Qwen3-0.6B's specific chat template (newer Qwen templates inject an opening <think> token when setting add_generation_prompt=True, which solves it). I fixed the VLM experiment by adding positions into the dataloader (I don't think there is anything I need to do for the VLM in the main repo).

I accidentally ran ufmt on too many files and touched the forge experiment :( I can revert.

The last question I have before I feel like I'm done is: Should chat template arguments (like enable_thinking) be exposed somehow to the user through something like process_sample?

@joecummings (Member)

The last question I have before I feel like I'm done is: Should chat template arguments (like enable_thinking) be exposed somehow to the user through something like process_sample?

process_sample should largely be user-defined, which means that if they are loading in a chat dataset that utilizes reasoning traces, they should insert reasoning tags like so. For our current setup, maybe we need to pass along chat_template_kwargs from the dataloader. Eventually we should really have better staging between: 1) getting data into a loose "message" structure and 2) applying the model transform (tokenization) for training.
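
For example, a user-defined process_sample might look something like this (column names made up for illustration):

def process_sample(sample):
    # Map a raw dataset row into the loose "message" structure the chat
    # dataset expects; put the reasoning trace wherever the target model's
    # chat template looks for it (reasoning_content for Qwen3).
    return [
        {"role": "user", "content": sample["prompt"]},
        {
            "role": "assistant",
            "content": sample["response"],
            "reasoning_content": sample.get("reasoning", ""),
        },
    ]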

@haydn-jones (Contributor, Author)

@joecummings Ok yeah I'm with you.

@haydn-jones haydn-jones marked this pull request as ready for review April 30, 2026 18:33
Allows an optional leading 'system' message, then requires
alternating user/assistant turns ending with 'assistant'.
"""
if len(messages) < 2:
Contributor:

this exception is covered by the one below, can we just keep that one?

self._assistant_end_token_ids.update(
token_id
for token_id, token in tokenizer.tokenizer.get_added_tokens_decoder().items()
if token.special
Contributor:

how robust is this -- does a special token always belong to "end tokens"?


_, dummy_tokens = self._render_conversation(blanked_messages)

# Scan from front to find where they diverge.
Contributor:

hmm this algorithm sounds correct, but

  • it's hard to read
  • the complexity seems to be Ω(num_assistant_turns × len(full_tokens))

Do you think we can extend the previous algorithm? (rough sketch below)

  • The messages alternate between user and assistant turns
  • To figure out the start of assistant i (within a user-assistant pair), we apply_chat_template(add_generation_prompt=True) up through user i
  • To figure out the length of assistant i, we apply_chat_template up through user i + assistant i
  • and we can iterate over all pairs sequentially
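
Rough sketch of that idea (untested; assumes the tokenization of each rendered prefix stays a prefix of the full tokenization, which the BPE concern earlier in this thread may break):

def assistant_spans_incremental(messages, tokenizer):
    # For each assistant turn i: tokens of the conversation rendered up to and
    # including the generation prompt give the span start, and tokens of the
    # conversation rendered through turn i give the span end.
    spans = []
    for i, msg in enumerate(messages):
        if msg["role"] != "assistant":
            continue
        prompt_ids = tokenizer.apply_chat_template(
            messages[:i], add_generation_prompt=True, tokenize=True
        )
        turn_ids = tokenizer.apply_chat_template(
            messages[: i + 1], add_generation_prompt=False, tokenize=True
        )
        spans.append((len(prompt_ids), len(turn_ids)))
    return spans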

# and last supervised position is end - 2
# (predicting full_tokens[end-1] from position end-2).
label_start = max(start - 1, 0)
label_end = min(end - 1, len(label_ids))
Contributor:

sounds strange to use len(label_ids) in multi-turn -- end should suffice
