diff --git a/examples/2048/train.py b/examples/2048/train.py
index 11c6edf66..7cf8e3070 100644
--- a/examples/2048/train.py
+++ b/examples/2048/train.py
@@ -71,9 +71,9 @@ async def train():
             model,
         )

-        await model.train(
-            train_groups,
-            config=art.TrainConfig(learning_rate=1e-5),
+        result = await backend.train(model, train_groups, learning_rate=1e-5)
+        await model.log(
+            train_groups, metrics=result.metrics, step=result.step, split="train"
         )
diff --git a/examples/benchmarking_comparison_models.py b/examples/benchmarking_comparison_models.py
index 9938add20..7f8a989e3 100644
--- a/examples/benchmarking_comparison_models.py
+++ b/examples/benchmarking_comparison_models.py
@@ -127,7 +127,8 @@ async def train_model(model: art.TrainableModel):
             )
             for scenario in batch.items
         )
-        await model.train(groups)
+        result = await backend.train(model, groups)
+        await model.log(groups, metrics=result.metrics, step=result.step, split="train")

         if batch.step % 20 == 0:
             # Every 20 steps let's benchmark our model under training so we can
diff --git a/examples/hn_title_generator/train.py b/examples/hn_title_generator/train.py
index b179263a8..abf4400b6 100644
--- a/examples/hn_title_generator/train.py
+++ b/examples/hn_title_generator/train.py
@@ -325,9 +325,11 @@ async def main():
             )
             continue

-        await model.train(
-            valid_train_groups,
-            config=art.TrainConfig(learning_rate=LEARNING_RATE),
+        result = await backend.train(
+            model, valid_train_groups, learning_rate=LEARNING_RATE
+        )
+        await model.log(
+            valid_train_groups, metrics=result.metrics, step=result.step, split="train"
         )

         if batch.step > 0 and batch.step % EVAL_STEPS == 0:
diff --git a/examples/just-the-facts/just_the_facts/train.py b/examples/just-the-facts/just_the_facts/train.py
index 3941c7a9b..621132adb 100644
--- a/examples/just-the-facts/just_the_facts/train.py
+++ b/examples/just-the-facts/just_the_facts/train.py
@@ -81,13 +81,13 @@ async def train(
         ),
     )

-    await model.train(
+    result = await backend.train(
+        model,
         groups,
-        config=art.TrainConfig(learning_rate=model.config.learning_rate),
-        _config=art.dev.TrainConfig(
-            scale_rewards=model.config.scale_rewards,
-        ),
+        learning_rate=model.config.learning_rate,
+        scale_rewards=model.config.scale_rewards,
     )
+    await model.log(groups, metrics=result.metrics, step=result.step, split="train")

     await backend._experimental_push_to_s3(model)
diff --git a/examples/mcp-rl/mcp_rl/train.py b/examples/mcp-rl/mcp_rl/train.py
index 3417423d2..020fc0e8d 100644
--- a/examples/mcp-rl/mcp_rl/train.py
+++ b/examples/mcp-rl/mcp_rl/train.py
@@ -168,7 +168,8 @@ async def train_mcp_agent(model: art.TrainableModel, use_skypilot: bool = False)
         await model.log(val_groups, split="val")

         print("starting train")
-        await model.train(groups, config=art.TrainConfig(learning_rate=learning_rate))
+        result = await backend.train(model, groups, learning_rate=learning_rate)
+        await model.log(groups, metrics=result.metrics, step=result.step, split="train")

         await backend._experimental_push_to_s3(
             model,
diff --git a/examples/openenv_echo.py b/examples/openenv_echo.py
index 8da3e8af9..3c3ad2e86 100644
--- a/examples/openenv_echo.py
+++ b/examples/openenv_echo.py
@@ -86,7 +86,8 @@ async def main() -> None:
         [art.TrajectoryGroup(rollout(model, env_client) for env_client in env_pool)]
     )
-    await model.train(groups)
+    result = await backend.train(model, groups)
+    await model.log(groups, metrics=result.metrics, step=result.step, split="train")


 asyncio.run(main())
diff --git a/examples/prisoners-dilemma.ipynb b/examples/prisoners-dilemma.ipynb
index ec6d31e7d..b04f7ad3d 100644
--- a/examples/prisoners-dilemma.ipynb
+++ b/examples/prisoners-dilemma.ipynb
@@ -136,17 +136,18 @@
     "    )\n",
     "    await model.log([ts[0] for ts in base_play_trajectories], split=\"versus-base\")\n",
     "    await model.log([ts[1] for ts in base_play_trajectories], split=\"base-model\")\n",
-    "    # Train the model on self-play and base-play trajectories.\n",
-    "    await model.train(\n",
-    "        trajectory_groups=[\n",
-    "            # Since all self-play games have the same starting state and are symmetric, we can gather\n",
-    "            # trajectories from all self-play games into a single trajectory group.\n",
-    "            art.TrajectoryGroup(t for ts in self_play_trajectories for t in ts),\n",
-    "            # We can also gather all base-play _trained model_ trajectories into a single trajectory group.\n",
-    "            # We don't want to train on base model trajectories, because they are sampled from a different distribution.\n",
-    "            art.TrajectoryGroup(ts[0] for ts in base_play_trajectories),\n",
-    "        ],\n",
-    "        config=art.TrainConfig(learning_rate=5e-5),\n",
+    "    # Train the model on self-play and base-play trajectories using the backend-first API.\n",
+    "    trajectory_groups = [\n",
+    "        # Since all self-play games have the same starting state and are symmetric, we can gather\n",
+    "        # trajectories from all self-play games into a single trajectory group.\n",
+    "        art.TrajectoryGroup(t for ts in self_play_trajectories for t in ts),\n",
+    "        # We can also gather all base-play _trained model_ trajectories into a single trajectory group.\n",
+    "        # We don't want to train on base model trajectories, because they are sampled from a different distribution.\n",
+    "        art.TrajectoryGroup(ts[0] for ts in base_play_trajectories),\n",
+    "    ]\n",
+    "    result = await backend.train(model, trajectory_groups, learning_rate=5e-5)\n",
+    "    await model.log(\n",
+    "        trajectory_groups, metrics=result.metrics, step=result.step, split=\"train\"\n",
     "    )"
    ]
   }
@@ -172,4 +173,4 @@
  },
  "nbformat": 4,
  "nbformat_minor": 2
-}
+}
\ No newline at end of file
diff --git a/examples/rock-paper-tool-use.ipynb b/examples/rock-paper-tool-use.ipynb
index f3e0a4497..01758d13b 100644
--- a/examples/rock-paper-tool-use.ipynb
+++ b/examples/rock-paper-tool-use.ipynb
@@ -52,7 +52,8 @@
     "model = art.TrainableModel(\n",
     "    name=MODEL_NAME, project=\"rock-paper-tool-use\", base_model=BASE_MODEL\n",
     ")\n",
-    "await model.register(LocalBackend())\n",
+    "backend = LocalBackend()\n",
+    "await model.register(backend)\n",
     "client = model.openai_client()\n",
     "\n",
     "\n",
@@ -180,10 +181,10 @@
     "    trajectories = await art.gather_trajectories(\n",
     "        (rollout() for _ in range(64)), max_exceptions=64\n",
     "    )\n",
-    "    await model.train(\n",
-    "        [art.TrajectoryGroup(trajectories)],\n",
-    "        config=art.TrainConfig(learning_rate=5e-5),\n",
-    "    )"
+    "    # Log trajectories and train using the backend-first API\n",
+    "    groups = [art.TrajectoryGroup(trajectories)]\n",
+    "    result = await backend.train(model, groups, learning_rate=5e-5)\n",
+    "    await model.log(groups, metrics=result.metrics, step=result.step, split=\"train\")"
    ]
   }
 ],
@@ -208,4 +209,4 @@
  },
  "nbformat": 4,
  "nbformat_minor": 2
-}
+}
\ No newline at end of file
diff --git a/examples/temporal_clue/temporal-clue-7b-async.ipynb b/examples/temporal_clue/temporal-clue-7b-async.ipynb
index 6d0e57fc8..929bfd017 100644
--- a/examples/temporal_clue/temporal-clue-7b-async.ipynb
+++ b/examples/temporal_clue/temporal-clue-7b-async.ipynb
@@ -156,10 +156,11 @@
     "        for trajectory in group:\n",
     "            trajectory.metrics[\"max_reward\"] = max_reward\n",
     "    await model.delete_checkpoints()\n",
-    "    await model.train(\n",
-    "        train_groups,\n",
-    "        config=art.TrainConfig(learning_rate=5e-6),\n",
-    "        _config=art.dev.TrainConfig(precalculate_logprobs=True),\n",
+    "    result = await backend.train(\n",
+    "        model, train_groups, learning_rate=5e-6, precalculate_logprobs=True\n",
+    "    )\n",
+    "    await model.log(\n",
+    "        train_groups, metrics=result.metrics, step=result.step, split=\"train\"\n",
     "    )"
    ]
   }
@@ -185,4 +186,4 @@
  },
  "nbformat": 4,
  "nbformat_minor": 2
-}
+}
\ No newline at end of file
diff --git a/examples/temporal_clue/temporal-clue-7b.ipynb b/examples/temporal_clue/temporal-clue-7b.ipynb
index 187455fb8..e75dad331 100644
--- a/examples/temporal_clue/temporal-clue-7b.ipynb
+++ b/examples/temporal_clue/temporal-clue-7b.ipynb
@@ -118,10 +118,15 @@
     "            trajectory.metrics[\"max_reward\"] = max_reward\n",
     "    await model.log(val_groups)\n",
     "    await model.delete_checkpoints()\n",
-    "    await model.train(\n",
+    "    result = await backend.train(\n",
+    "        model,\n",
     "        train_groups,\n",
-    "        config=art.TrainConfig(learning_rate=5e-6),\n",
-    "        _config=art.dev.TrainConfig(precalculate_logprobs=True, scale_rewards=False),\n",
+    "        learning_rate=5e-6,\n",
+    "        precalculate_logprobs=True,\n",
+    "        scale_rewards=False,\n",
+    "    )\n",
+    "    await model.log(\n",
+    "        train_groups, metrics=result.metrics, step=result.step, split=\"train\"\n",
     "    )"
    ]
   }
@@ -147,4 +152,4 @@
  },
  "nbformat": 4,
  "nbformat_minor": 2
-}
+}
\ No newline at end of file
diff --git a/examples/temporal_clue/temporal-clue-torchtune.ipynb b/examples/temporal_clue/temporal-clue-torchtune.ipynb
index 041e7315c..7996d5db6 100644
--- a/examples/temporal_clue/temporal-clue-torchtune.ipynb
+++ b/examples/temporal_clue/temporal-clue-torchtune.ipynb
@@ -147,9 +147,10 @@
     "    )\n",
     "    await model.log(val_groups)\n",
     "    await model.delete_checkpoints()\n",
-    "    await model.train(\n",
-    "        train_groups,\n",
-    "        config=art.TrainConfig(learning_rate=5e-6),\n",
+    "    # Log trajectories and train using the backend-first API\n",
+    "    result = await backend.train(model, train_groups, learning_rate=5e-6)\n",
+    "    await model.log(\n",
+    "        train_groups, metrics=result.metrics, step=result.step, split=\"train\"\n",
     "    )"
    ]
   }
@@ -175,4 +176,4 @@
  },
  "nbformat": 4,
  "nbformat_minor": 2
-}
+}
\ No newline at end of file
diff --git a/examples/temporal_clue/temporal-clue.py b/examples/temporal_clue/temporal-clue.py
index 34c91efb4..e4fc078de 100644
--- a/examples/temporal_clue/temporal-clue.py
+++ b/examples/temporal_clue/temporal-clue.py
@@ -85,9 +85,9 @@ async def main():
         await model.log(val_groups)
         await model.delete_checkpoints()
         await backend._experimental_push_to_s3(model)
-        await model.train(
-            train_groups,
-            config=art.TrainConfig(learning_rate=5e-5),
+        result = await backend.train(model, train_groups, learning_rate=5e-5)
+        await model.log(
+            train_groups, metrics=result.metrics, step=result.step, split="train"
         )
diff --git a/examples/tic_tac_toe/tic-tac-toe.py b/examples/tic_tac_toe/tic-tac-toe.py
index 65916f81e..15f9ac043 100644
--- a/examples/tic_tac_toe/tic-tac-toe.py
+++ b/examples/tic_tac_toe/tic-tac-toe.py
@@ -71,7 +71,10 @@ async def main():
             pbar_desc="gather",
         )
         await model.delete_checkpoints()
-        await model.train(train_groups, config=art.TrainConfig(learning_rate=5e-5))
+        result = await backend.train(model, train_groups, learning_rate=5e-5)
+        await model.log(
+            train_groups, metrics=result.metrics, step=result.step, split="train"
+        )
         await backend._experimental_push_to_s3(model)

     if DEPLOY_MODEL:
diff --git a/examples/tic_tac_toe_self_play/train.py b/examples/tic_tac_toe_self_play/train.py
index fc615e243..d18112942 100644
--- a/examples/tic_tac_toe_self_play/train.py
+++ b/examples/tic_tac_toe_self_play/train.py
@@ -140,10 +140,12 @@ async def main():
         await model.log(model_trajectories, split="val")
         # await model.delete_checkpoints()

-        await model.train(
-            trajectory_groups=[x_trajectory_group, o_trajectory_group],
-            config=art.TrainConfig(learning_rate=2e-5),
-            verbose=True,
+        trajectory_groups = [x_trajectory_group, o_trajectory_group]
+        result = await backend.train(
+            model, trajectory_groups, learning_rate=2e-5, verbose=True
+        )
+        await model.log(
+            trajectory_groups, metrics=result.metrics, step=result.step, split="train"
         )

         await backend._experimental_push_to_s3(model)
diff --git a/examples/tic_tac_toe_self_play/train_o4_mini.py b/examples/tic_tac_toe_self_play/train_o4_mini.py
index 004560a3e..325c84109 100644
--- a/examples/tic_tac_toe_self_play/train_o4_mini.py
+++ b/examples/tic_tac_toe_self_play/train_o4_mini.py
@@ -149,10 +149,12 @@ def get_model_trajectories(
         await model.log(model_trajectories, split="val")
         # await model.delete_checkpoints()

-        await model.train(
-            trajectory_groups=[x_trajectory_group, o_trajectory_group],
-            config=art.TrainConfig(learning_rate=2e-5),
-            verbose=True,
+        trajectory_groups = [x_trajectory_group, o_trajectory_group]
+        result = await backend.train(
+            model, trajectory_groups, learning_rate=2e-5, verbose=True
+        )
+        await model.log(
+            trajectory_groups, metrics=result.metrics, step=result.step, split="train"
         )

         await backend._experimental_push_to_s3(model)
diff --git a/src/art/__init__.py b/src/art/__init__.py
index cbb31bf4a..75957e176 100644
--- a/src/art/__init__.py
+++ b/src/art/__init__.py
@@ -54,11 +54,20 @@ def __init__(self, **kwargs):
 from .backend import Backend
 from .batches import trajectory_group_batches
 from .gather import gather_trajectories, gather_trajectory_groups
+from .local import LocalBackend
 from .model import Model, TrainableModel
 from .serverless import ServerlessBackend
 from .tinker import TinkerBackend
 from .trajectories import Trajectory, TrajectoryGroup
-from .types import Messages, MessagesAndChoices, Tools, TrainConfig
+from .types import (
+    LocalTrainResult,
+    Messages,
+    MessagesAndChoices,
+    ServerlessTrainResult,
+    Tools,
+    TrainConfig,
+    TrainResult,
+)
 from .utils import retry
 from .yield_trajectory import capture_yielded_trajectory, yield_trajectory
@@ -70,7 +79,10 @@ def __init__(self, **kwargs):
     "gather_trajectory_groups",
     "trajectory_group_batches",
     "Backend",
+    "LocalBackend",
+    "LocalTrainResult",
     "ServerlessBackend",
+    "ServerlessTrainResult",
     "Messages",
     "MessagesAndChoices",
     "Tools",
@@ -78,6 +90,7 @@
     "TrainableModel",
     "retry",
     "TrainConfig",
+    "TrainResult",
     "TinkerBackend",
     "Trajectory",
     "TrajectoryGroup",
diff --git a/src/art/backend.py b/src/art/backend.py
index b1b6f78ff..8d07e7153 100644
--- a/src/art/backend.py
+++ b/src/art/backend.py
@@ -1,5 +1,5 @@
 import json
-from typing import TYPE_CHECKING, AsyncIterator, Literal
+from typing import TYPE_CHECKING, Any, AsyncIterator, Iterable, Literal
 import warnings

 import httpx
@@ -15,7 +15,7 @@
 from . import dev
 from .trajectories import TrajectoryGroup
-from .types import TrainConfig
+from .types import TrainConfig, TrainResult

 if TYPE_CHECKING:
     from .model import Model, TrainableModel
@@ -80,6 +80,38 @@ async def _prepare_backend_for_training(
         base_url, api_key = tuple(response.json())
         return base_url, api_key

+    def _model_inference_name(self, model: "Model", step: int | None = None) -> str:
+        """Return the inference name for a model checkpoint.
+
+        Override in subclasses to provide backend-specific naming.
+        Default implementation returns model.name with optional @step suffix.
+        """
+        base_name = model.inference_model_name or model.name
+        if step is not None:
+            return f"{base_name}@{step}"
+        return base_name
+
+    async def train(
+        self,
+        model: "TrainableModel",
+        trajectory_groups: Iterable[TrajectoryGroup],
+        **kwargs: Any,
+    ) -> TrainResult:
+        """Train the model on the given trajectory groups.
+
+        This method is not implemented in the base Backend class. Use
+        LocalBackend, ServerlessBackend, or TinkerBackend directly for training.
+
+        Raises:
+            NotImplementedError: Always raised. Use a concrete backend instead.
+        """
+        raise NotImplementedError(
+            "The base Backend class does not support the train() method. "
+            "Use LocalBackend, ServerlessBackend, or TinkerBackend directly. "
+            "If you are using the 'art run' server, consider using LocalBackend "
+            "in-process instead."
+        )
+
     async def _train_model(
         self,
         model: "TrainableModel",
diff --git a/src/art/local/backend.py b/src/art/local/backend.py
index 22b3f4f09..0301bea91 100644
--- a/src/art/local/backend.py
+++ b/src/art/local/backend.py
@@ -4,7 +4,7 @@
 import os
 import subprocess
 from types import TracebackType
-from typing import AsyncIterator, Literal, cast
+from typing import AsyncIterator, Iterable, Literal, cast
 import warnings

 import aiohttp
@@ -41,7 +41,7 @@
 )
 from ..preprocessing.tokenize import tokenize_trajectory_groups
 from ..trajectories import Trajectory, TrajectoryGroup
-from ..types import Message, TrainConfig
+from ..types import LocalTrainResult, Message, TrainConfig
 from ..utils import format_message, get_model_step
 from .checkpoints import (
     delete_checkpoints,
@@ -117,6 +117,22 @@ async def register(
         if model.trainable and "WANDB_API_KEY" in os.environ:
             _ = model._get_wandb_run()

+    def _model_inference_name(self, model: Model, step: int | None = None) -> str:
+        """Return the inference name for a model checkpoint.
+
+        For LocalBackend with vLLM, the base model is served under its HF name,
+        and LoRA adapters are served as `model.name@step`.
+
+        Args:
+            model: The model.
+            step: If provided, returns name for specific checkpoint.
+                If None, defaults to step 0, the initial checkpoint created
+                at registration.
+ """ + # For LocalBackend, vLLM always serves LoRA adapters with @step suffix + # Default to step 0 when not specified (the initial checkpoint created at registration) + actual_step = step if step is not None else 0 + return f"{model.name}@{actual_step}" + async def _get_service(self, model: TrainableModel) -> ModelService: from ..dev.get_model_config import get_model_config @@ -350,6 +366,160 @@ def _trajectory_log(self, trajectory: Trajectory) -> str: formatted_messages.append(format_message(message)) return header + "\n".join(formatted_messages) + async def train( # type: ignore[override] + self, + model: TrainableModel, + trajectory_groups: Iterable[TrajectoryGroup], + *, + # Core training parameters + learning_rate: float = 5e-6, + beta: float = 0.0, + # RL algorithm settings + ppo: bool = False, + epsilon: float | None = None, + epsilon_high: float | None = None, + # Advantage computation + advantage_balance: float = 0.0, + scale_rewards: bool = True, + # Importance sampling + importance_sampling_level: Literal[ + "token", "sequence", "average", "geometric_average" + ] = "token", + max_negative_advantage_importance_sampling_weight: float | None = None, + mask_prob_ratio: bool = False, + # Experimental parameters + kimi_k2_tau: float | None = None, + precalculate_logprobs: bool = False, + # LocalBackend-specific parameters + allow_training_without_logprobs: bool = False, + plot_tensors: bool = False, + truncated_importance_sampling: float | None = None, + scale_learning_rate_by_reward_std_dev: bool = False, + logprob_calculation_chunk_size: int = 1024, + num_trajectories_learning_rate_multiplier_power: float = 0.0, + # Checkpoint behavior + save_checkpoint: bool = True, + # Verbosity + verbose: bool = False, + ) -> LocalTrainResult: + """Train the model on the given trajectory groups. + + This is the recommended way to train models. Unlike model.train(), this + method does NOT automatically log trajectories or metrics. Call model.log() + explicitly before and/or after training if you want to log data. + + Args: + model: The trainable model to train. + trajectory_groups: Batches of trajectories to train on. + learning_rate: Learning rate for training. Defaults to 5e-6. + beta: KL penalty coefficient. Defaults to 0.0. + ppo: Whether to use PPO clipping. Defaults to False. + epsilon: Clip epsilon for importance sampling. Defaults based on ppo. + epsilon_high: Asymmetric upper clip bound. Defaults to epsilon. + advantage_balance: Balance between negative and positive advantages + in range [-1.0, 1.0]. Defaults to 0.0 (balanced). + scale_rewards: Whether to scale rewards by standard deviation. + Defaults to True. + importance_sampling_level: Level at which to compute importance + sampling weights. Defaults to "token". + max_negative_advantage_importance_sampling_weight: Maximum weight + for negative advantage samples. + mask_prob_ratio: Whether to mask probability ratios. Defaults to False. + kimi_k2_tau: Tau parameter for Kimi K2 algorithm. + precalculate_logprobs: Whether to precalculate logprobs. + allow_training_without_logprobs: Allow training even when no logprobs + are available. Defaults to False. + plot_tensors: Whether to plot training tensors for debugging. + Defaults to False. + truncated_importance_sampling: Truncation threshold for importance + sampling weights. + scale_learning_rate_by_reward_std_dev: Whether to scale learning rate + by reward standard deviation. Defaults to False. + logprob_calculation_chunk_size: Chunk size for logprob calculation. + Defaults to 1024. 
+ num_trajectories_learning_rate_multiplier_power: Power for learning + rate multiplier based on number of trajectories. + save_checkpoint: Whether to save a checkpoint after training. + Defaults to True. + verbose: Whether to print verbose output. Defaults to False. + + Returns: + LocalTrainResult with step number, training metrics, and checkpoint path. + + Example: + # Before (deprecated): + await model.train(trajectory_groups, config=TrainConfig(learning_rate=5e-6)) + + # After (recommended): + await model.log(trajectory_groups, split="train") + result = await backend.train(model, trajectory_groups, learning_rate=5e-6) + # Optionally log training metrics: + # await model.log(metrics=result.metrics, step=result.step) + """ + groups_list = list(trajectory_groups) + + # Build config objects from explicit kwargs + config = TrainConfig(learning_rate=learning_rate, beta=beta) + dev_config: dev.TrainConfig = { + "advantage_balance": advantage_balance, + "allow_training_without_logprobs": allow_training_without_logprobs, + "importance_sampling_level": importance_sampling_level, + "mask_prob_ratio": mask_prob_ratio, + "plot_tensors": plot_tensors, + "ppo": ppo, + "precalculate_logprobs": precalculate_logprobs, + "scale_learning_rate_by_reward_std_dev": scale_learning_rate_by_reward_std_dev, + "scale_rewards": scale_rewards, + "logprob_calculation_chunk_size": logprob_calculation_chunk_size, + "num_trajectories_learning_rate_multiplier_power": num_trajectories_learning_rate_multiplier_power, + } + # Only include optional fields if they're set + if epsilon is not None: + dev_config["epsilon"] = epsilon + if epsilon_high is not None: + dev_config["epsilon_high"] = epsilon_high + if max_negative_advantage_importance_sampling_weight is not None: + dev_config["max_negative_advantage_importance_sampling_weight"] = ( + max_negative_advantage_importance_sampling_weight + ) + if kimi_k2_tau is not None: + dev_config["kimi_k2_tau"] = kimi_k2_tau + if truncated_importance_sampling is not None: + dev_config["truncated_importance_sampling"] = truncated_importance_sampling + + # Collect metrics from training + training_metrics: list[dict[str, float]] = [] + async for metrics in self._train_model( + model, groups_list, config, dev_config, verbose + ): + training_metrics.append(metrics) + + # Aggregate metrics + avg_metrics: dict[str, float] = {} + if training_metrics: + avg_metrics = { + k: sum(d.get(k, 0) for d in training_metrics) + / sum(1 for d in training_metrics if k in d) + for k in {k for d in training_metrics for k in d} + if k != "num_gradient_steps" + } + + # Get step and checkpoint path + step = await self._get_step(model) + checkpoint_path: str | None = None + if save_checkpoint: + checkpoint_path = get_step_checkpoint_dir( + get_model_dir(model=model, art_path=self._path), step + ) + if not os.path.exists(checkpoint_path): + checkpoint_path = None + + return LocalTrainResult( + step=step, + metrics=avg_metrics, + checkpoint_path=checkpoint_path, + ) + async def _train_model( self, model: TrainableModel, diff --git a/src/art/model.py b/src/art/model.py index b9d47a233..3e473cd83 100644 --- a/src/art/model.py +++ b/src/art/model.py @@ -2,6 +2,7 @@ import json import os from typing import TYPE_CHECKING, Generic, Iterable, Optional, TypeVar, cast, overload +import warnings import httpx from openai import AsyncOpenAI, DefaultAsyncHttpxClient @@ -233,14 +234,23 @@ def litellm_completion_params(self, step: int | None = None) -> dict: def get_inference_name(self, step: int | None = None) -> str: """Return 

-        If `inference_model_name` is provided we use that, otherwise we fall
-        back to the model's own `name`.
-
         Args:
-            step: If provided, returns name for specific checkpoint using
-                the `name@step` convention. If None, returns name for
-                latest checkpoint (default, backwards compatible).
+            step: If provided, returns name for specific checkpoint.
+                If None, returns name for latest/default checkpoint.
+
+        Note:
+            For TrainableModel with LocalBackend, vLLM serves LoRA adapters
+            as `model.name@step`, so this always includes the step suffix.
+            For ServerlessBackend, it uses W&B artifact naming conventions.
         """
+        # If we have a registered backend with _model_inference_name, use it
+        # This ensures proper step handling for each backend type
+        if self._backend is not None and hasattr(
+            self._backend, "_model_inference_name"
+        ):
+            return self._backend._model_inference_name(self, step=step)
+
+        # Fallback for non-registered models or backends without the method
         base_name = self.inference_model_name or self.name
         if step is not None:
             return f"{base_name}@{step}"
         return base_name
@@ -313,16 +323,40 @@ def _log_metrics(
     async def log(
         self,
-        trajectories: Iterable[Trajectory | BaseException] | Iterable[TrajectoryGroup],
+        trajectories: (
+            Iterable[Trajectory | BaseException] | Iterable[TrajectoryGroup] | None
+        ) = None,
         split: str = "val",
+        *,
+        metrics: dict[str, float] | None = None,
+        step: int | None = None,
     ) -> None:
         """
-        Log the model's performance for an evaluation batch of trajectories or trajectory groups.
+        Log trajectories and/or metrics.
+
+        Can be used in three ways:
+        1. Log trajectories: `await model.log(trajectory_groups, split="train")`
+        2. Log raw metrics: `await model.log(metrics={"loss": 0.5}, step=1)`
+        3. Both: `await model.log(trajectory_groups, metrics=extra_metrics)`

         Args:
-            trajectories: A batch of trajectories or trajectory groups.
+            trajectories: A batch of trajectories or trajectory groups. Optional if
+                logging only metrics.
             split: The evaluation's split. Defaults to "val".
+            metrics: Optional dict of metrics to log directly (e.g., training metrics
+                from backend.train()).
+            step: Optional step number for metrics. If not provided, uses current step.
         """
+        # Determine the step to use
+        if step is None:
+            step = await self.get_step() if self.trainable else 0
+
+        # If only metrics provided (no trajectories), just log them and return
+        if trajectories is None:
+            if metrics is not None:
+                self._log_metrics(metrics, split, step)
+            return
+
         # Convert to list[TrajectoryGroup]
         if any(isinstance(t, Trajectory) for t in trajectories) or any(
             isinstance(t, BaseException) for t in trajectories
@@ -335,9 +369,6 @@ async def log(
         else:
             trajectory_groups = cast(list[TrajectoryGroup], list(trajectories))

-        # Get the current step
-        step = await self.get_step() if self.trainable else 0
-
         # Ensure output directories exist
         output_dir = self._get_output_dir()
         trajectories_dir = f"{output_dir}/trajectories/{split}"
@@ -377,6 +408,10 @@
             # Calculate average standard deviation of rewards within groups
             averages["reward_std_dev"] = calculate_step_std_dev(trajectory_groups)

+        # Merge in any additional metrics passed directly
+        if metrics is not None:
+            averages.update(metrics)
+
         # 3. Log metrics (writes to history.jsonl and wandb)
         self._log_metrics(averages, split, step)
@@ -546,12 +581,31 @@ async def train(
         """
         Reinforce fine-tune the model with a batch of trajectory groups.

+        .. deprecated::
+            Use ``backend.train(model, trajectory_groups, ...)`` instead.
+            This method will be removed in a future version.
+
         Args:
             trajectory_groups: A batch of trajectory groups.
             config: Fine-tuning specific configuration
             _config: Additional configuration that is subject to change and not
                 yet part of the public API. Use at your own risk.
         """
+        warnings.warn(
+            "model.train() is deprecated. Use backend.train(model, ...) instead.\n\n"
+            "Migration guide:\n"
+            "  # Before (deprecated):\n"
+            "  await model.train(trajectory_groups, config=TrainConfig(learning_rate=5e-6))\n\n"
+            "  # After (recommended):\n"
+            "  result = await backend.train(model, trajectory_groups, learning_rate=5e-6)\n"
+            "  await model.log(trajectory_groups, metrics=result.metrics, step=result.step, split='train')\n\n"
+            "Key differences:\n"
+            "  - backend.train() does NOT automatically log trajectories or metrics\n"
+            "  - backend.train() returns a TrainResult with step, metrics, and checkpoint info\n"
+            "  - Each backend has its own type-checked parameters (no more generic config objects)",
+            DeprecationWarning,
+            stacklevel=2,
+        )
         groups_list = list(trajectory_groups)
         _config = _config or {}
diff --git a/src/art/serverless/backend.py b/src/art/serverless/backend.py
index 181f8d881..53e427aef 100644
--- a/src/art/serverless/backend.py
+++ b/src/art/serverless/backend.py
@@ -1,5 +1,5 @@
 import asyncio
-from typing import TYPE_CHECKING, AsyncIterator, Literal
+from typing import TYPE_CHECKING, AsyncIterator, Iterable, Literal
 import warnings

 from openai._types import NOT_GIVEN
@@ -10,7 +10,7 @@
 from .. import dev
 from ..backend import Backend
 from ..trajectories import TrajectoryGroup
-from ..types import TrainConfig
+from ..types import ServerlessTrainResult, TrainConfig

 if TYPE_CHECKING:
     from ..model import Model, TrainableModel
@@ -74,13 +74,11 @@ async def delete(
         assert model.id is not None, "Model ID is required"
         await self._client.models.delete(model_id=model.id)

-    def _model_inference_name(
-        self, model: "TrainableModel", step: int | None = None
-    ) -> str:
+    def _model_inference_name(self, model: "Model", step: int | None = None) -> str:
         """Return the inference name for a model checkpoint.

         Args:
-            model: The trainable model.
+            model: The model.
             step: If provided, returns name for specific checkpoint using
                 W&B artifact versioning (e.g., :step5). If None, returns
                 name for latest checkpoint (default, backwards compatible).
@@ -129,6 +127,126 @@ async def _prepare_backend_for_training(
     # Note: _log() method has been moved to the Model class (frontend)
     # Trajectories are now saved locally by the Model.log() method

+    async def train(  # type: ignore[override]
+        self,
+        model: "TrainableModel",
+        trajectory_groups: Iterable[TrajectoryGroup],
+        *,
+        # Core training parameters
+        learning_rate: float = 5e-6,
+        beta: float = 0.0,
+        # RL algorithm settings
+        ppo: bool = False,
+        epsilon: float | None = None,
+        epsilon_high: float | None = None,
+        # Advantage computation
+        advantage_balance: float = 0.0,
+        scale_rewards: bool = True,
+        # Importance sampling
+        importance_sampling_level: Literal[
+            "token", "sequence", "average", "geometric_average"
+        ] = "token",
+        max_negative_advantage_importance_sampling_weight: float | None = None,
+        mask_prob_ratio: bool = False,
+        # Experimental parameters
+        kimi_k2_tau: float | None = None,
+        precalculate_logprobs: bool = False,
+        # Verbosity
+        verbose: bool = False,
+    ) -> ServerlessTrainResult:
+        """Train the model on the given trajectory groups.
+
+        This is the recommended way to train models. Unlike model.train(), this
+        method does NOT automatically log trajectories or metrics. Call model.log()
+        explicitly before and/or after training if you want to log data.
+
+        Args:
+            model: The trainable model to train.
+            trajectory_groups: Batches of trajectories to train on.
+            learning_rate: Learning rate for training. Defaults to 5e-6.
+            beta: KL penalty coefficient. Defaults to 0.0.
+            ppo: Whether to use PPO clipping. Defaults to False.
+            epsilon: Clip epsilon for importance sampling. Default depends on ppo.
+            epsilon_high: Asymmetric upper clip bound. Defaults to epsilon.
+            advantage_balance: Balance between negative and positive advantages
+                in range [-1.0, 1.0]. Defaults to 0.0 (balanced).
+            scale_rewards: Whether to scale rewards by standard deviation.
+                Defaults to True.
+            importance_sampling_level: Level at which to compute importance
+                sampling weights. Defaults to "token".
+            max_negative_advantage_importance_sampling_weight: Maximum weight
+                for negative advantage samples.
+            mask_prob_ratio: Whether to mask probability ratios. Defaults to False.
+            kimi_k2_tau: Tau parameter for Kimi K2 algorithm.
+            precalculate_logprobs: Whether to precalculate logprobs.
+            verbose: Whether to print verbose output. Defaults to False.
+
+        Returns:
+            ServerlessTrainResult with step number, training metrics, and artifact name.
+
+        Example:
+            # Before (deprecated):
+            await model.train(trajectory_groups, config=TrainConfig(learning_rate=5e-6))
+
+            # After (recommended):
+            await model.log(trajectory_groups, split="train")
+            result = await backend.train(model, trajectory_groups, learning_rate=5e-6)
+            # Optionally log training metrics:
+            # await model.log(metrics=result.metrics, step=result.step)
+        """
+        groups_list = list(trajectory_groups)
+
+        # Build config objects from explicit kwargs
+        config = TrainConfig(learning_rate=learning_rate, beta=beta)
+        dev_config: dev.TrainConfig = {
+            "advantage_balance": advantage_balance,
+            "importance_sampling_level": importance_sampling_level,
+            "mask_prob_ratio": mask_prob_ratio,
+            "ppo": ppo,
+            "precalculate_logprobs": precalculate_logprobs,
+            "scale_rewards": scale_rewards,
+        }
+        # Only include optional fields if they're set
+        if epsilon is not None:
+            dev_config["epsilon"] = epsilon
+        if epsilon_high is not None:
+            dev_config["epsilon_high"] = epsilon_high
+        if max_negative_advantage_importance_sampling_weight is not None:
+            dev_config["max_negative_advantage_importance_sampling_weight"] = (
+                max_negative_advantage_importance_sampling_weight
+            )
+        if kimi_k2_tau is not None:
+            dev_config["kimi_k2_tau"] = kimi_k2_tau
+
+        # Collect metrics from training
+        training_metrics: list[dict[str, float]] = []
+        async for metrics in self._train_model(
+            model, groups_list, config, dev_config, verbose
+        ):
+            training_metrics.append(metrics)
+
+        # Aggregate metrics
+        avg_metrics: dict[str, float] = {}
+        if training_metrics:
+            avg_metrics = {
+                k: sum(d.get(k, 0) for d in training_metrics)
+                / sum(1 for d in training_metrics if k in d)
+                for k in {k for d in training_metrics for k in d}
+                if k != "num_gradient_steps"
+            }
+
+        # Get step and artifact name
+        step = await self._get_step(model)
+        artifact_name: str | None = None
+        if model.entity is not None:
+            artifact_name = f"{model.entity}/{model.project}/{model.name}:step{step}"
+
+        return ServerlessTrainResult(
+            step=step,
+            metrics=avg_metrics,
+            artifact_name=artifact_name,
+        )
+
     async def _train_model(
         self,
         model: "TrainableModel",
diff --git a/src/art/types.py b/src/art/types.py
index c3a82d049..df81d6842 100644
--- a/src/art/types.py
+++ b/src/art/types.py
@@ -1,3 +1,4 @@
+from dataclasses import dataclass, field
 from typing import Annotated, Literal

 from openai.types.chat.chat_completion import Choice
@@ -19,3 +20,49 @@ class TrainConfig(pydantic.BaseModel):


 Verbosity = Literal[0, 1, 2]
+
+
+# ---------------------------------------------------------------------------
+# TrainResult classes
+# ---------------------------------------------------------------------------
+
+
+@dataclass
+class TrainResult:
+    """Base result returned from backend.train().
+
+    Attributes:
+        step: The training step after this training call completed.
+        metrics: Aggregated training metrics (loss, gradient norms, etc.).
+    """
+
+    step: int
+    metrics: dict[str, float] = field(default_factory=dict)
+
+
+@dataclass
+class LocalTrainResult(TrainResult):
+    """Result from LocalBackend.train().
+
+    Attributes:
+        step: The training step after this training call completed.
+        metrics: Aggregated training metrics (loss, gradient norms, etc.).
+        checkpoint_path: Path to the saved checkpoint directory, or None if
+            no checkpoint was saved.
+    """
+
+    checkpoint_path: str | None = None
+
+
+@dataclass
+class ServerlessTrainResult(TrainResult):
+    """Result from ServerlessBackend.train().
+
+    Attributes:
+        step: The training step after this training call completed.
+        metrics: Aggregated training metrics (loss, gradient norms, etc.).
+        artifact_name: The W&B artifact name for the checkpoint
+            (e.g., "entity/project/model:step5").
+    """
+
+    artifact_name: str | None = None
diff --git a/tests/integration/test_multi_checkpoint_training.py b/tests/integration/test_multi_checkpoint_training.py
index fd43f8b21..38c3c3c9f 100644
--- a/tests/integration/test_multi_checkpoint_training.py
+++ b/tests/integration/test_multi_checkpoint_training.py
@@ -15,6 +15,7 @@
 import os
 import tempfile
+from typing import Union
 import uuid

 import openai
@@ -22,6 +23,7 @@
 import art
 from art.local import LocalBackend
+from art.types import LocalTrainResult, ServerlessTrainResult, TrainResult

 # Use a small model for fast testing
 DEFAULT_BASE_MODEL = "Qwen/Qwen3-0.6B"
@@ -59,27 +61,19 @@ async def simple_rollout(
 async def run_training_loop(
     model: art.TrainableModel,
+    backend: Union[LocalBackend, art.ServerlessBackend, art.TinkerBackend],
     num_steps: int = 1,
     rollouts_per_step: int = 4,
-) -> list[int]:
-    """Run a simple training loop and return the step numbers after each train call."""
+) -> list[TrainResult]:
+    """Run a simple training loop and return the TrainResults from each train call."""
     openai_client = model.openai_client()
     prompts = ["Say yes", "Say no", "Say maybe", "Say hello"]
-    steps_completed = []
-
-    async def resolve_model_name(preferred: str, fallback: str) -> str:
-        try:
-            available = [m.id async for m in openai_client.models.list()]
-        except Exception:
-            return preferred
-        return preferred if preferred in available else fallback
+    results: list[TrainResult] = []

     for _ in range(num_steps):
         current_step = await model.get_step()
-        preferred_name = model.get_inference_name(step=current_step)
-        model_name = await resolve_model_name(
-            preferred_name, model.get_inference_name(step=0)
-        )
+        # Use get_inference_name(step=current_step) to target the current checkpoint
+        model_name = model.get_inference_name(step=current_step)
         train_groups = await art.gather_trajectory_groups(
             [
                 art.TrajectoryGroup(
                     simple_rollout(openai_client, model_name, prompt)
                     for _ in range(rollouts_per_step)
                 )
                 for prompt in prompts
             ]
         )
-        await model.train(
-            train_groups,
-            config=art.TrainConfig(learning_rate=1e-5),
+        result = await backend.train(model, train_groups, learning_rate=1e-5)
+        await model.log(
+            train_groups, metrics=result.metrics, step=result.step, split="train"
         )
-        steps_completed.append(await model.get_step())
+        results.append(result)

-    return steps_completed
+    return results


 async def _run_inference_on_step(
@@ -130,8 +124,14 @@ async def test_tinker_backend():
     )
     try:
         await model.register(backend)
-        steps = await run_training_loop(model, num_steps=1, rollouts_per_step=2)
-        await _run_inference_on_step(model, step=steps[-1])
+        results = await run_training_loop(
+            model, backend, num_steps=1, rollouts_per_step=2
+        )
+        # Verify TrainResult structure
+        assert len(results) == 1
+        assert isinstance(results[0], LocalTrainResult)
+        assert results[0].step > 0
+        await _run_inference_on_step(model, step=results[-1].step)
         await _run_inference_on_step(model, step=0)
     finally:
         await backend.close()
@@ -153,8 +153,15 @@ async def test_local_backend():
     )
     try:
         await model.register(backend)
-        steps = await run_training_loop(model, num_steps=1, rollouts_per_step=2)
-        await _run_inference_on_step(model, step=steps[-1])
+        results = await run_training_loop(
+            model, backend, num_steps=1, rollouts_per_step=2
+        )
+        # Verify TrainResult structure
+        assert len(results) == 1
+        assert isinstance(results[0], LocalTrainResult)
+        assert results[0].step > 0
+        assert results[0].checkpoint_path is not None
+        await _run_inference_on_step(model, step=results[-1].step)
         await _run_inference_on_step(model, step=0)
     finally:
         await backend.close()
@@ -175,8 +182,15 @@ async def test_serverless_backend():
     )
     try:
         await model.register(backend)
-        steps = await run_training_loop(model, num_steps=1, rollouts_per_step=2)
-        await _run_inference_on_step(model, step=steps[-1])
+        results = await run_training_loop(
+            model, backend, num_steps=1, rollouts_per_step=2
+        )
+        # Verify TrainResult structure
+        assert len(results) == 1
+        assert isinstance(results[0], ServerlessTrainResult)
+        assert results[0].step > 0
+        assert results[0].artifact_name is not None
+        await _run_inference_on_step(model, step=results[-1].step)
         await _run_inference_on_step(model, step=0)
     finally:
         try:
diff --git a/tests/test_backend_train_api.py b/tests/test_backend_train_api.py
new file mode 100644
index 000000000..bc9551175
--- /dev/null
+++ b/tests/test_backend_train_api.py
@@ -0,0 +1,130 @@
+"""Test the new backend.train() API with real GPU training.
+
+This test runs a simple yes-no-maybe training loop using the new backend-first API.
+
+Usage:
+    cd /workspace/ART && source .venv/bin/activate
+    python tests/test_backend_train_api.py
+"""
+
+import asyncio
+import tempfile
+
+import art
+from art.local import LocalBackend
+from art.types import LocalTrainResult
+
+
+async def simple_rollout(client, model_name: str, prompt: str) -> art.Trajectory:
+    """A simple rollout function for testing."""
+    messages: art.Messages = [{"role": "user", "content": prompt}]
+    chat_completion = await client.chat.completions.create(
+        messages=messages,
+        model=model_name,
+        max_tokens=10,
+        timeout=60,
+        temperature=1,
+    )
+    choice = chat_completion.choices[0]
+    content = (choice.message.content or "").lower()
+    if "yes" in content:
+        reward = 1.0
+    elif "no" in content:
+        reward = 0.5
+    elif "maybe" in content:
+        reward = 0.25
+    else:
+        reward = 0.0
+    return art.Trajectory(messages_and_choices=[*messages, choice], reward=reward)
+
+
+async def main():
+    print("=" * 60)
+    print("Testing new backend.train() API")
+    print("=" * 60)
+
+    with tempfile.TemporaryDirectory() as tmpdir:
+        print(f"\nUsing temp directory: {tmpdir}")
+
+        # Create backend and model
+        backend = LocalBackend(path=tmpdir)
+        model = art.TrainableModel(
+            name="test-backend-train-api",
+            project="api-test",
+            base_model="Qwen/Qwen3-0.6B",
+        )
+
+        try:
+            print("\n1. Registering model with backend...")
+            await model.register(backend)
+            print("   ✓ Model registered")
+
+            # Get OpenAI client
+            openai_client = model.openai_client()
+
+            # Use get_inference_name() for the correct model name
+            # After registration, this returns the proper name (e.g., model.name@0)
+            inference_name = model.get_inference_name()
+            print(f"   Using model for inference: {inference_name}")
+
+            print("\n2. Gathering trajectories...")
+            prompts = ["Say yes", "Say no", "Say maybe", "Say hello"]
+            train_groups = await art.gather_trajectory_groups(
+                [
+                    art.TrajectoryGroup(
+                        [
+                            simple_rollout(openai_client, inference_name, prompt)
+                            for _ in range(4)  # 4 rollouts per prompt
+                        ]
+                    )
+                    for prompt in prompts
+                ]
+            )
+            print(f"   ✓ Gathered {len(train_groups)} trajectory groups")
+
+            # Print some sample rewards
+            for i, group in enumerate(train_groups):
+                rewards = [t.reward for t in group]
+                print(f"   Group {i} ({prompts[i]}): rewards = {rewards}")
+
+            print("\n3. Training with backend.train()...")
+            result = await backend.train(
+                model,
+                train_groups,
+                learning_rate=1e-5,
+                verbose=True,
+            )
+
+            print("\n4. Logging trajectories and training metrics...")
+            await model.log(
+                train_groups, metrics=result.metrics, step=result.step, split="train"
+            )
+            print("   ✓ Trajectories and metrics logged")
+
+            print("\n5. Checking TrainResult...")
+            print(f"   Result type: {type(result).__name__}")
+            print(f"   Step: {result.step}")
+            print(f"   Metrics: {result.metrics}")
+
+            assert isinstance(result, LocalTrainResult), (
+                f"Expected LocalTrainResult, got {type(result)}"
+            )
+            print(f"   Checkpoint path: {result.checkpoint_path}")
+
+            assert result.step > 0, f"Expected step > 0, got {result.step}"
+            assert isinstance(result.metrics, dict), (
+                f"Expected dict metrics, got {type(result.metrics)}"
+            )
+
+            print("\n" + "=" * 60)
+            print("✓ All checks passed! New backend.train() API works correctly.")
+            print("=" * 60)
+
+        finally:
+            print("\nCleaning up...")
+            await backend.close()
+            print("Done!")
+
+
+if __name__ == "__main__":
+    asyncio.run(main())
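
Taken together, the changes above replace the old `await model.train(...)` call with an explicit `backend.train(...)` / `model.log(...)` pair. Below is a minimal end-to-end sketch of the new flow against LocalBackend, assuming a GPU machine that can serve Qwen/Qwen3-0.6B; the rollout logic, model name, and project name are illustrative placeholders, not part of this diff:

import asyncio

import art
from art.local import LocalBackend


async def rollout(client, model_name: str) -> art.Trajectory:
    # Hypothetical single-turn rollout: reward 1.0 if the model says "yes".
    messages: art.Messages = [{"role": "user", "content": "Say yes"}]
    completion = await client.chat.completions.create(
        messages=messages, model=model_name, max_tokens=10
    )
    choice = completion.choices[0]
    reward = 1.0 if "yes" in (choice.message.content or "").lower() else 0.0
    return art.Trajectory(messages_and_choices=[*messages, choice], reward=reward)


async def main() -> None:
    backend = LocalBackend()
    model = art.TrainableModel(
        name="demo", project="backend-train-demo", base_model="Qwen/Qwen3-0.6B"
    )
    await model.register(backend)
    client = model.openai_client()
    groups = await art.gather_trajectory_groups(
        [
            art.TrajectoryGroup(
                rollout(client, model.get_inference_name()) for _ in range(4)
            )
        ]
    )
    # backend.train() no longer logs anything itself; log explicitly with the
    # step and metrics carried back in the TrainResult.
    result = await backend.train(model, groups, learning_rate=5e-6)
    await model.log(groups, metrics=result.metrics, step=result.step, split="train")
    await backend.close()


asyncio.run(main())

Because backend.train() returns the completed step and aggregated metrics instead of logging them implicitly, the caller decides what to record and under which split.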