diff --git a/dev/math-vista/math-vista.ipynb b/dev/math-vista/math-vista.ipynb
new file mode 100644
index 000000000..a33ba5614
--- /dev/null
+++ b/dev/math-vista/math-vista.ipynb
@@ -0,0 +1,207 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "46a6ad6d",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "%load_ext autoreload\n",
+ "%autoreload 2"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "96d51078",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "%%html\n",
+ ""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "7dd70e04",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import polars as pl\n",
+ "\n",
+ "splits = {\n",
+ " \"testmini\": \"data/testmini-00000-of-00001-725687bf7a18d64b.parquet\",\n",
+ " \"test\": \"data/test-*.parquet\",\n",
+ "}\n",
+ "df = pl.read_parquet(\"hf://datasets/AI4Math/MathVista/\" + splits[\"testmini\"]).sample(\n",
+ " fraction=1.0, shuffle=True, seed=42\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "81e02b97",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from typing import Iterator, TypedDict, cast\n",
+ "\n",
+ "\n",
+ "class DecodedImage(TypedDict):\n",
+ " bytes: bytes\n",
+ "\n",
+ "\n",
+ "class Scenario(TypedDict):\n",
+ " pid: int\n",
+ " question: str\n",
+ " answer: str\n",
+ " image: str\n",
+ " decoded_image: DecodedImage\n",
+ "\n",
+ "\n",
+ "val_scenarios = cast(list[Scenario], df.head(64).to_dicts())\n",
+ "train_scenarios_iter = cast(Iterator[Scenario], df.tail(-64).iter_rows(named=True))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "9287d8fe",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import re\n",
+ "\n",
+ "import art\n",
+ "from art.local import LocalBackend\n",
+ "\n",
+ "model = art.TrainableModel(\n",
+ " name=\"002\",\n",
+ " project=\"math-vista\",\n",
+ " base_model=\"Qwen/Qwen2.5-VL-7B-Instruct\",\n",
+ ")\n",
+ "backend = LocalBackend()\n",
+ "await model.register(backend)\n",
+ "client = model.openai_client()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "c92b4b11",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "async def rollout(scenario: Scenario) -> art.Trajectory:\n",
+ " image_path = f\"/tmp/{scenario['image']}\"\n",
+ "\n",
+ " import os\n",
+ "\n",
+ " os.makedirs(os.path.dirname(image_path), exist_ok=True)\n",
+ "\n",
+ " with open(image_path, \"wb\") as f:\n",
+ " f.write(scenario[\"decoded_image\"][\"bytes\"])\n",
+ "\n",
+ " trajectory = art.Trajectory(messages_and_choices=[], reward=0.0)\n",
+ " trajectory.messages_and_choices = [\n",
+ " {\n",
+ " \"role\": \"user\",\n",
+ " \"content\": [\n",
+ " {\n",
+ " \"type\": \"text\",\n",
+ " \"text\": scenario[\"question\"]\n",
+ " + \"\\n\\nNote: Provide your answer in a LaTeX box.\",\n",
+ " },\n",
+ " {\"type\": \"image_url\", \"image_url\": {\"url\": f\"file://{image_path}\"}},\n",
+ " ],\n",
+ " }\n",
+ " ]\n",
+ " chat_completion = await client.chat.completions.create(\n",
+ " model=model.name, messages=trajectory.messages()\n",
+ " )\n",
+ " choice = chat_completion.choices[0]\n",
+ " trajectory.messages_and_choices.append(choice)\n",
+ " content = choice.message.content\n",
+ " assert content is not None\n",
+ " if matches := list(re.finditer(r\"\\\\boxed\\{(.*?)\\}\", content, re.DOTALL)):\n",
+ " match = matches[-1]\n",
+ " answer = match.group(1)\n",
+ " if answer.lower() == scenario[\"answer\"].lower():\n",
+ " trajectory.reward = 1.0\n",
+ " return trajectory"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "359e530d",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import asyncio\n",
+ "import itertools\n",
+ "\n",
+ "SCENARIOS_PER_STEP = 8\n",
+ "TRAJECTORY_GROUP_SIZE = 8\n",
+ "start = await model.get_step()\n",
+ "train_scenarios_iter = itertools.cycle(train_scenarios_iter)\n",
+ "for _ in range(start * SCENARIOS_PER_STEP):\n",
+ " next(train_scenarios_iter)\n",
+ "\n",
+ "for i in range(start, 1000):\n",
+ " train_scenarios = [next(train_scenarios_iter) for _ in range(SCENARIOS_PER_STEP)]\n",
+ " val_trajectories, train_trajectory_groups = await asyncio.gather(\n",
+ " art.gather_trajectories(\n",
+ " (rollout(scenario) for scenario in val_scenarios),\n",
+ " pbar_desc=\"gather(val)\",\n",
+ " max_exceptions=32,\n",
+ " ),\n",
+ " art.gather_trajectory_groups(\n",
+ " (\n",
+ " art.TrajectoryGroup(\n",
+ " rollout(scenario) for _ in range(TRAJECTORY_GROUP_SIZE)\n",
+ " )\n",
+ " for scenario in train_scenarios\n",
+ " ),\n",
+ " pbar_desc=\"gather(train)\",\n",
+ " max_exceptions=32,\n",
+ " ),\n",
+ " )\n",
+ " await model.log(val_trajectories)\n",
+ " await model.train(train_trajectory_groups)"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": ".venv",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.13"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/dev/math-vista/math-vista.py b/dev/math-vista/math-vista.py
new file mode 100644
index 000000000..455bf764c
--- /dev/null
+++ b/dev/math-vista/math-vista.py
@@ -0,0 +1,136 @@
+import argparse
+import asyncio
+import itertools
+import os
+import re
+from typing import Iterator, TypedDict, cast
+
+import polars as pl
+
+import art
+from art.local import LocalBackend
+
+
+class DecodedImage(TypedDict):
+ bytes: bytes
+
+
+class Scenario(TypedDict):
+ pid: int
+ question: str
+ answer: str
+ image: str
+ decoded_image: DecodedImage
+
+
+async def main(model_name: str, steps: int) -> None:
+ # Load and shuffle the dataset
+ df = pl.read_parquet(
+ "hf://datasets/AI4Math/MathVista/data/testmini-00000-of-00001-725687bf7a18d64b.parquet"
+ ).sample(fraction=1.0, shuffle=True, seed=42)
+
+ val_scenarios = cast(list[Scenario], df.head(64).to_dicts())
+ train_scenarios_iter = cast(Iterator[Scenario], df.tail(-64).iter_rows(named=True))
+
+ # Initialize trainable model and backend
+ model = art.TrainableModel(
+ name=model_name,
+ project="math-vista",
+ base_model="Qwen/Qwen2.5-VL-7B-Instruct",
+ )
+
+ async def rollout(scenario: Scenario) -> art.Trajectory:
+ image_path = f"/tmp/{scenario['image']}"
+ os.makedirs(os.path.dirname(image_path), exist_ok=True)
+ with open(image_path, "wb") as f:
+ f.write(scenario["decoded_image"]["bytes"])
+
+ trajectory = art.Trajectory(messages_and_choices=[], reward=0.0)
+ trajectory.messages_and_choices = [
+ {
+ "role": "user",
+ "content": [
+ {
+ "type": "text",
+ "text": scenario["question"]
+ + "\n\nNote: Provide your answer in a LaTeX box.",
+ },
+ {"type": "image_url", "image_url": {"url": f"file://{image_path}"}},
+ ],
+ }
+ ]
+
+ chat_completion = await client.chat.completions.create(
+ model=model.name, messages=trajectory.messages()
+ )
+ choice = chat_completion.choices[0]
+ trajectory.messages_and_choices.append(choice)
+ content = choice.message.content
+ assert content is not None
+
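+        # Reward 1.0 when the last \boxed{...} answer matches the ground truth (case-insensitive)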
+ if matches := list(re.finditer(r"\\boxed\{(.*?)\}", content, re.DOTALL)):
+ match = matches[-1]
+ answer = match.group(1)
+ if answer.lower() == scenario["answer"].lower():
+ trajectory.reward = 1.0
+ return trajectory
+
+ SCENARIOS_PER_STEP = 8
+ TRAJECTORY_GROUP_SIZE = 8
+
+ with LocalBackend() as backend:
+ await model.register(backend)
+ client = model.openai_client()
+
+ start = await model.get_step()
+ train_scenarios_iter = itertools.cycle(train_scenarios_iter)
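+        # Skip scenarios already consumed in previous steps so resumed runs continue where they left off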
+ for _ in range(start * SCENARIOS_PER_STEP):
+ next(train_scenarios_iter)
+
+ # Training loop
+ for _ in range(start, steps):
+ train_scenarios = [
+ next(train_scenarios_iter) for _ in range(SCENARIOS_PER_STEP)
+ ]
+ val_trajectories, train_trajectory_groups = await asyncio.gather(
+ art.gather_trajectories(
+ (rollout(scenario) for scenario in val_scenarios),
+ pbar_desc="gather(val)",
+ max_exceptions=32,
+ ),
+ art.gather_trajectory_groups(
+ (
+ art.TrajectoryGroup(
+ rollout(scenario) for _ in range(TRAJECTORY_GROUP_SIZE)
+ )
+ for scenario in train_scenarios
+ ),
+ pbar_desc="gather(train)",
+ max_exceptions=32,
+ ),
+ )
+ await model.log(val_trajectories)
+ await model.train(train_trajectory_groups)
+
+
+def parse_args() -> argparse.Namespace:
+ parser = argparse.ArgumentParser(description="Minimal MathVista trainer script")
+ parser.add_argument(
+ "-n",
+ "--name",
+ required=True,
+ help="Run/model name to use for the TrainableModel",
+ )
+ parser.add_argument(
+ "-s",
+ "--steps",
+ type=int,
+ default=1000,
+ help="Number of training steps to run",
+ )
+ return parser.parse_args()
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ asyncio.run(main(args.name, args.steps))
diff --git a/dev/yes-no-maybe-vision/generate_images.py b/dev/yes-no-maybe-vision/generate_images.py
new file mode 100644
index 000000000..6e7410fb8
--- /dev/null
+++ b/dev/yes-no-maybe-vision/generate_images.py
@@ -0,0 +1,284 @@
+from __future__ import annotations
+
+from pathlib import Path
+from typing import Any, Iterable, Sequence
+
+try:
+ from PIL import Image, ImageDraw, ImageFont
+except Exception as exc: # pragma: no cover - clear import guidance
+ raise RuntimeError("Pillow is required. Install with: pip install pillow") from exc
+
+
+# Minimal font loader: tries an explicit path, then DejaVuSans, then falls back to PIL's default bitmap font
+def _load_font(font_path: str | Path | None, preferred_size: int) -> Any:
+ candidates: list[str] = []
+ if font_path is not None:
+ candidates.append(str(font_path))
+ candidates.extend(
+ [
+ "/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf",
+ "DejaVuSans.ttf",
+ ]
+ )
+ for candidate in candidates:
+ try:
+ return ImageFont.truetype(candidate, preferred_size)
+ except Exception:
+ continue
+ # Fallback to PIL's default bitmap font so the script always runs
+ return ImageFont.load_default()
+
+
+
+def generate_yes_no_maybe_prompts() -> list[str]:
+ from itertools import permutations
+
+ prompts = []
+ for prefix in ("respond", "just respond"):
+ for n in (3, 2):
+ for words in permutations(("yes", "no", "maybe"), n):
+ prompts.append(
+ f"{prefix} with {', '.join(words)}"
+ if n == 3
+ else f"{prefix} with {words[0]} or {words[1]}"
+ )
+ return prompts
+
+
+def _wrap_text_to_width(
+ draw: ImageDraw.ImageDraw, text: str, font: Any, max_width: int
+) -> list[str]:
+ """Greedy word-wrapping that ensures each line fits within `max_width`."""
+ words = text.split()
+ if not words:
+ return [""]
+
+ lines: list[str] = []
+ current: list[str] = []
+ for word in words:
+ candidate = (" ".join(current + [word])).strip()
+ left, top, right, bottom = draw.textbbox((0, 0), candidate, font=font)
+ if right - left <= max_width or not current:
+ current.append(word)
+ else:
+ lines.append(" ".join(current))
+ current = [word]
+ if current:
+ lines.append(" ".join(current))
+ return lines
+
+
+
+def _max_fit_font_size(
+ draw: ImageDraw.ImageDraw,
+ text: str,
+ image_width: int,
+ image_height: int,
+ margin_px: int,
+ min_size: int = 12,
+ max_size: int | None = None,
+ font_path: str | Path | None = None,
+) -> Any:
+ """Binary search the largest font that fits the canvas with wrapping.
+
+    This aggressively grows the font size to fill the available space
+    (subject to margins) and returns the font at the largest size that still fits.
+ """
+ if max_size is None:
+ max_size = min(image_width, image_height)
+
+ low = min_size
+ high = max_size
+ best_font: Any = _load_font(font_path, min_size)
+
+ while low <= high:
+ mid = (low + high) // 2
+ font = _load_font(font_path, mid)
+ lines = _wrap_text_to_width(draw, text, font, image_width - 2 * margin_px)
+ # measure
+ max_line_width = 0
+ line_heights: list[int] = []
+ for line in lines:
+ left, top, right, bottom = draw.textbbox((0, 0), line, font=font)
+ max_line_width = max(max_line_width, int(right - left))
+ line_heights.append(int(bottom - top))
+ avg_line_height = (
+ int(sum(line_heights) / len(line_heights)) if line_heights else mid
+ )
+ line_spacing = max(1, avg_line_height // 4)
+ total_height = sum(line_heights) + (len(lines) - 1) * line_spacing
+
+ fits = (
+ max_line_width <= image_width - 2 * margin_px
+ and total_height <= image_height - 2 * margin_px
+ )
+ if fits:
+ best_font = font
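+            # step by 2 for a coarser, faster search; the last fitting font is kept as best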
+ low = mid + 2
+ else:
+ high = mid - 2
+
+ return best_font
+
+
+def save_prompt_images(
+ prompts: Sequence[str] | Iterable[str],
+ output_dir: str | Path,
+ image_size: tuple[int, int] = (512, 512),
+ margin_px: int = 16,
+ font_path: str | Path | None = None,
+ font_size: int | None = None,
+ text_color: tuple[int, int, int] = (0, 0, 0),
+ background_color: tuple[int, int, int] = (255, 255, 255),
+) -> list[Path]:
+ """
+    Render each prompt as centered text on a solid background (white by default) and save as PNG.
+
+ Args:
+ prompts: Sequence of prompt strings.
+ output_dir: Directory to write images into (created if missing).
+ image_size: (width, height) for output images.
+ margin_px: Padding inside the canvas for text layout.
+ font_path: Optional path to a .ttf/.otf font.
+ font_size: Optional explicit font size; if None, will auto-fit.
+ text_color: RGB color for text.
+ background_color: RGB background color.
+
+ Returns:
+ List of file Paths written.
+ """
+ output_path = Path(output_dir)
+ output_path.mkdir(parents=True, exist_ok=True)
+
+ written_paths: list[Path] = []
+ used_names: set[str] = set()
+ width, height = image_size
+
+ for idx, raw_prompt in enumerate(prompts):
+ prompt = str(raw_prompt).strip()
+ if not prompt:
+ continue
+
+ image = Image.new("RGB", (width, height), color=background_color)
+ draw = ImageDraw.Draw(image)
+
+ if font_size is not None:
+ font = _load_font(font_path, font_size)
+ else:
+ # Aggressively maximize font size within margins.
+ max_size = min(width, height)
+ font = _max_fit_font_size(
+ draw,
+ prompt,
+ width,
+ height,
+ margin_px,
+ min_size=12,
+ max_size=max_size,
+ font_path=font_path,
+ )
+
+ lines = _wrap_text_to_width(draw, prompt, font, width - 2 * margin_px)
+
+ # Single drawing path: measure, center, render
+ line_bboxes = [draw.textbbox((0, 0), line, font=font) for line in lines]
+ line_heights = [int(b[3] - b[1]) for b in line_bboxes]
+ max_line_width = max((b[2] - b[0]) for b in line_bboxes) if line_bboxes else 0
+ avg_line_height = (
+ int(sum(line_heights) / len(line_heights)) if line_heights else 0
+ )
+ line_spacing = max(1, avg_line_height // 4) if avg_line_height else 8
+ total_text_height = sum(line_heights) + (len(lines) - 1) * line_spacing
+
+ y_start = (height - total_text_height) // 2
+ cursor_y = y_start
+ for i, line in enumerate(lines):
+ bbox = draw.textbbox((0, 0), line, font=font)
+ line_width = int(bbox[2] - bbox[0])
+ x = (width - line_width) // 2
+ draw.text((x, cursor_y), line, font=font, fill=text_color)
+ cursor_y += line_heights[i]
+ if i < len(lines) - 1:
+ cursor_y += line_spacing
+
+ # Build a deterministic, safe filename
+ slug = _slugify(prompt)
+ base_name = slug if slug else f"prompt_{idx:04d}"
+ name = base_name
+ suffix_counter = 1
+ while name in used_names:
+ name = f"{base_name}_{suffix_counter}"
+ suffix_counter += 1
+ used_names.add(name)
+
+ out_path = output_path / f"{name}.png"
+ image.save(out_path, format="PNG")
+ written_paths.append(out_path)
+
+ return written_paths
+
+
+def _slugify(text: str, max_length: int = 80) -> str:
+ """Create a filesystem-friendly slug from text."""
+ import re
+
+ # Lowercase, replace whitespace with underscores, strip quotes, keep alnum and _-
+ cleaned = text.lower().strip()
+ cleaned = cleaned.replace("'", "").replace('"', "")
+ cleaned = re.sub(r"\s+", "_", cleaned)
+ cleaned = re.sub(r"[^a-z0-9_\-]", "", cleaned)
+ return cleaned[:max_length].strip("_-")
+
+
+if __name__ == "__main__":
+ # Generate prompts and render images
+ import argparse
+
+ parser = argparse.ArgumentParser(description="Render prompts as text images")
+ parser.add_argument(
+ "output_dir",
+ nargs="?",
+ default="dev/yes-no-maybe-vision/images",
+ help="Directory to write images (default: dev/yes-no-maybe-images)",
+ )
+ parser.add_argument(
+ "--size", type=int, default=None, help="Square px (overrides width/height)"
+ )
+ parser.add_argument(
+ "--width", type=int, default=None, help="Width px (default 256)"
+ )
+ parser.add_argument(
+ "--height", type=int, default=None, help="Height px (default 256)"
+ )
+ parser.add_argument("--margin", type=int, default=16, help="Margin px (default 16)")
+ parser.add_argument(
+ "--font-path", type=str, default=None, help="Path to .ttf/.otf font"
+ )
+ args = parser.parse_args()
+
+ output_dir = Path(args.output_dir)
+ if args.size is not None:
+ width = height = int(args.size)
+ else:
+ width = int(args.width) if args.width is not None else 256
+ height = int(args.height) if args.height is not None else 256
+ margin = int(args.margin)
+
+ saved = save_prompt_images(
+ generate_yes_no_maybe_prompts(),
+ output_dir,
+ image_size=(width, height),
+ margin_px=margin,
+ font_path=args.font_path,
+ )
+ print(
+ f"Wrote {len(saved)} images to {output_dir.resolve()} at size {width}x{height}"
+ )
diff --git a/dev/yes-no-maybe-vision/train.ipynb b/dev/yes-no-maybe-vision/train.ipynb
new file mode 100644
index 000000000..46d29c726
--- /dev/null
+++ b/dev/yes-no-maybe-vision/train.ipynb
@@ -0,0 +1,126 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "%load_ext autoreload\n",
+ "%autoreload 2"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "%%html\n",
+ ""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import openai\n",
+ "from dotenv import load_dotenv\n",
+ "from generate_images import generate_yes_no_maybe_prompts, save_prompt_images\n",
+ "\n",
+ "import art\n",
+ "from art.local import LocalBackend\n",
+ "\n",
+ "load_dotenv()\n",
+ "\n",
+ "backend = LocalBackend()\n",
+ "model = art.TrainableModel(\n",
+ " name=\"009\",\n",
+ " project=\"yes-no-maybe-vision\",\n",
+ " base_model=\"Qwen/Qwen2.5-VL-7B-Instruct\",\n",
+ ")\n",
+ "await model.register(backend)\n",
+ "\n",
+ "\n",
+ "async def rollout(client: openai.AsyncOpenAI, image_path: str) -> art.Trajectory:\n",
+ " messages: art.Messages = [\n",
+ " {\n",
+ " \"role\": \"user\",\n",
+ " \"content\": [{\"type\": \"image_url\", \"image_url\": {\"url\": image_path}}],\n",
+ " }\n",
+ " ]\n",
+ " chat_completion = await client.chat.completions.create(\n",
+ " model=model.name, messages=messages, max_tokens=100, timeout=100\n",
+ " )\n",
+ " choice = chat_completion.choices[0]\n",
+ " content = choice.message.content\n",
+ " assert isinstance(content, str)\n",
+ " if content == \"yes\":\n",
+ " reward = 0.5\n",
+ " elif content == \"no\":\n",
+ " reward = 0.75\n",
+ " elif content == \"maybe\":\n",
+ " reward = 1.0\n",
+ " else:\n",
+ " reward = 0.0\n",
+ " return art.Trajectory(messages_and_choices=[*messages, choice], reward=reward)\n",
+ "\n",
+ "\n",
+ "image_paths = save_prompt_images(\n",
+ " generate_yes_no_maybe_prompts(),\n",
+ " \"/tmp/yes-no-maybe-vision/images\",\n",
+ " image_size=(256, 256),\n",
+ " margin_px=16,\n",
+ " font_path=None,\n",
+ ")\n",
+ "\n",
+ "\n",
+ "openai_client = model.openai_client()\n",
+ "for _ in range(await model.get_step(), 1_000):\n",
+ " train_groups = await art.gather_trajectory_groups(\n",
+ " (\n",
+ " art.TrajectoryGroup(\n",
+ " rollout(openai_client, image_path.as_uri()) for _ in range(32)\n",
+ " )\n",
+ " for image_path in image_paths\n",
+ " )\n",
+ " )\n",
+ " await model.train(\n",
+ " train_groups,\n",
+ " config=art.TrainConfig(learning_rate=1e-4),\n",
+ " )"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": ".venv",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.13"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/src/art/dev/get_model_config.py b/src/art/dev/get_model_config.py
index 7499839d4..1b3a43de6 100644
--- a/src/art/dev/get_model_config.py
+++ b/src/art/dev/get_model_config.py
@@ -34,6 +34,7 @@ def get_model_config(
init_args.pop("max_lora_rank")
init_args.pop("use_async")
engine_args = EngineArgs(
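+        # Allow vLLM to read image_url file:// references under /tmp for multimodal rollouts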
+ allowed_local_media_path="/tmp",
disable_log_requests=True,
enable_sleep_mode=enable_sleep_mode,
generation_config="vllm",
diff --git a/src/art/local/backend.py b/src/art/local/backend.py
index a470cbb75..13a906b44 100644
--- a/src/art/local/backend.py
+++ b/src/art/local/backend.py
@@ -15,7 +15,8 @@
import weave
from openai import AsyncOpenAI
from tqdm import auto as tqdm
-from transformers.models.auto.tokenization_auto import AutoTokenizer
+from transformers import AutoImageProcessor, AutoTokenizer
+from transformers.image_processing_utils import BaseImageProcessor
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
from typing_extensions import Self
from wandb.sdk.wandb_run import Run
@@ -80,7 +81,8 @@ def __init__(self, *, in_process: bool = False, path: str | None = None) -> None
# Other initialization
self._services: dict[str, ModelService] = {}
- self._tokenizers: dict[str, "PreTrainedTokenizerBase"] = {}
+ self._tokenizers: dict[str, PreTrainedTokenizerBase] = {}
+ self._image_processors: dict[str, BaseImageProcessor | None] = {}
self._wandb_runs: dict[str, Run] = {}
self._weave_clients: dict[str, WeaveClient] = {}
@@ -181,6 +183,13 @@ def _get_packed_tensors(
self._tokenizers[model.base_model] = AutoTokenizer.from_pretrained(
model.base_model
)
+ if model.base_model not in self._image_processors:
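+                # Not every base model ships an image processor; fall back to None for text-only models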
+ try:
+ self._image_processors[model.base_model] = (
+ AutoImageProcessor.from_pretrained(model.base_model, use_fast=True)
+ )
+ except Exception:
+ self._image_processors[model.base_model] = None
tokenizer = self._tokenizers[model.base_model]
tokenized_results = list(
tokenize_trajectory_groups(
@@ -188,6 +197,7 @@ def _get_packed_tensors(
trajectory_groups,
allow_training_without_logprobs,
scale_rewards,
+ image_processor=self._image_processors[model.base_model],
)
)
if not tokenized_results:
@@ -488,9 +498,9 @@ async def _train_model(
num_gradient_steps = int(
result.pop("num_gradient_steps", estimated_gradient_steps)
)
- assert (
- num_gradient_steps == estimated_gradient_steps
- ), f"num_gradient_steps {num_gradient_steps} != estimated_gradient_steps {estimated_gradient_steps}"
+ assert num_gradient_steps == estimated_gradient_steps, (
+ f"num_gradient_steps {num_gradient_steps} != estimated_gradient_steps {estimated_gradient_steps}"
+ )
results.append(result)
yield {**result, "num_gradient_steps": num_gradient_steps}
pbar.update(1)
diff --git a/src/art/preprocessing/pack.py b/src/art/preprocessing/pack.py
index fb1054d02..0e99176af 100644
--- a/src/art/preprocessing/pack.py
+++ b/src/art/preprocessing/pack.py
@@ -1,9 +1,10 @@
import os
import random
import time
+from typing import Any, cast
import torch
-from typing_extensions import TypedDict, Unpack
+from typing_extensions import NotRequired, TypedDict, Unpack
from ..types import Verbosity
from .tokenize import TokenizedResult
@@ -18,12 +19,16 @@ class PackedTensors(TypedDict):
logprobs: torch.Tensor
advantages: torch.Tensor
weights: torch.Tensor
+ pixel_values: list[torch.Tensor | None]
+ image_grid_thw: list[torch.Tensor | None]
class DiskPackedTensors(TypedDict):
dir: str
num_sequences: int
sequence_length: int
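+    # (inner_dim, cumulative row offsets) describing the optional flat vision tensor files on disk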
+ pixel_values: NotRequired[tuple[int, list[int]]]
+ image_grid_thw: NotRequired[tuple[int, list[int]]]
def packed_tensors_from_tokenized_results(
@@ -43,6 +48,8 @@ def packed_tensors_from_tokenized_results(
logprobs: list[list[float]] = [[]]
advantages: list[list[float]] = [[]]
weights: list[list[float]] = [[]]
+ pixel_values: list[list[torch.Tensor]] = [[]]
+ image_grid_thw: list[list[torch.Tensor]] = [[]]
for result in tokenized_results:
if len(result.token_ids) > seq_len and not truncate_long_results:
@@ -71,6 +78,8 @@ def packed_tensors_from_tokenized_results(
logprobs.append([])
advantages.append([])
weights.append([])
+ pixel_values.append([])
+ image_grid_thw.append([])
group_id = random.randint(-(2**63), 2**63 - 1)
if result.prompt_id in group_ids[-1]:
result = result_without_prompt
@@ -85,6 +94,10 @@ def packed_tensors_from_tokenized_results(
logprobs[-1].extend(result.logprobs)
advantages[-1].extend([result.advantage] * len(result.token_ids))
weights[-1].extend([result.weight] * len(result.token_ids))
+ if result.pixel_values is not None:
+ pixel_values[-1].append(result.pixel_values)
+ if result.image_grid_thw is not None:
+ image_grid_thw[-1].append(result.image_grid_thw)
if truncate_long_results:
token_ids[-1] = token_ids[-1][:seq_len]
group_ids[-1] = group_ids[-1][:seq_len]
@@ -105,6 +118,8 @@ def packed_tensors_from_tokenized_results(
logprobs = [logprobs[i] for i in permutation]
advantages = [advantages[i] for i in permutation]
weights = [weights[i] for i in permutation]
+ pixel_values = [pixel_values[i] for i in permutation]
+ image_grid_thw = [image_grid_thw[i] for i in permutation]
def pad(values: list[list], pad_value) -> list[list]:
max_len = seq_len
@@ -150,12 +165,18 @@ def pad(values: list[list], pad_value) -> list[list]:
"logprobs": torch.tensor(pad(logprobs, float("nan"))),
"advantages": advantages_tensor,
"weights": weights_tensor,
+ "pixel_values": [
+ torch.concat(tensors) if tensors else None for tensors in pixel_values
+ ],
+ "image_grid_thw": [
+ torch.concat(tensors) if tensors else None for tensors in image_grid_thw
+ ],
}
def packed_tensors_from_dir(**kwargs: Unpack[DiskPackedTensors]) -> PackedTensors:
os.makedirs(kwargs["dir"], exist_ok=True)
- return {
+ packed_tensors = {
key: torch.from_file(
f"{kwargs['dir']}/{key}.pt",
shared=True,
@@ -172,7 +193,33 @@ def packed_tensors_from_dir(**kwargs: Unpack[DiskPackedTensors]) -> PackedTensor
"advantages": torch.float32,
"weights": torch.float32,
}.items()
- } # type: ignore
+ }
+ _add_tensor_list(packed_tensors, kwargs, "pixel_values", torch.float32)
+ _add_tensor_list(packed_tensors, kwargs, "image_grid_thw", torch.long)
+ return cast(PackedTensors, packed_tensors)
+
+
+def _add_tensor_list(
+ packed_tensors: dict[str, Any],
+ disk_packed_tensors: DiskPackedTensors,
+ key: str,
+ dtype: torch.dtype,
+) -> None:
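+    # Memory-map the flat tensor file for `key` and split it into per-sequence slices using the stored offsets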
+ if info := disk_packed_tensors.get(key):
+ packed_tensors[key] = []
+ inner_dim, offsets = cast(tuple[int, list[int]], info)
+    packed_values = torch.from_file(
+ f"{disk_packed_tensors['dir']}/{key}.pt",
+ shared=True,
+ size=offsets[-1] * inner_dim,
+ dtype=dtype,
+ ).view(-1, inner_dim)
+ for start, end in zip(offsets[:-1], offsets[1:]):
+ packed_tensors[key].append(
+            packed_values[start:end] if start < end else None
+ )
+ else:
+ packed_tensors[key] = [None] * disk_packed_tensors["num_sequences"]
def packed_tensors_to_dir(tensors: PackedTensors, dir: str) -> DiskPackedTensors:
@@ -182,11 +229,36 @@ def packed_tensors_to_dir(tensors: PackedTensors, dir: str) -> DiskPackedTensors
"num_sequences": tensors["tokens"].shape[0],
"sequence_length": tensors["tokens"].shape[1],
}
+ if info := _get_tensor_list_info(tensors["pixel_values"]):
+ disk_packed_tensors["pixel_values"] = info
+ if info := _get_tensor_list_info(tensors["image_grid_thw"]):
+ disk_packed_tensors["image_grid_thw"] = info
for key, tensor in packed_tensors_from_dir(**disk_packed_tensors).items():
- tensor.copy_(tensors[key]) # type: ignore
+ if isinstance(tensor, list):
+ for i, t in enumerate(tensor):
+ if t is not None:
+ t.copy_(tensors[key][i])
+ else:
+ tensor.copy_(tensors[key]) # type: ignore
return disk_packed_tensors
+def _get_tensor_list_info(
+ tensors: list[torch.Tensor | None],
+) -> tuple[int, list[int]] | None:
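+    # Return the shared inner dimension and cumulative row offsets of the non-None tensors, or None if there are none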
+ inner_dims = {tensor.shape[1] for tensor in tensors if tensor is not None}
+ if len(inner_dims) == 0:
+ return None
+ assert len(inner_dims) == 1, f"Inner dimensions of {tensors} are not the same"
+ offsets = [0]
+ for tensor in tensors:
+ if tensor is not None:
+ offsets.append(offsets[-1] + tensor.shape[0])
+ else:
+ offsets.append(offsets[-1])
+ return inner_dims.pop(), offsets
+
+
def plot_packed_tensors(
packed_tensors: PackedTensors, output_dir: str | None = None
) -> None:
diff --git a/src/art/preprocessing/tokenize.py b/src/art/preprocessing/tokenize.py
index c16c1348e..70dbd766a 100644
--- a/src/art/preprocessing/tokenize.py
+++ b/src/art/preprocessing/tokenize.py
@@ -4,6 +4,9 @@
from itertools import takewhile
from typing import Generator, cast
+import torch
+from PIL import Image
+from transformers.image_processing_utils import BaseImageProcessor
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
from ..trajectories import History, TrajectoryGroup, get_messages
@@ -18,6 +21,8 @@ class TokenizedResult:
input_pos: list[int]
assistant_mask: list[int]
logprobs: list[float]
+ pixel_values: torch.Tensor | None
+ image_grid_thw: torch.Tensor | None
weight: float = 0.0
prompt_id: int = 0
prompt_length: int = 0
@@ -31,6 +36,8 @@ def without_prompt(self) -> "TokenizedResult":
input_pos=self.input_pos[self.prompt_length :],
assistant_mask=self.assistant_mask[self.prompt_length :],
logprobs=self.logprobs[self.prompt_length :],
+ pixel_values=None,
+ image_grid_thw=None,
weight=self.weight,
prompt_id=self.prompt_id,
prompt_length=0,
@@ -43,6 +50,7 @@ def tokenize_trajectory_groups(
allow_training_without_logprobs: bool,
scale_rewards: bool,
shuffle_group_trajectories: bool = True,
+ image_processor: BaseImageProcessor | None = None,
) -> Generator["TokenizedResult", None, None]:
for group in trajectory_groups:
if not group:
@@ -72,6 +80,7 @@ def tokenize_trajectory_groups(
]:
if result := tokenize_trajectory(
tokenizer,
+ image_processor,
history,
advantage,
allow_training_without_logprobs,
@@ -108,6 +117,7 @@ def tokenize_trajectory_groups(
def tokenize_trajectory(
tokenizer: "PreTrainedTokenizerBase",
+ image_processor: BaseImageProcessor | None,
history: History,
advantage: float,
allow_training_without_logprobs: bool,
@@ -117,15 +127,15 @@ def tokenize_trajectory(
"""
# Find the index of the last assistant message
last_assistant_index = -1
- for i, message_or_choice in enumerate(history.messages_and_choices):
+ for i, message in enumerate(history.messages_and_choices):
if (
- isinstance(message_or_choice, dict)
- and message_or_choice["role"] == "assistant"
+ isinstance(message, dict)
+ and message["role"] == "assistant"
and allow_training_without_logprobs
):
last_assistant_index = i
- elif not isinstance(message_or_choice, dict) and (
- message_or_choice.logprobs or allow_training_without_logprobs
+ elif not isinstance(message, dict) and (
+ message.logprobs or allow_training_without_logprobs
):
last_assistant_index = i
# If there are no trainable assistant messages, return None
@@ -175,16 +185,13 @@ def tokenize_trajectory(
)
assistant_mask: list[int] = [0] * len(token_ids)
logprobs = [float("nan")] * len(token_ids)
- for message_or_choice in messages_and_choices:
- if (
- isinstance(message_or_choice, dict)
- and not message_or_choice["role"] == "assistant"
- ):
+ for message in messages_and_choices:
+ if isinstance(message, dict) and not message["role"] == "assistant":
continue
start = token_ids.index(sentinal_token_id)
end = start + 1
- if isinstance(message_or_choice, dict):
- content = message_or_choice.get("content")
+ if isinstance(message, dict):
+ content = message.get("content")
assert isinstance(content, str)
content_token_ids = tokenizer.encode(
content,
@@ -194,7 +201,7 @@ def tokenize_trajectory(
logprobs[start:end] = [float("nan")] * len(content_token_ids)
assistant_mask[start:end] = [1] * len(content_token_ids)
else:
- choice = message_or_choice
+ choice = message
assert choice.logprobs or allow_training_without_logprobs, (
"Chat completion choices must have logprobs"
)
@@ -226,6 +233,47 @@ def tokenize_trajectory(
token_logprob.logprob for token_logprob in token_logprobs
)
assistant_mask[start:end] = [1] * len(token_logprobs)
+ if image_processor:
+ images: list[Image.Image] = []
+ for message in messages_and_choices:
+ if (
+ isinstance(message, dict)
+ and message["role"] == "user"
+ and isinstance(message["content"], (list, tuple))
+ ):
+ for content in message["content"]:
+ if content["type"] == "image_url":
+ image_url = content["image_url"]["url"].removeprefix("file://")
+ images.append(Image.open(image_url))
+ image_token_id = cast(
+ int,
+ getattr(image_processor, "image_token_id", None)
+ or tokenizer.convert_tokens_to_ids( # type: ignore
+ getattr(image_processor, "image_token", "<|image_pad|>")
+ ),
+ )
+ if images:
+ result = image_processor(images=images)
+ offset = 0
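+            # Expand each single image placeholder token into the full run of image tokens the model expects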
+ for num_image_tokens in (
+ image_grid_thw.prod().item()
+ // (getattr(image_processor, "merge_size", 1) ** 2)
+ for image_grid_thw in result["image_grid_thw"]
+ ):
+ start = token_ids.index(image_token_id, offset)
+ offset = start + num_image_tokens
+ end = start + 1
+ token_ids[start:end] = [image_token_id] * num_image_tokens
+ logprobs[start:end] = [float("nan")] * num_image_tokens
+ assistant_mask[start:end] = [0] * num_image_tokens
+ pixel_values = result["pixel_values"]
+ image_grid_thw = result["image_grid_thw"]
+ else:
+ pixel_values = None
+ image_grid_thw = None
+ else:
+ pixel_values = None
+ image_grid_thw = None
return TokenizedResult(
advantage=advantage,
chat=chat,
@@ -234,4 +282,6 @@ def tokenize_trajectory(
input_pos=list(range(len(token_ids))),
assistant_mask=assistant_mask,
logprobs=logprobs,
+ pixel_values=pixel_values,
+ image_grid_thw=image_grid_thw,
)
diff --git a/src/art/transformers/patches.py b/src/art/transformers/patches.py
index 97fc9fb71..6b16424fe 100644
--- a/src/art/transformers/patches.py
+++ b/src/art/transformers/patches.py
@@ -14,12 +14,12 @@
def _patched_preprocess_mask_arguments(
config: PretrainedConfig,
input_embeds: torch.Tensor,
- attention_mask: Optional[Union[torch.Tensor, BlockMask]],
+ attention_mask: Optional[Union[torch.Tensor, "BlockMask"]],
cache_position: torch.Tensor,
past_key_values: Optional[Cache],
position_ids: Optional[torch.Tensor],
layer_idx: Optional[int],
-) -> tuple[bool, Optional[Union[torch.Tensor, BlockMask]], int, int]:
+) -> tuple[bool, Optional[Union[torch.Tensor, "BlockMask"]], int, int]:
if position_ids is not None and len(position_ids.shape) == 3:
position_ids = position_ids[0]
return _preprocess_mask_arguments(
diff --git a/src/art/unsloth/decoupled_service.py b/src/art/unsloth/decoupled_service.py
index 69f0adddd..5ea3d082a 100644
--- a/src/art/unsloth/decoupled_service.py
+++ b/src/art/unsloth/decoupled_service.py
@@ -149,6 +149,12 @@ async def train(
for k, v in packed_tensors.items()
if isinstance(v, torch.Tensor)
},
+ pixel_values=packed_tensors["pixel_values"][
+ _offset : _offset + 1
+ ],
+ image_grid_thw=packed_tensors["image_grid_thw"][
+ _offset : _offset + 1
+ ],
config=config,
_config=_config,
return_new_logprobs=True,
@@ -169,6 +175,16 @@ async def train(
for k, v in packed_tensors.items()
if isinstance(v, torch.Tensor)
},
+ pixel_values=(
+ [None]
+ if warmup
+ else packed_tensors["pixel_values"][offset : offset + 1]
+ ),
+ image_grid_thw=(
+ [None]
+ if warmup
+ else packed_tensors["image_grid_thw"][offset : offset + 1]
+ ),
config=(
config.model_copy(
update={"lr": 1e-9, "beta": 0.0, "kl_coef": 0.0}
diff --git a/src/art/unsloth/service.py b/src/art/unsloth/service.py
index 802a1a0b1..621bc9b0e 100644
--- a/src/art/unsloth/service.py
+++ b/src/art/unsloth/service.py
@@ -118,6 +118,12 @@ async def train(
for k, v in packed_tensors.items()
if isinstance(v, torch.Tensor)
},
+ pixel_values=packed_tensors["pixel_values"][
+ _offset : _offset + 1
+ ],
+ image_grid_thw=packed_tensors["image_grid_thw"][
+ _offset : _offset + 1
+ ],
config=config,
_config=_config,
return_new_logprobs=True,
@@ -140,6 +146,18 @@ async def train(
for k, v in packed_tensors.items()
if isinstance(v, torch.Tensor)
},
+ pixel_values=(
+ [None]
+ if warmup
+ else packed_tensors["pixel_values"][offset : offset + 1]
+ ),
+ image_grid_thw=(
+ [None]
+ if warmup
+ else packed_tensors["image_grid_thw"][
+ offset : offset + 1
+ ]
+ ),
config=(
config.model_copy(
update={"lr": 1e-9, "beta": 0.0, "kl_coef": 0.0}
@@ -200,12 +218,21 @@ async def train(
def _set_lora(self, lora_path: str) -> None:
"""Sets the LoRA adapter with ID 1 in the vLLM engine."""
- lora_request: "LoRARequest" = self.state.peft_model.load_lora(
- lora_path,
- load_tensors=True,
- ) # type: ignore
- lora_request.lora_int_id = 1
- lora_request.lora_name = self.model_name
- lora_request.lora_path = lora_path
+ from vllm.lora.request import LoRARequest
+
+ if hasattr(self.state.peft_model, "load_lora"):
+ lora_request: LoRARequest = self.state.peft_model.load_lora(
+ lora_path,
+ load_tensors=True,
+ ) # type: ignore
+ lora_request.lora_int_id = 1
+ lora_request.lora_name = self.model_name
+ lora_request.lora_path = lora_path
+ else:
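+            # Some PEFT model wrappers don't expose load_lora; construct the LoRARequest directly instead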
+ lora_request = LoRARequest(
+ lora_name=self.model_name,
+ lora_int_id=1,
+ lora_path=lora_path,
+ )
self.state.vllm.async_engine.engine.remove_lora(1)
self.state.vllm.async_engine.engine.add_lora(lora_request) # type: ignore
diff --git a/src/art/unsloth/state.py b/src/art/unsloth/state.py
index d8d15288e..b32f646db 100644
--- a/src/art/unsloth/state.py
+++ b/src/art/unsloth/state.py
@@ -86,13 +86,15 @@ def _from_engine_args(
torch.cuda.empty_cache()
self.vllm = vLLMState(self.model.vllm_engine, enable_sleep_mode)
# Initialize PEFT model
- self.peft_model = cast(
- peft.peft_model.PeftModelForCausalLM,
- unsloth.FastLanguageModel.get_peft_model(
- self.model, **config.get("peft_args", {})
- ),
- )
- self.lora_model = cast(peft.tuners.lora.LoraModel, self.peft_model.base_model)
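+    # The model may already be a PeftModel; avoid wrapping it a second time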
+ if isinstance(self.model, peft.peft_model.PeftModelForCausalLM):
+ self.peft_model = self.model
+ else:
+ self.peft_model = cast(
+ peft.peft_model.PeftModelForCausalLM,
+ unsloth.FastLanguageModel.get_peft_model(
+ self.model, **config.get("peft_args", {})
+ ),
+ )
# Initialize trainer
data = {"prompt": ""}
self.trainer = GRPOTrainer(
diff --git a/src/art/unsloth/train.py b/src/art/unsloth/train.py
index 1e6fc3135..658efd0b2 100644
--- a/src/art/unsloth/train.py
+++ b/src/art/unsloth/train.py
@@ -69,6 +69,15 @@ def compute_loss(
# if param_group.get("weight_decay"):
# param_group["weight_decay"] = config.weight_decay
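+        # Vision inputs arrive as one-element lists; unwrap the tensors or drop the keys for text-only batches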
+ if inputs["pixel_values"][0] is not None:
+ inputs["pixel_values"] = inputs["pixel_values"][0] # type: ignore
+ else:
+ del inputs["pixel_values"] # type: ignore
+ if inputs["image_grid_thw"][0] is not None:
+ inputs["image_grid_thw"] = inputs["image_grid_thw"][0] # type: ignore
+ else:
+ del inputs["image_grid_thw"] # type: ignore
+
# Move tensors to the correct device
inputs = {
key: tensor.to(trainer.accelerator.device) # type: ignore
@@ -109,11 +118,17 @@ def compute_loss(
f"Sequence length ({seq_len}) must be evenly divisible by chunk size ({chunk_size})"
)
os.environ["UNSLOTH_RETURN_HIDDEN_STATES"] = "1"
+ forward_kwargs = {}
+ if "pixel_values" in inputs:
+ forward_kwargs["pixel_values"] = inputs["pixel_values"] # type: ignore
+ if "image_grid_thw" in inputs:
+ forward_kwargs["image_grid_thw"] = inputs["image_grid_thw"] # type: ignore
new_logprobs, entropies = calculate_logprobs(
dtype_for_autocasting,
trainer,
inputs["tokens"],
attn_bias,
+ forward_kwargs,
next_input_ids,
lm_head_t,
chunk_size=chunk_size,
@@ -129,6 +144,7 @@ def compute_loss(
trainer,
inputs["tokens"],
attn_bias,
+ forward_kwargs,
next_input_ids,
lm_head_t,
chunk_size=chunk_size,
@@ -297,6 +313,7 @@ def calculate_logprobs(
trainer: "GRPOTrainer",
input_ids: torch.Tensor,
causal_mask: torch.Tensor,
+ forward_kwargs: dict[str, torch.Tensor],
next_input_ids: torch.Tensor,
lm_head_t: torch.Tensor,
chunk_size: int,
@@ -319,7 +336,7 @@ def calculate_logprobs(
torch.amp.autocast_mode.autocast(device_type="cuda", dtype=dtype_for_autocast),
):
hidden_states = trainer.model( # type: ignore
- input_ids=input_ids, causal_mask=causal_mask
+ input_ids=input_ids, causal_mask=causal_mask, **forward_kwargs
).logits # Shape [B, S, H]
return _calculate_logprobs(lm_head_t, hidden_states, next_input_ids, chunk_size)
diff --git a/src/art/utils/trajectory_logging.py b/src/art/utils/trajectory_logging.py
index 85bb2b653..42b7fdab9 100644
--- a/src/art/utils/trajectory_logging.py
+++ b/src/art/utils/trajectory_logging.py
@@ -1,5 +1,5 @@
import json
-from typing import Any, cast
+from typing import Any, Iterator, cast
import yaml
@@ -73,6 +73,9 @@ def message_or_choice_to_dict(message_or_choice: MessageOrChoice) -> dict[str, A
# item is a choice with logprobs, remove the logprobs
item_dict.pop("logprobs")
+ if "content" in item_dict and isinstance(item_dict["content"], Iterator):
+ item_dict["content"] = list(item_dict["content"]) # type: ignore
+
return dict(item_dict)
diff --git a/tests/unit/test_tokenize_trajectory_groups.ipynb b/tests/unit/test_tokenize_trajectory_groups.ipynb
index b90739eda..63ddc2eaf 100644
--- a/tests/unit/test_tokenize_trajectory_groups.ipynb
+++ b/tests/unit/test_tokenize_trajectory_groups.ipynb
@@ -20,7 +20,7 @@
{
"data": {
"text/plain": [
- "TokenizedResult(advantage=-1.0, chat='<|im_start|>system\\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\\n<|im_start|>user\\nWhat is the capital of France?<|im_end|>\\n<|im_start|>assistant\\nLondon<|im_end|>\\n', tokens=['<|im_start|>', 'system', '\\n', 'You', ' are', ' Q', 'wen', ',', ' created', ' by', ' Alibaba', ' Cloud', '.', ' You', ' are', ' a', ' helpful', ' assistant', '.', '<|im_end|>', '\\n', '<|im_start|>', 'user', '\\n', 'What', ' is', ' the', ' capital', ' of', ' France', '?', '<|im_end|>', '\\n', '<|im_start|>', 'assistant', '\\n', 'London', '<|im_end|>', '\\n'], token_ids=[151644, 8948, 198, 2610, 525, 1207, 16948, 11, 3465, 553, 54364, 14817, 13, 1446, 525, 264, 10950, 17847, 13, 151645, 198, 151644, 872, 198, 3838, 374, 279, 6722, 315, 9625, 30, 151645, 198, 151644, 77091, 198, 39572, 151645, 198], input_pos=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38], assistant_mask=[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0], logprobs=[nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan], weight=1.0, prompt_id=0, prompt_length=36)"
+ "TokenizedResult(advantage=-1.0, chat='<|im_start|>system\\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\\n<|im_start|>user\\nWhat is the capital of France?<|im_end|>\\n<|im_start|>assistant\\nLondon<|im_end|>\\n', tokens=['<|im_start|>', 'system', '\\n', 'You', ' are', ' Q', 'wen', ',', ' created', ' by', ' Alibaba', ' Cloud', '.', ' You', ' are', ' a', ' helpful', ' assistant', '.', '<|im_end|>', '\\n', '<|im_start|>', 'user', '\\n', 'What', ' is', ' the', ' capital', ' of', ' France', '?', '<|im_end|>', '\\n', '<|im_start|>', 'assistant', '\\n', 'London', '<|im_end|>', '\\n'], token_ids=[151644, 8948, 198, 2610, 525, 1207, 16948, 11, 3465, 553, 54364, 14817, 13, 1446, 525, 264, 10950, 17847, 13, 151645, 198, 151644, 872, 198, 3838, 374, 279, 6722, 315, 9625, 30, 151645, 198, 151644, 77091, 198, 39572, 151645, 198], input_pos=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38], assistant_mask=[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0], logprobs=[nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan], pixel_values=None, image_grid_thw=None, weight=1.0, prompt_id=0, prompt_length=36)"
]
},
"metadata": {},
@@ -29,7 +29,7 @@
{
"data": {
"text/plain": [
- "TokenizedResult(advantage=1.0, chat='<|im_start|>system\\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\\n<|im_start|>user\\nWhat is the capital of France?<|im_end|>\\n<|im_start|>assistant\\nParis<|im_end|>\\n', tokens=['<|im_start|>', 'system', '\\n', 'You', ' are', ' Q', 'wen', ',', ' created', ' by', ' Alibaba', ' Cloud', '.', ' You', ' are', ' a', ' helpful', ' assistant', '.', '<|im_end|>', '\\n', '<|im_start|>', 'user', '\\n', 'What', ' is', ' the', ' capital', ' of', ' France', '?', '<|im_end|>', '\\n', '<|im_start|>', 'assistant', '\\n', 'Paris', '<|im_end|>', '\\n'], token_ids=[151644, 8948, 198, 2610, 525, 1207, 16948, 11, 3465, 553, 54364, 14817, 13, 1446, 525, 264, 10950, 17847, 13, 151645, 198, 151644, 872, 198, 3838, 374, 279, 6722, 315, 9625, 30, 151645, 198, 151644, 77091, 198, 59604, 151645, 198], input_pos=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38], assistant_mask=[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0], logprobs=[nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, -0.01, nan, nan], weight=1.0, prompt_id=0, prompt_length=36)"
+ "TokenizedResult(advantage=1.0, chat='<|im_start|>system\\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\\n<|im_start|>user\\nWhat is the capital of France?<|im_end|>\\n<|im_start|>assistant\\nParis<|im_end|>\\n', tokens=['<|im_start|>', 'system', '\\n', 'You', ' are', ' Q', 'wen', ',', ' created', ' by', ' Alibaba', ' Cloud', '.', ' You', ' are', ' a', ' helpful', ' assistant', '.', '<|im_end|>', '\\n', '<|im_start|>', 'user', '\\n', 'What', ' is', ' the', ' capital', ' of', ' France', '?', '<|im_end|>', '\\n', '<|im_start|>', 'assistant', '\\n', 'Paris', '<|im_end|>', '\\n'], token_ids=[151644, 8948, 198, 2610, 525, 1207, 16948, 11, 3465, 553, 54364, 14817, 13, 1446, 525, 264, 10950, 17847, 13, 151645, 198, 151644, 872, 198, 3838, 374, 279, 6722, 315, 9625, 30, 151645, 198, 151644, 77091, 198, 59604, 151645, 198], input_pos=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38], assistant_mask=[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0], logprobs=[nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, -0.01, nan, nan], pixel_values=None, image_grid_thw=None, weight=1.0, prompt_id=0, prompt_length=36)"
]
},
"metadata": {},