diff --git a/pyproject.toml b/pyproject.toml index 1a50198af..46a7e8bd6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,8 +9,6 @@ dependencies = [ "typer>=0.15.2", "litellm>=1.71.1", "weave>=0.52.23", - "tinker>=0.8.1", - "tinker-cookbook>=0.1.0", "polars>=1.26.0", "tblib>=3.0.0", "nest-asyncio>=1.6.0", @@ -115,6 +113,9 @@ unused-ignore-comment = "ignore" # Allow unresolved imports for optional dependencies that may not be installed locally. # In CI, we install all optional deps so these will be resolved and type-checked. allowed-unresolved-imports = [ + # tinker deps + "tinker.**", + "tinker_cookbook.**", # backend deps "accelerate.**", "awscli.**", @@ -165,6 +166,12 @@ dev = [ "pyarrow>=15.0.0", "prek>=0.2.29", ] +tinker = [ + "fastapi>=0.128.0", + "tinker>=0.8.1", + "tinker-cookbook>=0.1.0", + "uvicorn>=0.35.0", +] [tool.uv.sources] panza = { git = "https://github.com/corbt/panza.git" } diff --git a/src/art/__init__.py b/src/art/__init__.py index b6948f514..d07d20274 100644 --- a/src/art/__init__.py +++ b/src/art/__init__.py @@ -57,7 +57,13 @@ def __init__(self, **kwargs): from .local import LocalBackend from .model import Model, TrainableModel from .serverless import ServerlessBackend -from .tinker import TinkerBackend + +try: + from .tinker import TinkerBackend + from .tinker_native import TinkerNativeBackend +except ModuleNotFoundError: + TinkerBackend = None # type: ignore[assignment] + TinkerNativeBackend = None # type: ignore[assignment] from .trajectories import Trajectory, TrajectoryGroup from .types import ( LocalTrainResult, @@ -91,9 +97,10 @@ def __init__(self, **kwargs): "retry", "TrainConfig", "TrainResult", - "TinkerBackend", "Trajectory", "TrajectoryGroup", "capture_yielded_trajectory", "yield_trajectory", ] +if TinkerBackend is not None: + __all__.extend(["TinkerBackend", "TinkerNativeBackend"]) diff --git a/src/art/dev/__init__.py b/src/art/dev/__init__.py index 45e1cdef6..9d04c26bd 100644 --- a/src/art/dev/__init__.py +++ b/src/art/dev/__init__.py @@ -4,6 +4,7 @@ InternalModelConfig, PeftArgs, TinkerArgs, + TinkerNativeArgs, TinkerTrainingClientArgs, TrainerArgs, ) @@ -16,6 +17,7 @@ "InitArgs", "PeftArgs", "TinkerArgs", + "TinkerNativeArgs", "TinkerTrainingClientArgs", "TrainerArgs", "get_openai_server_config", diff --git a/src/art/dev/model.py b/src/art/dev/model.py index 68316233f..8bd342b81 100644 --- a/src/art/dev/model.py +++ b/src/art/dev/model.py @@ -121,6 +121,7 @@ class InternalModelConfig(TypedDict, total=False): engine_args: "EngineArgs" peft_args: "PeftArgs" tinker_args: "TinkerArgs | None" + tinker_native_args: "TinkerNativeArgs | None" trainer_args: "TrainerArgs" @@ -129,6 +130,11 @@ class TinkerArgs(TypedDict, total=False): training_client_args: "TinkerTrainingClientArgs" +class TinkerNativeArgs(TypedDict, total=False): + renderer_name: Required[str] + training_client_args: "TinkerTrainingClientArgs" + + class TinkerTrainingClientArgs(TypedDict, total=False): rank: int seed: int | None diff --git a/src/art/local/backend.py b/src/art/local/backend.py index 19f26afbe..4d46df28b 100644 --- a/src/art/local/backend.py +++ b/src/art/local/backend.py @@ -102,6 +102,8 @@ async def register( Args: model: An art.Model instance. 
""" + # Ensure model state/logging uses the backend path + model.base_path = self._path output_dir = get_model_dir(model=model, art_path=self._path) os.makedirs(output_dir, exist_ok=True) with open(f"{output_dir}/model.json", "w") as f: diff --git a/src/art/model.py b/src/art/model.py index 2fc38640b..05f55bd67 100644 --- a/src/art/model.py +++ b/src/art/model.py @@ -261,18 +261,22 @@ def _get_output_dir(self) -> str: """Get the output directory for this model.""" return f"{self.base_path}/{self.project}/models/{self.name}" - def write_state(self, state: StateType) -> None: - """Write persistent state to the model directory as JSON. + def overwrite_state(self, state: StateType) -> None: + """Overwrite persistent state in the model directory as JSON. This state is stored in `state.json` within the model's output directory and can be used to track training progress, dataset position, or any other information that should persist across runs. + Warning: + This overwrites the entire state file. Prefer `merge_state()` unless + you intentionally want to replace all existing keys. + Args: state: A dictionary of JSON-serializable values to persist. Example: - model.write_state({ + model.overwrite_state({ "step": 5, "dataset_offset": 100, "last_checkpoint_time": "2024-01-15T10:30:00", @@ -283,6 +287,45 @@ def write_state(self, state: StateType) -> None: with open(f"{output_dir}/state.json", "w") as f: json.dump(state, f, indent=2) + def write_state(self, state: StateType) -> None: + """Deprecated: use `overwrite_state()` or `merge_state()` instead.""" + warnings.warn( + "write_state() is deprecated. Use overwrite_state() or merge_state() instead.", + DeprecationWarning, + stacklevel=2, + ) + self.overwrite_state(state) + + def merge_state(self, state: StateType) -> StateType: + """Deep-merge state into the existing state and persist it. + + Args: + state: A dictionary of JSON-serializable values to merge. + + Returns: + The merged state dictionary that was persisted. + """ + existing = self.read_state() or {} + merged = self._deep_merge_dicts(existing, state) + self.overwrite_state(merged) + return cast(StateType, merged) + + @staticmethod + def _deep_merge_dicts( + base: dict[str, Any], updates: dict[str, Any] + ) -> dict[str, Any]: + merged = dict(base) + for key, value in updates.items(): + if ( + key in merged + and isinstance(merged[key], dict) + and isinstance(value, dict) + ): + merged[key] = Model._deep_merge_dicts(merged[key], value) + else: + merged[key] = value + return merged + def read_state(self) -> StateType | None: """Read persistent state from the model directory. 
diff --git a/src/art/pipeline_trainer/__init__.py b/src/art/pipeline_trainer/__init__.py new file mode 100644 index 000000000..3e0d0f9ce --- /dev/null +++ b/src/art/pipeline_trainer/__init__.py @@ -0,0 +1,13 @@ +from .status import StatusReporter +from .trainer import PipelineTrainer, make_group_rollout_fn +from .types import EvalFn, RolloutFn, ScenarioT, SingleRolloutFn + +__all__ = [ + "PipelineTrainer", + "make_group_rollout_fn", + "StatusReporter", + "RolloutFn", + "SingleRolloutFn", + "EvalFn", + "ScenarioT", +] diff --git a/src/art/pipeline_trainer/binary_prefix_tool_pipeline.py b/src/art/pipeline_trainer/binary_prefix_tool_pipeline.py new file mode 100644 index 000000000..70b413b66 --- /dev/null +++ b/src/art/pipeline_trainer/binary_prefix_tool_pipeline.py @@ -0,0 +1,346 @@ +from __future__ import annotations + +import asyncio +from dataclasses import dataclass +import json +import os +from pathlib import Path +import re +from typing import Any, cast + +from dotenv import load_dotenv +from openai.types.chat.chat_completion_tool_choice_option_param import ( + ChatCompletionToolChoiceOptionParam, +) +from openai.types.chat.chat_completion_tool_param import ChatCompletionToolParam +import polars as pl + +import art + +from . import PipelineTrainer, make_group_rollout_fn + +Scenario = dict[str, Any] + + +@dataclass +class PipelineConfig: + temperature: float + eval_temperature: float + max_tokens: int + + +TOOL_NAME = "make_guess" +SECRET_BITS = "110000110100101011111010101011" +SECRET_LEN = len(SECRET_BITS) + +TOOLS: list[ChatCompletionToolParam] = [ + cast( + ChatCompletionToolParam, + { + "type": "function", + "function": { + "name": TOOL_NAME, + "description": "Submit a binary guess for the secret string.", + "parameters": { + "type": "object", + "properties": { + "guess": { + "type": "string", + "description": ( + "A binary string of length " + f"{SECRET_LEN} consisting of 0 and 1." + ), + } + }, + "required": ["guess"], + "additionalProperties": False, + }, + }, + }, + ) +] +TOOL_CHOICE: ChatCompletionToolChoiceOptionParam = { + "type": "function", + "function": {"name": TOOL_NAME}, +} + +SYSTEM_PROMPT = ( + "You are playing a prefix-guessing game. You must call the tool " + f"{TOOL_NAME} exactly once with your best guess. The argument must be a " + f"{SECRET_LEN}-character string of only 0 and 1. Do not output any other text. " + "Your reward is the length of the shared prefix with the secret string." +) +USER_PROMPT = "Call the tool with your best binary guess." 
+ + +def is_valid_guess(guess: str) -> bool: + return all(ch in {"0", "1"} for ch in guess) + + +def shared_prefix_len(guess: str, secret: str) -> int: + matched = 0 + for guessed, actual in zip(guess, secret): + if guessed != actual: + break + matched += 1 + return matched + + +def _parse_guess_args(arguments: str | None) -> str | None: + if not arguments: + return None + text = arguments.strip() + if not text: + return None + payload: Any | None = None + try: + payload = json.loads(text) + except json.JSONDecodeError: + start = text.find("{") + end = text.rfind("}") + if start != -1 and end != -1 and end > start: + try: + payload = json.loads(text[start : end + 1]) + except json.JSONDecodeError: + payload = None + + if isinstance(payload, dict): + guess = payload.get("guess") + if isinstance(guess, str): + return guess + if guess is not None: + return str(guess) + + match = re.search(r'guess\s*[:=]\s*"([^"]*)"', text) + if match: + return match.group(1) + match = re.search(r"guess\s*[:=]\s*'([^']*)'", text) + if match: + return match.group(1) + if re.fullmatch(r"[01\s]+", text): + return text + return None + + +def _tool_name_and_args(tool_call: Any) -> tuple[str | None, str | None]: + if hasattr(tool_call, "function"): + function = tool_call.function + return getattr(function, "name", None), getattr(function, "arguments", None) + if isinstance(tool_call, dict): + func = tool_call.get("function") or {} + return func.get("name"), func.get("arguments") + return None, None + + +def extract_guess(choice: Any) -> tuple[str | None, str]: + tool_calls = getattr(choice.message, "tool_calls", None) or [] + for tool_call in tool_calls: + name, args = _tool_name_and_args(tool_call) + if name != TOOL_NAME: + continue + guess = _parse_guess_args(args) + if guess is not None: + return guess, "tool_call" + + return None, "missing" + + +def get_model_output_dir(model: art.TrainableModel) -> Path: + return Path(model.base_path) / model.project / "models" / model.name + + +def print_history_summary(model: art.TrainableModel, tail: int = 5) -> None: + history_path = get_model_output_dir(model) / "history.jsonl" + if not history_path.exists(): + print(f"No history found at {history_path}") + return + + rows = pl.read_ndjson(str(history_path)).to_dicts() + + train_rows = [row for row in rows if "train/reward" in row] + print("\nRecent training metrics:") + for row in train_rows[-tail:]: + step = row["step"] + reward = row["train/reward"] + std_dev = row["train/reward_std_dev"] + discarded = row["train/discarded_stale_samples"] + off_policy = row["train/steps_off_policy"] + print( + f" step={step} reward={reward} std={std_dev} " + f"discarded={discarded} off_policy={off_policy}" + ) + + +async def main() -> None: + load_dotenv() + + base_model = os.environ.get( + "BASE_MODEL", "Qwen/Qwen3-4B-Instruct-2507" + ) # Qwen/Qwen3-30B-A3B-Instruct-2507 + model_name = os.environ.get("MODEL_NAME", "pipeline-binary-prefix-tool") + project = os.environ.get("PROJECT", "binary-prefix-tool-pipeline") + art_path = os.environ.get("ART_PATH") + + min_batch_size = int(os.environ.get("MIN_BATCH_SIZE", "4")) + num_rollout_workers = int(os.environ.get("NUM_ROLLOUT_WORKERS", "8")) + rollouts_per_scenario = int(os.environ.get("ROLLOUTS_PER_SCENARIO", "8")) + max_steps_off_policy = int(os.environ.get("MAX_STEPS_OFF_POLICY", "6")) + max_batch_size_env = os.environ.get("MAX_BATCH_SIZE") + max_batch_size = int(max_batch_size_env) if max_batch_size_env else None + eval_every_n_steps = int(os.environ.get("EVAL_EVERY_N_STEPS", "2")) + 
eval_step_0 = os.environ.get("EVAL_STEP_0", "1") == "1" + max_steps = int(os.environ.get("MAX_STEPS", "10")) + save_checkpoint = os.environ.get("SAVE_CHECKPOINT", "0") == "1" + resume_env = os.environ.get("RESUME") + resume = (resume_env == "1") if resume_env is not None else save_checkpoint + if resume and not save_checkpoint: + print("RESUME=1 but SAVE_CHECKPOINT=0; disabling resume for a clean run.") + resume = False + + temperature = float(os.environ.get("ROLLOUT_TEMPERATURE", "1.0")) + eval_temperature = float(os.environ.get("EVAL_TEMPERATURE", "0.0")) + max_tokens = int(os.environ.get("MAX_TOKENS", "300")) + request_timeout = float(os.environ.get("REQUEST_TIMEOUT", "60")) + log_interval_seconds = float(os.environ.get("STATUS_LOG_INTERVAL_SECONDS", "60")) + + internal_config: art.dev.InternalModelConfig | None = None + lora_rank = os.environ.get("LORA_RANK") + if lora_rank is not None: + internal_config = { + "tinker_native_args": { + "renderer_name": os.environ.get("RENDERER_NAME", "qwen3_instruct"), + "training_client_args": {"rank": int(lora_rank)}, + } + } + + backend = art.TinkerNativeBackend(path=art_path) + model = art.TrainableModel( + name=model_name, + project=project, + base_model=base_model, + _internal_config=internal_config, + report_metrics=[], # Disable wandb logging + ) + await model.register(backend) + + openai_client = model.openai_client() + + async def do_rollout(scenario: Scenario, temp: float) -> art.Trajectory: + """Core rollout logic used by both training and eval.""" + messages: art.Messages = scenario["messages"] + response = await openai_client.chat.completions.create( + messages=messages, + model=model.name, + max_tokens=max_tokens, + timeout=request_timeout, + temperature=temp, + tools=TOOLS, + tool_choice=TOOL_CHOICE, + ) + choice = response.choices[0] + raw_guess, source = extract_guess(choice) + guess = raw_guess or "" + valid_guess = is_valid_guess(guess) + prefix_len = shared_prefix_len(guess, SECRET_BITS) if valid_guess else 0 + reward = float(prefix_len) + metrics = { + "prefix_len": prefix_len, + "guess_len": len(guess), + "secret_len": SECRET_LEN, + "valid_guess": 1.0 if valid_guess else 0.0, + "tool_call_count": float( + len(getattr(choice.message, "tool_calls", None) or []) + ), + "tool_call_found": 1.0 if source != "missing" else 0.0, + "tool_call_structured": 1.0 if source == "tool_call" else 0.0, + } + return art.Trajectory( + messages_and_choices=[*messages, choice], + tools=TOOLS, + reward=reward, + metrics=metrics, + ) + + async def single_rollout( + _model: art.TrainableModel, + scenario: Scenario, + _config: PipelineConfig, + ) -> art.Trajectory: + return await do_rollout(scenario, temperature) + + rollout_fn = make_group_rollout_fn(single_rollout, n=rollouts_per_scenario) + + last_eval: dict[str, float | None] = {"avg_reward": None} + + async def eval_fn( + _model: art.TrainableModel, _step: int, _config: PipelineConfig + ) -> list[art.Trajectory]: + tasks = [do_rollout(build_scenario(), eval_temperature)] + results = await asyncio.gather(*tasks, return_exceptions=True) + trajectories = [r for r in results if isinstance(r, art.Trajectory)] + if trajectories: + avg_reward = sum(t.reward for t in trajectories) / len(trajectories) + last_eval["avg_reward"] = avg_reward + return trajectories + + scenario_count = int(os.environ.get("SCENARIO_COUNT", "1000")) + scenario_count = max(1, scenario_count) + + def build_scenario() -> Scenario: + return { + "messages": [ + {"role": "system", "content": SYSTEM_PROMPT}, + {"role": "user", "content": 
USER_PROMPT}, + ], + } + + async def scenario_iter(): + for i in range(scenario_count): + scenario = build_scenario() + scenario["metadata"] = {"scenario_idx": i} + yield scenario + + config = PipelineConfig( + temperature=temperature, + eval_temperature=eval_temperature, + max_tokens=max_tokens, + ) + + trainer = PipelineTrainer( + model=model, + backend=backend, + rollout_fn=rollout_fn, + scenarios=scenario_iter(), + config=config, + eval_fn=eval_fn, + num_rollout_workers=num_rollout_workers, + min_batch_size=min_batch_size, + max_steps_off_policy=max_steps_off_policy, + max_batch_size=max_batch_size, + learning_rate=float(os.environ.get("LEARNING_RATE", "1e-4")), + log_interval_seconds=log_interval_seconds, + eval_every_n_steps=eval_every_n_steps, + eval_step_0=eval_step_0, + save_checkpoint=save_checkpoint, + resume=resume, + max_steps=max_steps, + total_scenarios=scenario_count, + ) + + print( + "Starting pipeline trainer test: " + f"max_steps={max_steps} scenarios={scenario_count} " + f"rollouts_per_scenario={rollouts_per_scenario} " + f"secret_len={SECRET_LEN}" + ) + await trainer.train() + print("Training completed.") + if last_eval["avg_reward"] is not None: + print(f"Last eval avg reward: {last_eval['avg_reward']:.3f}") + + print_history_summary(model, tail=5) + await backend.close() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/src/art/pipeline_trainer/state.py b/src/art/pipeline_trainer/state.py new file mode 100644 index 000000000..46f8a9cb0 --- /dev/null +++ b/src/art/pipeline_trainer/state.py @@ -0,0 +1,25 @@ +from __future__ import annotations + +import asyncio +from dataclasses import dataclass, field + + +@dataclass +class PipelineState: + """Shared state across pipeline stages.""" + + # Policy versioning + policy_version: int = 0 + next_training_step: int = 0 + + # Scenario tracking + scenario_offset: int = 0 + total_scenarios_consumed: int = 0 + last_eval_step: int = 0 + + # Metrics + discarded_stale_samples: int = 0 + + # Synchronization + policy_updated: asyncio.Condition = field(default_factory=asyncio.Condition) + done: bool = False diff --git a/src/art/pipeline_trainer/status.py b/src/art/pipeline_trainer/status.py new file mode 100644 index 000000000..cb58bdb3e --- /dev/null +++ b/src/art/pipeline_trainer/status.py @@ -0,0 +1,407 @@ +from __future__ import annotations + +import math +import shutil +import sys +import time +from typing import Callable, cast + +from tqdm import tqdm + +from art import TrajectoryGroup + + +class StatusReporter: + def __init__( + self, + *, + get_scenario_offset: Callable[[], int], + log_interval_seconds: float = 60.0, + status_ewa_alpha: float = 0.2, + total_scenarios: int | None = None, + num_workers: int = 1, + ) -> None: + if log_interval_seconds <= 0: + raise ValueError("log_interval_seconds must be > 0") + if not 0 < status_ewa_alpha <= 1: + raise ValueError("status_ewa_alpha must be in (0, 1]") + if total_scenarios is not None and total_scenarios < 0: + raise ValueError("total_scenarios must be >= 0") + if num_workers <= 0: + raise ValueError("num_workers must be > 0") + + self._get_scenario_offset = get_scenario_offset + self._log_interval_seconds = log_interval_seconds + self._status_ewa_alpha = status_ewa_alpha + self._total_scenarios = total_scenarios + self._num_workers = num_workers + + self._current_step: int | None = None + self._rolling_out = 0 + self._queued = 0 + self._training = 0 + self._trained = 0 + self._stale = 0 + self._zero_var = 0 + self._errored = 0 + + self._train_reward_ewa: float | 
None = None + self._seconds_ewa: float | None = None + self._avg_std_dev_ewa: float | None = None + + self._last_val_step: int | None = None + self._last_val_reward: float | None = None + self._val_running_step: int | None = None + + self._tqdm: tqdm | None = None + self._started = False + self._last_log_time = 0.0 + self._last_refresh_time = 0.0 + self._refresh_interval_seconds = 0.25 + + def start(self, *, initial_step: int | None = None) -> None: + if self._started: + return + self._started = True + if initial_step is not None: + self._current_step = initial_step + self._last_log_time = time.monotonic() - self._log_interval_seconds + self._last_refresh_time = 0.0 + if sys.stdout.isatty(): + self._tqdm = tqdm( + total=self._total_scenarios, + bar_format="{desc}", + dynamic_ncols=True, + leave=False, + file=sys.stdout, + ) + self._refresh_status(force=True) + + def close(self) -> None: + if self._tqdm is not None: + self._tqdm.close() + self._tqdm = None + self._started = False + + def flush(self) -> None: + self.log_if_due(force=True) + + def set_step(self, step: int) -> None: + self._current_step = step + self._refresh_status() + + def log_if_due(self, *, force: bool = False) -> None: + if not self._started: + return + now = time.monotonic() + if not force and (now - self._last_log_time) < self._log_interval_seconds: + return + self._last_log_time = now + self._write_log_line(self._format_full_log()) + + def note_rollout_started(self) -> None: + self._rolling_out += 1 + self._refresh_status() + + def note_rollout_finished(self, *, errored: bool) -> None: + if self._rolling_out > 0: + self._rolling_out -= 1 + if errored: + self._errored += 1 + self._refresh_status() + + def note_group_enqueued(self, _group: TrajectoryGroup) -> None: + self._queued += 1 + self._refresh_status() + + def note_group_dequeued(self, _group: TrajectoryGroup) -> None: + if self._queued > 0: + self._queued -= 1 + self._refresh_status() + + def note_stale(self, count: int) -> None: + if count > 0: + self._stale += count + self._refresh_status() + + def note_zero_variance_discarded(self, count: int) -> None: + if count > 0: + self._zero_var += count + self._refresh_status() + + def note_training_start(self, group_count: int) -> None: + self._training = group_count + self._refresh_status() + + def note_training_end(self) -> None: + self._training = 0 + self._refresh_status() + + def note_training_batch( + self, batch: list[TrajectoryGroup], *, step: int, step_seconds: float + ) -> None: + zero_variance_groups = self._count_zero_variance_groups(batch) + trainable_groups = len(batch) - zero_variance_groups + avg_std = self._compute_batch_avg_std_dev(batch) + avg_reward = self._compute_batch_avg_reward(batch) + + self._current_step = step + self._trained += trainable_groups + self._zero_var += zero_variance_groups + + if avg_reward is not None: + self._train_reward_ewa = self._update_ewa( + self._train_reward_ewa, avg_reward + ) + self._seconds_ewa = self._update_ewa(self._seconds_ewa, step_seconds) + self._avg_std_dev_ewa = self._update_ewa(self._avg_std_dev_ewa, avg_std) + self._refresh_status(force=True) + + def note_val_started(self, step: int) -> None: + self._val_running_step = step + self._refresh_status(force=True) + + def note_val_finished(self, step: int, reward: float | None) -> None: + self._last_val_step = step + self._last_val_reward = reward + if self._val_running_step == step: + self._val_running_step = None + self._refresh_status(force=True) + + def _build_snapshot(self) -> dict[str, object]: + 
remaining = None + if self._total_scenarios is not None: + remaining = max(self._total_scenarios - self._get_scenario_offset(), 0) + return { + "step": self._current_step, + "remaining": remaining, + "rolling": self._rolling_out, + "workers": self._num_workers, + "queued": self._queued, + "training": self._training, + "trained": self._trained, + "zero_var": self._zero_var, + "stale": self._stale, + "errored": self._errored, + "discarded": self._zero_var + self._stale + self._errored, + "train_reward_ewa": self._train_reward_ewa, + "train_seconds_ewa": self._seconds_ewa, + "train_avg_std_ewa": self._avg_std_dev_ewa, + "val_step": self._last_val_step, + "val_reward": self._last_val_reward, + "val_running": self._val_running_step, + } + + def _format_condensed_line(self) -> str: + snapshot = self._build_snapshot() + trained = cast(int, snapshot["trained"]) + discarded = cast(int, snapshot["discarded"]) + remaining = snapshot["remaining"] + + scenarios_fields = [ + "scenarios", + f"tr={trained}", + ] + if remaining is not None: + scenarios_fields.append(f"r={self._fmt_int_compact(remaining)}") + scenarios_fields.extend( + [ + f"q={self._fmt_int_compact(snapshot['queued'])}", + f"b={self._fmt_int_compact(snapshot['training'])}", + f"d={discarded}", + ] + ) + + train_fields = [ + "train", + f"s={self._fmt_int_compact(snapshot['step'])}", + f"r={self._fmt_float_compact(snapshot['train_reward_ewa'], 2)}", + f"dt={self._fmt_float_compact(snapshot['train_seconds_ewa'], 1)}", + f"sd={self._fmt_float_compact(snapshot['train_avg_std_ewa'], 2)}", + ] + + val_run = "y" if snapshot["val_running"] is not None else "n" + val_fields = [ + "val", + f"r={self._fmt_float_compact(snapshot['val_reward'], 2)}", + f"act={val_run}", + ] + + def build_line() -> str: + return " ".join( + [ + f"scenarios[{' '.join(scenarios_fields[1:])}]", + f"train[{' '.join(train_fields[1:])}]", + f"val[{' '.join(val_fields[1:])}]", + ] + ) + + line = build_line() + max_width = shutil.get_terminal_size(fallback=(120, 20)).columns + if len(line) <= max_width: + return line + + train_fields = [field for field in train_fields if not field.startswith("sd=")] + line = build_line() + if len(line) <= max_width: + return line + + val_fields = [field for field in val_fields if not field.startswith("act=")] + line = build_line() + if len(line) <= max_width: + return line + + if remaining is not None: + scenarios_fields = [ + field for field in scenarios_fields if not field.startswith("r=") + ] + return build_line() + + def _format_full_log(self) -> str: + snapshot = self._build_snapshot() + scenarios_fields = [ + "scenarios", + f"trained={self._fmt_int(snapshot['trained'])}", + ] + if snapshot["remaining"] is not None: + scenarios_fields.append(f"remaining={self._fmt_int(snapshot['remaining'])}") + scenarios_fields.extend( + [ + f"queued={self._fmt_int(snapshot['queued'])}", + f"training={self._fmt_int(snapshot['training'])}", + ( + "discarded[" + f"total={self._fmt_int(snapshot['discarded'])} " + f"0_var={self._fmt_int(snapshot['zero_var'])} " + f"stale={self._fmt_int(snapshot['stale'])} " + f"errored={self._fmt_int(snapshot['errored'])}" + "]" + ), + f"rollouts={snapshot['rolling']}/{snapshot['workers']}", + ] + ) + + train_fields = [ + "train", + f"step={self._fmt_int(snapshot['step'])}", + f"reward={self._fmt_float(snapshot['train_reward_ewa'], 3)}", + f"step_seconds={self._fmt_float(snapshot['train_seconds_ewa'], 2)}", + f"avg_std={self._fmt_float(snapshot['train_avg_std_ewa'], 3)}", + ] + + val_run = "yes" if snapshot["val_running"] is 
not None else "no" + val_fields = [ + "val", + f"reward={self._fmt_float(snapshot['val_reward'], 3)}", + f"active={val_run}", + ] + if snapshot["val_step"] is not None: + val_fields.append(f"step={snapshot['val_step']}") + if snapshot["val_running"] is not None: + val_fields.append(f"active_step={snapshot['val_running']}") + + return "[status] " + " ".join( + [ + f"scenarios[{' '.join(scenarios_fields[1:])}]", + f"train[{' '.join(train_fields[1:])}]", + f"val[{' '.join(val_fields[1:])}]", + ] + ) + + def _write_log_line(self, line: str) -> None: + if self._tqdm is not None: + self._tqdm.write(line) + else: + print(line) + + def _refresh_status(self, *, force: bool = False) -> None: + if self._tqdm is None: + return + now = time.monotonic() + if ( + not force + and (now - self._last_refresh_time) < self._refresh_interval_seconds + ): + return + self._tqdm.set_description_str(self._format_condensed_line()) + self._last_refresh_time = now + + def _count_zero_variance_groups(self, batch: list[TrajectoryGroup]) -> int: + return sum(1 for group in batch if self._group_zero_variance(group)) + + def _group_zero_variance(self, group: TrajectoryGroup) -> bool: + rewards = [t.reward for t in group.trajectories] + if len(rewards) <= 1: + return True + first = rewards[0] + return all(abs(r - first) <= 1e-12 for r in rewards[1:]) + + def _compute_group_std_dev(self, group: TrajectoryGroup) -> float: + rewards = [t.reward for t in group.trajectories] + if len(rewards) <= 1: + return 0.0 + mean = sum(rewards) / len(rewards) + variance = sum((r - mean) ** 2 for r in rewards) / len(rewards) + return math.sqrt(variance) + + def _compute_batch_avg_std_dev(self, batch: list[TrajectoryGroup]) -> float: + if not batch: + return 0.0 + std_devs = [self._compute_group_std_dev(group) for group in batch] + return sum(std_devs) / len(std_devs) + + def _compute_batch_avg_reward(self, batch: list[TrajectoryGroup]) -> float | None: + rewards = [t.reward for group in batch for t in group.trajectories] + if not rewards: + return None + return sum(rewards) / len(rewards) + + def _update_ewa(self, previous: float | None, new_value: float) -> float: + if previous is None: + return new_value + alpha = self._status_ewa_alpha + return alpha * new_value + (1 - alpha) * previous + + @staticmethod + def _format_count(value: int) -> str: + if value >= 1_000_000: + return StatusReporter._format_scaled(value, 1_000_000, "m") + if value >= 1_000: + return StatusReporter._format_scaled(value, 1_000, "k") + return str(value) + + @staticmethod + def _format_scaled(value: int, scale: int, suffix: str) -> str: + scaled = value / scale + text = f"{scaled:.1f}" + if text.endswith(".0"): + text = text[:-2] + return f"{text}{suffix}" + + @staticmethod + def _fmt_int(value: object) -> str: + if value is None: + return "n/a" + return str(value) + + @staticmethod + def _fmt_int_compact(value: object) -> str: + if value is None: + return "na" + return str(value) + + @staticmethod + def _fmt_float(value: object, decimals: int) -> str: + if value is None: + return "n/a" + if not isinstance(value, (int, float)): + return str(value) + return f"{value:.{decimals}f}" + + @staticmethod + def _fmt_float_compact(value: object, decimals: int) -> str: + if value is None: + return "na" + if not isinstance(value, (int, float)): + return str(value) + return f"{value:.{decimals}f}" diff --git a/src/art/pipeline_trainer/trainer.py b/src/art/pipeline_trainer/trainer.py new file mode 100644 index 000000000..4bb062d4c --- /dev/null +++ 
b/src/art/pipeline_trainer/trainer.py @@ -0,0 +1,726 @@ +from __future__ import annotations + +import asyncio +import os +import signal +import time +from typing import Any, AsyncIterator, Generic, Iterable, TypeVar, cast + +T = TypeVar("T") + +import art +from art import TrajectoryGroup + +from .state import PipelineState +from .status import StatusReporter +from .types import ConfigT, EvalFn, RolloutFn, ScenarioT, SingleRolloutFn # noqa: F401 + +PIPELINE_STATE_KEY = "_pipeline_trainer" + + +def _to_async_iterator(iterable: Iterable[T] | AsyncIterator[T]) -> AsyncIterator[T]: + """Convert a sync Iterable to an AsyncIterator, or pass through if already async.""" + if isinstance(iterable, AsyncIterator): + return iterable + + async def _iter(): + for item in iterable: + yield item + + return _iter() + + +def make_group_rollout_fn( + single_rollout_fn: SingleRolloutFn[ScenarioT, ConfigT], + n: int = 4, +) -> RolloutFn[ScenarioT, ConfigT]: + """Create a RolloutFn from a SingleRolloutFn by running it N times in parallel.""" + + async def group_rollout( + model: art.TrainableModel, + scenario: ScenarioT, + config: ConfigT, + ) -> TrajectoryGroup: + if n <= 0: + return TrajectoryGroup([]) + results = await asyncio.gather( + *[single_rollout_fn(model, scenario, config) for _ in range(n)], + return_exceptions=True, + ) + return TrajectoryGroup(results) + + return group_rollout + + +class PipelineTrainer(Generic[ScenarioT, ConfigT]): + """Async 3-stage pipeline for rollouts, training, and eval.""" + + def __init__( + self, + model: art.TrainableModel, + backend: art.Backend, + rollout_fn: RolloutFn[ScenarioT, ConfigT], + scenarios: AsyncIterator[ScenarioT] | Iterable[ScenarioT], + config: ConfigT, + eval_fn: EvalFn[ConfigT] | None = None, + *, + # Pipeline settings + num_rollout_workers: int = 16, + min_batch_size: int = 4, + max_batch_size: int | None = None, + max_steps_off_policy: int = 4, + queue_maxsize: int | None = None, + # Training + learning_rate: float = 1e-5, + loss_fn: str = "cispo", + loss_fn_config: dict | None = None, + normalize_advantages: bool = True, + adam_params: object | None = None, + max_steps: int | None = None, + # Discard handling + discard_queue_multiplier: int = 100, + # Status output + log_interval_seconds: float = 60.0, + status_ewa_alpha: float = 0.2, + total_scenarios: int | None = None, + # Eval/Checkpointing + eval_every_n_steps: int = 20, + eval_step_0: bool = True, + save_checkpoint: bool = True, + # Resumption + resume: bool = True, + ) -> None: + if num_rollout_workers <= 0: + raise ValueError("num_rollout_workers must be > 0") + if min_batch_size <= 0: + raise ValueError("min_batch_size must be > 0") + if max_batch_size is not None and max_batch_size <= 0: + raise ValueError("max_batch_size must be > 0") + if max_batch_size is not None and max_batch_size < min_batch_size: + raise ValueError("max_batch_size must be >= min_batch_size") + if max_steps_off_policy < 0: + raise ValueError("max_steps_off_policy must be >= 0") + if queue_maxsize is not None and queue_maxsize <= 0: + raise ValueError("queue_maxsize must be > 0") + if eval_every_n_steps < 0: + raise ValueError("eval_every_n_steps must be >= 0") + if max_steps is not None and max_steps < 0: + raise ValueError("max_steps must be >= 0") + if log_interval_seconds <= 0: + raise ValueError("log_interval_seconds must be > 0") + if discard_queue_multiplier <= 0: + raise ValueError("discard_queue_multiplier must be > 0") + self.model = model + self.backend = backend + self.rollout_fn = rollout_fn + self.config 
= config + self.eval_fn = eval_fn + self.num_rollout_workers = num_rollout_workers + self.min_batch_size = min_batch_size + self.max_batch_size = ( + max_batch_size if max_batch_size is not None else 10 * min_batch_size + ) + self.max_steps_off_policy = max_steps_off_policy + self.queue_maxsize = queue_maxsize + self.learning_rate = learning_rate + self.loss_fn = loss_fn + self.loss_fn_config = loss_fn_config + self.normalize_advantages = normalize_advantages + self.adam_params = adam_params + self.max_steps = max_steps + self._status_log_interval_seconds = log_interval_seconds + self.eval_every_n_steps = eval_every_n_steps + self.eval_step_0 = eval_step_0 + self.save_checkpoint = save_checkpoint + self.resume = resume + self.discard_queue_multiplier = discard_queue_multiplier + self._discard_queue: list[TrajectoryGroup] = [] + self._discard_queue_limit = discard_queue_multiplier * min_batch_size + self._collapse_triggered = False + + self.state = PipelineState() + self._scenario_lock = asyncio.Lock() + self._scenario_iter: AsyncIterator[ScenarioT] | None = _to_async_iterator( + scenarios + ) + self._output_queue: asyncio.Queue[TrajectoryGroup | None] | None = None + self._eval_queue: asyncio.Queue[int] | None = None + self._status = StatusReporter( + get_scenario_offset=lambda: self.state.scenario_offset, + log_interval_seconds=log_interval_seconds, + status_ewa_alpha=status_ewa_alpha, + total_scenarios=total_scenarios, + num_workers=num_rollout_workers, + ) + + async def train(self, *, handle_signals: bool = True) -> None: + """Run the training pipeline over the configured scenario iterator.""" + start_step = await self.model.get_step() + pipeline_state = self._read_pipeline_state() if self.resume else {} + scenario_offset = int(pipeline_state.get("scenario_offset", 0) or 0) + last_eval_step = int(pipeline_state.get("last_eval_step", 0) or 0) + stored_step = pipeline_state.get("training_step") + + if stored_step is not None and int(stored_step) != start_step: + print( + "Warning: pipeline trainer state step does not match backend step " + f"({stored_step} != {start_step}); using backend step." 
+ ) + + self.state.policy_version = start_step + self.state.next_training_step = start_step + self.state.scenario_offset = scenario_offset + self.state.total_scenarios_consumed = int( + pipeline_state.get("total_scenarios_consumed", scenario_offset) or 0 + ) + self.state.last_eval_step = last_eval_step + + if scenario_offset > 0 and self._scenario_iter is not None: + skipped = await self._skip_scenarios(self._scenario_iter, scenario_offset) + self.state.scenario_offset = skipped + self.state.total_scenarios_consumed = skipped + + queue_maxsize = ( + self.queue_maxsize + if self.queue_maxsize is not None + else max(1, self.max_steps_off_policy * self.max_batch_size) + ) + self._output_queue = asyncio.Queue(maxsize=queue_maxsize) + self._eval_queue = asyncio.Queue() + + if self.eval_fn is not None and self.eval_step_0 and start_step == 0: + await self._eval_queue.put(start_step) + self.state.last_eval_step = start_step + self._persist_state(start_step) + + self._status.start(initial_step=start_step) + loop = asyncio.get_running_loop() + stop_requested = False + installed_handlers: list[tuple[str, signal.Signals]] = [] + original_handlers: dict[signal.Signals, object] = {} + + def _request_stop(sig: signal.Signals) -> None: + nonlocal stop_requested + if stop_requested: + return + stop_requested = True + print(f"Shutdown requested ({sig.name}); finishing current work...") + self.request_stop() + + def _sync_signal_handler(signum: int, _frame: object | None) -> None: + _request_stop(signal.Signals(signum)) + + if handle_signals: + for sig in (signal.SIGINT, signal.SIGTERM): + original_handlers[sig] = signal.getsignal(sig) + try: + loop.add_signal_handler(sig, _request_stop, sig) + installed_handlers.append(("loop", sig)) + except (NotImplementedError, RuntimeError): + try: + signal.signal(sig, _sync_signal_handler) + installed_handlers.append(("signal", sig)) + except (ValueError, RuntimeError): + continue + try: + async with asyncio.TaskGroup() as tg: + tg.create_task(self._rollout_stage(), name="rollout_stage") + tg.create_task(self._training_stage(), name="training_stage") + tg.create_task(self._eval_stage(), name="eval_stage") + tg.create_task(self._status_loop(), name="status_loop") + except* Exception as eg: + for exc in eg.exceptions: + if not isinstance(exc, asyncio.CancelledError): + print(f"Pipeline stage failed: {exc}") + raise + finally: + if handle_signals: + for mode, sig in installed_handlers: + if mode == "loop": + try: + loop.remove_signal_handler(sig) + except (NotImplementedError, RuntimeError): + pass + try: + previous = original_handlers.get(sig) + if previous is not None: + signal.signal(sig, cast(signal.Handlers, previous)) + except (ValueError, RuntimeError): + pass + self._status.flush() + self._status.close() + + def request_stop(self) -> None: + """Request a clean shutdown of the pipeline stages.""" + if self.state.done: + return + self.state.done = True + + async def _notify_policy() -> None: + async with self.state.policy_updated: + self.state.policy_updated.notify_all() + + try: + loop = asyncio.get_running_loop() + except RuntimeError: + loop = None + + if loop is None: + return + + loop.create_task(_notify_policy()) + if self._output_queue is not None: + try: + self._output_queue.put_nowait(None) + except asyncio.QueueFull: + loop.create_task(self._output_queue.put(None)) + + async def _skip_scenarios( + self, scenarios: AsyncIterator[ScenarioT], count: int + ) -> int: + skipped = 0 + while skipped < count: + try: + await anext(scenarios) + except 
StopAsyncIteration: + break + skipped += 1 + if skipped < count: + print( + f"Warning: scenario iterator exhausted early while skipping " + f"(skipped {skipped}/{count})." + ) + return skipped + + async def _get_next_scenario(self) -> ScenarioT | None: + if self._scenario_iter is None: + return None + async with self._scenario_lock: + try: + scenario = await anext(self._scenario_iter) + except StopAsyncIteration: + return None + self.state.scenario_offset += 1 + self.state.total_scenarios_consumed += 1 + return scenario + + async def _wait_for_policy(self) -> None: + async with self.state.policy_updated: + while ( + not self.state.done + and self.state.policy_version + < self.state.next_training_step - self.max_steps_off_policy + ): + await self.state.policy_updated.wait() + + async def _rollout_worker(self, worker_id: int) -> None: + assert self._output_queue is not None + while not self.state.done: + scenario = await self._get_next_scenario() + if scenario is None: + break + self._status.note_rollout_started() + errored = False + try: + await self._wait_for_policy() + if self.state.done: + break + + initial_version = self.state.policy_version + + group = await self.rollout_fn(self.model, scenario, self.config) + if not isinstance(group, TrajectoryGroup): + errored = True + continue + self._apply_scenario_metadata(group, scenario) + self._apply_policy_versions( + group, + initial_version=initial_version, + final_version=self.state.policy_version, + ) + if self.state.done: + break + await self._put_output_group(group) + except asyncio.CancelledError: + raise + except Exception as exc: + errored = True + print(f"Worker {worker_id}: rollout failed: {exc}") + finally: + self._status.note_rollout_finished(errored=errored) + + async def _rollout_stage(self) -> None: + async with asyncio.TaskGroup() as tg: + for i in range(self.num_rollout_workers): + tg.create_task(self._rollout_worker(i)) + if not self.state.done and self._output_queue is not None: + try: + self._output_queue.put_nowait(None) + except asyncio.QueueFull: + await self._output_queue.put(None) + + async def _training_stage(self) -> None: + if self._output_queue is None: + return + + current_step = self.state.next_training_step + stop_at_step = ( + current_step + self.max_steps if self.max_steps is not None else None + ) + if stop_at_step is not None and current_step >= stop_at_step: + self.state.done = True + self._persist_state(current_step) + async with self.state.policy_updated: + self.state.policy_updated.notify_all() + return + stop_after_batch = False + + while True: + if stop_at_step is not None and current_step >= stop_at_step: + break + step_start = time.monotonic() + batch, discarded, saw_sentinel = await self._collect_batch(current_step) + self.state.discarded_stale_samples += discarded + if discarded: + self._status.note_stale(discarded) + if not batch: + break + + expected_step = current_step + 1 + should_eval_step = self._should_eval_step(expected_step) + should_checkpoint = self.save_checkpoint and should_eval_step + + async with self.state.policy_updated: + self.state.next_training_step = expected_step + self.state.policy_updated.notify_all() + + self._status.note_training_start(len(batch)) + train_call_start: float | None = None + if os.getenv("ART_TRAIN_STEP_LOG"): + print(f"[train] step {expected_step} starting (batch={len(batch)})") + train_call_start = time.perf_counter() + try: + result = await self.backend.train( + self.model, + batch, + learning_rate=self.learning_rate, + loss_fn=self.loss_fn, + 
loss_fn_config=self.loss_fn_config, + normalize_advantages=self.normalize_advantages, + save_checkpoint=should_checkpoint, + adam_params=self.adam_params, + ) + except Exception: + self._status.note_training_end() + raise + finally: + if train_call_start is not None: + train_call_elapsed = time.perf_counter() - train_call_start + print( + f"[train] step {expected_step} done in " + f"{train_call_elapsed:.1f}s" + ) + + try: + current_step = result.step + self.state.policy_version = current_step + self.state.next_training_step = current_step + + step_seconds = time.monotonic() - step_start + self._status.note_training_batch( + batch, step=current_step, step_seconds=step_seconds + ) + + steps_off_policy = self._average_steps_off_policy(current_step, batch) + metrics = { + "discarded_stale_samples": float( + self.state.discarded_stale_samples + ), + "steps_off_policy": steps_off_policy, + "num_groups": float(len(batch)), + } + metrics.update(result.metrics) + + await self.model.log( + batch, + split="train", + step=current_step, + metrics=metrics, + ) + await self._log_discarded_groups(current_step) + + if self.eval_fn is not None and should_eval_step: + if self._eval_queue is not None: + await self._eval_queue.put(current_step) + self.state.last_eval_step = current_step + + self._persist_state(current_step) + finally: + self._status.note_training_end() + + async with self.state.policy_updated: + self.state.policy_updated.notify_all() + + if saw_sentinel: + stop_after_batch = True + if stop_after_batch: + break + + self.state.done = True + self._persist_state(current_step) + async with self.state.policy_updated: + self.state.policy_updated.notify_all() + + async def _collect_batch( + self, current_step: int + ) -> tuple[list[TrajectoryGroup], int, bool]: + assert self._output_queue is not None + batch: list[TrajectoryGroup] = [] + discarded = 0 + saw_sentinel = False + min_version = current_step - self.max_steps_off_policy + + while len(batch) < self.min_batch_size: + item = await self._output_queue.get() + if item is None: + saw_sentinel = True + break + self._status.note_group_dequeued(item) + self._check_all_failed(item) + if self._is_group_stale(item, min_version): + discarded += 1 + continue + if self._group_zero_variance(item): + if self._record_zero_variance(item): + return [], discarded, saw_sentinel + continue + batch.append(item) + + while not saw_sentinel: + try: + item = self._output_queue.get_nowait() + except asyncio.QueueEmpty: + break + if item is None: + saw_sentinel = True + break + self._status.note_group_dequeued(item) + self._check_all_failed(item) + if self._is_group_stale(item, min_version): + discarded += 1 + continue + if self._group_zero_variance(item): + if self._record_zero_variance(item): + return [], discarded, saw_sentinel + continue + batch.append(item) + + return batch, discarded, saw_sentinel + + def _check_all_failed(self, group: TrajectoryGroup) -> None: + """Raise if all rollouts in a group failed with exceptions.""" + if not group.trajectories and group.exceptions: + first_exc = group.exceptions[0] + raise RuntimeError( + f"All {len(group.exceptions)} rollouts in group failed. 
" + f"First exception ({first_exc.type}): {first_exc.message}" + ) + + async def _eval_stage(self) -> None: + if self.eval_fn is None or self._eval_queue is None: + return + + pending_eval: asyncio.Task[None] | None = None + while not self.state.done or not self._eval_queue.empty(): + try: + step = await asyncio.wait_for(self._eval_queue.get(), timeout=1.0) + except asyncio.TimeoutError: + continue + + if pending_eval is not None and not pending_eval.done(): + try: + await pending_eval + except Exception as exc: + print(f"Warning: previous eval failed: {exc}") + + pending_eval = asyncio.create_task(self._run_eval(step)) + + if pending_eval is not None and not pending_eval.done(): + try: + await pending_eval + except Exception as exc: + print(f"Warning: final eval failed: {exc}") + + async def _status_loop(self) -> None: + sleep_seconds = min(1.0, max(0.2, self._status_log_interval_seconds / 10)) + while not self.state.done: + self._status.log_if_due() + await asyncio.sleep(sleep_seconds) + + async def _run_eval(self, step: int) -> None: + assert self.eval_fn is not None + self._status.note_val_started(step) + reward: float | None = None + try: + result = await self.eval_fn(self.model, step, self.config) + splits: dict[str, list[art.Trajectory]] + if isinstance(result, dict): + splits = result + else: + splits = {"val": result} + + val_trajectories = splits.get("val") or [] + if val_trajectories: + reward = sum(t.reward for t in val_trajectories) / len(val_trajectories) + else: + reward = None + + for split_name, trajectories in splits.items(): + if trajectories: + await self.model.log(trajectories, split=split_name, step=step) + except asyncio.CancelledError: + raise + except Exception as exc: + print(f"Eval failed at step {step}: {exc}") + finally: + self._status.note_val_finished(step, reward) + + def _apply_policy_versions( + self, + group: TrajectoryGroup, + *, + initial_version: int, + final_version: int, + ) -> None: + for trajectory in group.trajectories: + if trajectory.initial_policy_version is None: + trajectory.initial_policy_version = initial_version + if trajectory.final_policy_version is None: + trajectory.final_policy_version = final_version + + def _apply_scenario_metadata( + self, group: TrajectoryGroup, scenario: ScenarioT + ) -> None: + metadata = scenario.get("metadata") if isinstance(scenario, dict) else None + if metadata is None or not isinstance(metadata, dict): + return + + for key, value in metadata.items(): + if not isinstance(key, str): + continue + if not self._is_scalar_metadata(value): + continue + for trajectory in group.trajectories: + trajectory.metadata[f"scenario_{key}"] = value + + def _is_group_stale(self, group: TrajectoryGroup, min_version: int) -> bool: + group_version = self._group_initial_version(group) + if group_version is None: + return False + return group_version < min_version + + def _record_zero_variance(self, group: TrajectoryGroup) -> bool: + self._discard_queue.append(group) + self._status.note_zero_variance_discarded(1) + if len(self._discard_queue) >= self._discard_queue_limit: + self._trigger_collapse() + return True + return False + + def _trigger_collapse(self) -> None: + if self._collapse_triggered: + return + self._collapse_triggered = True + self.state.done = True + print( + "\n" + "========================================\n" + "MODEL COLLAPSE DETECTED - Training stopped\n" + "========================================\n" + "\n" + f"Too many trajectory groups ({self._discard_queue_limit}) had zero reward variance,\n" + "indicating 
the model may have collapsed to a degenerate policy.\n" + "\n" + "To improve training dynamics:\n" + " - Lower the learning rate to reduce instability\n" + " - Ensure your reward function provides meaningful variance\n" + " - Check that prompts are diverse enough to elicit different responses\n" + " - Consider using a smaller batch size for more frequent updates\n" + "\n" + "To disable this failsafe:\n" + " - Increase `discard_queue_multiplier` (currently triggers after\n" + f" {self.discard_queue_multiplier} * min_batch_size = {self._discard_queue_limit} zero-variance groups)\n" + "\n" + ) + + async def _log_discarded_groups(self, step: int) -> None: + if not self._discard_queue: + return + discarded = list(self._discard_queue) + await self.model.log(discarded, split="discarded", step=step) + self._discard_queue.clear() + + @staticmethod + def _group_zero_variance(group: TrajectoryGroup) -> bool: + rewards = [t.reward for t in group.trajectories] + if len(rewards) <= 1: + return True + first = rewards[0] + return all(abs(r - first) <= 1e-12 for r in rewards[1:]) + + def _group_initial_version(self, group: TrajectoryGroup) -> int | None: + versions = [ + trajectory.initial_policy_version + for trajectory in group.trajectories + if trajectory.initial_policy_version is not None + ] + if not versions: + return None + return min(versions) + + def _average_steps_off_policy( + self, current_step: int, batch: list[TrajectoryGroup] + ) -> float: + steps: list[int] = [] + for group in batch: + group_version = self._group_initial_version(group) + if group_version is None: + continue + steps.append(current_step - group_version) + if not steps: + return 0.0 + return sum(steps) / len(steps) + + def _should_eval_step(self, step: int) -> bool: + if self.eval_fn is None: + return False + if self.eval_every_n_steps <= 0: + return False + return (step - self.state.last_eval_step) >= self.eval_every_n_steps + + def _read_pipeline_state(self) -> dict[str, Any]: + state = self.model.read_state() or {} + return state.get(PIPELINE_STATE_KEY, {}) + + def _persist_state(self, training_step: int) -> None: + payload = { + "scenario_offset": self.state.scenario_offset, + "total_scenarios_consumed": self.state.total_scenarios_consumed, + "training_step": training_step, + "last_eval_step": self.state.last_eval_step, + } + self.model.merge_state({PIPELINE_STATE_KEY: payload}) + + @staticmethod + def _is_scalar_metadata(value: object) -> bool: + return value is None or isinstance(value, (str, int, float, bool)) + + async def _put_output_group(self, group: TrajectoryGroup) -> None: + assert self._output_queue is not None + while not self.state.done: + try: + await asyncio.wait_for(self._output_queue.put(group), timeout=1.0) + self._status.note_group_enqueued(group) + return + except asyncio.TimeoutError: + continue diff --git a/src/art/pipeline_trainer/types.py b/src/art/pipeline_trainer/types.py new file mode 100644 index 000000000..532acf9cd --- /dev/null +++ b/src/art/pipeline_trainer/types.py @@ -0,0 +1,25 @@ +from __future__ import annotations + +from collections.abc import Awaitable, Callable +from typing import TypeVar + +import art +from art import Trajectory, TrajectoryGroup + +ScenarioT = TypeVar("ScenarioT", bound=dict) +ConfigT = TypeVar("ConfigT") +ScalarMetadataValue = float | int | str | bool | None + + +RolloutFn = Callable[ + [art.TrainableModel, ScenarioT, ConfigT], Awaitable[TrajectoryGroup] +] + +SingleRolloutFn = Callable[ + [art.TrainableModel, ScenarioT, ConfigT], Awaitable[Trajectory] +] + +EvalFn 
= Callable[ + [art.TrainableModel, int, ConfigT], + Awaitable[list[Trajectory] | dict[str, list[Trajectory]]], +] diff --git a/src/art/pipeline_trainer/yes_no_maybe_pipeline.py b/src/art/pipeline_trainer/yes_no_maybe_pipeline.py new file mode 100644 index 000000000..63e3323dd --- /dev/null +++ b/src/art/pipeline_trainer/yes_no_maybe_pipeline.py @@ -0,0 +1,149 @@ +"""Minimal yes/no/maybe RL training example using PipelineTrainer.""" + +from __future__ import annotations + +import asyncio +from datetime import datetime +from functools import partial +from itertools import cycle, permutations +import os +import re + +from dotenv import load_dotenv + +import art + +from . import PipelineTrainer + +# Training config +BASE_MODEL = "Qwen/Qwen3-30B-A3B-Instruct-2507" # or Qwen/Qwen3-4B-Instruct-2507 +MODEL_NAME = "pipeline-yes-no-maybe" +PROJECT = "yes-no-maybe-pipeline" +ROLLOUTS_PER_SCENARIO = 32 +MAX_TOKENS = 5 +MAX_STEPS = 20 +EVAL_TRAJECTORY_COUNT = 30 +EVAL_EVERY_N_STEPS = 2 + + +def build_scenarios() -> list[dict]: + """Generate all scenario variations.""" + scenarios: list[dict] = [] + for prefix in ["respond", "just respond"]: + for use_quotes in [True, False]: + for n in [3, 2]: + for words in permutations(["yes", "no", "maybe"], n): + quoted = [f"'{w}'" if use_quotes else w for w in words] + if len(words) == 3: + body = ", ".join(quoted) + else: + body = " or ".join(quoted) + scenarios.append({"prompt": f"{prefix} with {body}"}) + return scenarios + + +def reward_for_answer(text: str) -> float: + """Score: maybe=1.0, no=0.75, yes=0.5, other=0.0.""" + if not text: + return 0.0 + first_word = re.split(r"\s+", text.strip().lower())[0].strip(".,!?:;\"'()[]{}") + return {"maybe": 1.0, "no": 0.75, "yes": 0.5}.get(first_word, 0.0) + + +async def eval_fn( + model: art.TrainableModel, + step: int, + _config: None, + *, + scenarios: list[dict], +) -> list[art.Trajectory]: + trajectories: list[art.Trajectory] = [] + openai_client = model.openai_client() + model_name = model.get_inference_name(step) + for scenario in scenarios: + messages: art.Messages = [{"role": "user", "content": scenario["prompt"]}] + response = await openai_client.chat.completions.create( + messages=messages, + model=model_name, + max_tokens=MAX_TOKENS, + n=1, + ) + choice = response.choices[0] + trajectories.append( + art.Trajectory( + messages_and_choices=[*messages, choice], + reward=reward_for_answer(choice.message.content or ""), + ) + ) + return trajectories + + +async def rollout_fn(model, scenario, _config) -> art.TrajectoryGroup: + """Single inference call returns N completions for the group.""" + messages: art.Messages = [{"role": "user", "content": scenario["prompt"]}] + response = await model.openai_client().chat.completions.create( + messages=messages, + model=model.get_inference_name(), + max_tokens=MAX_TOKENS, + n=ROLLOUTS_PER_SCENARIO, + ) + return art.TrajectoryGroup( + [ + art.Trajectory( + messages_and_choices=[*messages, choice], + reward=reward_for_answer(choice.message.content or ""), + ) + for choice in response.choices + ] + ) + + +async def main() -> None: + load_dotenv() + if not os.environ.get("TINKER_API_KEY"): + raise RuntimeError("TINKER_API_KEY environment variable is required") + + model_name = f"{MODEL_NAME}-{datetime.now().strftime('%Y%m%d-%H%M%S')}" + + print("Initializing TinkerNativeBackend") + backend = art.TinkerNativeBackend() + + print(f"Initializing TrainableModel: {model_name}") + model = art.TrainableModel(name=model_name, project=PROJECT, base_model=BASE_MODEL) + + 
print("Registering model with backend") + await model.register(backend) + print("Model registered") + + openai_client = model.openai_client() + base_scenarios = build_scenarios() + scenarios = cycle(base_scenarios) + eval_scenarios = base_scenarios[:EVAL_TRAJECTORY_COUNT] + + eval_callback = partial(eval_fn, scenarios=eval_scenarios) + + trainer = PipelineTrainer( + model=model, + backend=backend, + rollout_fn=rollout_fn, + scenarios=scenarios, + config=None, + learning_rate=5e-5, + loss_fn="cispo", + eval_fn=eval_callback, + max_steps=MAX_STEPS, + eval_every_n_steps=EVAL_EVERY_N_STEPS, + eval_step_0=False, + total_scenarios=None, + ) + + print( + f"Training {model_name}: {MAX_STEPS} steps, " + f"{len(base_scenarios)} unique scenarios (cycling)" + ) + await trainer.train() + await backend.close() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/src/art/tinker_native/__init__.py b/src/art/tinker_native/__init__.py new file mode 100644 index 000000000..a6dc5bc59 --- /dev/null +++ b/src/art/tinker_native/__init__.py @@ -0,0 +1,3 @@ +from .backend import TinkerNativeBackend + +__all__ = ["TinkerNativeBackend"] diff --git a/src/art/tinker_native/backend.py b/src/art/tinker_native/backend.py new file mode 100644 index 000000000..291621b6c --- /dev/null +++ b/src/art/tinker_native/backend.py @@ -0,0 +1,766 @@ +from __future__ import annotations + +import asyncio +from dataclasses import dataclass +import os +import re +import time +from typing import Any, Awaitable, Iterable, Literal, TypeVar, cast +import uuid + +from fastapi import FastAPI, HTTPException +from openai import AsyncOpenAI +from openai.types.chat.chat_completion import ChatCompletion, Choice, ChoiceLogprobs +from openai.types.chat.chat_completion_message import ChatCompletionMessage +from openai.types.chat.chat_completion_message_function_tool_call import ( + ChatCompletionMessageFunctionToolCall, + Function, +) +from openai.types.chat.chat_completion_message_tool_call import ( + ChatCompletionMessageToolCallUnion, +) +from openai.types.chat.chat_completion_token_logprob import ChatCompletionTokenLogprob +from openai.types.chat.completion_create_params import CompletionCreateParams +from openai.types.completion_usage import CompletionUsage +import tinker +from tinker_cookbook import renderers, tokenizer_utils +import uvicorn + +from .. 
import dev +from ..backend import Backend +from ..model import Model, TrainableModel +from ..tinker.backend import get_renderer_name +from ..tinker.server import get_free_port +from ..trajectories import TrajectoryGroup +from ..types import TrainResult +from ..utils.output_dirs import get_model_dir +from ..utils.trajectory_migration import auto_migrate_on_register +from .data import ( + convert_openai_messages_to_renderer_format, + parse_completion_to_openai_message, + trajectory_groups_to_datums, +) + +STATE_KEY_RUN_IDS = "tinker_run_ids" +STATE_KEY_LATEST_STEP = "latest_step" +T = TypeVar("T") + + +@dataclass +class ModelState: + service_client: tinker.ServiceClient + rest_client: Any + training_client: tinker.TrainingClient + sampler_clients: dict[int, tinker.SamplingClient] + sampler_checkpoint_paths: dict[int, str] + training_checkpoint_paths: dict[int, str] + current_step: int + renderer: Any + tokenizer: Any + output_dir: str + tinker_run_ids: list[str] + model_name: str + server_task: asyncio.Task[None] | None = None + server_host: str | None = None + server_port: int | None = None + server_api_key: str | None = None + + +@dataclass +class TinkerNativeModelConfig: + renderer_name: str + training_client_args: dict[str, Any] + + +class TinkerNativeBackend(Backend): + _tinker_train_log_env = "ART_TINKER_TRAIN_LOG" + _tinker_sample_log_env = "ART_TINKER_SAMPLE_LOG" + + def __init__( + self, + *, + tinker_api_key: str | None = None, + path: str | None = None, + ) -> None: + if not "TINKER_API_KEY" in os.environ or tinker_api_key is not None: + assert tinker_api_key is not None, ( + "TINKER_API_KEY is not set and no tinker_api_key was provided" + ) + print("Setting TINKER_API_KEY to", tinker_api_key, "in environment") + os.environ["TINKER_API_KEY"] = tinker_api_key + + self._path = path or ".art" + os.makedirs(self._path, exist_ok=True) + self._model_state: dict[str, ModelState] = {} + + def _env_enabled(self, env_name: str) -> bool: + value = os.getenv(env_name) + if value is None: + return False + return value.lower() not in ("", "0", "false", "no") + + def _timestamp(self) -> str: + return time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) + + async def _tinker_call( + self, + label: str, + awaitable: Awaitable[T], + *, + env_name: str, + prefix: str, + ) -> T: + if not self._env_enabled(env_name): + return await awaitable + start = time.perf_counter() + print(f"[tinker:{prefix}] {label} start {self._timestamp()}") + try: + return await awaitable + finally: + elapsed = time.perf_counter() - start + print( + f"[tinker:{prefix}] {label} done in {elapsed:.2f}s " + f"at {self._timestamp()}" + ) + + async def _tinker_train_call(self, label: str, awaitable: Awaitable[T]) -> T: + return await self._tinker_call( + label, + awaitable, + env_name=self._tinker_train_log_env, + prefix="train", + ) + + async def _tinker_sample_call(self, label: str, awaitable: Awaitable[T]) -> T: + return await self._tinker_call( + label, + awaitable, + env_name=self._tinker_sample_log_env, + prefix="sample", + ) + + async def close(self) -> None: + for state in self._model_state.values(): + if state.server_task is not None: + state.server_task.cancel() + + async def register(self, model: Model) -> None: + model.base_path = self._path + output_dir = get_model_dir(model=model, art_path=self._path) + os.makedirs(output_dir, exist_ok=True) + with open(f"{output_dir}/model.json", "w") as f: + import json + + json.dump(model.model_dump(), f) + + auto_migrate_on_register(output_dir) + + if not model.trainable: + 
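+ # Non-trainable models only need their model.json written (plus any trajectory migration
+ # applied above); no Tinker training or sampling clients are created for them.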
return + trainable_model = cast(TrainableModel, model) + state = await self._build_model_state(trainable_model) + self._model_state[model.name] = state + + async def _prepare_backend_for_training( + self, + model: TrainableModel, + config: dev.OpenAIServerConfig | None = None, + ) -> tuple[str, str]: + state = self._model_state[model.name] + + raw_config: dict[str, Any] = cast(dict[str, Any], config) if config else {} + server_args = cast(dict[str, Any], raw_config.get("server_args", {})) + host = server_args.get("host", raw_config.get("host", "0.0.0.0")) + port = server_args.get("port", raw_config.get("port")) + if port is None: + port = get_free_port() + api_key = server_args.get("api_key", raw_config.get("api_key")) or "default" + + if state.server_task is None: + state.server_host = host + state.server_port = port + state.server_api_key = api_key + state.server_task = asyncio.create_task( + self._run_openai_server(state, host=host, port=port) + ) + state.server_task.add_done_callback(self._crash_on_server_exit) + + base_url = f"http://{host}:{port}/v1" + await self._wait_for_server_ready(base_url, api_key, model) + return base_url, api_key + + async def train( # type: ignore[override] + self, + model: TrainableModel, + trajectory_groups: Iterable[TrajectoryGroup], + *, + learning_rate: float = 1e-5, + loss_fn: Literal["cispo", "ppo", "importance_sampling", "dro"] = "cispo", + normalize_advantages: bool = True, + save_checkpoint: bool = False, + loss_fn_config: dict | None = None, + adam_params: tinker.AdamParams | None = None, + ) -> TrainResult: + state = self._model_state[model.name] + groups_list = list(trajectory_groups) + + datums = trajectory_groups_to_datums( + groups_list, + state.renderer, + state.tokenizer, + normalize_advantages, + ) + + metrics: dict[str, float] = { + "num_groups_submitted": float(len(groups_list)), + "num_datums": float(len(datums)), + } + + if not datums: + return TrainResult(step=state.current_step, metrics=metrics) + + if adam_params is None: + adam_params = tinker.AdamParams( + learning_rate=learning_rate, + beta1=0.9, + beta2=0.95, + eps=1e-8, + ) + + def remove_mask(datum: tinker.Datum) -> tinker.Datum: + if "mask" not in datum.loss_fn_inputs: + return datum + loss_fn_inputs = { + key: value + for key, value in datum.loss_fn_inputs.items() + if key != "mask" + } + return tinker.Datum( + model_input=datum.model_input, loss_fn_inputs=loss_fn_inputs + ) + + forward_output = await self._tinker_train_call( + "forward_backward", + state.training_client.forward_backward( + [remove_mask(datum) for datum in datums], + loss_fn=loss_fn, + loss_fn_config=loss_fn_config, + ), + ) + optim_output = await self._tinker_train_call( + "optim_step", state.training_client.optim_step(adam_params) + ) + + if forward_output.metrics: + for key, value in forward_output.metrics.items(): + if value is None: + continue + metrics[key] = float(value) + if optim_output.metrics: + for key, value in optim_output.metrics.items(): + if value is None: + continue + metrics[key] = float(value) + + next_step = state.current_step + 1 + checkpoint_name = f"step_{next_step:06d}" + + if save_checkpoint: + state_response, sampler_response = await self._save_checkpoint( + state.training_client, checkpoint_name + ) + state.training_checkpoint_paths[next_step] = state_response.path + else: + sampler_response = await self._save_sampler_weights( + state.training_client, checkpoint_name + ) + sampler_client = await self._tinker_train_call( + "create_sampling_client_async", + 
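+ # A sampling client built from the just-saved sampler weights means inference requests
+ # addressed to "<model>@<next_step>" immediately serve the updated policy, without waiting
+ # for a full training-state checkpoint.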
state.training_client.create_sampling_client_async( + model_path=sampler_response.path + ), + ) + state.sampler_clients[next_step] = sampler_client + state.sampler_checkpoint_paths[next_step] = sampler_response.path + + state.current_step = next_step + self._persist_model_state(model, state) + + return TrainResult(step=state.current_step, metrics=metrics) + + async def _get_step(self, model: TrainableModel) -> int: + if model.name in self._model_state: + return self._model_state[model.name].current_step + state = model.read_state() or {} + return int(state.get(STATE_KEY_LATEST_STEP, 0)) + + async def _delete_checkpoint_files( + self, + model: TrainableModel, + steps_to_keep: list[int], + ) -> None: + print("Checkpoint deletion is not yet implemented for TinkerNativeBackend.") + + def _model_inference_name(self, model: Model, step: int | None = None) -> str: + base_name = model.inference_model_name or model.name + if "@" in base_name: + base_name = base_name.split("@", 1)[0] + if step is None: + state = self._model_state.get(model.name) + step = state.current_step if state is not None else 0 + return f"{base_name}@{step}" + + async def _run_openai_server( + self, + state: ModelState, + host: str, + port: int, + ) -> None: + app = FastAPI() + + @app.post("/v1/chat/completions") + async def chat_completions(body: CompletionCreateParams) -> ChatCompletion: + model_name = body.get("model") + _, step = self._parse_model_name(model_name) + sampler_client = await self._get_sampler_client(state, step) + + messages = self._normalize_messages(body["messages"]) + tools = self._normalize_tools(body.get("tools")) + renderer_messages = convert_openai_messages_to_renderer_format( + messages=messages, + tools=tools, + renderer=state.renderer, + ) + prompt = state.renderer.build_generation_prompt(renderer_messages) + prompt_tokens = list(prompt.to_ints()) + + max_tokens = body.get("max_completion_tokens") + if max_tokens is None: + max_tokens = body.get("max_tokens") + temperature = body.get("temperature") + top_k = body.get("top_k") + top_p = body.get("top_p") + sampling_params = tinker.SamplingParams( + max_tokens=max_tokens, + seed=body.get("seed"), + temperature=temperature if temperature is not None else 1.0, + top_k=top_k if top_k is not None else -1, + top_p=top_p if top_p is not None else 1.0, + stop=state.renderer.get_stop_sequences(), + ) + sample_response = await self._tinker_sample_call( + "sample_async", + sampler_client.sample_async( + prompt=prompt, + num_samples=body.get("n") or 1, + sampling_params=sampling_params, + ), + ) + + choices: list[Choice] = [] + for i, sequence in enumerate(sample_response.sequences): + if sequence.logprobs is None: + raise HTTPException(status_code=400, detail="Logprobs are required") + if len(sequence.tokens) != len(sequence.logprobs): + raise HTTPException( + status_code=400, + detail="Tokens and logprobs must have the same length", + ) + parsed_message = parse_completion_to_openai_message( + list(sequence.tokens), state.renderer + ) + tool_calls: list[ChatCompletionMessageToolCallUnion] | None = None + if parsed_message.get("tool_calls"): + tool_calls = [ + ChatCompletionMessageFunctionToolCall( + type="function", + id=tool_call["id"], + function=Function( + name=tool_call["function"]["name"], + arguments=tool_call["function"]["arguments"], + ), + ) + for tool_call in parsed_message["tool_calls"] + ] + choices.append( + Choice( + finish_reason=sequence.stop_reason, + index=i, + message=ChatCompletionMessage( + content=parsed_message.get("content", ""), + 
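+ # parse_completion_to_openai_message() flattens any "thinking" parts of the renderer output
+ # into plain text, so content is always a plain string here.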
role="assistant", + tool_calls=tool_calls, + ), + logprobs=ChoiceLogprobs( + content=[ + ChatCompletionTokenLogprob( + token=f"token_id:{token}", + logprob=logprob, + top_logprobs=[], + ) + for token, logprob in zip( + sequence.tokens, sequence.logprobs + ) + ] + ), + ) + ) + + completion_tokens = sum( + len(sequence.tokens) for sequence in sample_response.sequences + ) + return ChatCompletion( + id=str(uuid.uuid4()), + choices=choices, + created=int(time.time()), + model=self._format_response_model(model_name, step, state), + object="chat.completion", + usage=CompletionUsage( + completion_tokens=completion_tokens, + prompt_tokens=len(prompt_tokens), + total_tokens=completion_tokens + len(prompt_tokens), + ), + ) + + server_config = uvicorn.Config(app, host=host, port=port, log_level="error") + server = uvicorn.Server(server_config) + await server.serve() + + def _crash_on_server_exit(self, task: asyncio.Task[None]) -> None: + try: + task.result() + except asyncio.CancelledError: + return + except Exception as exc: + print(f"OpenAI server crashed: {exc}") + else: + print("OpenAI server exited unexpectedly.") + os._exit(1) + + async def _wait_for_server_ready( + self, base_url: str, api_key: str, model: TrainableModel + ) -> None: + client = AsyncOpenAI(base_url=base_url, api_key=api_key) + with_timeout = float(os.environ.get("ART_SERVER_TIMEOUT", 300.0)) + start = time.time() + while True: + if time.time() - start > with_timeout: + raise TimeoutError( + f"Unable to reach OpenAI-compatible server within {with_timeout} seconds." + ) + try: + await client.chat.completions.create( + model=self._model_inference_name(model), + messages=[{"role": "user", "content": "Hello, world!"}], + max_completion_tokens=1, + ) + return + except Exception: + await asyncio.sleep(0.1) + + async def _build_model_state(self, model: TrainableModel) -> ModelState: + config = self._resolve_model_config(model) + service_client = tinker.ServiceClient() + rest_client = service_client.create_rest_client() + + tokenizer = tokenizer_utils.get_tokenizer(model.base_model) + renderer = renderers.get_renderer( + name=config.renderer_name, + tokenizer=tokenizer, + ) + + saved_state = model.read_state() or {} + tinker_run_ids = list(saved_state.get(STATE_KEY_RUN_IDS, [])) + training_paths, sampler_paths = await self._list_checkpoints( + rest_client, tinker_run_ids + ) + + training_client: tinker.TrainingClient + current_step = 0 + + if training_paths: + current_step = max(training_paths.keys()) + checkpoint_path = training_paths[current_step] + training_client = await self._create_training_client_from_checkpoint( + service_client=service_client, + checkpoint_state_path=checkpoint_path, + base_model=model.base_model, + training_client_args=config.training_client_args, + reset_optimizer=False, + ) + else: + training_client = await self._tinker_train_call( + "create_lora_training_client_async", + service_client.create_lora_training_client_async( + model.base_model, **config.training_client_args + ), + ) + checkpoint_name = f"step_{current_step:06d}" + sampler_response = await self._save_sampler_weights( + training_client, checkpoint_name + ) + sampler_paths[current_step] = sampler_response.path + + run_id = training_client.model_id + if run_id not in tinker_run_ids: + tinker_run_ids.append(run_id) + + sampler_clients: dict[int, tinker.SamplingClient] = {} + if current_step in sampler_paths: + sampler_clients[current_step] = await self._tinker_train_call( + "create_sampling_client_async", + 
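+ # Resume path: reuse the sampler checkpoint already recorded for this step instead of
+ # exporting sampler weights again.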
training_client.create_sampling_client_async( + model_path=sampler_paths[current_step] + ), + ) + else: + checkpoint_name = f"step_{current_step:06d}" + sampler_response = await self._save_sampler_weights( + training_client, checkpoint_name + ) + sampler_paths[current_step] = sampler_response.path + sampler_clients[current_step] = await self._tinker_train_call( + "create_sampling_client_async", + training_client.create_sampling_client_async( + model_path=sampler_response.path + ), + ) + + state = ModelState( + service_client=service_client, + rest_client=rest_client, + training_client=training_client, + sampler_clients=sampler_clients, + sampler_checkpoint_paths=sampler_paths, + training_checkpoint_paths=training_paths, + current_step=current_step, + renderer=renderer, + tokenizer=tokenizer, + output_dir=get_model_dir(model=model, art_path=self._path), + tinker_run_ids=tinker_run_ids, + model_name=((model.inference_model_name or model.name).split("@", 1)[0]), + ) + self._persist_model_state(model, state) + return state + + def _resolve_model_config(self, model: TrainableModel) -> TinkerNativeModelConfig: + internal_config = model._internal_config or {} + tinker_native_args = cast( + dev.TinkerNativeArgs | None, + internal_config.get("tinker_native_args"), + ) + renderer_name = ( + tinker_native_args.get("renderer_name") + if tinker_native_args is not None + else None + ) + if renderer_name is None: + renderer_name = get_renderer_name(model.base_model) + + training_client_args = dict( + tinker_native_args.get("training_client_args", {}) + if tinker_native_args is not None + else {} + ) + if "rank" not in training_client_args: + training_client_args["rank"] = 8 + if "train_unembed" not in training_client_args: + training_client_args["train_unembed"] = False + + return TinkerNativeModelConfig( + renderer_name=renderer_name, + training_client_args=training_client_args, + ) + + async def _list_checkpoints( + self, rest_client: Any, tinker_run_ids: list[str] + ) -> tuple[dict[int, str], dict[int, str]]: + training_paths: dict[int, str] = {} + sampler_paths: dict[int, str] = {} + step_pattern = re.compile(r"(?:weights/)?step_(\d+)$") + + for run_id in tinker_run_ids: + try: + response = await self._tinker_train_call( + f"list_checkpoints_async {run_id}", + rest_client.list_checkpoints_async(run_id), + ) + except Exception as exc: + print(f"Warning: Could not list checkpoints for {run_id}: {exc}") + continue + for checkpoint in response.checkpoints: + match = step_pattern.match(checkpoint.checkpoint_id) + if not match: + continue + step = int(match.group(1)) + path = f"tinker://{run_id}/{checkpoint.checkpoint_id}" + if checkpoint.checkpoint_type == "training": + training_paths[step] = path + elif checkpoint.checkpoint_type == "sampler": + sampler_paths[step] = path + return training_paths, sampler_paths + + async def _get_sampler_client( + self, state: ModelState, step: int | None + ) -> tinker.SamplingClient: + actual_step = step if step is not None else state.current_step + if actual_step in state.sampler_clients: + return state.sampler_clients[actual_step] + + if actual_step not in state.sampler_checkpoint_paths: + training_paths, sampler_paths = await self._list_checkpoints( + state.rest_client, state.tinker_run_ids + ) + state.training_checkpoint_paths.update(training_paths) + state.sampler_checkpoint_paths.update(sampler_paths) + + if actual_step not in state.sampler_checkpoint_paths: + available = sorted(state.sampler_checkpoint_paths.keys()) + raise HTTPException( + status_code=404, + 
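+ # An unknown "<model>@<step>" suffix surfaces as an explicit 404 to the OpenAI client rather
+ # than silently falling back to a different checkpoint.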
detail=f"No sampler checkpoint for step {actual_step}. Available: {available}", + ) + + sampler_client = await self._tinker_train_call( + "create_sampling_client_async", + state.training_client.create_sampling_client_async( + model_path=state.sampler_checkpoint_paths[actual_step] + ), + ) + state.sampler_clients[actual_step] = sampler_client + return sampler_client + + def _normalize_messages(self, messages: Iterable[Any]) -> list[dict[str, Any]]: + normalized: list[dict[str, Any]] = [] + for message in messages: + if hasattr(message, "model_dump"): + normalized.append(message.model_dump()) + else: + normalized.append(dict(message)) + return normalized + + def _normalize_tools( + self, tools: Iterable[Any] | None + ) -> list[dict[str, Any]] | None: + if tools is None: + return None + normalized: list[dict[str, Any]] = [] + for tool in tools: + if hasattr(tool, "model_dump"): + normalized.append(tool.model_dump()) + else: + normalized.append(dict(tool)) + return normalized + + def _parse_model_name( + self, model_name: str | None + ) -> tuple[str | None, int | None]: + if model_name and "@" in model_name: + base_name, step_str = model_name.rsplit("@", 1) + try: + return base_name, int(step_str) + except ValueError as exc: + raise HTTPException( + status_code=400, detail=f"Invalid model step: {model_name}" + ) from exc + return model_name, None + + def _format_response_model( + self, model_name: str | None, step: int | None, state: ModelState + ) -> str: + if model_name is None: + return f"{state.model_name}@{state.current_step}" + if step is None and "@" not in model_name: + return f"{model_name}@{state.current_step}" + return model_name + + async def _create_training_client_from_checkpoint( + self, + service_client: tinker.ServiceClient, + checkpoint_state_path: str, + base_model: str, + training_client_args: dict[str, Any], + reset_optimizer: bool = False, + ) -> tinker.TrainingClient: + training_client = await self._tinker_train_call( + "create_lora_training_client_async", + service_client.create_lora_training_client_async( + base_model, **training_client_args + ), + ) + + if reset_optimizer: + load_future = await self._tinker_train_call( + "load_state_async", + training_client.load_state_async(checkpoint_state_path), + ) + await self._tinker_train_call( + "load_state_result_async", load_future.result_async() + ) + else: + load_future = await self._tinker_train_call( + "load_state_with_optimizer_async", + training_client.load_state_with_optimizer_async(checkpoint_state_path), + ) + await self._tinker_train_call( + "load_state_with_optimizer_result_async", load_future.result_async() + ) + + return training_client + + async def _save_checkpoint( + self, + training_client: tinker.TrainingClient, + checkpoint_name: str, + ) -> tuple[Any, Any]: + state_future, sampler_future = await asyncio.gather( + self._tinker_train_call( + "save_state_async", + training_client.save_state_async(checkpoint_name), + ), + self._tinker_train_call( + "save_weights_for_sampler_async", + training_client.save_weights_for_sampler_async(checkpoint_name), + ), + ) + state_result = await self._tinker_train_call( + "save_state_result_async", state_future.result_async() + ) + sampler_result = await self._tinker_train_call( + "save_weights_for_sampler_result_async", sampler_future.result_async() + ) + return state_result, sampler_result + + async def _save_sampler_weights( + self, + training_client: tinker.TrainingClient, + checkpoint_name: str, + ) -> Any: + sampler_future = await self._tinker_train_call( + 
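+ # Tinker's save APIs are two-phase: this first await returns a future, and its result_async()
+ # is awaited separately below to obtain the saved checkpoint path.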
"save_weights_for_sampler_async", + training_client.save_weights_for_sampler_async(checkpoint_name), + ) + return await self._tinker_train_call( + "save_weights_for_sampler_result_async", sampler_future.result_async() + ) + + async def _save_training_state( + self, + training_client: tinker.TrainingClient, + checkpoint_name: str, + ) -> Any: + state_future = await self._tinker_train_call( + "save_state_async", + training_client.save_state_async(checkpoint_name), + ) + return await self._tinker_train_call( + "save_state_result_async", state_future.result_async() + ) + + def _persist_model_state(self, model: TrainableModel, state: ModelState) -> None: + model.merge_state( + { + STATE_KEY_RUN_IDS: state.tinker_run_ids, + STATE_KEY_LATEST_STEP: state.current_step, + } + ) diff --git a/src/art/tinker_native/data.py b/src/art/tinker_native/data.py new file mode 100644 index 000000000..994d6e39a --- /dev/null +++ b/src/art/tinker_native/data.py @@ -0,0 +1,382 @@ +from __future__ import annotations + +import json +import re +from typing import Any, Iterable, cast + +from openai.types.chat.chat_completion import Choice +import tinker +from tinker_cookbook import renderers +import torch + +from ..trajectories import History, Trajectory, TrajectoryGroup, get_messages +from ..types import MessagesAndChoices + + +def _create_conversation_prefix_with_tools_fallback( + tools: list[dict[str, Any]], system_prompt: str = "" +) -> list[dict[str, Any]]: + """Fallback implementation for create_conversation_prefix_with_tools. + + Used when the installed tinker_cookbook version doesn't have this method. + Implements the Qwen3 tool format. + """ + tools_text = "" + if tools: + # Each tool is wrapped in {"type": "function", "function": {...}} per OpenAI format + tool_lines = "\n".join( + json.dumps({"type": "function", "function": tool}, separators=(", ", ": ")) + for tool in tools + ) + tools_text = f"""# Tools + +You may call one or more functions to assist with the user query. 
+ +You are provided with function signatures within XML tags: + +{tool_lines} + + +For each function call, return a json object with function name and arguments within XML tags: + +{{"name": , "arguments": }} +""" + + # Add separator between system prompt and tools if system prompt exists + if system_prompt: + content = system_prompt + "\n\n" + tools_text + else: + content = tools_text + + return [{"role": "system", "content": content}] + + +def create_conversation_prefix_with_tools( + renderer: Any, tools: list[dict[str, Any]], system_prompt: str = "" +) -> list[dict[str, Any]]: + """Create conversation prefix with tools, using renderer method or fallback.""" + if hasattr(renderer, "create_conversation_prefix_with_tools"): + return renderer.create_conversation_prefix_with_tools(tools, system_prompt) + return _create_conversation_prefix_with_tools_fallback(tools, system_prompt) + + +def compute_advantages( + rewards: list[float], normalize_advantages: bool = True +) -> list[float]: + if not rewards: + return [] + rewards_tensor = torch.tensor(rewards, dtype=torch.float32) + centered = rewards_tensor - rewards_tensor.mean() + if not normalize_advantages: + return centered.tolist() + std_reward = rewards_tensor.std() + if std_reward > 1e-8: + return (centered / std_reward).tolist() + return [0.0] * len(rewards) + + +def convert_openai_messages_to_renderer_format( + messages: list[dict[str, Any]], + tools: list[dict[str, Any]] | None, + renderer: Any, +) -> list[dict[str, Any]]: + if tools and len(messages) > 0 and messages[0].get("role") == "system": + original_system = messages[0].get("content", "") + + tool_specs = [] + for tool in tools: + if tool.get("type") == "function": + func = tool.get("function", {}) + tool_specs.append(func) + else: + tool_specs.append(tool) + + tool_messages = create_conversation_prefix_with_tools( + renderer, tool_specs, system_prompt=original_system + ) + + converted = list(tool_messages) + messages = messages[1:] + else: + converted = [] + + for msg in messages: + role = msg.get("role") + content = msg.get("content", "") + + if role == "system": + converted.append({"role": "system", "content": content}) + + elif role == "user": + converted.append({"role": "user", "content": content}) + + elif role == "assistant": + assistant_msg: dict[str, Any] = { + "role": "assistant", + "content": content or "", + } + + if "tool_calls" in msg and msg["tool_calls"]: + tool_calls = [] + for tool_call in msg["tool_calls"]: + func = tool_call.get("function", {}) + tool_calls.append( + renderers.ToolCall( + id=tool_call.get("id", ""), + function=renderers.ToolCall.FunctionBody( + name=func.get("name", ""), + arguments=func.get("arguments", "{}"), + ), + ) + ) + assistant_msg["tool_calls"] = tool_calls + + converted.append(assistant_msg) + + elif role == "tool": + converted.append( + { + "role": "tool", + "content": content, + "tool_call_id": msg.get("tool_call_id", ""), + "name": msg.get("name", ""), + } + ) + + return converted + + +def _extract_gpt_oss_tool_calls(content: str) -> tuple[str, list[dict[str, Any]]]: + tool_calls = [] + cleaned_content = content + + pattern = r"(\{[^}]*\})(?:<\|call\|>)?" 
+ + matches = list(re.finditer(pattern, content)) + for i, match in enumerate(matches): + func_name = match.group(1) + args_json = match.group(2) + + tool_calls.append( + { + "id": f"call_{i}", + "type": "function", + "function": { + "name": func_name, + "arguments": args_json, + }, + } + ) + + cleaned_content = cleaned_content.replace(match.group(0), "").strip() + + return cleaned_content, tool_calls + + +def parse_completion_to_openai_message( + completion_tokens: list[int], + renderer: Any, +) -> dict[str, Any]: + message, _ = renderer.parse_response(completion_tokens) + + result: dict[str, Any] = {"role": "assistant"} + + content = message.get("content", "") + if isinstance(content, str): + result["content"] = content + else: + text_parts = [] + for part in content: + if part["type"] == "text": + text_parts.append(part["text"]) + elif part["type"] == "thinking": + text_parts.append(part["thinking"]) + result["content"] = "".join(text_parts) + + if "tool_calls" in message and message["tool_calls"]: + result["tool_calls"] = [ + { + "id": tool_call.id or f"call_{i}", + "type": "function", + "function": { + "name": tool_call.function.name, + "arguments": tool_call.function.arguments, + }, + } + for i, tool_call in enumerate(message["tool_calls"]) + ] + else: + if result.get("content") and " bool: + for message_or_choice in trajectory.messages_and_choices: + if isinstance(message_or_choice, Choice): + return True + for history in trajectory.additional_histories: + for message_or_choice in history.messages_and_choices: + if isinstance(message_or_choice, Choice): + return True + return False + + +def trajectory_groups_to_datums( + trajectory_groups: Iterable[TrajectoryGroup], + renderer: Any, + tokenizer: Any, + normalize_advantages: bool = True, +) -> list[tinker.Datum]: + datums: list[tinker.Datum] = [] + + for group in trajectory_groups: + if not group.trajectories: + continue + for trajectory in group.trajectories: + if not _trajectory_has_choice(trajectory): + raise ValueError( + "Trajectory is missing a Choice object. Training requires at least one Choice " + "to compute logprobs. Ensure your rollout includes an OpenAI Choice in " + "Trajectory.messages_and_choices." 
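+ # The Choice's per-token logprobs are what feed the "logprobs" loss input assembled later in
+ # history_to_datum()/build_datum().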
+ ) + rewards = [trajectory.reward for trajectory in group.trajectories] + advantages = compute_advantages(rewards, normalize_advantages) + + if all(advantage == 0.0 for advantage in advantages): + continue + for trajectory, advantage in zip(group.trajectories, advantages): + for history in iter_trajectory_histories(trajectory): + datum = history_to_datum(history, advantage, renderer, tokenizer) + if datum is not None: + datums.append(datum) + + return datums + + +def iter_trajectory_histories(trajectory: Trajectory) -> Iterable[History]: + yield History( + messages_and_choices=trajectory.messages_and_choices, + tools=trajectory.tools, + ) + yield from trajectory.additional_histories + + +def find_last_choice( + messages_and_choices: MessagesAndChoices, +) -> tuple[int, Choice] | None: + for idx in range(len(messages_and_choices) - 1, -1, -1): + message = messages_and_choices[idx] + if isinstance(message, Choice): + return idx, message + return None + + +def extract_logprobs_from_choice( + choice: Choice, tokenizer: Any +) -> tuple[list[int], list[float]]: + if choice.logprobs is None: + return [], [] + token_logprobs = choice.logprobs.content or choice.logprobs.refusal or [] + tokens: list[int] = [] + logprobs: list[float] = [] + for token_logprob in token_logprobs: + token_str = token_logprob.token or "" + if token_str.startswith("token_id:"): + try: + token_id = int(token_str.split(":")[1]) + except ValueError: + continue + tokens.append(token_id) + logprobs.append(token_logprob.logprob) + else: + token_id = tokenizer.convert_tokens_to_ids(token_str) + if token_id is None: + continue + tokens.append(int(token_id)) + logprobs.append(token_logprob.logprob) + return tokens, logprobs + + +def history_to_datum( + history: History, + advantage: float, + renderer: Any, + tokenizer: Any, +) -> tinker.Datum | None: + choice_info = find_last_choice(history.messages_and_choices) + if choice_info is None: + return None + choice_index, choice = choice_info + + completion_tokens, logprobs = extract_logprobs_from_choice(choice, tokenizer) + if not completion_tokens or len(completion_tokens) != len(logprobs): + return None + + prompt_messages = cast( + list[dict[str, Any]], get_messages(history.messages_and_choices[:choice_index]) + ) + renderer_messages = convert_openai_messages_to_renderer_format( + messages=prompt_messages, + tools=cast(list[dict[str, Any]] | None, history.tools), + renderer=renderer, + ) + prompt_input = renderer.build_generation_prompt(renderer_messages) + prompt_tokens = list(prompt_input.to_ints()) + + return build_datum( + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + logprobs=logprobs, + advantage=advantage, + ) + + +def build_datum( + prompt_tokens: list[int], + completion_tokens: list[int], + logprobs: list[float], + advantage: float, +) -> tinker.Datum | None: + if not prompt_tokens or not completion_tokens: + return None + ob_len = max(len(prompt_tokens) - 1, 0) + + all_tokens = prompt_tokens + completion_tokens + input_tokens = all_tokens[:-1] + target_tokens = all_tokens[1:] + + padded_logprobs = [0.0] * ob_len + list(logprobs) + padded_advantages = [0.0] * ob_len + [advantage] * len(completion_tokens) + action_mask = [0.0] * ob_len + [1.0] * len(completion_tokens) + + if not ( + len(input_tokens) + == len(target_tokens) + == len(padded_logprobs) + == len(padded_advantages) + == len(action_mask) + ): + return None + + return tinker.Datum( + model_input=tinker.ModelInput.from_ints(tokens=input_tokens), + loss_fn_inputs={ + "target_tokens": 
tinker.TensorData.from_torch(torch.tensor(target_tokens)), + "logprobs": tinker.TensorData.from_torch( + torch.tensor(padded_logprobs, dtype=torch.float32) + ), + "advantages": tinker.TensorData.from_torch( + torch.tensor(padded_advantages, dtype=torch.float32) + ), + "mask": tinker.TensorData.from_torch( + torch.tensor(action_mask, dtype=torch.float32) + ), + }, + ) diff --git a/src/art/trajectories.py b/src/art/trajectories.py index f74fb3039..a04762463 100644 --- a/src/art/trajectories.py +++ b/src/art/trajectories.py @@ -40,6 +40,8 @@ class Trajectory(pydantic.BaseModel): tools: Tools | None = None additional_histories: list[History] = [] reward: float + initial_policy_version: int | None = None + final_policy_version: int | None = None metrics: dict[str, float | int | bool] = {} auto_metrics: dict[str, float | int | bool] = {} metadata: dict[str, MetadataValue] = {} @@ -78,6 +80,8 @@ def messages(self) -> Messages: def for_logging(self) -> dict[str, Any]: loggable_dict = { "reward": self.reward, + "initial_policy_version": self.initial_policy_version, + "final_policy_version": self.final_policy_version, "metrics": self.metrics, "metadata": self.metadata, "messages": [], diff --git a/tests/integration/test_multi_checkpoint_training.py b/tests/integration/test_multi_checkpoint_training.py index 9252f59a4..5d74da0e4 100644 --- a/tests/integration/test_multi_checkpoint_training.py +++ b/tests/integration/test_multi_checkpoint_training.py @@ -61,7 +61,7 @@ async def simple_rollout( async def run_training_loop( model: art.TrainableModel, - backend: Union[LocalBackend, art.ServerlessBackend, art.TinkerBackend], + backend: art.Backend, num_steps: int = 1, rollouts_per_step: int = 4, ) -> list[TrainResult]: diff --git a/tests/integration/test_tinker_native_backend.py b/tests/integration/test_tinker_native_backend.py new file mode 100644 index 000000000..bbaa72729 --- /dev/null +++ b/tests/integration/test_tinker_native_backend.py @@ -0,0 +1,114 @@ +"""Integration test for TinkerNativeBackend based on yes-no-maybe.""" + +import os +import tempfile +import uuid + +import openai +import pytest + +import art + +DEFAULT_BASE_MODEL = "Qwen/Qwen3-30B-A3B-Instruct-2507" + + +def get_base_model() -> str: + return os.environ.get("BASE_MODEL", DEFAULT_BASE_MODEL) + + +def ensure_reward_variance(groups) -> None: + for group in groups: + rewards = [t.reward for t in group] + if len(rewards) < 2: + continue + if len(set(rewards)) <= 1: + group.trajectories[0].reward = 1.0 + group.trajectories[1].reward = 0.0 + + +async def simple_rollout( + client: openai.AsyncOpenAI, model_name: str, prompt: str +) -> art.Trajectory: + messages: art.Messages = [{"role": "user", "content": prompt}] + chat_completion = await client.chat.completions.create( + messages=messages, + model=model_name, + max_tokens=10, + timeout=60, + temperature=1, + ) + choice = chat_completion.choices[0] + content = (choice.message.content or "").lower() + if "yes" in content: + reward = 1.0 + elif "no" in content: + reward = 0.5 + elif "maybe" in content: + reward = 0.25 + else: + reward = 0.0 + return art.Trajectory(messages_and_choices=[*messages, choice], reward=reward) # type: ignore[attr-defined] + + +@pytest.mark.skipif( + "TINKER_API_KEY" not in os.environ, + reason="TINKER_API_KEY not set - skipping TinkerNativeBackend test", +) +async def test_tinker_native_backend(): + model_name = f"test-tinker-native-{uuid.uuid4().hex[:8]}" + with tempfile.TemporaryDirectory() as tmpdir: + backend = art.TinkerNativeBackend(path=tmpdir) # type: 
ignore[attr-defined] + model = art.TrainableModel( # type: ignore[attr-defined] + name=model_name, + project="integration-tests", + base_model=get_base_model(), + ) + try: + await model.register(backend) + + openai_client = model.openai_client() + current_step = await model.get_step() + model_name_step = model.get_inference_name(step=current_step) + prompts = ["Say yes", "Say no", "Say maybe"] + + async def make_group(prompt: str) -> art.TrajectoryGroup: + import asyncio + + trajectories = await asyncio.gather( + *[ + simple_rollout(openai_client, model_name_step, prompt) + for _ in range(2) + ] + ) + return art.TrajectoryGroup(trajectories) # type: ignore[attr-defined] + + train_groups = await art.gather_trajectory_groups( # type: ignore[attr-defined] + [make_group(prompt) for prompt in prompts] + ) + ensure_reward_variance(train_groups) + + result = await backend.train( + model, + train_groups, + learning_rate=1e-5, + ) + await model.log( + train_groups, metrics=result.metrics, step=result.step, split="train" + ) + + assert result.step > current_step + + await openai_client.chat.completions.create( + messages=[{"role": "user", "content": "Say hello"}], + model=model.get_inference_name(step=result.step), + max_tokens=10, + timeout=30, + ) + await openai_client.chat.completions.create( + messages=[{"role": "user", "content": "Say hello"}], + model=model.get_inference_name(step=0), + max_tokens=10, + timeout=30, + ) + finally: + await backend.close() diff --git a/uv.lock b/uv.lock index 1e73e5e41..fcc458ce5 100644 --- a/uv.lock +++ b/uv.lock @@ -314,9 +314,12 @@ wheels = [ [[package]] name = "antlr4-python3-runtime" -version = "4.9.3" +version = "4.13.2" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/3e/38/7859ff46355f76f8d19459005ca000b6e7012f2f1ca597746cbcd1fbfe5e/antlr4-python3-runtime-4.9.3.tar.gz", hash = "sha256:f224469b4168294902bb1efa80a8bf7855f24c99aef99cbefc1bcd3cce77881b", size = 117034, upload-time = "2021-11-06T17:52:23.524Z" } +sdist = { url = "https://files.pythonhosted.org/packages/33/5f/2cdf6f7aca3b20d3f316e9f505292e1f256a32089bd702034c29ebde6242/antlr4_python3_runtime-4.13.2.tar.gz", hash = "sha256:909b647e1d2fc2b70180ac586df3933e38919c85f98ccc656a96cd3f25ef3916", size = 117467, upload-time = "2024-08-03T19:00:12.757Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/89/03/a851e84fcbb85214dc637b6378121ef9a0dd61b4c65264675d8a5c9b1ae7/antlr4_python3_runtime-4.13.2-py3-none-any.whl", hash = "sha256:fe3835eb8d33daece0e799090eda89719dbccee7aa39ef94eed3818cafa5a7e8", size = 144462, upload-time = "2024-08-03T19:00:11.134Z" }, +] [[package]] name = "anyio" @@ -2918,7 +2921,7 @@ wheels = [ [[package]] name = "inspect-ai" -version = "0.3.158" +version = "0.3.163" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "aioboto3" }, @@ -2953,13 +2956,14 @@ dependencies = [ { name = "sniffio" }, { name = "tenacity" }, { name = "textual" }, + { name = "tiktoken" }, { name = "typing-extensions" }, { name = "universal-pathlib" }, { name = "zipp" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/99/ed/39847f3251ad4cc78553be3ebe458cf9962614ea17badaef0578e27d0221/inspect_ai-0.3.158.tar.gz", hash = "sha256:08b8025341c5815075e2f3fdcb2313074b0dc7a22d6ac8ee48306d7bc92700e0", size = 43435860, upload-time = "2025-12-24T12:03:47.24Z" } +sdist = { url = "https://files.pythonhosted.org/packages/d8/1f/ceaff3a92c03196cc2503a1ec8cc865ca4695a7e25a20f3c9fb9892664da/inspect_ai-0.3.163.tar.gz", 
hash = "sha256:4a3b131a1d48430bf6d64ab9842fababf1ce66d64aa126f96ab09f399c4f9f61", size = 43358268, upload-time = "2026-01-21T20:36:44.792Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/d9/ca/308f85765bcbc681bcb9f308a2443fb8a91d57b80e2e74f64e734cb6f946/inspect_ai-0.3.158-py3-none-any.whl", hash = "sha256:80d895b62c1cfea9cca20c54dffc2f7a23e822634b1cce45af7b03eb4fb92885", size = 34682434, upload-time = "2025-12-24T12:03:39.798Z" }, + { url = "https://files.pythonhosted.org/packages/5e/a0/bc25e3c895ff462f8b901784813a4241bdef8ed6aed66837f757a5e36747/inspect_ai-0.3.163-py3-none-any.whl", hash = "sha256:c09fd251d184a77f7a69fdd75695c457ed1c328fee4dafeabd9232f7309c6741", size = 34559953, upload-time = "2026-01-21T20:36:36.141Z" }, ] [[package]] @@ -3632,15 +3636,15 @@ wheels = [ [[package]] name = "latex2sympy2-extended" -version = "1.10.2" +version = "1.11.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "antlr4-python3-runtime" }, { name = "sympy" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/f4/de/472f9115c14c6f6d8a5889cabe3418283d708bde62ce00402c29441deed4/latex2sympy2_extended-1.10.2.tar.gz", hash = "sha256:41a517ffcc5a140e910a7d1646ce6ff440817e5f9d48fc8279d88bd0925bc389", size = 206188, upload-time = "2025-07-02T15:26:06.225Z" } +sdist = { url = "https://files.pythonhosted.org/packages/30/75/456da2da05f6380ea96e6ea804ab2c03e41fc3ed80052307fe8efe6ea20e/latex2sympy2_extended-1.11.0.tar.gz", hash = "sha256:9695657c81b50abba2636638638618db59f4663ed2a4a12d62cef74a40e28fec", size = 207023, upload-time = "2026-01-10T01:43:21.319Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/ab/60/dfbbf40e3a371388c0e03ff65b01319b7d4023e883df6d7261125772ffdc/latex2sympy2_extended-1.10.2-py3-none-any.whl", hash = "sha256:f910442c5b02a466c1046f47d05cc5285181068b882399281f30102715337fb7", size = 207855, upload-time = "2025-07-02T15:26:04.88Z" }, + { url = "https://files.pythonhosted.org/packages/e9/61/f75cd1fa54d8434276126034aed54dd120747de9a8fa013cdd79545ccbeb/latex2sympy2_extended-1.11.0-py3-none-any.whl", hash = "sha256:aebb77d52ce269e25028e4bea89ddb14d242ba36bcf7b636496fb5fd9728d234", size = 209050, upload-time = "2026-01-10T01:43:19.458Z" }, ] [[package]] @@ -3950,14 +3954,14 @@ wheels = [ [[package]] name = "math-verify" -version = "0.8.0" +version = "0.9.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "latex2sympy2-extended" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/35/b5/b1db6fa6b6c28ebbe1889ee11a4703a72a2ca7750ec415f4559c758cf01a/math_verify-0.8.0.tar.gz", hash = "sha256:3295e0adb94bfe553ff6e3189c44f1916a85aa24ab5d1900f2086a706e28f7c4", size = 60191, upload-time = "2025-07-02T15:52:07.209Z" } +sdist = { url = "https://files.pythonhosted.org/packages/4f/12/b8d13b581e110ac2f724a2351a8361a70fa36d057eb945d6379e8747c256/math_verify-0.9.0.tar.gz", hash = "sha256:45ac6c61344ba056b9e99a660a4bc8d044ed408f730aed68c60435aa5eec4645", size = 60329, upload-time = "2026-01-10T01:48:33.056Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/fe/9f/59979f699b5c97334298f1295bc9fcdc9904d98d2276479bffff863d23b1/math_verify-0.8.0-py3-none-any.whl", hash = "sha256:31ca651296d817a9bb3fd58ca1fd0d192dcea709b1e5ecf2d0a4514c16f89087", size = 29994, upload-time = "2025-07-02T15:52:05.023Z" }, + { url = "https://files.pythonhosted.org/packages/62/76/6b4969bccc842b6567f7e6ee015684b9428a9b7fcbdf479e73716f43597f/math_verify-0.9.0-py3-none-any.whl", hash = 
"sha256:3703e7c4885354027fa84409d762a596a2906d1fd4deb78361876bd905a76194", size = 29967, upload-time = "2026-01-10T01:48:31.674Z" }, ] [[package]] @@ -5034,8 +5038,6 @@ dependencies = [ { name = "polars" }, { name = "setproctitle" }, { name = "tblib" }, - { name = "tinker" }, - { name = "tinker-cookbook" }, { name = "typer" }, { name = "weave" }, ] @@ -5093,6 +5095,12 @@ dev = [ { name = "ruff" }, { name = "ty" }, ] +tinker = [ + { name = "fastapi" }, + { name = "tinker" }, + { name = "tinker-cookbook" }, + { name = "uvicorn" }, +] [package.metadata] requires-dist = [ @@ -5121,8 +5129,6 @@ requires-dist = [ { name = "setuptools", marker = "extra == 'backend'", specifier = ">=78.1.0" }, { name = "skypilot", extras = ["cudo", "do", "fluidstack", "gcp", "lambda", "kubernetes", "paperspace", "runpod"], marker = "extra == 'skypilot'", specifier = "==0.10.5" }, { name = "tblib", specifier = ">=3.0.0" }, - { name = "tinker", specifier = ">=0.8.1" }, - { name = "tinker-cookbook", specifier = ">=0.1.0" }, { name = "torch", marker = "extra == 'backend'", specifier = ">=2.8.0" }, { name = "torchao", marker = "extra == 'backend'", specifier = "==0.14.1" }, { name = "transformers", marker = "extra == 'backend'", specifier = ">=4.55.2,<=4.57.3" }, @@ -5152,6 +5158,12 @@ dev = [ { name = "ruff", specifier = ">=0.12.1" }, { name = "ty", specifier = ">=0.0.14" }, ] +tinker = [ + { name = "fastapi", specifier = ">=0.128.0" }, + { name = "tinker", specifier = ">=0.8.1" }, + { name = "tinker-cookbook", specifier = ">=0.1.0" }, + { name = "uvicorn", specifier = ">=0.35.0" }, +] [[package]] name = "orjson" @@ -7946,11 +7958,11 @@ wheels = [ [[package]] name = "soupsieve" -version = "2.8.1" +version = "2.8.3" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/89/23/adf3796d740536d63a6fbda113d07e60c734b6ed5d3058d1e47fc0495e47/soupsieve-2.8.1.tar.gz", hash = "sha256:4cf733bc50fa805f5df4b8ef4740fc0e0fa6218cf3006269afd3f9d6d80fd350", size = 117856, upload-time = "2025-12-18T13:50:34.655Z" } +sdist = { url = "https://files.pythonhosted.org/packages/7b/ae/2d9c981590ed9999a0d91755b47fc74f74de286b0f5cee14c9269041e6c4/soupsieve-2.8.3.tar.gz", hash = "sha256:3267f1eeea4251fb42728b6dfb746edc9acaffc4a45b27e19450b676586e8349", size = 118627, upload-time = "2026-01-20T04:27:02.457Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/48/f3/b67d6ea49ca9154453b6d70b34ea22f3996b9fa55da105a79d8732227adc/soupsieve-2.8.1-py3-none-any.whl", hash = "sha256:a11fe2a6f3d76ab3cf2de04eb339c1be5b506a8a47f2ceb6d139803177f85434", size = 36710, upload-time = "2025-12-18T13:50:33.267Z" }, + { url = "https://files.pythonhosted.org/packages/46/2c/1462b1d0a634697ae9e55b3cecdcb64788e8b7d63f54d923fcd0bb140aed/soupsieve-2.8.3-py3-none-any.whl", hash = "sha256:ed64f2ba4eebeab06cc4962affce381647455978ffc1e36bb79a545b91f45a95", size = 37016, upload-time = "2026-01-20T04:27:01.012Z" }, ] [[package]] @@ -8134,7 +8146,7 @@ wheels = [ [[package]] name = "textual" -version = "6.12.0" +version = "7.3.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "markdown-it-py", extra = ["linkify"] }, @@ -8144,9 +8156,9 @@ dependencies = [ { name = "rich" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/39/55/29416ef63de4c37b37da217b94439a28496a4dc585209f5bf1437a61d120/textual-6.12.0.tar.gz", hash = "sha256:a32e8edbf6abdb0c42d486e96bdf419eb3aa378edb1b1271b84637f3dbd64c73", size = 1584182, upload-time = "2026-01-02T09:42:30.415Z" 
} +sdist = { url = "https://files.pythonhosted.org/packages/6f/ee/620c887bfad9d6eba062dfa3b6b0e735e0259102e2667b19f21625ef598d/textual-7.3.0.tar.gz", hash = "sha256:3169e8ba5518a979b0771e60be380ab1a6c344f30a2126e360e6f38d009a3de4", size = 1590692, upload-time = "2026-01-15T16:32:02.342Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/13/f8/2a6a6ff1d07788f635493867d5a4003dfecacad16af1fdc9814d10daca3d/textual-6.12.0-py3-none-any.whl", hash = "sha256:cf9ea9a54d213c7736efe9fef440c7f49218d4e6ab75279afd060eded9c567ec", size = 714912, upload-time = "2026-01-02T09:42:28.786Z" }, + { url = "https://files.pythonhosted.org/packages/c3/1f/abeb4e5cb36b99dd37db72beb2a74d58598ccb35aaadf14624ee967d4a6b/textual-7.3.0-py3-none-any.whl", hash = "sha256:db235cecf969c87fe5a9c04d83595f506affc9db81f3a53ab849534d726d330a", size = 716374, upload-time = "2026-01-15T16:31:58.233Z" }, ] [[package]] @@ -8205,7 +8217,7 @@ wheels = [ [[package]] name = "tinker" -version = "0.8.1" +version = "0.9.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "anyio" }, @@ -8219,9 +8231,9 @@ dependencies = [ { name = "transformers" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/ec/a2/fe985e880a4a0a7ee47a0fa4824f5a1d1b24c27b7dd8a32e2606838b42a0/tinker-0.8.1.tar.gz", hash = "sha256:5ccf49f2ad7ca2dade303e9e7c33a1d21e14948d051e2a70e4a4f89d4fa52abf", size = 171213, upload-time = "2026-01-21T23:13:42.738Z" } +sdist = { url = "https://files.pythonhosted.org/packages/a0/f6/5ad5faa0e34f18f4d921fb1afab038b4fa2f662cd199b78072c486840d75/tinker-0.9.0.tar.gz", hash = "sha256:d47dd5870f3f3c982ad515254903753ccb626a2ee096e72ebaa90beb24573926", size = 173090, upload-time = "2026-01-26T22:33:58.998Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/70/d6/41d8c5ad6145f360f7bb476ff69cfa836cd961ffe7492d7ab3fa0da46bc1/tinker-0.8.1-py3-none-any.whl", hash = "sha256:ec743006d596735e8197b3c43adb03bdb46a5cbc01c02b920a0b7f60743c3377", size = 168164, upload-time = "2026-01-21T23:13:44.192Z" }, + { url = "https://files.pythonhosted.org/packages/e7/00/0282156cf66331e3f2dc0f8cb7020886fdbe6843771d3afac810c94f2638/tinker-0.9.0-py3-none-any.whl", hash = "sha256:e7c4a476a3c68799654021807cd9e1a4b3954f664b30f60fe613caeb774d7f94", size = 168536, upload-time = "2026-01-26T22:33:57.478Z" }, ] [[package]] @@ -8792,15 +8804,15 @@ wheels = [ [[package]] name = "universal-pathlib" -version = "0.3.7" +version = "0.3.8" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "fsspec" }, { name = "pathlib-abc" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/d5/96/b58b00ce27cbc37fd3c79944438dd8630d2c39f9467c9e73e1a4a3eec1ef/universal_pathlib-0.3.7.tar.gz", hash = "sha256:36331056fa59a7d7cd3b61b4045f3a3418f446f23ec1a01d281c4510814b4b05", size = 253466, upload-time = "2025-12-03T00:06:43.859Z" } +sdist = { url = "https://files.pythonhosted.org/packages/6e/ec/764b0d4593c6a8f5f66b347a19b5db9486dd0024b5e3339d468064a90c76/universal_pathlib-0.3.8.tar.gz", hash = "sha256:ead2b65bca3df6e11c3b7cb36fc9846340bc3c2db4ef57131550260422b0a3e8", size = 258837, upload-time = "2026-01-11T22:13:53.328Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/79/77/53c2d3a0413bc55b4c91067a0c41e55451be9b0784d204e4e46ce5abf668/universal_pathlib-0.3.7-py3-none-any.whl", hash = "sha256:fb95117b20b5981f86ef9d887fddbf9c61d3596634ba42cccea444931d87c201", size = 80738, upload-time = "2025-12-03T00:06:41.997Z" }, + { url = 
"https://files.pythonhosted.org/packages/86/2c/fc9416619a418e94576aef84ef263906a24f76a21a1c3e96ddae25c82df9/universal_pathlib-0.3.8-py3-none-any.whl", hash = "sha256:dac4fd9a3df918d85bb6da678e794b5dfa9ecdb5ff74675b497553dbe50134b8", size = 82608, upload-time = "2026-01-11T22:13:51.313Z" }, ] [[package]]