Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
fdef126
chore: Start adding Math Vista task
bradhilton Sep 17, 2025
87de7f1
feat: Add text-only rollout function
bradhilton Sep 17, 2025
5297dd9
refactor: Simplify message handling in Math Vista notebook
bradhilton Sep 17, 2025
ff8bbb7
feat: Enhance Math Vista notebook with execution tracking and image h…
bradhilton Sep 17, 2025
e316551
fix: Update run data paths and timestamps in Math Vista notebook
bradhilton Sep 17, 2025
4a1c34e
fix: Reset execution counts and clean up outputs in Math Vista notebook
bradhilton Sep 17, 2025
657960e
Merge branch 'main' into feat/vlm-support
bradhilton Oct 14, 2025
030ff87
Merge branch 'main' into feat/vlm-support
bradhilton Oct 14, 2025
5343690
chore: Ruff linting autofix
bradhilton Oct 14, 2025
12124b0
chore: Cast types
bradhilton Oct 14, 2025
dbfbcd3
fix: Update execution counts and outputs in math-vista notebook
bradhilton Oct 14, 2025
9d426d2
refactor: Enhance image processing integration in tokenizer and backend
bradhilton Oct 15, 2025
f233b30
refactor: Enhance tensor handling in packing and tokenization
bradhilton Oct 15, 2025
170acda
refactor: Improve tensor handling and execution tracking in math-vist…
bradhilton Oct 15, 2025
9645b86
Merge branch 'main' into feat/vlm-support
bradhilton Oct 15, 2025
6a8977c
refactor: Optimize token indexing in trajectory tokenization
bradhilton Oct 15, 2025
b81c2dc
chore: Clean up math-vista.ipynb
bradhilton Oct 16, 2025
c0ab347
refactor: Improve LoRA handling in UnslothService and ModelState
bradhilton Oct 16, 2025
cf66a19
refactor: Enhance tensor handling in UnslothService during warmup
bradhilton Oct 16, 2025
01cd44f
feat: Add image generation and training notebook for yes-no-maybe vision
bradhilton Oct 17, 2025
585f903
feat: Implement MathVista training script for image-based question an…
bradhilton Oct 17, 2025
01d2358
refactor: Update tensor handling in DecoupledUnslothService and traje…
bradhilton Oct 17, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
207 changes: 207 additions & 0 deletions dev/math-vista/math-vista.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,207 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "46a6ad6d",
"metadata": {},
"outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "96d51078",
"metadata": {},
"outputs": [],
"source": [
"%%html\n",
"<style>\n",
".cell-output-ipywidget-background {\n",
" background-color: transparent !important;\n",
"}\n",
":root {\n",
" --jp-widgets-color: var(--vscode-editor-foreground);\n",
" --jp-widgets-font-size: var(--vscode-editor-font-size);\n",
"} \n",
"</style>"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7dd70e04",
"metadata": {},
"outputs": [],
"source": [
"import polars as pl\n",
"\n",
"splits = {\n",
" \"testmini\": \"data/testmini-00000-of-00001-725687bf7a18d64b.parquet\",\n",
" \"test\": \"data/test-*.parquet\",\n",
"}\n",
"df = pl.read_parquet(\"hf://datasets/AI4Math/MathVista/\" + splits[\"testmini\"]).sample(\n",
" fraction=1.0, shuffle=True, seed=42\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "81e02b97",
"metadata": {},
"outputs": [],
"source": [
"from typing import Iterator, TypedDict, cast\n",
"\n",
"\n",
"class DecodedImage(TypedDict):\n",
" bytes: bytes\n",
"\n",
"\n",
"class Scenario(TypedDict):\n",
" pid: int\n",
" question: str\n",
" answer: str\n",
" image: str\n",
" decoded_image: DecodedImage\n",
"\n",
"\n",
"val_scenarios = cast(list[Scenario], df.head(64).to_dicts())\n",
"train_scenarios_iter = cast(Iterator[Scenario], df.tail(-64).iter_rows(named=True))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9287d8fe",
"metadata": {},
"outputs": [],
"source": [
"import re\n",
"\n",
"import art\n",
"from art.local import LocalBackend\n",
"\n",
"model = art.TrainableModel(\n",
" name=\"002\",\n",
" project=\"math-vista\",\n",
" base_model=\"Qwen/Qwen2.5-VL-7B-Instruct\",\n",
")\n",
"backend = LocalBackend()\n",
"await model.register(backend)\n",
"client = model.openai_client()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c92b4b11",
"metadata": {},
"outputs": [],
"source": [
"async def rollout(scenario: Scenario) -> art.Trajectory:\n",
" image_path = f\"/tmp/{scenario['image']}\"\n",
"\n",
" import os\n",
"\n",
" os.makedirs(os.path.dirname(image_path), exist_ok=True)\n",
"\n",
" with open(image_path, \"wb\") as f:\n",
" f.write(scenario[\"decoded_image\"][\"bytes\"])\n",
"\n",
" trajectory = art.Trajectory(messages_and_choices=[], reward=0.0)\n",
" trajectory.messages_and_choices = [\n",
" {\n",
" \"role\": \"user\",\n",
" \"content\": [\n",
" {\n",
" \"type\": \"text\",\n",
" \"text\": scenario[\"question\"]\n",
" + \"\\n\\nNote: Provide your answer in a LaTeX box.\",\n",
" },\n",
" {\"type\": \"image_url\", \"image_url\": {\"url\": f\"file://{image_path}\"}},\n",
" ],\n",
" }\n",
" ]\n",
" chat_completion = await client.chat.completions.create(\n",
" model=model.name, messages=trajectory.messages()\n",
" )\n",
" choice = chat_completion.choices[0]\n",
" trajectory.messages_and_choices.append(choice)\n",
" content = choice.message.content\n",
" assert content is not None\n",
" if matches := list(re.finditer(r\"\\\\boxed\\{(.*?)\\}\", content, re.DOTALL)):\n",
" match = matches[-1]\n",
" answer = match.group(1)\n",
" if answer.lower() == scenario[\"answer\"].lower():\n",
" trajectory.reward = 1.0\n",
" return trajectory"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "359e530d",
"metadata": {},
"outputs": [],
"source": [
"import asyncio\n",
"import itertools\n",
"\n",
"SCENARIOS_PER_STEP = 8\n",
"TRAJECTORY_GROUP_SIZE = 8\n",
"start = await model.get_step()\n",
"train_scenarios_iter = itertools.cycle(train_scenarios_iter)\n",
"for _ in range(start * SCENARIOS_PER_STEP):\n",
" next(train_scenarios_iter)\n",
"\n",
"for i in range(start, 1000):\n",
" train_scenarios = [next(train_scenarios_iter) for _ in range(SCENARIOS_PER_STEP)]\n",
" val_trajectories, train_trajectory_groups = await asyncio.gather(\n",
" art.gather_trajectories(\n",
" (rollout(scenario) for scenario in val_scenarios),\n",
" pbar_desc=\"gather(val)\",\n",
" max_exceptions=32,\n",
" ),\n",
" art.gather_trajectory_groups(\n",
" (\n",
" art.TrajectoryGroup(\n",
" rollout(scenario) for _ in range(TRAJECTORY_GROUP_SIZE)\n",
" )\n",
" for scenario in train_scenarios\n",
" ),\n",
" pbar_desc=\"gather(train)\",\n",
" max_exceptions=32,\n",
" ),\n",
" )\n",
" await model.log(val_trajectories)\n",
" await model.train(train_trajectory_groups)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.13"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
136 changes: 136 additions & 0 deletions dev/math-vista/math-vista.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
import argparse
import asyncio
import itertools
import os
import re
from typing import Iterator, TypedDict, cast

import polars as pl

import art
from art.local import LocalBackend


# Image payload carried by each dataset row: a mapping with a single key,
# "bytes", holding the binary data that the rollout writes out as the image
# file. Declared with the functional TypedDict syntax so the key name does not
# read like a re-binding of the builtin ``bytes`` type.
DecodedImage = TypedDict("DecodedImage", {"bytes": bytes})


# One dataset row as consumed by the rollout function.
Scenario = TypedDict(
    "Scenario",
    {
        # Row identifier (presumably the MathVista problem id -- confirm
        # against the dataset schema).
        "pid": int,
        # Question text; sent to the model as the text part of the prompt.
        "question": str,
        # Reference answer; compared case-insensitively with the model's
        # boxed answer to decide the reward.
        "answer": str,
        # Relative image path; appended to "/tmp/" to choose where the image
        # is written on disk.
        "image": str,
        # Binary image payload written to that path.
        "decoded_image": DecodedImage,
    },
)


# Matches the opening of a LaTeX ``\boxed{`` command; the closing brace is
# located by brace counting so nested groups are handled.
_BOXED_OPEN = re.compile(r"\\boxed\{")


def _extract_boxed_answer(content: str) -> str | None:
    r"""Return the contents of the last balanced ``\boxed{...}`` in *content*.

    A non-greedy regex such as ``\\boxed\{(.*?)\}`` stops at the first ``}``
    and therefore truncates nested answers (``\boxed{\frac{1}{2}}`` would
    yield ``\frac{1``). This helper counts braces instead.

    Occurrences are tried from last to first, so an unterminated trailing
    ``\boxed{`` (e.g. a truncated completion) falls back to an earlier,
    complete one. Returns ``None`` when no balanced occurrence exists.
    """
    for opening in reversed(list(_BOXED_OPEN.finditer(content))):
        body_start = opening.end()
        depth = 1
        for pos in range(body_start, len(content)):
            char = content[pos]
            if char == "{":
                depth += 1
            elif char == "}":
                depth -= 1
                if depth == 0:
                    return content[body_start:pos]
    return None


def _write_image_once(image_path: str, data: bytes) -> None:
    """Make sure *image_path* holds *data* without ever exposing a partial file.

    Many rollouts for the same scenario run concurrently and all point at the
    same path. Re-opening that path with ``"wb"`` truncates it, so whatever is
    resolving an earlier request's ``file://`` URL could observe an empty or
    half-written image. To avoid that, the file is left alone when it already
    has the expected size, and otherwise written to a sibling temporary file
    and moved into place with ``os.replace``.
    """
    try:
        if os.path.getsize(image_path) == len(data):
            return
    except OSError:
        # Not written yet (or not readable): fall through and create it.
        pass

    os.makedirs(os.path.dirname(image_path), exist_ok=True)
    # The pid keeps separate processes from sharing a temporary file.
    partial_path = f"{image_path}.{os.getpid()}.part"
    with open(partial_path, "wb") as f:
        f.write(data)
    os.replace(partial_path, image_path)


async def main(model_name: str, steps: int) -> None:
    """Train a vision-language model on MathVista questions.

    Loads the MathVista ``testmini`` parquet split, shuffles it with a fixed
    seed, keeps the first 64 rows for validation and cycles over the rest for
    training. Each step gathers validation trajectories and training
    trajectory groups concurrently, logs the former and trains on the latter.

    Args:
        model_name: Name given to the ``art.TrainableModel``.
        steps: Step index at which to stop; the loop runs from the model's
            current step up to (not including) this value, so a run that has
            already reached ``steps`` does nothing.
    """
    SCENARIOS_PER_STEP = 8
    TRAJECTORY_GROUP_SIZE = 8

    # Load and shuffle the dataset (fixed seed keeps the split reproducible).
    df = pl.read_parquet(
        "hf://datasets/AI4Math/MathVista/data/testmini-00000-of-00001-725687bf7a18d64b.parquet"
    ).sample(fraction=1.0, shuffle=True, seed=42)

    val_scenarios = cast(list[Scenario], df.head(64).to_dicts())
    train_scenarios_iter = cast(Iterator[Scenario], df.tail(-64).iter_rows(named=True))

    # Initialize trainable model and backend
    model = art.TrainableModel(
        name=model_name,
        project="math-vista",
        base_model="Qwen/Qwen2.5-VL-7B-Instruct",
    )

    with LocalBackend() as backend:
        await model.register(backend)
        client = model.openai_client()

        async def rollout(scenario: Scenario) -> art.Trajectory:
            """Ask the model one question and score its boxed answer (0.0 or 1.0)."""
            image_path = f"/tmp/{scenario['image']}"
            _write_image_once(image_path, scenario["decoded_image"]["bytes"])

            trajectory = art.Trajectory(messages_and_choices=[], reward=0.0)
            # NOTE(review): messages are assigned after construction, as in
            # the original code -- presumably deliberate; confirm before
            # folding this into the constructor call.
            trajectory.messages_and_choices = [
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "text",
                            "text": scenario["question"]
                            + "\n\nNote: Provide your answer in a LaTeX box.",
                        },
                        {"type": "image_url", "image_url": {"url": f"file://{image_path}"}},
                    ],
                }
            ]

            chat_completion = await client.chat.completions.create(
                model=model.name, messages=trajectory.messages()
            )
            choice = chat_completion.choices[0]
            trajectory.messages_and_choices.append(choice)
            content = choice.message.content
            # A choice without text content cannot be scored.
            assert content is not None

            answer = _extract_boxed_answer(content)
            # Whitespace inside the box (``\boxed{ 5 }``) should not cost the
            # model its reward, so both sides are stripped before comparing.
            if (
                answer is not None
                and answer.strip().lower() == scenario["answer"].strip().lower()
            ):
                trajectory.reward = 1.0
            return trajectory

        # Resume support: skip the scenarios already consumed by earlier steps
        # so a restarted run continues where the previous one stopped.
        start = await model.get_step()
        train_scenarios_iter = itertools.islice(
            itertools.cycle(train_scenarios_iter),
            start * SCENARIOS_PER_STEP,
            None,
        )

        # Training loop
        for _ in range(start, steps):
            train_scenarios = [
                next(train_scenarios_iter) for _ in range(SCENARIOS_PER_STEP)
            ]
            val_trajectories, train_trajectory_groups = await asyncio.gather(
                art.gather_trajectories(
                    (rollout(scenario) for scenario in val_scenarios),
                    pbar_desc="gather(val)",
                    max_exceptions=32,
                ),
                art.gather_trajectory_groups(
                    (
                        art.TrajectoryGroup(
                            rollout(scenario) for _ in range(TRAJECTORY_GROUP_SIZE)
                        )
                        for scenario in train_scenarios
                    ),
                    pbar_desc="gather(train)",
                    max_exceptions=32,
                ),
            )
            await model.log(val_trajectories)
            await model.train(train_trajectory_groups)


def parse_args() -> argparse.Namespace:
    """Parse the trainer's command-line options.

    Returns:
        A namespace with ``name`` (required string, ``-n``/``--name``) and
        ``steps`` (int, ``-s``/``--steps``, default 1000).
    """
    cli = argparse.ArgumentParser(description="Minimal MathVista trainer script")

    # (flags, keyword options) for each supported argument, in display order.
    option_table = (
        (
            ("-n", "--name"),
            {
                "required": True,
                "help": "Run/model name to use for the TrainableModel",
            },
        ),
        (
            ("-s", "--steps"),
            {
                "type": int,
                "default": 1000,
                "help": "Number of training steps to run",
            },
        ),
    )
    for flags, options in option_table:
        cli.add_argument(*flags, **options)

    return cli.parse_args()


if __name__ == "__main__":
    # Script entry point: read the CLI options, then drive the async trainer
    # to completion on a fresh event loop.
    cli_args = parse_args()
    asyncio.run(main(model_name=cli_args.name, steps=cli_args.steps))
Loading