From ec3ee1891606c3d8a369d22f64c687fb16e67354 Mon Sep 17 00:00:00 2001
From: Angky William
Date: Tue, 20 Jan 2026 11:31:10 -0800
Subject: [PATCH 1/6] SFT preprocessing

---
 src/art/preprocessing/tokenize.py | 236 +++++++++++++++++++++++++++++-
 1 file changed, 235 insertions(+), 1 deletion(-)

diff --git a/src/art/preprocessing/tokenize.py b/src/art/preprocessing/tokenize.py
index a9807b052..8f2f14e55 100644
--- a/src/art/preprocessing/tokenize.py
+++ b/src/art/preprocessing/tokenize.py
@@ -9,7 +9,12 @@
 from transformers.image_processing_utils import BaseImageProcessor
 from transformers.tokenization_utils_base import PreTrainedTokenizerBase
 
-from ..trajectories import History, TrajectoryGroup, get_messages
+from ..trajectories import History, TrajectoryGroup, Trajectory, get_messages
+
+# Import Unsloth Zoo utilities for robust token matching
+# Source: https://github.com/unslothai/unsloth-zoo/blob/main/unsloth_zoo/dataset_utils.py
+# These functions handle edge cases with tokenization (newlines, spaces, etc.)
+from unsloth_zoo.dataset_utils import _find_common_token_ids
 
 
 @dataclass
@@ -44,6 +49,23 @@ def without_prompt(self) -> "TokenizedResult":
         )
 
 
+@dataclass
+class SFTBatch:
+    """A batch of tokenized trajectories for supervised fine-tuning.
+    Attributes:
+        trajectory_tensors: List of tensor dictionaries, one per trajectory.
+            Each dict contains 'input_ids', 'attention_mask', and 'labels'.
+        learning_rate: Learning rate to use for this batch.
+        num_trajectories: Number of trajectories in this batch.
+        num_trainable_tokens: Total number of tokens being trained on (labels != -100).
+    """
+
+    trajectory_tensors: list[dict[str, torch.Tensor]]
+    learning_rate: float
+    num_trajectories: int
+    num_trainable_tokens: int
+
+
 def tokenize_trajectory_groups(
     tokenizer: "PreTrainedTokenizerBase",
     trajectory_groups: list[TrajectoryGroup],
@@ -312,3 +334,215 @@ def tokenize_trajectory(
         pixel_values=pixel_values,
         image_grid_thw=image_grid_thw,
     )
+
+def tokenize_sft_batches(
+    trajectories: list[Trajectory],
+    batch_size: int,
+    learning_rates: list[float],
+    tokenizer: PreTrainedTokenizerBase,
+    instruction_part: str,
+    response_part: str,
+) -> Generator[SFTBatch, None, None]:
+    """
+    Tokenize trajectories into batches for supervised fine-tuning.
+    Args:
+        trajectories: Flat list of trajectories
+        batch_size: Number of trajectories per batch
+        learning_rates: Learning rate for each batch
+        tokenizer: Tokenizer to use for encoding
+        instruction_part: Instruction template part (e.g., "User:")
+        response_part: Response template part (e.g., "Assistant:")
+    Yields:
+        SFTBatch object containing:
+        - trajectory_tensors: List of tensors for each trajectory
+        - learning_rate: Learning rate for this batch
+        - num_trajectories: Number of trajectories in this batch
+        - num_trainable_tokens: Total number of trainable tokens
+    """
+    # Validate inputs
+    num_trajectories = len(trajectories)
+    num_learning_rates = len(learning_rates)
+    expected_num_batches = math.ceil(num_trajectories / batch_size)
+
+    if num_learning_rates != expected_num_batches:
+        raise ValueError(
+            f"Mismatch between trajectories and learning_rates: "
+            f"{num_trajectories} trajectories with batch_size={batch_size} "
+            f"yields {expected_num_batches} batches, but got {num_learning_rates} learning_rates"
+        )
+
+    # Handle missing pad_token_id (common for LLaMA and similar models)
+    pad_token_id = tokenizer.pad_token_id
+    if pad_token_id is None:
+        pad_token_id = tokenizer.eos_token_id
+
+    # Get most common tokens using Unsloth approach
+    Q_must, Q_left, Q_right = _find_common_token_ids(
+        instruction_part, tokenizer, force_match=False
+    )
+    A_must, A_left, A_right = _find_common_token_ids(
+        response_part, tokenizer, force_match=False
+    )
+
+    # Store temporary stuff
+    A_first = A_must[0]
+    len_A_must = len(A_must)
+    A_left_reversed = A_left[::-1]
+    A_right_forward = A_right
+
+    Q_first = Q_must[0]
+    len_Q_must = len(Q_must)
+    Q_left_reversed = Q_left[::-1]
+    Q_right_forward = Q_right
+
+    def _train_on_responses_only(input_ids: list[int]) -> list[int]:
+        """Unsloth-based implementation for marking trainable tokens."""
+        n = len(input_ids)
+        labels = [-100] * n
+        n_minus_1 = n - 1
+        j = 0
+
+        while j < n:
+            # Find the start of a response (assistant) segment
+            if (input_ids[j] == A_first) and (
+                input_ids[j : (k := j + len_A_must)] == A_must
+            ):
+                # Now backtrack to get previous optional tokens
+                for optional_left in A_left_reversed:
+                    if j < 1:
+                        break
+                    if optional_left == input_ids[j - 1]:
+                        j -= 1
+                    else:
+                        break
+
+                # And forwards look as well
+                for optional_right in A_right_forward:
+                    if k >= n_minus_1:
+                        break
+                    if optional_right == input_ids[k + 1]:
+                        k += 1
+                    else:
+                        break
+
+                assistant_k = k
+                j = assistant_k
+
+                # Given the end of the response marker, now find the next user turn
+                while j < n:
+                    # Find the start of the next instruction (user) segment
+                    # Also accept last final item if assistant is the last turn
+                    if (j == n_minus_1) or (
+                        (input_ids[j] == Q_first)
+                        and (input_ids[j : (k := j + len_Q_must)] == Q_must)
+                    ):
+                        # Now backtrack to get previous optional tokens
+                        for optional_left in Q_left_reversed:
+                            if j < 1:
+                                break
+                            if optional_left == input_ids[j - 1]:
+                                j -= 1
+                            else:
+                                break
+
+                        # And forwards look as well
+                        for optional_right in Q_right_forward:
+                            if k >= n_minus_1:
+                                break
+                            if optional_right == input_ids[k + 1]:
+                                k += 1
+                            else:
+                                break
+
+                        user_j = j
+
+                        # Account for last item
+                        if user_j != n_minus_1:
+                            j = k
+                        else:
+                            user_j = n
+                            k = n
+
+                        # Now copy input_ids to labels
+                        labels[assistant_k:user_j] = input_ids[assistant_k:user_j]
+                        break
+
+                    j += 1
+
+            j += 1
+
+        return labels
+
+    # Batch trajectories
+    for batch_idx, lr in enumerate(learning_rates):
+        start_idx = batch_idx * batch_size
+        end_idx = start_idx + batch_size
+        trajectory_batch = trajectories[start_idx:end_idx]
+
+        # First pass: tokenize all trajectories
+        tokenized_trajectories = []
+        for trajectory in trajectory_batch:
+            messages = trajectory.messages_and_choices
+            tools = trajectory.tools
+
+            # Single-step tokenization: apply_chat_template with tokenize=True
+            input_ids = cast(
+                list[int],
+                tokenizer.apply_chat_template(
+                    cast(Any, messages),
+                    tools=cast(Any, tools),
+                    tokenize=True,
+                    add_generation_prompt=False,
+                ),
+            )
+
+            # Create attention mask (all 1s - no padding yet)
+            attention_mask = [1] * len(input_ids)
+
+            labels = _train_on_responses_only(input_ids)
+
+            tokenized_trajectories.append(
+                {
+                    "input_ids": input_ids,
+                    "attention_mask": attention_mask,
+                    "labels": labels,
+                }
+            )
+
+        # Find max length in this batch for padding
+        max_seq_length = max(len(t["input_ids"]) for t in tokenized_trajectories)
+
+        # Second pass: pad all trajectories to max_seq_length
+        trajectory_tensors = []
+        for tokenized in tokenized_trajectories:
+            input_ids = tokenized["input_ids"]
+            attention_mask = tokenized["attention_mask"]
+            labels = tokenized["labels"]
+
+            # Pad to max_seq_length
+            padding_length = max_seq_length - len(input_ids)
+            if padding_length > 0:
+                input_ids = input_ids + [pad_token_id] * padding_length
+                attention_mask = attention_mask + [0] * padding_length
+                labels = labels + [-100] * padding_length
+
+            trajectory_tensor = {
+                "input_ids": torch.tensor([input_ids], dtype=torch.long),
+                "attention_mask": torch.tensor([attention_mask], dtype=torch.long),
+                "labels": torch.tensor([labels], dtype=torch.long),
+            }
+
+            trajectory_tensors.append(trajectory_tensor)
+
+        # Calculate total trainable tokens (labels != -100)
+        num_trainable_tokens = sum(
+            (tensor_dict["labels"] != -100).sum().item()
+            for tensor_dict in trajectory_tensors
+        )
+
+        yield SFTBatch(
+            trajectory_tensors=trajectory_tensors,
+            learning_rate=lr,
+            num_trajectories=len(trajectory_tensors),
+            num_trainable_tokens=num_trainable_tokens,
+        )
\ No newline at end of file

From 686e097266597918bc5a120753f1b3ef905441b1 Mon Sep 17 00:00:00 2001
From: Angky William
Date: Tue, 20 Jan 2026 13:16:26 -0800
Subject: [PATCH 2/6] Use unsloth-zoo _train_on_responses_only

---
 src/art/preprocessing/tokenize.py | 115 ++++--------------------------
 1 file changed, 15 insertions(+), 100 deletions(-)

diff --git a/src/art/preprocessing/tokenize.py b/src/art/preprocessing/tokenize.py
index 8f2f14e55..7901a491e 100644
--- a/src/art/preprocessing/tokenize.py
+++ b/src/art/preprocessing/tokenize.py
@@ -11,10 +11,11 @@
 
 from ..trajectories import History, TrajectoryGroup, Trajectory, get_messages
 
-# Import Unsloth Zoo utilities for robust token matching
+# Import Unsloth Zoo utility for training on responses only
 # Source: https://github.com/unslothai/unsloth-zoo/blob/main/unsloth_zoo/dataset_utils.py
-# These functions handle edge cases with tokenization (newlines, spaces, etc.)
-from unsloth_zoo.dataset_utils import _find_common_token_ids
+# This function handles edge cases with tokenization (newlines, spaces, etc.)
+import unsloth  # noqa: F401  # Must import first to set UNSLOTH_IS_PRESENT env var
+from unsloth_zoo.dataset_utils import train_on_responses_only
 
 
 @dataclass
@@ -376,104 +377,18 @@ def tokenize_sft_batches(
     if pad_token_id is None:
         pad_token_id = tokenizer.eos_token_id
 
-    # Get most common tokens using Unsloth approach
-    Q_must, Q_left, Q_right = _find_common_token_ids(
-        instruction_part, tokenizer, force_match=False
+    # Get the _train_on_responses_only function from unsloth_zoo
+    # This handles edge cases with tokenization (newlines, spaces, etc.)
+    _train_on_responses_only = train_on_responses_only(
+        trainer=None,
+        instruction_part=instruction_part,
+        response_part=response_part,
+        force_match=False,
+        tokenizer=tokenizer,
+        return_function=True,
     )
-    A_must, A_left, A_right = _find_common_token_ids(
-        response_part, tokenizer, force_match=False
-    )
-
-    # Store temporary stuff
-    A_first = A_must[0]
-    len_A_must = len(A_must)
-    A_left_reversed = A_left[::-1]
-    A_right_forward = A_right
-
-    Q_first = Q_must[0]
-    len_Q_must = len(Q_must)
-    Q_left_reversed = Q_left[::-1]
-    Q_right_forward = Q_right
-
-    def _train_on_responses_only(input_ids: list[int]) -> list[int]:
-        """Unsloth-based implementation for marking trainable tokens."""
-        n = len(input_ids)
-        labels = [-100] * n
-        n_minus_1 = n - 1
-        j = 0
-
-        while j < n:
-            # Find the start of a response (assistant) segment
-            if (input_ids[j] == A_first) and (
-                input_ids[j : (k := j + len_A_must)] == A_must
-            ):
-                # Now backtrack to get previous optional tokens
-                for optional_left in A_left_reversed:
-                    if j < 1:
-                        break
-                    if optional_left == input_ids[j - 1]:
-                        j -= 1
-                    else:
-                        break
-
-                # And forwards look as well
-                for optional_right in A_right_forward:
-                    if k >= n_minus_1:
-                        break
-                    if optional_right == input_ids[k + 1]:
-                        k += 1
-                    else:
-                        break
-
-                assistant_k = k
-                j = assistant_k
-
-                # Given the end of the response marker, now find the next user turn
-                while j < n:
-                    # Find the start of the next instruction (user) segment
-                    # Also accept last final item if assistant is the last turn
-                    if (j == n_minus_1) or (
-                        (input_ids[j] == Q_first)
-                        and (input_ids[j : (k := j + len_Q_must)] == Q_must)
-                    ):
-                        # Now backtrack to get previous optional tokens
-                        for optional_left in Q_left_reversed:
-                            if j < 1:
-                                break
-                            if optional_left == input_ids[j - 1]:
-                                j -= 1
-                            else:
-                                break
-
-                        # And forwards look as well
-                        for optional_right in Q_right_forward:
-                            if k >= n_minus_1:
-                                break
-                            if optional_right == input_ids[k + 1]:
-                                k += 1
-                            else:
-                                break
-
-                        user_j = j
-
-                        # Account for last item
-                        if user_j != n_minus_1:
-                            j = k
-                        else:
-                            user_j = n
-                            k = n
-
-                        # Now copy input_ids to labels
-                        labels[assistant_k:user_j] = input_ids[assistant_k:user_j]
-                        break
-
-                    j += 1
-
-            j += 1
-
-        return labels
-
-    # Batch trajectories
+    # TODO Process input_ids in batch for better efficiency
     for batch_idx, lr in enumerate(learning_rates):
         start_idx = batch_idx * batch_size
         end_idx = start_idx + batch_size
         trajectory_batch = trajectories[start_idx:end_idx]
@@ -499,7 +414,7 @@ def _train_on_responses_only(input_ids: list[int]) -> list[int]:
             # Create attention mask (all 1s - no padding yet)
             attention_mask = [1] * len(input_ids)
 
-            labels = _train_on_responses_only(input_ids)
+            labels = _train_on_responses_only({"input_ids": [input_ids]})["labels"][0]
 
             tokenized_trajectories.append(
                 {

From bf9394a926d3e7811c455b1eda7c7d587b6af987 Mon Sep 17 00:00:00 2001
From: Angky William
Date: Tue, 20 Jan 2026 13:22:37 -0800
Subject: [PATCH 3/6] Fix ruff check

---
 src/art/preprocessing/tokenize.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/art/preprocessing/tokenize.py b/src/art/preprocessing/tokenize.py
index 7901a491e..3f2da3589 100644
--- a/src/art/preprocessing/tokenize.py
+++ b/src/art/preprocessing/tokenize.py
@@ -9,14 +9,14 @@
 from transformers.image_processing_utils import BaseImageProcessor
 from transformers.tokenization_utils_base import PreTrainedTokenizerBase
 
-from ..trajectories import History, TrajectoryGroup, Trajectory, get_messages
-
 # Import Unsloth Zoo utility for training on responses only
 # Source: https://github.com/unslothai/unsloth-zoo/blob/main/unsloth_zoo/dataset_utils.py
 # This function handles edge cases with tokenization (newlines, spaces, etc.)
 import unsloth  # noqa: F401  # Must import first to set UNSLOTH_IS_PRESENT env var
 from unsloth_zoo.dataset_utils import train_on_responses_only
 
+from ..trajectories import History, Trajectory, TrajectoryGroup, get_messages
+
 
 @dataclass
 class TokenizedResult:

From 4c82d4cd97585a0c9ac685b195921c4934e832d4 Mon Sep 17 00:00:00 2001
From: Angky William
Date: Tue, 20 Jan 2026 13:36:46 -0800
Subject: [PATCH 4/6] Reorder imports

---
 src/art/preprocessing/tokenize.py | 14 +++++++++-----
 1 file changed, 9 insertions(+), 5 deletions(-)

diff --git a/src/art/preprocessing/tokenize.py b/src/art/preprocessing/tokenize.py
index 3f2da3589..b75968688 100644
--- a/src/art/preprocessing/tokenize.py
+++ b/src/art/preprocessing/tokenize.py
@@ -1,18 +1,21 @@
-from dataclasses import dataclass
-from itertools import takewhile
+# ruff: noqa: I001
+# Import order is intentional - unsloth MUST be imported before transformers
 import math
 import random
+from dataclasses import dataclass
+from itertools import takewhile
 from typing import Any, Generator, cast
 
-from PIL import Image
+import unsloth  # noqa: F401
+
 import torch
+from PIL import Image
 from transformers.image_processing_utils import BaseImageProcessor
 from transformers.tokenization_utils_base import PreTrainedTokenizerBase
 
 # Import Unsloth Zoo utility for training on responses only
 # Source: https://github.com/unslothai/unsloth-zoo/blob/main/unsloth_zoo/dataset_utils.py
 # This function handles edge cases with tokenization (newlines, spaces, etc.)
-import unsloth  # noqa: F401  # Must import first to set UNSLOTH_IS_PRESENT env var
 from unsloth_zoo.dataset_utils import train_on_responses_only
 
 from ..trajectories import History, Trajectory, TrajectoryGroup, get_messages
@@ -336,6 +339,7 @@ def tokenize_trajectory(
         image_grid_thw=image_grid_thw,
     )
 
+
 def tokenize_sft_batches(
     trajectories: list[Trajectory],
     batch_size: int,
@@ -460,4 +464,4 @@ def tokenize_sft_batches(
             learning_rate=lr,
             num_trajectories=len(trajectory_tensors),
             num_trainable_tokens=num_trainable_tokens,
-        )
\ No newline at end of file
+        )

From a6d5ce2c87606db1eba33c5ae62434a8079253db Mon Sep 17 00:00:00 2001
From: Angky William
Date: Tue, 20 Jan 2026 13:39:26 -0800
Subject: [PATCH 5/6] Add comment for unsloth import

---
 src/art/preprocessing/tokenize.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/art/preprocessing/tokenize.py b/src/art/preprocessing/tokenize.py
index b75968688..910eb9947 100644
--- a/src/art/preprocessing/tokenize.py
+++ b/src/art/preprocessing/tokenize.py
@@ -6,7 +6,7 @@
 from itertools import takewhile
 from typing import Any, Generator, cast
 
-import unsloth  # noqa: F401
+import unsloth  # noqa: F401  # Must import first to set UNSLOTH_IS_PRESENT env var
 
 import torch
 from PIL import Image

From b1e89d4a953f508113ce038f8485c46484cac3d9 Mon Sep 17 00:00:00 2001
From: Angky William
Date: Tue, 20 Jan 2026 13:54:03 -0800
Subject: [PATCH 6/6] Import unsloth-zoo inside tokenize_sft_batches

---
 src/art/preprocessing/tokenize.py | 12 +++++-------
 1 file changed, 5 insertions(+), 7 deletions(-)

diff --git a/src/art/preprocessing/tokenize.py b/src/art/preprocessing/tokenize.py
index 910eb9947..722fa446a 100644
--- a/src/art/preprocessing/tokenize.py
+++ b/src/art/preprocessing/tokenize.py
@@ -13,11 +13,6 @@
 from transformers.image_processing_utils import BaseImageProcessor
 from transformers.tokenization_utils_base import PreTrainedTokenizerBase
 
-# Import Unsloth Zoo utility for training on responses only
-# Source: https://github.com/unslothai/unsloth-zoo/blob/main/unsloth_zoo/dataset_utils.py
-# This function handles edge cases with tokenization (newlines, spaces, etc.)
-from unsloth_zoo.dataset_utils import train_on_responses_only
-
 from ..trajectories import History, Trajectory, TrajectoryGroup, get_messages
 
 
 @dataclass
@@ -364,6 +359,11 @@ def tokenize_sft_batches(
         - num_trajectories: Number of trajectories in this batch
         - num_trainable_tokens: Total number of trainable tokens
     """
+    # Import Unsloth Zoo utility for training on responses only
+    # Source: https://github.com/unslothai/unsloth-zoo/blob/main/unsloth_zoo/dataset_utils.py
+    # This function handles edge cases with tokenization (newlines, spaces, etc.)
+    from unsloth_zoo.dataset_utils import train_on_responses_only
+
     # Validate inputs
     num_trajectories = len(trajectories)
     num_learning_rates = len(learning_rates)
@@ -381,8 +381,6 @@ def tokenize_sft_batches(
     if pad_token_id is None:
         pad_token_id = tokenizer.eos_token_id
 
-    # Get the _train_on_responses_only function from unsloth_zoo
-    # This handles edge cases with tokenization (newlines, spaces, etc.)
     _train_on_responses_only = train_on_responses_only(
         trainer=None,
         instruction_part=instruction_part,
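
Usage sketch (not part of the patches above). The snippet shows one way the new tokenize_sft_batches entry point could be driven end to end. Only the function signature, the learning-rate validation, and the SFTBatch fields come from this series; the model name, chat-template marker strings, import paths, and the way trajectories are built are illustrative assumptions.

# Illustrative only: assumes the package imports as `art` and that a ChatML-style
# template is in use; adjust the marker strings to match the actual chat template.
import math

from transformers import AutoTokenizer

import art
from art.preprocessing.tokenize import tokenize_sft_batches

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct")  # assumed model

# Hypothetical trajectories; in practice these come from logged rollouts.
trajectories = [
    art.Trajectory(
        messages_and_choices=[
            {"role": "user", "content": "What is 2 + 2?"},
            {"role": "assistant", "content": "4"},
        ],
        reward=1.0,
    )
    for _ in range(8)
]

batch_size = 4
# The function validates that there is exactly one learning rate per batch,
# i.e. len(learning_rates) == math.ceil(len(trajectories) / batch_size).
num_batches = math.ceil(len(trajectories) / batch_size)
learning_rates = [2e-5] * num_batches

for batch in tokenize_sft_batches(
    trajectories=trajectories,
    batch_size=batch_size,
    learning_rates=learning_rates,
    tokenizer=tokenizer,
    instruction_part="<|im_start|>user\n",    # assumed marker for this template
    response_part="<|im_start|>assistant\n",  # assumed marker for this template
):
    # Per-trajectory tensors are padded to the longest sequence in the batch.
    print(batch.learning_rate, batch.num_trajectories, batch.num_trainable_tokens)
    for tensors in batch.trajectory_tensors:
        assert tensors["input_ids"].shape == tensors["labels"].shape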
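
The series defines how SFT batches are produced but not how a trainer consumes them. The sketch below is one plausible consumer, assuming a Hugging Face causal LM and a standard torch optimizer (neither is part of these patches): it applies the per-batch learning rate, runs each per-trajectory tensor dict through the model, and averages the losses so trajectories contribute equally within a batch.

# Minimal consumer sketch under the assumptions stated above.
from typing import Iterable

from art.preprocessing.tokenize import SFTBatch


def train_on_sft_batches(model, optimizer, batches: Iterable[SFTBatch]) -> None:
    model.train()
    for batch in batches:
        # One optimizer step per SFTBatch, at that batch's learning rate.
        for group in optimizer.param_groups:
            group["lr"] = batch.learning_rate

        optimizer.zero_grad()
        for tensors in batch.trajectory_tensors:
            # Each dict holds [1, seq_len] input_ids / attention_mask / labels.
            # Labels are -100 outside assistant responses, so the loss only
            # covers the trainable tokens counted in num_trainable_tokens.
            outputs = model(
                input_ids=tensors["input_ids"].to(model.device),
                attention_mask=tensors["attention_mask"].to(model.device),
                labels=tensors["labels"].to(model.device),
            )
            # Scale so each trajectory contributes equally to this step.
            (outputs.loss / batch.num_trajectories).backward()
        optimizer.step()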