Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions examples/2048/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,9 @@ async def train():
model,
)

await model.train(
train_groups,
config=art.TrainConfig(learning_rate=1e-5),
result = await backend.train(model, train_groups, learning_rate=1e-5)
await model.log(
train_groups, metrics=result.metrics, step=result.step, split="train"
)


Expand Down
3 changes: 2 additions & 1 deletion examples/benchmarking_comparison_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,8 @@ async def train_model(model: art.TrainableModel):
)
for scenario in batch.items
)
await model.train(groups)
result = await backend.train(model, groups)
await model.log(groups, metrics=result.metrics, step=result.step, split="train")

if batch.step % 20 == 0:
# Every 20 steps let's benchmark our model under training so we can
Expand Down
8 changes: 5 additions & 3 deletions examples/hn_title_generator/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,9 +325,11 @@ async def main():
)
continue

await model.train(
valid_train_groups,
config=art.TrainConfig(learning_rate=LEARNING_RATE),
result = await backend.train(
model, valid_train_groups, learning_rate=LEARNING_RATE
)
await model.log(
valid_train_groups, metrics=result.metrics, step=result.step, split="train"
)

if batch.step > 0 and batch.step % EVAL_STEPS == 0:
Expand Down
10 changes: 5 additions & 5 deletions examples/just-the-facts/just_the_facts/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,13 +81,13 @@ async def train(
),
)

await model.train(
result = await backend.train(
model,
groups,
config=art.TrainConfig(learning_rate=model.config.learning_rate),
_config=art.dev.TrainConfig(
scale_rewards=model.config.scale_rewards,
),
learning_rate=model.config.learning_rate,
scale_rewards=model.config.scale_rewards,
)
await model.log(groups, metrics=result.metrics, step=result.step, split="train")

await backend._experimental_push_to_s3(model)

Expand Down
3 changes: 2 additions & 1 deletion examples/mcp-rl/mcp_rl/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,8 @@ async def train_mcp_agent(model: art.TrainableModel, use_skypilot: bool = False)
await model.log(val_groups, split="val")

print("starting train")
await model.train(groups, config=art.TrainConfig(learning_rate=learning_rate))
result = await backend.train(model, groups, learning_rate=learning_rate)
await model.log(groups, metrics=result.metrics, step=result.step, split="train")

await backend._experimental_push_to_s3(
model,
Expand Down
3 changes: 2 additions & 1 deletion examples/openenv_echo.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,8 @@ async def main() -> None:
[art.TrajectoryGroup(rollout(model, env_client) for env_client in env_pool)]
)

await model.train(groups)
result = await backend.train(model, groups)
await model.log(groups, metrics=result.metrics, step=result.step, split="train")


asyncio.run(main())
25 changes: 13 additions & 12 deletions examples/prisoners-dilemma.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -136,17 +136,18 @@
" )\n",
" await model.log([ts[0] for ts in base_play_trajectories], split=\"versus-base\")\n",
" await model.log([ts[1] for ts in base_play_trajectories], split=\"base-model\")\n",
" # Train the model on self-play and base-play trajectories.\n",
" await model.train(\n",
" trajectory_groups=[\n",
" # Since all self-play games have the same starting state and are symmetric, we can gather\n",
" # trajectories from all self-play games into a single trajectory group.\n",
" art.TrajectoryGroup(t for ts in self_play_trajectories for t in ts),\n",
" # We can also gather all base-play _trained model_ trajectories into a single trajectory group.\n",
" # We don't want to train on base model trajectories, because they are sampled from a different distribution.\n",
" art.TrajectoryGroup(ts[0] for ts in base_play_trajectories),\n",
" ],\n",
" config=art.TrainConfig(learning_rate=5e-5),\n",
" # Train the model on self-play and base-play trajectories using the backend-first API.\n",
" trajectory_groups = [\n",
" # Since all self-play games have the same starting state and are symmetric, we can gather\n",
" # trajectories from all self-play games into a single trajectory group.\n",
" art.TrajectoryGroup(t for ts in self_play_trajectories for t in ts),\n",
" # We can also gather all base-play _trained model_ trajectories into a single trajectory group.\n",
" # We don't want to train on base model trajectories, because they are sampled from a different distribution.\n",
" art.TrajectoryGroup(ts[0] for ts in base_play_trajectories),\n",
" ]\n",
" result = await backend.train(model, trajectory_groups, learning_rate=5e-5)\n",
" await model.log(\n",
" trajectory_groups, metrics=result.metrics, step=result.step, split=\"train\"\n",
" )"
]
}
Expand All @@ -172,4 +173,4 @@
},
"nbformat": 4,
"nbformat_minor": 2
}
}
13 changes: 7 additions & 6 deletions examples/rock-paper-tool-use.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,8 @@
"model = art.TrainableModel(\n",
" name=MODEL_NAME, project=\"rock-paper-tool-use\", base_model=BASE_MODEL\n",
")\n",
"await model.register(LocalBackend())\n",
"backend = LocalBackend()\n",
"await model.register(backend)\n",
"client = model.openai_client()\n",
"\n",
"\n",
Expand Down Expand Up @@ -180,10 +181,10 @@
" trajectories = await art.gather_trajectories(\n",
" (rollout() for _ in range(64)), max_exceptions=64\n",
" )\n",
" await model.train(\n",
" [art.TrajectoryGroup(trajectories)],\n",
" config=art.TrainConfig(learning_rate=5e-5),\n",
" )"
" # Log trajectories and train using the backend-first API\n",
" groups = [art.TrajectoryGroup(trajectories)]\n",
" result = await backend.train(model, groups, learning_rate=5e-5)\n",
" await model.log(groups, metrics=result.metrics, step=result.step, split=\"train\")"
]
}
],
Expand All @@ -208,4 +209,4 @@
},
"nbformat": 4,
"nbformat_minor": 2
}
}
11 changes: 6 additions & 5 deletions examples/temporal_clue/temporal-clue-7b-async.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -156,10 +156,11 @@
" for trajectory in group:\n",
" trajectory.metrics[\"max_reward\"] = max_reward\n",
" await model.delete_checkpoints()\n",
" await model.train(\n",
" train_groups,\n",
" config=art.TrainConfig(learning_rate=5e-6),\n",
" _config=art.dev.TrainConfig(precalculate_logprobs=True),\n",
" result = await backend.train(\n",
" model, train_groups, learning_rate=5e-6, precalculate_logprobs=True\n",
" )\n",
" await model.log(\n",
" train_groups, metrics=result.metrics, step=result.step, split=\"train\"\n",
" )"
]
}
Expand All @@ -185,4 +186,4 @@
},
"nbformat": 4,
"nbformat_minor": 2
}
}
13 changes: 9 additions & 4 deletions examples/temporal_clue/temporal-clue-7b.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -118,10 +118,15 @@
" trajectory.metrics[\"max_reward\"] = max_reward\n",
" await model.log(val_groups)\n",
" await model.delete_checkpoints()\n",
" await model.train(\n",
" result = await backend.train(\n",
" model,\n",
" train_groups,\n",
" config=art.TrainConfig(learning_rate=5e-6),\n",
" _config=art.dev.TrainConfig(precalculate_logprobs=True, scale_rewards=False),\n",
" learning_rate=5e-6,\n",
" precalculate_logprobs=True,\n",
" scale_rewards=False,\n",
" )\n",
" await model.log(\n",
" train_groups, metrics=result.metrics, step=result.step, split=\"train\"\n",
" )"
]
}
Expand All @@ -147,4 +152,4 @@
},
"nbformat": 4,
"nbformat_minor": 2
}
}
9 changes: 5 additions & 4 deletions examples/temporal_clue/temporal-clue-torchtune.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -147,9 +147,10 @@
" )\n",
" await model.log(val_groups)\n",
" await model.delete_checkpoints()\n",
" await model.train(\n",
" train_groups,\n",
" config=art.TrainConfig(learning_rate=5e-6),\n",
" # Log trajectories and train using the backend-first API\n",
" result = await backend.train(model, train_groups, learning_rate=5e-6)\n",
" await model.log(\n",
" train_groups, metrics=result.metrics, step=result.step, split=\"train\"\n",
" )"
]
}
Expand All @@ -175,4 +176,4 @@
},
"nbformat": 4,
"nbformat_minor": 2
}
}
6 changes: 3 additions & 3 deletions examples/temporal_clue/temporal-clue.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,9 +85,9 @@ async def main():
await model.log(val_groups)
await model.delete_checkpoints()
await backend._experimental_push_to_s3(model)
await model.train(
train_groups,
config=art.TrainConfig(learning_rate=5e-5),
result = await backend.train(model, train_groups, learning_rate=5e-5)
await model.log(
train_groups, metrics=result.metrics, step=result.step, split="train"
)


Expand Down
5 changes: 4 additions & 1 deletion examples/tic_tac_toe/tic-tac-toe.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,10 @@ async def main():
pbar_desc="gather",
)
await model.delete_checkpoints()
await model.train(train_groups, config=art.TrainConfig(learning_rate=5e-5))
result = await backend.train(model, train_groups, learning_rate=5e-5)
await model.log(
train_groups, metrics=result.metrics, step=result.step, split="train"
)
await backend._experimental_push_to_s3(model)

if DEPLOY_MODEL:
Expand Down
10 changes: 6 additions & 4 deletions examples/tic_tac_toe_self_play/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,10 +140,12 @@ async def main():
await model.log(model_trajectories, split="val")

# await model.delete_checkpoints()
await model.train(
trajectory_groups=[x_trajectory_group, o_trajectory_group],
config=art.TrainConfig(learning_rate=2e-5),
verbose=True,
trajectory_groups = [x_trajectory_group, o_trajectory_group]
result = await backend.train(
model, trajectory_groups, learning_rate=2e-5, verbose=True
)
await model.log(
trajectory_groups, metrics=result.metrics, step=result.step, split="train"
)
await backend._experimental_push_to_s3(model)

Expand Down
10 changes: 6 additions & 4 deletions examples/tic_tac_toe_self_play/train_o4_mini.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,10 +149,12 @@ def get_model_trajectories(
await model.log(model_trajectories, split="val")

# await model.delete_checkpoints()
await model.train(
trajectory_groups=[x_trajectory_group, o_trajectory_group],
config=art.TrainConfig(learning_rate=2e-5),
verbose=True,
trajectory_groups = [x_trajectory_group, o_trajectory_group]
result = await backend.train(
model, trajectory_groups, learning_rate=2e-5, verbose=True
)
await model.log(
trajectory_groups, metrics=result.metrics, step=result.step, split="train"
)
await backend._experimental_push_to_s3(model)

Expand Down
15 changes: 14 additions & 1 deletion src/art/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,20 @@ def __init__(self, **kwargs):
from .backend import Backend
from .batches import trajectory_group_batches
from .gather import gather_trajectories, gather_trajectory_groups
from .local import LocalBackend
from .model import Model, TrainableModel
from .serverless import ServerlessBackend
from .tinker import TinkerBackend
from .trajectories import Trajectory, TrajectoryGroup
from .types import Messages, MessagesAndChoices, Tools, TrainConfig
from .types import (
LocalTrainResult,
Messages,
MessagesAndChoices,
ServerlessTrainResult,
Tools,
TrainConfig,
TrainResult,
)
from .utils import retry
from .yield_trajectory import capture_yielded_trajectory, yield_trajectory

Expand All @@ -70,14 +79,18 @@ def __init__(self, **kwargs):
"gather_trajectory_groups",
"trajectory_group_batches",
"Backend",
"LocalBackend",
"LocalTrainResult",
"ServerlessBackend",
"ServerlessTrainResult",
"Messages",
"MessagesAndChoices",
"Tools",
"Model",
"TrainableModel",
"retry",
"TrainConfig",
"TrainResult",
"TinkerBackend",
"Trajectory",
"TrajectoryGroup",
Expand Down
36 changes: 34 additions & 2 deletions src/art/backend.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import json
from typing import TYPE_CHECKING, AsyncIterator, Literal
from typing import TYPE_CHECKING, Any, AsyncIterator, Iterable, Literal
import warnings

import httpx
Expand All @@ -15,7 +15,7 @@

from . import dev
from .trajectories import TrajectoryGroup
from .types import TrainConfig
from .types import TrainConfig, TrainResult

if TYPE_CHECKING:
from .model import Model, TrainableModel
Expand Down Expand Up @@ -80,6 +80,38 @@ async def _prepare_backend_for_training(
base_url, api_key = tuple(response.json())
return base_url, api_key

def _model_inference_name(self, model: "Model", step: int | None = None) -> str:
"""Return the inference name for a model checkpoint.

Override in subclasses to provide backend-specific naming.
Default implementation returns model.name with optional @step suffix.
"""
base_name = model.inference_model_name or model.name
if step is not None:
return f"{base_name}@{step}"
return base_name

async def train(
    self,
    model: "TrainableModel",
    trajectory_groups: Iterable[TrajectoryGroup],
    **kwargs: Any,
) -> TrainResult:
    """Train the model on the given trajectory groups.

    The abstract base Backend provides no training implementation; call
    ``train()`` on a concrete backend (LocalBackend, ServerlessBackend,
    or TinkerBackend) instead.

    Raises:
        NotImplementedError: Always raised. Use a concrete backend instead.
    """
    # Build the guidance message separately so the raise reads cleanly.
    message = (
        "The base Backend class does not support the train() method. "
        "Use LocalBackend, ServerlessBackend, or TinkerBackend directly. "
        "If you are using the 'art run' server, consider using LocalBackend "
        "in-process instead."
    )
    raise NotImplementedError(message)

async def _train_model(
self,
model: "TrainableModel",
Expand Down
Loading