diff --git a/src/art/preprocessing/tokenize.py b/src/art/preprocessing/tokenize.py
index a9807b052..722fa446a 100644
--- a/src/art/preprocessing/tokenize.py
+++ b/src/art/preprocessing/tokenize.py
@@ -1,15 +1,19 @@
-from dataclasses import dataclass
-from itertools import takewhile
+# ruff: noqa: I001
+# Import order is intentional - unsloth MUST be imported before transformers
 import math
 import random
+from dataclasses import dataclass
+from itertools import takewhile
 from typing import Any, Generator, cast
 
-from PIL import Image
+import unsloth  # noqa: F401 # Must import first to set UNSLOTH_IS_PRESENT env var
+
 import torch
+from PIL import Image
 from transformers.image_processing_utils import BaseImageProcessor
 from transformers.tokenization_utils_base import PreTrainedTokenizerBase
 
-from ..trajectories import History, TrajectoryGroup, get_messages
+from ..trajectories import History, Trajectory, TrajectoryGroup, get_messages
 
 
 @dataclass
@@ -44,6 +48,23 @@ def without_prompt(self) -> "TokenizedResult":
         )
 
 
+@dataclass
+class SFTBatch:
+    """A batch of tokenized trajectories for supervised fine-tuning.
+    Attributes:
+        trajectory_tensors: List of tensor dictionaries, one per trajectory.
+            Each dict contains 'input_ids', 'attention_mask', and 'labels'.
+        learning_rate: Learning rate to use for this batch.
+        num_trajectories: Number of trajectories in this batch.
+        num_trainable_tokens: Total number of tokens being trained on (labels != -100).
+    """
+
+    trajectory_tensors: list[dict[str, torch.Tensor]]
+    learning_rate: float
+    num_trajectories: int
+    num_trainable_tokens: int
+
+
 def tokenize_trajectory_groups(
     tokenizer: "PreTrainedTokenizerBase",
     trajectory_groups: list[TrajectoryGroup],
@@ -312,3 +333,133 @@ def tokenize_trajectory(
         pixel_values=pixel_values,
         image_grid_thw=image_grid_thw,
     )
+
+
+def tokenize_sft_batches(
+    trajectories: list[Trajectory],
+    batch_size: int,
+    learning_rates: list[float],
+    tokenizer: PreTrainedTokenizerBase,
+    instruction_part: str,
+    response_part: str,
+) -> Generator[SFTBatch, None, None]:
+    """
+    Tokenize trajectories into batches for supervised fine-tuning.
+    Args:
+        trajectories: Flat list of trajectories
+        batch_size: Number of trajectories per batch
+        learning_rates: Learning rate for each batch
+        tokenizer: Tokenizer to use for encoding
+        instruction_part: Instruction template part (e.g., "User:")
+        response_part: Response template part (e.g., "Assistant:")
+    Yields:
+        SFTBatch object containing:
+        - trajectory_tensors: List of tensors for each trajectory
+        - learning_rate: Learning rate for this batch
+        - num_trajectories: Number of trajectories in this batch
+        - num_trainable_tokens: Total number of trainable tokens
+    """
+    # Import Unsloth Zoo utility for training on responses only
+    # Source: https://github.com/unslothai/unsloth-zoo/blob/main/unsloth_zoo/dataset_utils.py
+    # This function handles edge cases with tokenization (newlines, spaces, etc.)
+    from unsloth_zoo.dataset_utils import train_on_responses_only
+
+    # Validate inputs
+    num_trajectories = len(trajectories)
+    num_learning_rates = len(learning_rates)
+    expected_num_batches = math.ceil(num_trajectories / batch_size)
+
+    if num_learning_rates != expected_num_batches:
+        raise ValueError(
+            f"Mismatch between trajectories and learning_rates: "
+            f"{num_trajectories} trajectories with batch_size={batch_size} "
+            f"yields {expected_num_batches} batches, but got {num_learning_rates} learning_rates"
+        )
+
+    # Handle missing pad_token_id (common for LLaMA and similar models)
+    pad_token_id = tokenizer.pad_token_id
+    if pad_token_id is None:
+        pad_token_id = tokenizer.eos_token_id
+
+    _train_on_responses_only = train_on_responses_only(
+        trainer=None,
+        instruction_part=instruction_part,
+        response_part=response_part,
+        force_match=False,
+        tokenizer=tokenizer,
+        return_function=True,
+    )
+
+    # TODO: Process input_ids in batch for better efficiency
+    for batch_idx, lr in enumerate(learning_rates):
+        start_idx = batch_idx * batch_size
+        end_idx = start_idx + batch_size
+        trajectory_batch = trajectories[start_idx:end_idx]
+
+        # First pass: tokenize all trajectories
+        tokenized_trajectories = []
+        for trajectory in trajectory_batch:
+            messages = trajectory.messages_and_choices
+            tools = trajectory.tools
+
+            # Single-step tokenization: apply_chat_template with tokenize=True
+            input_ids = cast(
+                list[int],
+                tokenizer.apply_chat_template(
+                    cast(Any, messages),
+                    tools=cast(Any, tools),
+                    tokenize=True,
+                    add_generation_prompt=False,
+                ),
+            )
+
+            # Create attention mask (all 1s - no padding yet)
+            attention_mask = [1] * len(input_ids)
+
+            labels = _train_on_responses_only({"input_ids": [input_ids]})["labels"][0]
+
+            tokenized_trajectories.append(
+                {
+                    "input_ids": input_ids,
+                    "attention_mask": attention_mask,
+                    "labels": labels,
+                }
+            )
+
+        # Find max length in this batch for padding
+        max_seq_length = max(len(t["input_ids"]) for t in tokenized_trajectories)
+
+        # Second pass: pad all trajectories to max_seq_length
+        trajectory_tensors = []
+        for tokenized in tokenized_trajectories:
+            input_ids = tokenized["input_ids"]
+            attention_mask = tokenized["attention_mask"]
+            labels = tokenized["labels"]
+
+            # Pad to max_seq_length
+            padding_length = max_seq_length - len(input_ids)
+            if padding_length > 0:
+                input_ids = input_ids + [pad_token_id] * padding_length
+                attention_mask = attention_mask + [0] * padding_length
+                labels = labels + [-100] * padding_length
+
+            trajectory_tensor = {
+                "input_ids": torch.tensor([input_ids], dtype=torch.long),
+                "attention_mask": torch.tensor([attention_mask], dtype=torch.long),
+                "labels": torch.tensor([labels], dtype=torch.long),
+            }
+
+            trajectory_tensors.append(trajectory_tensor)
+
+        # Calculate total trainable tokens (labels != -100)
+        num_trainable_tokens = sum(
+            (tensor_dict["labels"] != -100).sum().item()
+            for tensor_dict in trajectory_tensors
+        )
+
+        yield SFTBatch(
+            trajectory_tensors=trajectory_tensors,
+            learning_rate=lr,
+            num_trajectories=len(trajectory_tensors),
+            num_trainable_tokens=num_trainable_tokens,
+        )
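Note on the label masking: `train_on_responses_only` (from unsloth_zoo) is called with `return_function=True`, so it returns a mapping function that masks every token outside the assistant responses with -100, leaving the loss to be computed only on response tokens. A minimal toy sketch of the resulting label layout is below; the hand-picked span indices stand in for the template matching against `instruction_part`/`response_part` that the real utility performs, and `mask_non_response_labels` is a hypothetical illustration, not the library's implementation:

```python
def mask_non_response_labels(
    input_ids: list[int],
    response_spans: list[tuple[int, int]],  # [start, end) ranges of assistant responses
) -> list[int]:
    # Start fully masked: -100 is ignored by PyTorch cross-entropy.
    labels = [-100] * len(input_ids)
    for start, end in response_spans:
        # Keep the original token ids inside each response span.
        labels[start:end] = input_ids[start:end]
    return labels


# Example: a 10-token sequence where tokens 4..8 are the assistant response.
ids = [101, 7, 8, 9, 55, 56, 57, 58, 102, 0]
print(mask_non_response_labels(ids, [(4, 8)]))
# -> [-100, -100, -100, -100, 55, 56, 57, 58, -100, -100]
```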
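For reference, a hypothetical end-to-end sketch of how the new generator might be consumed in a training loop. The model name, the `load_trajectories` helper, the optimizer setup, the linear learning-rate schedule, and the Qwen-style `instruction_part`/`response_part` markers are all assumptions for illustration, not part of this diff:

```python
import math

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from art.preprocessing.tokenize import tokenize_sft_batches  # assumes src/ maps to the art package

# Assumed model/tokenizer; any causal LM that accepts `labels` works the same way.
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)

trajectories = load_trajectories()  # hypothetical: returns a flat list[Trajectory]
batch_size = 4
num_batches = math.ceil(len(trajectories) / batch_size)
# One learning rate per batch, as the validation in tokenize_sft_batches requires;
# a linear decay is just an example schedule.
learning_rates = [2e-5 * (1 - i / num_batches) for i in range(num_batches)]

for batch in tokenize_sft_batches(
    trajectories=trajectories,
    batch_size=batch_size,
    learning_rates=learning_rates,
    tokenizer=tokenizer,
    # Assumed markers for a Qwen-style chat template; use the strings that
    # open user and assistant turns in your tokenizer's template.
    instruction_part="<|im_start|>user\n",
    response_part="<|im_start|>assistant\n",
):
    # Apply this batch's learning rate before stepping.
    for group in optimizer.param_groups:
        group["lr"] = batch.learning_rate
    optimizer.zero_grad()
    for tensors in batch.trajectory_tensors:
        # Each dict holds [1, seq_len] tensors; gradient-accumulate across the
        # batch, averaging the per-trajectory losses.
        outputs = model(**tensors)
        (outputs.loss / batch.num_trajectories).backward()
    optimizer.step()
```

Since every trajectory tensor comes out as a [1, seq_len] dict padded to the batch maximum, a consumer can equally well `torch.cat` them into one [batch, seq_len] forward pass instead of accumulating per trajectory as above.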