Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 79 additions & 0 deletions dev/yes-no-maybe-megatron.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
import asyncio
from itertools import permutations
import os

from dotenv import load_dotenv
import openai

import art
from art.megatron import MegatronBackend


async def rollout(
    client: openai.AsyncOpenAI, model_name: str, prompt: str
) -> art.Trajectory:
    """Run one chat completion for *prompt* and score the reply.

    Reward schedule (exact-match on the reply text):
    ``'yes'`` -> 0.5, ``'no'`` -> 0.75, ``'maybe'`` -> 1.0, anything else -> 0.0.
    """
    messages: art.Messages = [{"role": "user", "content": prompt}]
    completion = await client.chat.completions.create(
        messages=messages, model=model_name, max_tokens=100, timeout=100
    )
    first_choice = completion.choices[0]
    reply = first_choice.message.content
    assert isinstance(reply, str)
    # Dispatch on the exact reply text; unrecognized replies earn zero reward.
    reward = {"yes": 0.5, "no": 0.75, "maybe": 1.0}.get(reply, 0.0)
    return art.Trajectory(messages_and_choices=[*messages, first_choice], reward=reward)


def with_quotes(w: str) -> str:
    """Wrap *w* in single quotes, e.g. ``yes`` -> ``'yes'``."""
    return "'" + w + "'"


async def main():
    """Train a model on the yes/no/maybe toy task via the Megatron backend.

    Reads ``BASE_MODEL``, ``MODEL_NAME``, and ``NUM_STEPS`` from the
    environment (after loading ``.env``), registers the model with a
    ``MegatronBackend``, then runs ``NUM_STEPS`` train steps of 32 rollouts
    per prompt.
    """
    load_dotenv()

    backend = MegatronBackend()
    base_model = os.environ.get("BASE_MODEL", "Qwen/Qwen3-30B-A3B-Instruct-2507")
    model = art.TrainableModel(
        name=os.environ.get("MODEL_NAME", "megatron-001"),
        project="yes-no-maybe-megatron",
        base_model=base_model,
    )
    await model.register(backend)

    # Build prompt variants: every 3- and 2-word ordering of yes/no/maybe,
    # crossed with two prefixes and quoted/unquoted word lists.
    # 3-word case: "respond with 'yes', 'no', 'maybe'" (comma-joined, quotes
    # applied per use_quotes). 2-word case: "respond with yes or no".
    # NOTE(review): the 2-word branch ignores use_quotes, so each 2-word
    # prompt is generated twice identically — confirm whether the duplicate
    # prompts (and their extra trajectory groups) are intended.
    prompts = [
        f"{prefix} with {', '.join([with_quotes(w) if use_quotes else w for w in words]) if len(words) == 3 else f'{words[0]}' + (f' or {words[1]}' if len(words) > 1 else '')}"
        for prefix in ["respond", "just respond"]
        for use_quotes in [True, False]
        for words in (
            list(p) for n in [3, 2] for p in permutations(["yes", "no", "maybe"], n)
        )
    ]

    openai_client = model.openai_client()
    max_steps = int(os.environ.get("NUM_STEPS", "20"))
    # Resume from the model's last recorded step so reruns continue training.
    start_step = await model.get_step()

    for step in range(start_step, start_step + max_steps):
        print(f"\n=== Step {step + 1} ===")
        # One TrajectoryGroup per prompt, each with 32 independent rollouts.
        train_groups = await art.gather_trajectory_groups(
            (
                art.TrajectoryGroup(
                    rollout(openai_client, model.name, prompt) for _ in range(32)
                )
                for prompt in prompts
            )
        )
        await model.train(
            train_groups,
            config=art.TrainConfig(learning_rate=1e-4),
        )


if __name__ == "__main__":
    # Script entry point: drive the async training loop to completion.
    asyncio.run(main())
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,8 @@ allowed-unresolved-imports = [
# plotting deps
"matplotlib.**",
"seaborn.**",
# megatron deps
"megatron.**",
]

[dependency-groups]
Expand Down
36 changes: 32 additions & 4 deletions scripts/setup.sh
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,31 @@ if [ -f .env ]; then
done < .env
fi

if ! command -v sudo >/dev/null 2>&1; then
sudo_path="/usr/local/bin/sudo"
if [ ! -w /usr/local/bin ]; then
sudo_path="$HOME/.local/bin/sudo"
mkdir -p "$HOME/.local/bin"
export PATH="$HOME/.local/bin:$PATH"
fi

cat <<'EOF' > "$sudo_path"
#!/bin/sh
exec "$@"
EOF
chmod +x "$sudo_path"
fi

need_pkgs=()
command -v git >/dev/null 2>&1 || need_pkgs+=("git")
command -v curl >/dev/null 2>&1 || need_pkgs+=("curl")
command -v tmux >/dev/null 2>&1 || need_pkgs+=("tmux")

if [ "${#need_pkgs[@]}" -gt 0 ]; then
apt-get update
apt-get install -y "${need_pkgs[@]}"
fi

# Configure git user name and email
git config --global user.name "${GIT_USER_NAME}"
git config --global user.email "${GIT_USER_EMAIL}"
Expand All @@ -29,14 +54,17 @@ else
fi

# Install astral-uv
sudo snap install --classic astral-uv
if ! command -v uv >/dev/null 2>&1; then
if ! curl -LsSf https://astral.sh/uv/install.sh | sh; then
echo "Failed to install uv." >&2
exit 1
fi
export PATH="$HOME/.local/bin:$HOME/.cargo/bin:$PATH"
fi

# Update uv
uv self update

# Install tmux
apt install tmux -y

# Sync the dependencies
if [ "${INSTALL_EXTRAS:-false}" = "true" ]; then
uv sync --all-extras
Expand Down
1 change: 1 addition & 0 deletions skypilot-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -383,6 +383,7 @@
workdir: .
resources:
accelerators: ["H100-SXM:1", "H100:1", "A100-80GB:1"]
image_id: docker:pytorch/pytorch:2.9.0-cuda12.8-cudnn9-devel
ports:
- 7999 # main ART server
- 8000 # vLLM server
Expand Down
17 changes: 10 additions & 7 deletions src/art/local/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import json
import math
import os
import shutil
import subprocess
from types import TracebackType
from typing import AsyncIterator, Iterable, Literal, cast
Expand Down Expand Up @@ -570,20 +571,22 @@ async def _train_model(
get_model_dir(model=model, art_path=self._path), next_step
)

# If the current checkpoint exists, rename it to the next step
# If the current checkpoint exists, copy it to the next step
if os.path.exists(current_checkpoint_dir):
os.rename(current_checkpoint_dir, next_checkpoint_dir)
shutil.copytree(
current_checkpoint_dir,
next_checkpoint_dir,
dirs_exist_ok=True,
)
print(
f"Advanced step from {current_step} to {next_step} (no training occurred)"
)

try:
# Register the renamed checkpoint as a new LoRA adapter
# Register the copied checkpoint as a new LoRA adapter
# so it's available for inference at the new step
from ..unsloth.service import UnslothService

if isinstance(service, UnslothService):
await service.register_lora_for_step(
if hasattr(service, "register_lora_for_step"):
await service.register_lora_for_step( # type: ignore[attr-defined]
next_step, next_checkpoint_dir
)
except ModuleNotFoundError:
Expand Down
3 changes: 3 additions & 0 deletions src/art/megatron/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
"""Public entry point for the Megatron training backend."""

from .backend import MegatronBackend

__all__ = ["MegatronBackend"]
39 changes: 39 additions & 0 deletions src/art/megatron/backend.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from mp_actors import move_to_child_process

from ..local.backend import LocalBackend
from ..local.service import ModelService
from ..model import TrainableModel
from ..utils.output_dirs import get_model_dir


class MegatronBackend(LocalBackend):
    """``LocalBackend`` variant that serves models with a ``MegatronService``.

    Behaves exactly like ``LocalBackend`` except that ``_get_service``
    constructs a ``MegatronService`` (optionally moved to a child process)
    instead of the default service implementation.
    """

    def __init__(
        self,
        *,
        in_process: bool = False,
        path: str | None = None,
    ) -> None:
        """Create the backend.

        Args:
            in_process: If True, keep the model service in this process
                instead of moving it to a dedicated child process.
            path: Optional ART data-directory override, forwarded to
                ``LocalBackend``.
        """
        super().__init__(in_process=in_process, path=path)

    async def _get_service(self, model: TrainableModel) -> ModelService:
        """Return the ``MegatronService`` for *model*, creating it on first use.

        Services are cached in ``self._services`` keyed by model name, so the
        (possibly child-process-wrapped) service is built at most once per
        model; repeated calls return the cached instance.
        """
        # Imported lazily to avoid pulling heavy dependencies at module import.
        from ..dev.get_model_config import get_model_config
        from .service import MegatronService

        if model.name not in self._services:
            # Compute once: the same directory feeds both the model config
            # and the service's output location.
            output_dir = get_model_dir(model=model, art_path=self._path)
            config = get_model_config(
                base_model=model.base_model,
                output_dir=output_dir,
                config=model._internal_config,
            )
            self._services[model.name] = MegatronService(
                model_name=model.name,
                base_model=model.base_model,
                config=config,
                output_dir=output_dir,
            )
            if not self._in_process:
                # Wrap only the freshly created service; wrapping on every
                # call would re-proxy the cached service into a new child
                # process each time.
                self._services[model.name] = move_to_child_process(
                    self._services[model.name],
                    process_name="megatron-service",
                )
        return self._services[model.name]
Loading