Multi-Turn Chat Dataset SFT #2769
Conversation
Thanks for the thorough work here @haydn-jones. A couple of things that give me pause:
I've been looking at how https://github.com/axolotl-ai-cloud/axolotl/blob/main/src/axolotl/prompt_strategies/chat_template.py does it: to find where assistant turn i lives in the token stream, they render the conversation with that turn's content replaced by an empty string, tokenize both versions, and diff them. Some Claude-drafted pseudocode (`tokenize` / `apply_chat_template` stand in for the real tokenizer calls):

```python
def _get_assistant_spans(self, messages, full_tokens):
    spans = []
    for i, msg in enumerate(messages):
        if msg["role"] != "assistant":
            continue
        # Same conversation structure, just empty content for this turn.
        # Stripping extra fields (e.g. reasoning_content) means template
        # conditionals like Qwen3's <think> insertion also show up in the diff.
        dummy_messages = [
            {"role": m["role"], "content": ""} if j == i else m
            for j, m in enumerate(messages)
        ]
        dummy_tokens = tokenize(apply_chat_template(dummy_messages))

        # Scan from the front to find where the two renderings diverge.
        start = 0
        while start < len(dummy_tokens) and full_tokens[start] == dummy_tokens[start]:
            start += 1

        # Scan from the back to find where they reconverge; `end` is the
        # exclusive end of the differing span in full_tokens.
        tail = 0
        while (
            tail < len(dummy_tokens) - start
            and full_tokens[-1 - tail] == dummy_tokens[-1 - tail]
        ):
            tail += 1
        end = len(full_tokens) - tail

        spans.append((start, end))
    return spans
```

I think this handles the Qwen3 problem because both renderings have the same conversation structure - the template's conditional logic fires identically. The only difference is the content of one turn. It should also handle any future template quirks without needing new regexes.

The tradeoff is 2N extra tokenization calls per sample (N = assistant turns). Tokenization is fast relative to a training step, but I acknowledge it's not free. For train_on="last_assistant" we can reduce this to a single extra call.

Thoughts? Something I might be missing here?
I think your suggestion is the better approach, and I've implemented it. I've pushed so you can look at it.

With respect to the token diffing method: it correctly finds the assistant content span, but I don't think the diff alone can reliably recover the end-of-turn token. Concrete example: when scanning from the back, the two sequences reconverge at the end-of-turn token itself (both renderings contain it), so the diffed span stops just before it. As far as I can tell, Axolotl also needs more than just the diff for this.

How do you feel about this? If we want to also supervise the end-of-turn token, I think that requires additional machinery beyond the diff itself, and I'd be happy to add that once we agree on the right approach.

Edit: Not sure how commit hygiene is enforced here; if we settle on an approach I'd be happy to open a separate PR with cleaner diffs.
Thanks for implementing - looks straightforward. Yeah, we have to supervise eot. Could it be as simple as a post-hoc extension that checks whether full_tokens[end] is a special token and, if so, includes it?

Don't worry about commit hygiene; we'll squash on merge, and unless you have 50 reverting commits it should be fine.
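A minimal sketch of that post-hoc extension, assuming the span `(start, end)` comes out of the diff and the end-of-turn special-token ids have already been collected (names here are illustrative, not the PR's actual helpers):

```python
def extend_span_to_eot(full_tokens, span, assistant_end_token_ids):
    """Grow a diffed span to also cover the end-of-turn token, if present."""
    start, end = span
    # The back-scan stops just before the eot token because both renderings
    # contain it, so check the token right after the span and absorb it.
    if end < len(full_tokens) and full_tokens[end] in assistant_end_token_ids:
        end += 1
    return start, end
```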
I checked gpt-oss, llama3, llama4, qwen3, and deepseek, and yes the heuristic is correct for all of them. The last things on my mind are:
If you have any comments on the above or anything else, let me know and I'll take care of them so we can hopefully close this out.
1 and 2 are fine - maybe use an Enum for 2. WRT 3, my current thinking is to make this the responsibility of the user constructing the dataset (see how I set up the Qwen3 experiment config). |
Looks like I've done everything I intended to. Merged main in and hit some issues in the validation pipeline. I'm not certain this is correct for context parallel, so this might need some scrutiny. 0582ecc#diff-a05647646c2ceafcb0c9312bebbefe6a7952fbb16e6ffa1c0bf7367680cba184R201
```python
ALL = "all"
ASSISTANT = "assistant"
LAST_ASSISTANT = "last_assistant"
```
As someone who doesn't know SFT well in general, can I ask when we would want the three modes, respectively? E.g., my naive thinking would be that ASSISTANT mode should always be used.
Generally, you only want to train on assistant completions, as far as I know, though I don't have any insight into real CPT after standard pretraining. I definitely always default to assistant-only loss.
I have the other two to match TRL from huggingface and Axolotl. TRL defaults to ALL and axolotl technically supports LAST_ASSISTANT but it's not really advertised (you have to set roles_to_train: []).
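Concretely, the three modes pick out turns like this; a sketch only, with `SupervisionMode` as defined in the enum quoted above:

```python
def supervised_message_indices(messages, mode):
    """Return the indices of the messages whose tokens should receive loss."""
    assistant_idxs = [i for i, m in enumerate(messages) if m["role"] == "assistant"]
    if mode == SupervisionMode.ALL:
        return list(range(len(messages)))  # every turn, all roles (TRL's default)
    if mode == SupervisionMode.ASSISTANT:
        return assistant_idxs              # every assistant turn (default here)
    if mode == SupervisionMode.LAST_ASSISTANT:
        return assistant_idxs[-1:]         # only the final assistant turn
    raise ValueError(f"unknown mode: {mode}")
```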
Thanks. I would lean on the side of "keeping it simple unless we know it's going to be useful". But maybe @joecummings knows more.
It's not uncommon to be able to specify weighting per-turn (OpenAI also allows this). The best reasoning would be some datasets wherein there's a lot of "setup" and we're really only interested in having the model predict the last turn.
It is definitely not as common as just training on all assistant turns.
In the spirit of "as minimal as possible", we can omit from this PR but please leave a TODO to revisit at some point. Sorry, I got overeager here :)
```python
sample_processor: Annotated[Callable, tyro.conf.Suppress]
"""Callable(sample_dict) -> list[message_dict]. Set in config functions."""

train_on: Annotated[SupervisionMode, tyro.conf.EnumChoicesFromValues] = (
```
In terms of naming, can we do `supervision_mode: Literal["assistant", ...]`?
oops, name seems not addressed -- but we can wait until the decision is made on whether or not to have this field.
Need to do a training run on my silly dataset to make sure everything is looking good.
A couple of things:
Marked this as draft so CI won't run. Points (1) and (2) still need to be addressed. Will wait a few days for comments before checking back in.
@joecummings I'm a bit swamped with NeurIPS deadline stuff, but I do think I have a solution to the small tokenization issue here. Will implement when I have time. Maybe this weekend?
@joecummings Long story short the issue I saw with my toy dataset and I accidentally ran ufmt on too many files and touched the forge experiment :( can revert. The last question I have before I feel like I'm done is: should chat template arguments (like
@joecummings Ok yeah I'm with you. |
```python
    Allows an optional leading 'system' message, then requires
    alternating user/assistant turns ending with 'assistant'.
    """
    if len(messages) < 2:
```
this exception is covered by the one below, can we just keep that one?
```python
self._assistant_end_token_ids.update(
    token_id
    for token_id, token in tokenizer.tokenizer.get_added_tokens_decoder().items()
    if token.special
```
how robust is this -- does a special token always belong to "end tokens"?
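For reference, a quick way to inspect what that set ends up containing; an inspection sketch only (assumes a Hugging Face tokenizer, using the Qwen3 checkpoint mentioned elsewhere in this PR):

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")

# get_added_tokens_decoder() maps token id -> AddedToken; .special marks special tokens.
special_added = {
    token_id: token.content
    for token_id, token in tok.get_added_tokens_decoder().items()
    if token.special
}
print(special_added)  # may include non-eot tokens (e.g. <|im_start|>) alongside <|im_end|>
```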
```python
_, dummy_tokens = self._render_conversation(blanked_messages)

# Scan from front to find where they diverge.
```
hmm this algorithm sounds correct, but

- it's hard to read
- the complexity seems Omega(num assistant turns * length of full tokens)

Do you think we can extend the previous algorithm?

- The messages alternate between users and assistants.
- To figure out the start of assistant i (within a user-assistant pair), we `apply_chat_template(add_generation_prompt=True)` to user i.
- To figure out the length of assistant i, we `apply_chat_template` to user i + assistant i.
- Then we can iterate over all pairs sequentially.
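A rough sketch of that prefix-based idea (assuming strictly alternating turns and that each prefix renders the same way it appears in the full conversation; the function name is illustrative):

```python
def assistant_spans_from_prefixes(messages, tokenizer):
    """Locate assistant spans by tokenizing growing prefixes instead of diffing."""
    spans = []
    for i, msg in enumerate(messages):
        if msg["role"] != "assistant":
            continue
        # Prefix up to and including user i, plus the generation prompt:
        # its length marks where assistant i's content starts.
        prompt_ids = tokenizer.apply_chat_template(
            messages[:i], add_generation_prompt=True, tokenize=True
        )
        # Prefix that also includes assistant i: its length marks where the turn ends.
        turn_ids = tokenizer.apply_chat_template(
            messages[: i + 1], add_generation_prompt=False, tokenize=True
        )
        spans.append((len(prompt_ids), len(turn_ids)))
    return spans
```

Worth noting: this assumes each prefix tokenizes to an exact prefix of the full rendering, which is related to the Qwen3 template behavior discussed above.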
```python
# and last supervised position is end - 2
# (predicting full_tokens[end-1] from position end-2).
label_start = max(start - 1, 0)
label_end = min(end - 1, len(label_ids))
```
sounds strange to use len(label_ids) in multi-turn -- end should suffice
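To make the shifted indexing concrete, a tiny worked example (the numbers are made up):

```python
# Say the diff finds the assistant span (start=3, end=6), i.e. full_tokens[3:6] is
# the assistant content. With labels shifted left by one (labels[i] = full_tokens[i + 1]),
# position 2 predicts token 3 and position 4 (= end - 2) predicts token 5 (= end - 1).
start, end = 3, 6
label_start = max(start - 1, 0)  # 2
label_end = end - 1              # 5 -> supervised label positions are 2, 3, 4
```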
Discussed in #2757.
Overview
I've added multi-turn support to SFT. This involved:
To do assistant-only loss (which is the default), you specify the `train_on` in the dataset:
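A minimal illustration of the idea (the surrounding dataset config is omitted; the dict keys here are hypothetical, while the `train_on` field and `SupervisionMode` values come from this PR):

```python
# Hypothetical dataset config entry; assistant-only loss is the default.
chat_dataset = dict(
    dataset_name="haydn-jones/SynthMultiTurn",  # the synthetic dataset used below
    train_on=SupervisionMode.ASSISTANT,         # or ALL / LAST_ASSISTANT
)
```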
This PR ended up touching a lot of files, but by and large the meaningful changes are localized to:

- `tests/unit_tests/test_chat_dataset.py`: Adding various tests for new functionality.
- `torchtitan/hf_datasets/text_datasets.py`: Adding support for multi-turn chat datasets.
- `torchtitan/models/common/attention.py`: Compute document boundaries from RoPE positions.

Notes
SigLIP2 uses `get_document_mask_mod`, but it doesn't actually need to, as images aren't being packed. Instead of deleting it I just migrated it by building a monotonically increasing positions array. On main it looks like: torchtitan/torchtitan/experiments/vlm/model/siglip2.py, line 239 in 2fdf2e4.
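Roughly, the migration amounts to something like the following; this is a sketch of the idea, not the actual diff (shapes and names are illustrative):

```python
import torch

batch_size, seq_len = 2, 16  # illustrative shapes
# Images aren't packed, so every patch belongs to the same "document".
# A strictly increasing positions tensor produces no internal document
# boundaries when boundaries are derived from RoPE positions.
positions = torch.arange(seq_len).unsqueeze(0).expand(batch_size, -1)
```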
Validation
I made a simple synthetic dataset (https://huggingface.co/datasets/haydn-jones/SynthMultiTurn) and trained Qwen3-0.6B on it. The model behaved exactly as you would expect. Piping `validation-000006` through vLLM gives:

```json
{
  "id": "chatcmpl-816a0ab269352f1d",
  "object": "chat.completion",
  "created": 1775017889,
  "model": "outputs/step-500-hf",
  "choices": [
    {
      "index": 0,
      "message": {
        "role": "assistant",
        "content": "\n\nANSWER: 7 | STATE: A=7, B=-3, C=-7, D=2 | LIST=[]",
        "refusal": null,
        "annotations": null,
        "audio": null,
        "function_call": null,
        "tool_calls": [],
        "reasoning": "\nCompare A, B, C, and D directly and return the largest register value.\n"
      }
    }
  ]
}
```