159 changes: 155 additions & 4 deletions src/art/preprocessing/tokenize.py
@@ -1,15 +1,19 @@
# ruff: noqa: I001
# Import order is intentional - unsloth MUST be imported before transformers
import math
import random
from dataclasses import dataclass
from itertools import takewhile
from typing import Any, Generator, cast

import unsloth  # noqa: F401  # Must import first to set UNSLOTH_IS_PRESENT env var

import torch
from PIL import Image
from transformers.image_processing_utils import BaseImageProcessor
from transformers.tokenization_utils_base import PreTrainedTokenizerBase

from ..trajectories import History, Trajectory, TrajectoryGroup, get_messages


@dataclass
@@ -44,6 +48,23 @@ def without_prompt(self) -> "TokenizedResult":
)


@dataclass
class SFTBatch:
"""A batch of tokenized trajectories for supervised fine-tuning.
Attributes:
trajectory_tensors: List of tensor dictionaries, one per trajectory.
Each dict contains 'input_ids', 'attention_mask', and 'labels'.
learning_rate: Learning rate to use for this batch.
num_trajectories: Number of trajectories in this batch.
num_trainable_tokens: Total number of tokens being trained on (labels != -100).
"""

trajectory_tensors: list[dict[str, torch.Tensor]]
learning_rate: float
num_trajectories: int
num_trainable_tokens: int


def tokenize_trajectory_groups(
tokenizer: "PreTrainedTokenizerBase",
trajectory_groups: list[TrajectoryGroup],
@@ -312,3 +333,133 @@ def tokenize_trajectory(
pixel_values=pixel_values,
image_grid_thw=image_grid_thw,
)


def tokenize_sft_batches(
trajectories: list[Trajectory],
batch_size: int,
learning_rates: list[float],
tokenizer: PreTrainedTokenizerBase,
instruction_part: str,
response_part: str,
) -> Generator[SFTBatch, None, None]:
"""
Tokenize trajectories into batches for supervised fine-tuning.
Args:
trajectories: Flat list of trajectories
batch_size: Number of trajectories per batch
learning_rates: Learning rate for each batch
tokenizer: Tokenizer to use for encoding
instruction_part: Instruction template part (e.g., "User:")
response_part: Response template part (e.g., "Assistant:")
Yields:
SFTBatch object containing:
- trajectory_tensors: List of tensors for each trajectory
- learning_rate: Learning rate for this batch
- num_trajectories: Number of trajectories in this batch
- num_trainable_tokens: Total number of trainable tokens
"""
# Import Unsloth Zoo utility for training on responses only
# Source: https://github.com/unslothai/unsloth-zoo/blob/main/unsloth_zoo/dataset_utils.py
# This function handles edge cases with tokenization (newlines, spaces, etc.)
from unsloth_zoo.dataset_utils import train_on_responses_only

# Validate inputs
num_trajectories = len(trajectories)
num_learning_rates = len(learning_rates)
expected_num_batches = math.ceil(num_trajectories / batch_size)

if num_learning_rates != expected_num_batches:
raise ValueError(
f"Mismatch between trajectories and learning_rates: "
f"{num_trajectories} trajectories with batch_size={batch_size} "
f"yields {expected_num_batches} batches, but got {num_learning_rates} learning_rates"
)

# Handle missing pad_token_id (common for LLaMA and similar models)
pad_token_id = tokenizer.pad_token_id
if pad_token_id is None:
pad_token_id = tokenizer.eos_token_id

_train_on_responses_only = train_on_responses_only(
trainer=None,
instruction_part=instruction_part,
response_part=response_part,
force_match=False,
tokenizer=tokenizer,
return_function=True,
)
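# With return_function=True, train_on_responses_only returns a standalone
# mapping function (no Trainer needed) that takes {"input_ids": [...]} and
# returns labels with -100 on every token outside the response spans.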

# TODO Process input_ids in batch for better efficiency
for batch_idx, lr in enumerate(learning_rates):
start_idx = batch_idx * batch_size
end_idx = start_idx + batch_size
trajectory_batch = trajectories[start_idx:end_idx]

# First pass: tokenize all trajectories
tokenized_trajectories = []
for trajectory in trajectory_batch:
messages = trajectory.messages_and_choices
tools = trajectory.tools

# Single-step tokenization: apply_chat_template with tokenize=True
input_ids = cast(
list[int],
tokenizer.apply_chat_template(
cast(Any, messages),
tools=cast(Any, tools),
tokenize=True,
add_generation_prompt=False,
),
)

# Create attention mask (all 1s - no padding yet)
attention_mask = [1] * len(input_ids)

labels = _train_on_responses_only({"input_ids": [input_ids]})["labels"][0]

tokenized_trajectories.append(
{
"input_ids": input_ids,
"attention_mask": attention_mask,
"labels": labels,
}
)

# Find max length in this batch for padding
max_seq_length = max(len(t["input_ids"]) for t in tokenized_trajectories)

# Second pass: pad all trajectories to max_seq_length
trajectory_tensors = []
for tokenized in tokenized_trajectories:
input_ids = tokenized["input_ids"]
attention_mask = tokenized["attention_mask"]
labels = tokenized["labels"]

# Pad to max_seq_length
padding_length = max_seq_length - len(input_ids)
if padding_length > 0:
input_ids = input_ids + [pad_token_id] * padding_length
attention_mask = attention_mask + [0] * padding_length
labels = labels + [-100] * padding_length
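# -100 is the default ignore_index for the cross-entropy loss, so padded
# positions contribute nothing to training (and are excluded from the
# num_trainable_tokens count below).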

trajectory_tensor = {
"input_ids": torch.tensor([input_ids], dtype=torch.long),
"attention_mask": torch.tensor([attention_mask], dtype=torch.long),
"labels": torch.tensor([labels], dtype=torch.long),
}

trajectory_tensors.append(trajectory_tensor)

# Calculate total trainable tokens (labels != -100)
num_trainable_tokens = sum(
(tensor_dict["labels"] != -100).sum().item()
for tensor_dict in trajectory_tensors
)

yield SFTBatch(
trajectory_tensors=trajectory_tensors,
learning_rate=lr,
num_trajectories=len(trajectory_tensors),
num_trainable_tokens=num_trainable_tokens,
)
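For reviewers, a minimal usage sketch of the new API. Everything here is illustrative and not part of this PR: the import path assumes the package installs as `art`, the tokenizer name and chat-template markers are placeholders that must match whatever model you actually train, and `load_trajectories()` is a hypothetical helper standing in for however the `Trajectory` list is produced.

import math

from transformers import AutoTokenizer

from art.preprocessing.tokenize import tokenize_sft_batches

# Hypothetical setup: any chat-template tokenizer and a flat list of Trajectory objects.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
trajectories = load_trajectories()  # hypothetical helper returning list[Trajectory]

batch_size = 4
num_batches = math.ceil(len(trajectories) / batch_size)
learning_rates = [2e-5] * num_batches  # exactly one learning rate per batch

for batch in tokenize_sft_batches(
    trajectories=trajectories,
    batch_size=batch_size,
    learning_rates=learning_rates,
    tokenizer=tokenizer,
    instruction_part="<|im_start|>user\n",      # ChatML-style markers for the Qwen template above
    response_part="<|im_start|>assistant\n",
):
    # Each entry in batch.trajectory_tensors holds [1, max_seq_length] tensors
    # for 'input_ids', 'attention_mask', and 'labels'.
    print(batch.learning_rate, batch.num_trajectories, batch.num_trainable_tokens)

Note that the trailing batch may hold fewer than batch_size trajectories; the math.ceil in the validation accounts for it, so learning_rates must still include an entry for that final partial batch.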