diff --git a/pyproject.toml b/pyproject.toml
index 1a50198af..46a7e8bd6 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -9,8 +9,6 @@ dependencies = [
"typer>=0.15.2",
"litellm>=1.71.1",
"weave>=0.52.23",
- "tinker>=0.8.1",
- "tinker-cookbook>=0.1.0",
"polars>=1.26.0",
"tblib>=3.0.0",
"nest-asyncio>=1.6.0",
@@ -115,6 +113,9 @@ unused-ignore-comment = "ignore"
# Allow unresolved imports for optional dependencies that may not be installed locally.
# In CI, we install all optional deps so these will be resolved and type-checked.
allowed-unresolved-imports = [
+ # tinker deps
+ "tinker.**",
+ "tinker_cookbook.**",
# backend deps
"accelerate.**",
"awscli.**",
@@ -165,6 +166,12 @@ dev = [
"pyarrow>=15.0.0",
"prek>=0.2.29",
]
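+# Optional extra for the Tinker backends; fastapi/uvicorn are used for the
+# Tinker-native backend's local OpenAI-compatible server.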
+tinker = [
+ "fastapi>=0.128.0",
+ "tinker>=0.8.1",
+ "tinker-cookbook>=0.1.0",
+ "uvicorn>=0.35.0",
+]
[tool.uv.sources]
panza = { git = "https://github.com/corbt/panza.git" }
diff --git a/src/art/__init__.py b/src/art/__init__.py
index b6948f514..d07d20274 100644
--- a/src/art/__init__.py
+++ b/src/art/__init__.py
@@ -57,7 +57,13 @@ def __init__(self, **kwargs):
from .local import LocalBackend
from .model import Model, TrainableModel
from .serverless import ServerlessBackend
-from .tinker import TinkerBackend
+
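+# The Tinker backends rely on the optional `tinker` extra (see pyproject.toml);
+# when it is not installed, these names resolve to None and are omitted from
+# __all__ below.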
+try:
+ from .tinker import TinkerBackend
+ from .tinker_native import TinkerNativeBackend
+except ModuleNotFoundError:
+ TinkerBackend = None # type: ignore[assignment]
+ TinkerNativeBackend = None # type: ignore[assignment]
from .trajectories import Trajectory, TrajectoryGroup
from .types import (
LocalTrainResult,
@@ -91,9 +97,10 @@ def __init__(self, **kwargs):
"retry",
"TrainConfig",
"TrainResult",
- "TinkerBackend",
"Trajectory",
"TrajectoryGroup",
"capture_yielded_trajectory",
"yield_trajectory",
]
+if TinkerBackend is not None:
+ __all__.extend(["TinkerBackend", "TinkerNativeBackend"])
diff --git a/src/art/dev/__init__.py b/src/art/dev/__init__.py
index 45e1cdef6..9d04c26bd 100644
--- a/src/art/dev/__init__.py
+++ b/src/art/dev/__init__.py
@@ -4,6 +4,7 @@
InternalModelConfig,
PeftArgs,
TinkerArgs,
+ TinkerNativeArgs,
TinkerTrainingClientArgs,
TrainerArgs,
)
@@ -16,6 +17,7 @@
"InitArgs",
"PeftArgs",
"TinkerArgs",
+ "TinkerNativeArgs",
"TinkerTrainingClientArgs",
"TrainerArgs",
"get_openai_server_config",
diff --git a/src/art/dev/model.py b/src/art/dev/model.py
index 68316233f..8bd342b81 100644
--- a/src/art/dev/model.py
+++ b/src/art/dev/model.py
@@ -121,6 +121,7 @@ class InternalModelConfig(TypedDict, total=False):
engine_args: "EngineArgs"
peft_args: "PeftArgs"
tinker_args: "TinkerArgs | None"
+ tinker_native_args: "TinkerNativeArgs | None"
trainer_args: "TrainerArgs"
@@ -129,6 +130,11 @@ class TinkerArgs(TypedDict, total=False):
training_client_args: "TinkerTrainingClientArgs"
+class TinkerNativeArgs(TypedDict, total=False):
+ renderer_name: Required[str]
+ training_client_args: "TinkerTrainingClientArgs"
+
+
class TinkerTrainingClientArgs(TypedDict, total=False):
rank: int
seed: int | None
diff --git a/src/art/local/backend.py b/src/art/local/backend.py
index 19f26afbe..4d46df28b 100644
--- a/src/art/local/backend.py
+++ b/src/art/local/backend.py
@@ -102,6 +102,8 @@ async def register(
Args:
model: An art.Model instance.
"""
+ # Ensure model state/logging uses the backend path
+ model.base_path = self._path
output_dir = get_model_dir(model=model, art_path=self._path)
os.makedirs(output_dir, exist_ok=True)
with open(f"{output_dir}/model.json", "w") as f:
diff --git a/src/art/model.py b/src/art/model.py
index 2fc38640b..05f55bd67 100644
--- a/src/art/model.py
+++ b/src/art/model.py
@@ -261,18 +261,22 @@ def _get_output_dir(self) -> str:
"""Get the output directory for this model."""
return f"{self.base_path}/{self.project}/models/{self.name}"
- def write_state(self, state: StateType) -> None:
- """Write persistent state to the model directory as JSON.
+ def overwrite_state(self, state: StateType) -> None:
+ """Overwrite persistent state in the model directory as JSON.
This state is stored in `state.json` within the model's output directory
and can be used to track training progress, dataset position, or any
other information that should persist across runs.
+ Warning:
+ This overwrites the entire state file. Prefer `merge_state()` unless
+ you intentionally want to replace all existing keys.
+
Args:
state: A dictionary of JSON-serializable values to persist.
Example:
- model.write_state({
+ model.overwrite_state({
"step": 5,
"dataset_offset": 100,
"last_checkpoint_time": "2024-01-15T10:30:00",
@@ -283,6 +287,45 @@ def write_state(self, state: StateType) -> None:
with open(f"{output_dir}/state.json", "w") as f:
json.dump(state, f, indent=2)
+ def write_state(self, state: StateType) -> None:
+ """Deprecated: use `overwrite_state()` or `merge_state()` instead."""
+ warnings.warn(
+ "write_state() is deprecated. Use overwrite_state() or merge_state() instead.",
+ DeprecationWarning,
+ stacklevel=2,
+ )
+ self.overwrite_state(state)
+
+ def merge_state(self, state: StateType) -> StateType:
+ """Deep-merge state into the existing state and persist it.
+
+ Args:
+ state: A dictionary of JSON-serializable values to merge.
+
+ Returns:
+ The merged state dictionary that was persisted.
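+
+        Example:
+            # "progress" is an arbitrary illustrative key: only the nested keys
+            # given here are updated; other persisted keys are left untouched.
+            model.merge_state({"progress": {"step": 6}})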
+ """
+ existing = self.read_state() or {}
+ merged = self._deep_merge_dicts(existing, state)
+ self.overwrite_state(merged)
+ return cast(StateType, merged)
+
+ @staticmethod
+ def _deep_merge_dicts(
+ base: dict[str, Any], updates: dict[str, Any]
+ ) -> dict[str, Any]:
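+        """Recursively merge `updates` into `base`, returning a new dict."""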
+ merged = dict(base)
+ for key, value in updates.items():
+ if (
+ key in merged
+ and isinstance(merged[key], dict)
+ and isinstance(value, dict)
+ ):
+ merged[key] = Model._deep_merge_dicts(merged[key], value)
+ else:
+ merged[key] = value
+ return merged
+
def read_state(self) -> StateType | None:
"""Read persistent state from the model directory.
diff --git a/src/art/pipeline_trainer/__init__.py b/src/art/pipeline_trainer/__init__.py
new file mode 100644
index 000000000..3e0d0f9ce
--- /dev/null
+++ b/src/art/pipeline_trainer/__init__.py
@@ -0,0 +1,13 @@
+from .status import StatusReporter
+from .trainer import PipelineTrainer, make_group_rollout_fn
+from .types import EvalFn, RolloutFn, ScenarioT, SingleRolloutFn
+
+__all__ = [
+ "PipelineTrainer",
+ "make_group_rollout_fn",
+ "StatusReporter",
+ "RolloutFn",
+ "SingleRolloutFn",
+ "EvalFn",
+ "ScenarioT",
+]
diff --git a/src/art/pipeline_trainer/binary_prefix_tool_pipeline.py b/src/art/pipeline_trainer/binary_prefix_tool_pipeline.py
new file mode 100644
index 000000000..70b413b66
--- /dev/null
+++ b/src/art/pipeline_trainer/binary_prefix_tool_pipeline.py
@@ -0,0 +1,346 @@
+from __future__ import annotations
+
+import asyncio
+from dataclasses import dataclass
+import json
+import os
+from pathlib import Path
+import re
+from typing import Any, cast
+
+from dotenv import load_dotenv
+from openai.types.chat.chat_completion_tool_choice_option_param import (
+ ChatCompletionToolChoiceOptionParam,
+)
+from openai.types.chat.chat_completion_tool_param import ChatCompletionToolParam
+import polars as pl
+
+import art
+
+from . import PipelineTrainer, make_group_rollout_fn
+
+Scenario = dict[str, Any]
+
+
+@dataclass
+class PipelineConfig:
+ temperature: float
+ eval_temperature: float
+ max_tokens: int
+
+
+TOOL_NAME = "make_guess"
+SECRET_BITS = "110000110100101011111010101011"
+SECRET_LEN = len(SECRET_BITS)
+
+TOOLS: list[ChatCompletionToolParam] = [
+ cast(
+ ChatCompletionToolParam,
+ {
+ "type": "function",
+ "function": {
+ "name": TOOL_NAME,
+ "description": "Submit a binary guess for the secret string.",
+ "parameters": {
+ "type": "object",
+ "properties": {
+ "guess": {
+ "type": "string",
+ "description": (
+ "A binary string of length "
+ f"{SECRET_LEN} consisting of 0 and 1."
+ ),
+ }
+ },
+ "required": ["guess"],
+ "additionalProperties": False,
+ },
+ },
+ },
+ )
+]
+TOOL_CHOICE: ChatCompletionToolChoiceOptionParam = {
+ "type": "function",
+ "function": {"name": TOOL_NAME},
+}
+
+SYSTEM_PROMPT = (
+ "You are playing a prefix-guessing game. You must call the tool "
+ f"{TOOL_NAME} exactly once with your best guess. The argument must be a "
+ f"{SECRET_LEN}-character string of only 0 and 1. Do not output any other text. "
+ "Your reward is the length of the shared prefix with the secret string."
+)
+USER_PROMPT = "Call the tool with your best binary guess."
+
+
+def is_valid_guess(guess: str) -> bool:
+    # An empty guess (for example, when no tool call was produced) is not valid.
+    return bool(guess) and all(ch in {"0", "1"} for ch in guess)
+
+
+def shared_prefix_len(guess: str, secret: str) -> int:
+ matched = 0
+ for guessed, actual in zip(guess, secret):
+ if guessed != actual:
+ break
+ matched += 1
+ return matched
+
+
+def _parse_guess_args(arguments: str | None) -> str | None:
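+    """Best-effort extraction of the `guess` value from tool-call arguments.
+
+    Falls back to regex matching and raw-binary parsing when the arguments are
+    not valid JSON.
+    """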
+ if not arguments:
+ return None
+ text = arguments.strip()
+ if not text:
+ return None
+ payload: Any | None = None
+ try:
+ payload = json.loads(text)
+ except json.JSONDecodeError:
+ start = text.find("{")
+ end = text.rfind("}")
+ if start != -1 and end != -1 and end > start:
+ try:
+ payload = json.loads(text[start : end + 1])
+ except json.JSONDecodeError:
+ payload = None
+
+ if isinstance(payload, dict):
+ guess = payload.get("guess")
+ if isinstance(guess, str):
+ return guess
+ if guess is not None:
+ return str(guess)
+
+ match = re.search(r'guess\s*[:=]\s*"([^"]*)"', text)
+ if match:
+ return match.group(1)
+ match = re.search(r"guess\s*[:=]\s*'([^']*)'", text)
+ if match:
+ return match.group(1)
+    if re.fullmatch(r"[01\s]+", text):
+        # Strip whitespace so a bare "0101 0011"-style answer still validates.
+        return re.sub(r"\s+", "", text)
+ return None
+
+
+def _tool_name_and_args(tool_call: Any) -> tuple[str | None, str | None]:
+ if hasattr(tool_call, "function"):
+ function = tool_call.function
+ return getattr(function, "name", None), getattr(function, "arguments", None)
+ if isinstance(tool_call, dict):
+ func = tool_call.get("function") or {}
+ return func.get("name"), func.get("arguments")
+ return None, None
+
+
+def extract_guess(choice: Any) -> tuple[str | None, str]:
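+    """Return (guess, source) from the first matching tool call, else (None, "missing")."""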
+ tool_calls = getattr(choice.message, "tool_calls", None) or []
+ for tool_call in tool_calls:
+ name, args = _tool_name_and_args(tool_call)
+ if name != TOOL_NAME:
+ continue
+ guess = _parse_guess_args(args)
+ if guess is not None:
+ return guess, "tool_call"
+
+ return None, "missing"
+
+
+def get_model_output_dir(model: art.TrainableModel) -> Path:
+ return Path(model.base_path) / model.project / "models" / model.name
+
+
+def print_history_summary(model: art.TrainableModel, tail: int = 5) -> None:
+ history_path = get_model_output_dir(model) / "history.jsonl"
+ if not history_path.exists():
+ print(f"No history found at {history_path}")
+ return
+
+ rows = pl.read_ndjson(str(history_path)).to_dicts()
+
+ train_rows = [row for row in rows if "train/reward" in row]
+ print("\nRecent training metrics:")
+ for row in train_rows[-tail:]:
+ step = row["step"]
+ reward = row["train/reward"]
+ std_dev = row["train/reward_std_dev"]
+ discarded = row["train/discarded_stale_samples"]
+ off_policy = row["train/steps_off_policy"]
+ print(
+ f" step={step} reward={reward} std={std_dev} "
+ f"discarded={discarded} off_policy={off_policy}"
+ )
+
+
+async def main() -> None:
+ load_dotenv()
+
+    base_model = os.environ.get(
+        "BASE_MODEL", "Qwen/Qwen3-4B-Instruct-2507"
+    )  # larger alternative: Qwen/Qwen3-30B-A3B-Instruct-2507
+ model_name = os.environ.get("MODEL_NAME", "pipeline-binary-prefix-tool")
+ project = os.environ.get("PROJECT", "binary-prefix-tool-pipeline")
+ art_path = os.environ.get("ART_PATH")
+
+ min_batch_size = int(os.environ.get("MIN_BATCH_SIZE", "4"))
+ num_rollout_workers = int(os.environ.get("NUM_ROLLOUT_WORKERS", "8"))
+ rollouts_per_scenario = int(os.environ.get("ROLLOUTS_PER_SCENARIO", "8"))
+ max_steps_off_policy = int(os.environ.get("MAX_STEPS_OFF_POLICY", "6"))
+ max_batch_size_env = os.environ.get("MAX_BATCH_SIZE")
+ max_batch_size = int(max_batch_size_env) if max_batch_size_env else None
+ eval_every_n_steps = int(os.environ.get("EVAL_EVERY_N_STEPS", "2"))
+ eval_step_0 = os.environ.get("EVAL_STEP_0", "1") == "1"
+ max_steps = int(os.environ.get("MAX_STEPS", "10"))
+ save_checkpoint = os.environ.get("SAVE_CHECKPOINT", "0") == "1"
+ resume_env = os.environ.get("RESUME")
+ resume = (resume_env == "1") if resume_env is not None else save_checkpoint
+ if resume and not save_checkpoint:
+ print("RESUME=1 but SAVE_CHECKPOINT=0; disabling resume for a clean run.")
+ resume = False
+
+ temperature = float(os.environ.get("ROLLOUT_TEMPERATURE", "1.0"))
+ eval_temperature = float(os.environ.get("EVAL_TEMPERATURE", "0.0"))
+ max_tokens = int(os.environ.get("MAX_TOKENS", "300"))
+ request_timeout = float(os.environ.get("REQUEST_TIMEOUT", "60"))
+ log_interval_seconds = float(os.environ.get("STATUS_LOG_INTERVAL_SECONDS", "60"))
+
+ internal_config: art.dev.InternalModelConfig | None = None
+ lora_rank = os.environ.get("LORA_RANK")
+ if lora_rank is not None:
+ internal_config = {
+ "tinker_native_args": {
+ "renderer_name": os.environ.get("RENDERER_NAME", "qwen3_instruct"),
+ "training_client_args": {"rank": int(lora_rank)},
+ }
+ }
+
+ backend = art.TinkerNativeBackend(path=art_path)
+ model = art.TrainableModel(
+ name=model_name,
+ project=project,
+ base_model=base_model,
+ _internal_config=internal_config,
+ report_metrics=[], # Disable wandb logging
+ )
+ await model.register(backend)
+
+ openai_client = model.openai_client()
+
+ async def do_rollout(scenario: Scenario, temp: float) -> art.Trajectory:
+ """Core rollout logic used by both training and eval."""
+ messages: art.Messages = scenario["messages"]
+ response = await openai_client.chat.completions.create(
+ messages=messages,
+ model=model.name,
+ max_tokens=max_tokens,
+ timeout=request_timeout,
+ temperature=temp,
+ tools=TOOLS,
+ tool_choice=TOOL_CHOICE,
+ )
+ choice = response.choices[0]
+ raw_guess, source = extract_guess(choice)
+ guess = raw_guess or ""
+ valid_guess = is_valid_guess(guess)
+ prefix_len = shared_prefix_len(guess, SECRET_BITS) if valid_guess else 0
+ reward = float(prefix_len)
+ metrics = {
+ "prefix_len": prefix_len,
+ "guess_len": len(guess),
+ "secret_len": SECRET_LEN,
+ "valid_guess": 1.0 if valid_guess else 0.0,
+ "tool_call_count": float(
+ len(getattr(choice.message, "tool_calls", None) or [])
+ ),
+ "tool_call_found": 1.0 if source != "missing" else 0.0,
+ "tool_call_structured": 1.0 if source == "tool_call" else 0.0,
+ }
+ return art.Trajectory(
+ messages_and_choices=[*messages, choice],
+ tools=TOOLS,
+ reward=reward,
+ metrics=metrics,
+ )
+
+ async def single_rollout(
+ _model: art.TrainableModel,
+ scenario: Scenario,
+ _config: PipelineConfig,
+ ) -> art.Trajectory:
+ return await do_rollout(scenario, temperature)
+
+ rollout_fn = make_group_rollout_fn(single_rollout, n=rollouts_per_scenario)
+
+ last_eval: dict[str, float | None] = {"avg_reward": None}
+
+ async def eval_fn(
+ _model: art.TrainableModel, _step: int, _config: PipelineConfig
+ ) -> list[art.Trajectory]:
+ tasks = [do_rollout(build_scenario(), eval_temperature)]
+ results = await asyncio.gather(*tasks, return_exceptions=True)
+ trajectories = [r for r in results if isinstance(r, art.Trajectory)]
+ if trajectories:
+ avg_reward = sum(t.reward for t in trajectories) / len(trajectories)
+ last_eval["avg_reward"] = avg_reward
+ return trajectories
+
+ scenario_count = int(os.environ.get("SCENARIO_COUNT", "1000"))
+ scenario_count = max(1, scenario_count)
+
+ def build_scenario() -> Scenario:
+ return {
+ "messages": [
+ {"role": "system", "content": SYSTEM_PROMPT},
+ {"role": "user", "content": USER_PROMPT},
+ ],
+ }
+
+ async def scenario_iter():
+ for i in range(scenario_count):
+ scenario = build_scenario()
+ scenario["metadata"] = {"scenario_idx": i}
+ yield scenario
+
+ config = PipelineConfig(
+ temperature=temperature,
+ eval_temperature=eval_temperature,
+ max_tokens=max_tokens,
+ )
+
+ trainer = PipelineTrainer(
+ model=model,
+ backend=backend,
+ rollout_fn=rollout_fn,
+ scenarios=scenario_iter(),
+ config=config,
+ eval_fn=eval_fn,
+ num_rollout_workers=num_rollout_workers,
+ min_batch_size=min_batch_size,
+ max_steps_off_policy=max_steps_off_policy,
+ max_batch_size=max_batch_size,
+ learning_rate=float(os.environ.get("LEARNING_RATE", "1e-4")),
+ log_interval_seconds=log_interval_seconds,
+ eval_every_n_steps=eval_every_n_steps,
+ eval_step_0=eval_step_0,
+ save_checkpoint=save_checkpoint,
+ resume=resume,
+ max_steps=max_steps,
+ total_scenarios=scenario_count,
+ )
+
+ print(
+ "Starting pipeline trainer test: "
+ f"max_steps={max_steps} scenarios={scenario_count} "
+ f"rollouts_per_scenario={rollouts_per_scenario} "
+ f"secret_len={SECRET_LEN}"
+ )
+ await trainer.train()
+ print("Training completed.")
+ if last_eval["avg_reward"] is not None:
+ print(f"Last eval avg reward: {last_eval['avg_reward']:.3f}")
+
+ print_history_summary(model, tail=5)
+ await backend.close()
+
+
+if __name__ == "__main__":
+ asyncio.run(main())
diff --git a/src/art/pipeline_trainer/state.py b/src/art/pipeline_trainer/state.py
new file mode 100644
index 000000000..46f8a9cb0
--- /dev/null
+++ b/src/art/pipeline_trainer/state.py
@@ -0,0 +1,25 @@
+from __future__ import annotations
+
+import asyncio
+from dataclasses import dataclass, field
+
+
+@dataclass
+class PipelineState:
+ """Shared state across pipeline stages."""
+
+ # Policy versioning
+ policy_version: int = 0
+ next_training_step: int = 0
+
+ # Scenario tracking
+ scenario_offset: int = 0
+ total_scenarios_consumed: int = 0
+ last_eval_step: int = 0
+
+ # Metrics
+ discarded_stale_samples: int = 0
+
+ # Synchronization
+ policy_updated: asyncio.Condition = field(default_factory=asyncio.Condition)
+ done: bool = False
diff --git a/src/art/pipeline_trainer/status.py b/src/art/pipeline_trainer/status.py
new file mode 100644
index 000000000..cb58bdb3e
--- /dev/null
+++ b/src/art/pipeline_trainer/status.py
@@ -0,0 +1,407 @@
+from __future__ import annotations
+
+import math
+import shutil
+import sys
+import time
+from typing import Callable, cast
+
+from tqdm import tqdm
+
+from art import TrajectoryGroup
+
+
+class StatusReporter:
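+    """Aggregates pipeline progress counters and renders periodic status output.
+
+    A condensed one-line summary is shown via tqdm when attached to a TTY, and a
+    fuller `[status]` line is logged every `log_interval_seconds`.
+    """
+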
+ def __init__(
+ self,
+ *,
+ get_scenario_offset: Callable[[], int],
+ log_interval_seconds: float = 60.0,
+ status_ewa_alpha: float = 0.2,
+ total_scenarios: int | None = None,
+ num_workers: int = 1,
+ ) -> None:
+ if log_interval_seconds <= 0:
+ raise ValueError("log_interval_seconds must be > 0")
+ if not 0 < status_ewa_alpha <= 1:
+ raise ValueError("status_ewa_alpha must be in (0, 1]")
+ if total_scenarios is not None and total_scenarios < 0:
+ raise ValueError("total_scenarios must be >= 0")
+ if num_workers <= 0:
+ raise ValueError("num_workers must be > 0")
+
+ self._get_scenario_offset = get_scenario_offset
+ self._log_interval_seconds = log_interval_seconds
+ self._status_ewa_alpha = status_ewa_alpha
+ self._total_scenarios = total_scenarios
+ self._num_workers = num_workers
+
+ self._current_step: int | None = None
+ self._rolling_out = 0
+ self._queued = 0
+ self._training = 0
+ self._trained = 0
+ self._stale = 0
+ self._zero_var = 0
+ self._errored = 0
+
+ self._train_reward_ewa: float | None = None
+ self._seconds_ewa: float | None = None
+ self._avg_std_dev_ewa: float | None = None
+
+ self._last_val_step: int | None = None
+ self._last_val_reward: float | None = None
+ self._val_running_step: int | None = None
+
+ self._tqdm: tqdm | None = None
+ self._started = False
+ self._last_log_time = 0.0
+ self._last_refresh_time = 0.0
+ self._refresh_interval_seconds = 0.25
+
+ def start(self, *, initial_step: int | None = None) -> None:
+ if self._started:
+ return
+ self._started = True
+ if initial_step is not None:
+ self._current_step = initial_step
+ self._last_log_time = time.monotonic() - self._log_interval_seconds
+ self._last_refresh_time = 0.0
+ if sys.stdout.isatty():
+ self._tqdm = tqdm(
+ total=self._total_scenarios,
+ bar_format="{desc}",
+ dynamic_ncols=True,
+ leave=False,
+ file=sys.stdout,
+ )
+ self._refresh_status(force=True)
+
+ def close(self) -> None:
+ if self._tqdm is not None:
+ self._tqdm.close()
+ self._tqdm = None
+ self._started = False
+
+ def flush(self) -> None:
+ self.log_if_due(force=True)
+
+ def set_step(self, step: int) -> None:
+ self._current_step = step
+ self._refresh_status()
+
+ def log_if_due(self, *, force: bool = False) -> None:
+ if not self._started:
+ return
+ now = time.monotonic()
+ if not force and (now - self._last_log_time) < self._log_interval_seconds:
+ return
+ self._last_log_time = now
+ self._write_log_line(self._format_full_log())
+
+ def note_rollout_started(self) -> None:
+ self._rolling_out += 1
+ self._refresh_status()
+
+ def note_rollout_finished(self, *, errored: bool) -> None:
+ if self._rolling_out > 0:
+ self._rolling_out -= 1
+ if errored:
+ self._errored += 1
+ self._refresh_status()
+
+ def note_group_enqueued(self, _group: TrajectoryGroup) -> None:
+ self._queued += 1
+ self._refresh_status()
+
+ def note_group_dequeued(self, _group: TrajectoryGroup) -> None:
+ if self._queued > 0:
+ self._queued -= 1
+ self._refresh_status()
+
+ def note_stale(self, count: int) -> None:
+ if count > 0:
+ self._stale += count
+ self._refresh_status()
+
+ def note_zero_variance_discarded(self, count: int) -> None:
+ if count > 0:
+ self._zero_var += count
+ self._refresh_status()
+
+ def note_training_start(self, group_count: int) -> None:
+ self._training = group_count
+ self._refresh_status()
+
+ def note_training_end(self) -> None:
+ self._training = 0
+ self._refresh_status()
+
+ def note_training_batch(
+ self, batch: list[TrajectoryGroup], *, step: int, step_seconds: float
+ ) -> None:
+ zero_variance_groups = self._count_zero_variance_groups(batch)
+ trainable_groups = len(batch) - zero_variance_groups
+ avg_std = self._compute_batch_avg_std_dev(batch)
+ avg_reward = self._compute_batch_avg_reward(batch)
+
+ self._current_step = step
+ self._trained += trainable_groups
+ self._zero_var += zero_variance_groups
+
+ if avg_reward is not None:
+ self._train_reward_ewa = self._update_ewa(
+ self._train_reward_ewa, avg_reward
+ )
+ self._seconds_ewa = self._update_ewa(self._seconds_ewa, step_seconds)
+ self._avg_std_dev_ewa = self._update_ewa(self._avg_std_dev_ewa, avg_std)
+ self._refresh_status(force=True)
+
+ def note_val_started(self, step: int) -> None:
+ self._val_running_step = step
+ self._refresh_status(force=True)
+
+ def note_val_finished(self, step: int, reward: float | None) -> None:
+ self._last_val_step = step
+ self._last_val_reward = reward
+ if self._val_running_step == step:
+ self._val_running_step = None
+ self._refresh_status(force=True)
+
+ def _build_snapshot(self) -> dict[str, object]:
+ remaining = None
+ if self._total_scenarios is not None:
+ remaining = max(self._total_scenarios - self._get_scenario_offset(), 0)
+ return {
+ "step": self._current_step,
+ "remaining": remaining,
+ "rolling": self._rolling_out,
+ "workers": self._num_workers,
+ "queued": self._queued,
+ "training": self._training,
+ "trained": self._trained,
+ "zero_var": self._zero_var,
+ "stale": self._stale,
+ "errored": self._errored,
+ "discarded": self._zero_var + self._stale + self._errored,
+ "train_reward_ewa": self._train_reward_ewa,
+ "train_seconds_ewa": self._seconds_ewa,
+ "train_avg_std_ewa": self._avg_std_dev_ewa,
+ "val_step": self._last_val_step,
+ "val_reward": self._last_val_reward,
+ "val_running": self._val_running_step,
+ }
+
+ def _format_condensed_line(self) -> str:
+ snapshot = self._build_snapshot()
+ trained = cast(int, snapshot["trained"])
+ discarded = cast(int, snapshot["discarded"])
+ remaining = snapshot["remaining"]
+
+ scenarios_fields = [
+ "scenarios",
+ f"tr={trained}",
+ ]
+ if remaining is not None:
+ scenarios_fields.append(f"r={self._fmt_int_compact(remaining)}")
+ scenarios_fields.extend(
+ [
+ f"q={self._fmt_int_compact(snapshot['queued'])}",
+ f"b={self._fmt_int_compact(snapshot['training'])}",
+ f"d={discarded}",
+ ]
+ )
+
+ train_fields = [
+ "train",
+ f"s={self._fmt_int_compact(snapshot['step'])}",
+ f"r={self._fmt_float_compact(snapshot['train_reward_ewa'], 2)}",
+ f"dt={self._fmt_float_compact(snapshot['train_seconds_ewa'], 1)}",
+ f"sd={self._fmt_float_compact(snapshot['train_avg_std_ewa'], 2)}",
+ ]
+
+ val_run = "y" if snapshot["val_running"] is not None else "n"
+ val_fields = [
+ "val",
+ f"r={self._fmt_float_compact(snapshot['val_reward'], 2)}",
+ f"act={val_run}",
+ ]
+
+ def build_line() -> str:
+ return " ".join(
+ [
+ f"scenarios[{' '.join(scenarios_fields[1:])}]",
+ f"train[{' '.join(train_fields[1:])}]",
+ f"val[{' '.join(val_fields[1:])}]",
+ ]
+ )
+
+ line = build_line()
+ max_width = shutil.get_terminal_size(fallback=(120, 20)).columns
+ if len(line) <= max_width:
+ return line
+
+ train_fields = [field for field in train_fields if not field.startswith("sd=")]
+ line = build_line()
+ if len(line) <= max_width:
+ return line
+
+ val_fields = [field for field in val_fields if not field.startswith("act=")]
+ line = build_line()
+ if len(line) <= max_width:
+ return line
+
+        if remaining is not None:
+            scenarios_fields = [
+                field for field in scenarios_fields if not field.startswith("r=")
+            ]
+        # Always return a string, even if the line is still over-width.
+        return build_line()
+
+ def _format_full_log(self) -> str:
+ snapshot = self._build_snapshot()
+ scenarios_fields = [
+ "scenarios",
+ f"trained={self._fmt_int(snapshot['trained'])}",
+ ]
+ if snapshot["remaining"] is not None:
+ scenarios_fields.append(f"remaining={self._fmt_int(snapshot['remaining'])}")
+ scenarios_fields.extend(
+ [
+ f"queued={self._fmt_int(snapshot['queued'])}",
+ f"training={self._fmt_int(snapshot['training'])}",
+ (
+ "discarded["
+ f"total={self._fmt_int(snapshot['discarded'])} "
+ f"0_var={self._fmt_int(snapshot['zero_var'])} "
+ f"stale={self._fmt_int(snapshot['stale'])} "
+ f"errored={self._fmt_int(snapshot['errored'])}"
+ "]"
+ ),
+ f"rollouts={snapshot['rolling']}/{snapshot['workers']}",
+ ]
+ )
+
+ train_fields = [
+ "train",
+ f"step={self._fmt_int(snapshot['step'])}",
+ f"reward={self._fmt_float(snapshot['train_reward_ewa'], 3)}",
+ f"step_seconds={self._fmt_float(snapshot['train_seconds_ewa'], 2)}",
+ f"avg_std={self._fmt_float(snapshot['train_avg_std_ewa'], 3)}",
+ ]
+
+ val_run = "yes" if snapshot["val_running"] is not None else "no"
+ val_fields = [
+ "val",
+ f"reward={self._fmt_float(snapshot['val_reward'], 3)}",
+ f"active={val_run}",
+ ]
+ if snapshot["val_step"] is not None:
+ val_fields.append(f"step={snapshot['val_step']}")
+ if snapshot["val_running"] is not None:
+ val_fields.append(f"active_step={snapshot['val_running']}")
+
+ return "[status] " + " ".join(
+ [
+ f"scenarios[{' '.join(scenarios_fields[1:])}]",
+ f"train[{' '.join(train_fields[1:])}]",
+ f"val[{' '.join(val_fields[1:])}]",
+ ]
+ )
+
+ def _write_log_line(self, line: str) -> None:
+ if self._tqdm is not None:
+ self._tqdm.write(line)
+ else:
+ print(line)
+
+ def _refresh_status(self, *, force: bool = False) -> None:
+ if self._tqdm is None:
+ return
+ now = time.monotonic()
+ if (
+ not force
+ and (now - self._last_refresh_time) < self._refresh_interval_seconds
+ ):
+ return
+ self._tqdm.set_description_str(self._format_condensed_line())
+ self._last_refresh_time = now
+
+ def _count_zero_variance_groups(self, batch: list[TrajectoryGroup]) -> int:
+ return sum(1 for group in batch if self._group_zero_variance(group))
+
+ def _group_zero_variance(self, group: TrajectoryGroup) -> bool:
+ rewards = [t.reward for t in group.trajectories]
+ if len(rewards) <= 1:
+ return True
+ first = rewards[0]
+ return all(abs(r - first) <= 1e-12 for r in rewards[1:])
+
+ def _compute_group_std_dev(self, group: TrajectoryGroup) -> float:
+ rewards = [t.reward for t in group.trajectories]
+ if len(rewards) <= 1:
+ return 0.0
+ mean = sum(rewards) / len(rewards)
+ variance = sum((r - mean) ** 2 for r in rewards) / len(rewards)
+ return math.sqrt(variance)
+
+ def _compute_batch_avg_std_dev(self, batch: list[TrajectoryGroup]) -> float:
+ if not batch:
+ return 0.0
+ std_devs = [self._compute_group_std_dev(group) for group in batch]
+ return sum(std_devs) / len(std_devs)
+
+ def _compute_batch_avg_reward(self, batch: list[TrajectoryGroup]) -> float | None:
+ rewards = [t.reward for group in batch for t in group.trajectories]
+ if not rewards:
+ return None
+ return sum(rewards) / len(rewards)
+
+ def _update_ewa(self, previous: float | None, new_value: float) -> float:
+ if previous is None:
+ return new_value
+ alpha = self._status_ewa_alpha
+ return alpha * new_value + (1 - alpha) * previous
+
+ @staticmethod
+ def _format_count(value: int) -> str:
+ if value >= 1_000_000:
+ return StatusReporter._format_scaled(value, 1_000_000, "m")
+ if value >= 1_000:
+ return StatusReporter._format_scaled(value, 1_000, "k")
+ return str(value)
+
+ @staticmethod
+ def _format_scaled(value: int, scale: int, suffix: str) -> str:
+ scaled = value / scale
+ text = f"{scaled:.1f}"
+ if text.endswith(".0"):
+ text = text[:-2]
+ return f"{text}{suffix}"
+
+ @staticmethod
+ def _fmt_int(value: object) -> str:
+ if value is None:
+ return "n/a"
+ return str(value)
+
+    @staticmethod
+    def _fmt_int_compact(value: object) -> str:
+        if value is None:
+            return "na"
+        if isinstance(value, int):
+            # Abbreviate large counts (k/m suffixes) so the condensed line stays short.
+            return StatusReporter._format_count(value)
+        return str(value)
+
+ @staticmethod
+ def _fmt_float(value: object, decimals: int) -> str:
+ if value is None:
+ return "n/a"
+ if not isinstance(value, (int, float)):
+ return str(value)
+ return f"{value:.{decimals}f}"
+
+ @staticmethod
+ def _fmt_float_compact(value: object, decimals: int) -> str:
+ if value is None:
+ return "na"
+ if not isinstance(value, (int, float)):
+ return str(value)
+ return f"{value:.{decimals}f}"
diff --git a/src/art/pipeline_trainer/trainer.py b/src/art/pipeline_trainer/trainer.py
new file mode 100644
index 000000000..4bb062d4c
--- /dev/null
+++ b/src/art/pipeline_trainer/trainer.py
@@ -0,0 +1,726 @@
+from __future__ import annotations
+
+import asyncio
+import os
+import signal
+import time
+from typing import Any, AsyncIterator, Generic, Iterable, TypeVar, cast
+
+import art
+from art import TrajectoryGroup
+
+from .state import PipelineState
+from .status import StatusReporter
+from .types import ConfigT, EvalFn, RolloutFn, ScenarioT, SingleRolloutFn  # noqa: F401
+
+T = TypeVar("T")
+
+PIPELINE_STATE_KEY = "_pipeline_trainer"
+
+
+def _to_async_iterator(iterable: Iterable[T] | AsyncIterator[T]) -> AsyncIterator[T]:
+ """Convert a sync Iterable to an AsyncIterator, or pass through if already async."""
+ if isinstance(iterable, AsyncIterator):
+ return iterable
+
+ async def _iter():
+ for item in iterable:
+ yield item
+
+ return _iter()
+
+
+def make_group_rollout_fn(
+ single_rollout_fn: SingleRolloutFn[ScenarioT, ConfigT],
+ n: int = 4,
+) -> RolloutFn[ScenarioT, ConfigT]:
+ """Create a RolloutFn from a SingleRolloutFn by running it N times in parallel."""
+
+ async def group_rollout(
+ model: art.TrainableModel,
+ scenario: ScenarioT,
+ config: ConfigT,
+ ) -> TrajectoryGroup:
+ if n <= 0:
+ return TrajectoryGroup([])
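+        # return_exceptions=True lets partially failed groups through; the
+        # TrajectoryGroup records those exceptions alongside successful trajectories.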
+ results = await asyncio.gather(
+ *[single_rollout_fn(model, scenario, config) for _ in range(n)],
+ return_exceptions=True,
+ )
+ return TrajectoryGroup(results)
+
+ return group_rollout
+
+
+class PipelineTrainer(Generic[ScenarioT, ConfigT]):
+    """Async three-stage pipeline that overlaps rollouts, training, and eval.
+
+    Rollout workers pull scenarios and enqueue trajectory groups, the training
+    stage batches them into policy updates (discarding stale and zero-variance
+    groups), and the eval stage periodically runs `eval_fn` at new steps.
+    """
+
+ def __init__(
+ self,
+ model: art.TrainableModel,
+ backend: art.Backend,
+ rollout_fn: RolloutFn[ScenarioT, ConfigT],
+ scenarios: AsyncIterator[ScenarioT] | Iterable[ScenarioT],
+ config: ConfigT,
+ eval_fn: EvalFn[ConfigT] | None = None,
+ *,
+ # Pipeline settings
+ num_rollout_workers: int = 16,
+ min_batch_size: int = 4,
+ max_batch_size: int | None = None,
+ max_steps_off_policy: int = 4,
+ queue_maxsize: int | None = None,
+ # Training
+ learning_rate: float = 1e-5,
+ loss_fn: str = "cispo",
+ loss_fn_config: dict | None = None,
+ normalize_advantages: bool = True,
+ adam_params: object | None = None,
+ max_steps: int | None = None,
+ # Discard handling
+ discard_queue_multiplier: int = 100,
+ # Status output
+ log_interval_seconds: float = 60.0,
+ status_ewa_alpha: float = 0.2,
+ total_scenarios: int | None = None,
+ # Eval/Checkpointing
+ eval_every_n_steps: int = 20,
+ eval_step_0: bool = True,
+ save_checkpoint: bool = True,
+ # Resumption
+ resume: bool = True,
+ ) -> None:
+ if num_rollout_workers <= 0:
+ raise ValueError("num_rollout_workers must be > 0")
+ if min_batch_size <= 0:
+ raise ValueError("min_batch_size must be > 0")
+ if max_batch_size is not None and max_batch_size <= 0:
+ raise ValueError("max_batch_size must be > 0")
+ if max_batch_size is not None and max_batch_size < min_batch_size:
+ raise ValueError("max_batch_size must be >= min_batch_size")
+ if max_steps_off_policy < 0:
+ raise ValueError("max_steps_off_policy must be >= 0")
+ if queue_maxsize is not None and queue_maxsize <= 0:
+ raise ValueError("queue_maxsize must be > 0")
+ if eval_every_n_steps < 0:
+ raise ValueError("eval_every_n_steps must be >= 0")
+ if max_steps is not None and max_steps < 0:
+ raise ValueError("max_steps must be >= 0")
+ if log_interval_seconds <= 0:
+ raise ValueError("log_interval_seconds must be > 0")
+ if discard_queue_multiplier <= 0:
+ raise ValueError("discard_queue_multiplier must be > 0")
+ self.model = model
+ self.backend = backend
+ self.rollout_fn = rollout_fn
+ self.config = config
+ self.eval_fn = eval_fn
+ self.num_rollout_workers = num_rollout_workers
+ self.min_batch_size = min_batch_size
+ self.max_batch_size = (
+ max_batch_size if max_batch_size is not None else 10 * min_batch_size
+ )
+ self.max_steps_off_policy = max_steps_off_policy
+ self.queue_maxsize = queue_maxsize
+ self.learning_rate = learning_rate
+ self.loss_fn = loss_fn
+ self.loss_fn_config = loss_fn_config
+ self.normalize_advantages = normalize_advantages
+ self.adam_params = adam_params
+ self.max_steps = max_steps
+ self._status_log_interval_seconds = log_interval_seconds
+ self.eval_every_n_steps = eval_every_n_steps
+ self.eval_step_0 = eval_step_0
+ self.save_checkpoint = save_checkpoint
+ self.resume = resume
+ self.discard_queue_multiplier = discard_queue_multiplier
+ self._discard_queue: list[TrajectoryGroup] = []
+ self._discard_queue_limit = discard_queue_multiplier * min_batch_size
+ self._collapse_triggered = False
+
+ self.state = PipelineState()
+ self._scenario_lock = asyncio.Lock()
+ self._scenario_iter: AsyncIterator[ScenarioT] | None = _to_async_iterator(
+ scenarios
+ )
+ self._output_queue: asyncio.Queue[TrajectoryGroup | None] | None = None
+ self._eval_queue: asyncio.Queue[int] | None = None
+ self._status = StatusReporter(
+ get_scenario_offset=lambda: self.state.scenario_offset,
+ log_interval_seconds=log_interval_seconds,
+ status_ewa_alpha=status_ewa_alpha,
+ total_scenarios=total_scenarios,
+ num_workers=num_rollout_workers,
+ )
+
+ async def train(self, *, handle_signals: bool = True) -> None:
+ """Run the training pipeline over the configured scenario iterator."""
+ start_step = await self.model.get_step()
+ pipeline_state = self._read_pipeline_state() if self.resume else {}
+ scenario_offset = int(pipeline_state.get("scenario_offset", 0) or 0)
+ last_eval_step = int(pipeline_state.get("last_eval_step", 0) or 0)
+ stored_step = pipeline_state.get("training_step")
+
+ if stored_step is not None and int(stored_step) != start_step:
+ print(
+ "Warning: pipeline trainer state step does not match backend step "
+ f"({stored_step} != {start_step}); using backend step."
+ )
+
+ self.state.policy_version = start_step
+ self.state.next_training_step = start_step
+ self.state.scenario_offset = scenario_offset
+ self.state.total_scenarios_consumed = int(
+ pipeline_state.get("total_scenarios_consumed", scenario_offset) or 0
+ )
+ self.state.last_eval_step = last_eval_step
+
+ if scenario_offset > 0 and self._scenario_iter is not None:
+ skipped = await self._skip_scenarios(self._scenario_iter, scenario_offset)
+ self.state.scenario_offset = skipped
+ self.state.total_scenarios_consumed = skipped
+
+ queue_maxsize = (
+ self.queue_maxsize
+ if self.queue_maxsize is not None
+ else max(1, self.max_steps_off_policy * self.max_batch_size)
+ )
+ self._output_queue = asyncio.Queue(maxsize=queue_maxsize)
+ self._eval_queue = asyncio.Queue()
+
+ if self.eval_fn is not None and self.eval_step_0 and start_step == 0:
+ await self._eval_queue.put(start_step)
+ self.state.last_eval_step = start_step
+ self._persist_state(start_step)
+
+ self._status.start(initial_step=start_step)
+ loop = asyncio.get_running_loop()
+ stop_requested = False
+ installed_handlers: list[tuple[str, signal.Signals]] = []
+ original_handlers: dict[signal.Signals, object] = {}
+
+ def _request_stop(sig: signal.Signals) -> None:
+ nonlocal stop_requested
+ if stop_requested:
+ return
+ stop_requested = True
+ print(f"Shutdown requested ({sig.name}); finishing current work...")
+ self.request_stop()
+
+ def _sync_signal_handler(signum: int, _frame: object | None) -> None:
+ _request_stop(signal.Signals(signum))
+
+ if handle_signals:
+ for sig in (signal.SIGINT, signal.SIGTERM):
+ original_handlers[sig] = signal.getsignal(sig)
+ try:
+ loop.add_signal_handler(sig, _request_stop, sig)
+ installed_handlers.append(("loop", sig))
+ except (NotImplementedError, RuntimeError):
+ try:
+ signal.signal(sig, _sync_signal_handler)
+ installed_handlers.append(("signal", sig))
+ except (ValueError, RuntimeError):
+ continue
+ try:
+ async with asyncio.TaskGroup() as tg:
+ tg.create_task(self._rollout_stage(), name="rollout_stage")
+ tg.create_task(self._training_stage(), name="training_stage")
+ tg.create_task(self._eval_stage(), name="eval_stage")
+ tg.create_task(self._status_loop(), name="status_loop")
+ except* Exception as eg:
+ for exc in eg.exceptions:
+ if not isinstance(exc, asyncio.CancelledError):
+ print(f"Pipeline stage failed: {exc}")
+ raise
+ finally:
+ if handle_signals:
+ for mode, sig in installed_handlers:
+ if mode == "loop":
+ try:
+ loop.remove_signal_handler(sig)
+ except (NotImplementedError, RuntimeError):
+ pass
+ try:
+ previous = original_handlers.get(sig)
+ if previous is not None:
+ signal.signal(sig, cast(signal.Handlers, previous))
+ except (ValueError, RuntimeError):
+ pass
+ self._status.flush()
+ self._status.close()
+
+ def request_stop(self) -> None:
+ """Request a clean shutdown of the pipeline stages."""
+ if self.state.done:
+ return
+ self.state.done = True
+
+ async def _notify_policy() -> None:
+ async with self.state.policy_updated:
+ self.state.policy_updated.notify_all()
+
+ try:
+ loop = asyncio.get_running_loop()
+ except RuntimeError:
+ loop = None
+
+ if loop is None:
+ return
+
+ loop.create_task(_notify_policy())
+ if self._output_queue is not None:
+ try:
+ self._output_queue.put_nowait(None)
+ except asyncio.QueueFull:
+ loop.create_task(self._output_queue.put(None))
+
+ async def _skip_scenarios(
+ self, scenarios: AsyncIterator[ScenarioT], count: int
+ ) -> int:
+ skipped = 0
+ while skipped < count:
+ try:
+ await anext(scenarios)
+ except StopAsyncIteration:
+ break
+ skipped += 1
+ if skipped < count:
+ print(
+ f"Warning: scenario iterator exhausted early while skipping "
+ f"(skipped {skipped}/{count})."
+ )
+ return skipped
+
+ async def _get_next_scenario(self) -> ScenarioT | None:
+ if self._scenario_iter is None:
+ return None
+ async with self._scenario_lock:
+ try:
+ scenario = await anext(self._scenario_iter)
+ except StopAsyncIteration:
+ return None
+ self.state.scenario_offset += 1
+ self.state.total_scenarios_consumed += 1
+ return scenario
+
+ async def _wait_for_policy(self) -> None:
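+        """Block until the sampling policy is at most `max_steps_off_policy` steps stale."""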
+ async with self.state.policy_updated:
+ while (
+ not self.state.done
+ and self.state.policy_version
+ < self.state.next_training_step - self.max_steps_off_policy
+ ):
+ await self.state.policy_updated.wait()
+
+ async def _rollout_worker(self, worker_id: int) -> None:
+ assert self._output_queue is not None
+ while not self.state.done:
+ scenario = await self._get_next_scenario()
+ if scenario is None:
+ break
+ self._status.note_rollout_started()
+ errored = False
+ try:
+ await self._wait_for_policy()
+ if self.state.done:
+ break
+
+ initial_version = self.state.policy_version
+
+ group = await self.rollout_fn(self.model, scenario, self.config)
+ if not isinstance(group, TrajectoryGroup):
+ errored = True
+ continue
+ self._apply_scenario_metadata(group, scenario)
+ self._apply_policy_versions(
+ group,
+ initial_version=initial_version,
+ final_version=self.state.policy_version,
+ )
+ if self.state.done:
+ break
+ await self._put_output_group(group)
+ except asyncio.CancelledError:
+ raise
+ except Exception as exc:
+ errored = True
+ print(f"Worker {worker_id}: rollout failed: {exc}")
+ finally:
+ self._status.note_rollout_finished(errored=errored)
+
+ async def _rollout_stage(self) -> None:
+ async with asyncio.TaskGroup() as tg:
+ for i in range(self.num_rollout_workers):
+ tg.create_task(self._rollout_worker(i))
+ if not self.state.done and self._output_queue is not None:
+ try:
+ self._output_queue.put_nowait(None)
+ except asyncio.QueueFull:
+ await self._output_queue.put(None)
+
+ async def _training_stage(self) -> None:
+ if self._output_queue is None:
+ return
+
+ current_step = self.state.next_training_step
+ stop_at_step = (
+ current_step + self.max_steps if self.max_steps is not None else None
+ )
+ if stop_at_step is not None and current_step >= stop_at_step:
+ self.state.done = True
+ self._persist_state(current_step)
+ async with self.state.policy_updated:
+ self.state.policy_updated.notify_all()
+ return
+ stop_after_batch = False
+
+ while True:
+ if stop_at_step is not None and current_step >= stop_at_step:
+ break
+ step_start = time.monotonic()
+ batch, discarded, saw_sentinel = await self._collect_batch(current_step)
+ self.state.discarded_stale_samples += discarded
+ if discarded:
+ self._status.note_stale(discarded)
+ if not batch:
+ break
+
+ expected_step = current_step + 1
+ should_eval_step = self._should_eval_step(expected_step)
+ should_checkpoint = self.save_checkpoint and should_eval_step
+
+ async with self.state.policy_updated:
+ self.state.next_training_step = expected_step
+ self.state.policy_updated.notify_all()
+
+ self._status.note_training_start(len(batch))
+ train_call_start: float | None = None
+ if os.getenv("ART_TRAIN_STEP_LOG"):
+ print(f"[train] step {expected_step} starting (batch={len(batch)})")
+ train_call_start = time.perf_counter()
+ try:
+ result = await self.backend.train(
+ self.model,
+ batch,
+ learning_rate=self.learning_rate,
+ loss_fn=self.loss_fn,
+ loss_fn_config=self.loss_fn_config,
+ normalize_advantages=self.normalize_advantages,
+ save_checkpoint=should_checkpoint,
+ adam_params=self.adam_params,
+ )
+ except Exception:
+ self._status.note_training_end()
+ raise
+ finally:
+ if train_call_start is not None:
+ train_call_elapsed = time.perf_counter() - train_call_start
+ print(
+ f"[train] step {expected_step} done in "
+ f"{train_call_elapsed:.1f}s"
+ )
+
+ try:
+ current_step = result.step
+ self.state.policy_version = current_step
+ self.state.next_training_step = current_step
+
+ step_seconds = time.monotonic() - step_start
+ self._status.note_training_batch(
+ batch, step=current_step, step_seconds=step_seconds
+ )
+
+ steps_off_policy = self._average_steps_off_policy(current_step, batch)
+ metrics = {
+ "discarded_stale_samples": float(
+ self.state.discarded_stale_samples
+ ),
+ "steps_off_policy": steps_off_policy,
+ "num_groups": float(len(batch)),
+ }
+ metrics.update(result.metrics)
+
+ await self.model.log(
+ batch,
+ split="train",
+ step=current_step,
+ metrics=metrics,
+ )
+ await self._log_discarded_groups(current_step)
+
+ if self.eval_fn is not None and should_eval_step:
+ if self._eval_queue is not None:
+ await self._eval_queue.put(current_step)
+ self.state.last_eval_step = current_step
+
+ self._persist_state(current_step)
+ finally:
+ self._status.note_training_end()
+
+ async with self.state.policy_updated:
+ self.state.policy_updated.notify_all()
+
+ if saw_sentinel:
+ stop_after_batch = True
+ if stop_after_batch:
+ break
+
+ self.state.done = True
+ self._persist_state(current_step)
+ async with self.state.policy_updated:
+ self.state.policy_updated.notify_all()
+
+ async def _collect_batch(
+ self, current_step: int
+ ) -> tuple[list[TrajectoryGroup], int, bool]:
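+        """Collect at least `min_batch_size` usable groups, then drain queued extras.
+
+        Returns a tuple of (batch, stale_discard_count, saw_sentinel).
+        """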
+ assert self._output_queue is not None
+ batch: list[TrajectoryGroup] = []
+ discarded = 0
+ saw_sentinel = False
+ min_version = current_step - self.max_steps_off_policy
+
+ while len(batch) < self.min_batch_size:
+ item = await self._output_queue.get()
+ if item is None:
+ saw_sentinel = True
+ break
+ self._status.note_group_dequeued(item)
+ self._check_all_failed(item)
+ if self._is_group_stale(item, min_version):
+ discarded += 1
+ continue
+ if self._group_zero_variance(item):
+ if self._record_zero_variance(item):
+ return [], discarded, saw_sentinel
+ continue
+ batch.append(item)
+
+        while not saw_sentinel and len(batch) < self.max_batch_size:
+ try:
+ item = self._output_queue.get_nowait()
+ except asyncio.QueueEmpty:
+ break
+ if item is None:
+ saw_sentinel = True
+ break
+ self._status.note_group_dequeued(item)
+ self._check_all_failed(item)
+ if self._is_group_stale(item, min_version):
+ discarded += 1
+ continue
+ if self._group_zero_variance(item):
+ if self._record_zero_variance(item):
+ return [], discarded, saw_sentinel
+ continue
+ batch.append(item)
+
+ return batch, discarded, saw_sentinel
+
+ def _check_all_failed(self, group: TrajectoryGroup) -> None:
+ """Raise if all rollouts in a group failed with exceptions."""
+ if not group.trajectories and group.exceptions:
+ first_exc = group.exceptions[0]
+ raise RuntimeError(
+ f"All {len(group.exceptions)} rollouts in group failed. "
+ f"First exception ({first_exc.type}): {first_exc.message}"
+ )
+
+ async def _eval_stage(self) -> None:
+ if self.eval_fn is None or self._eval_queue is None:
+ return
+
+ pending_eval: asyncio.Task[None] | None = None
+ while not self.state.done or not self._eval_queue.empty():
+ try:
+ step = await asyncio.wait_for(self._eval_queue.get(), timeout=1.0)
+ except asyncio.TimeoutError:
+ continue
+
+ if pending_eval is not None and not pending_eval.done():
+ try:
+ await pending_eval
+ except Exception as exc:
+ print(f"Warning: previous eval failed: {exc}")
+
+ pending_eval = asyncio.create_task(self._run_eval(step))
+
+ if pending_eval is not None and not pending_eval.done():
+ try:
+ await pending_eval
+ except Exception as exc:
+ print(f"Warning: final eval failed: {exc}")
+
+ async def _status_loop(self) -> None:
+ sleep_seconds = min(1.0, max(0.2, self._status_log_interval_seconds / 10))
+ while not self.state.done:
+ self._status.log_if_due()
+ await asyncio.sleep(sleep_seconds)
+
+ async def _run_eval(self, step: int) -> None:
+ assert self.eval_fn is not None
+ self._status.note_val_started(step)
+ reward: float | None = None
+ try:
+ result = await self.eval_fn(self.model, step, self.config)
+ splits: dict[str, list[art.Trajectory]]
+ if isinstance(result, dict):
+ splits = result
+ else:
+ splits = {"val": result}
+
+ val_trajectories = splits.get("val") or []
+ if val_trajectories:
+ reward = sum(t.reward for t in val_trajectories) / len(val_trajectories)
+ else:
+ reward = None
+
+ for split_name, trajectories in splits.items():
+ if trajectories:
+ await self.model.log(trajectories, split=split_name, step=step)
+ except asyncio.CancelledError:
+ raise
+ except Exception as exc:
+ print(f"Eval failed at step {step}: {exc}")
+ finally:
+ self._status.note_val_finished(step, reward)
+
+ def _apply_policy_versions(
+ self,
+ group: TrajectoryGroup,
+ *,
+ initial_version: int,
+ final_version: int,
+ ) -> None:
+ for trajectory in group.trajectories:
+ if trajectory.initial_policy_version is None:
+ trajectory.initial_policy_version = initial_version
+ if trajectory.final_policy_version is None:
+ trajectory.final_policy_version = final_version
+
+ def _apply_scenario_metadata(
+ self, group: TrajectoryGroup, scenario: ScenarioT
+ ) -> None:
+ metadata = scenario.get("metadata") if isinstance(scenario, dict) else None
+ if metadata is None or not isinstance(metadata, dict):
+ return
+
+ for key, value in metadata.items():
+ if not isinstance(key, str):
+ continue
+ if not self._is_scalar_metadata(value):
+ continue
+ for trajectory in group.trajectories:
+ trajectory.metadata[f"scenario_{key}"] = value
+
+ def _is_group_stale(self, group: TrajectoryGroup, min_version: int) -> bool:
+ group_version = self._group_initial_version(group)
+ if group_version is None:
+ return False
+ return group_version < min_version
+
+ def _record_zero_variance(self, group: TrajectoryGroup) -> bool:
+ self._discard_queue.append(group)
+ self._status.note_zero_variance_discarded(1)
+ if len(self._discard_queue) >= self._discard_queue_limit:
+ self._trigger_collapse()
+ return True
+ return False
+
+ def _trigger_collapse(self) -> None:
+ if self._collapse_triggered:
+ return
+ self._collapse_triggered = True
+ self.state.done = True
+ print(
+ "\n"
+ "========================================\n"
+ "MODEL COLLAPSE DETECTED - Training stopped\n"
+ "========================================\n"
+ "\n"
+ f"Too many trajectory groups ({self._discard_queue_limit}) had zero reward variance,\n"
+ "indicating the model may have collapsed to a degenerate policy.\n"
+ "\n"
+ "To improve training dynamics:\n"
+ " - Lower the learning rate to reduce instability\n"
+ " - Ensure your reward function provides meaningful variance\n"
+ " - Check that prompts are diverse enough to elicit different responses\n"
+ " - Consider using a smaller batch size for more frequent updates\n"
+ "\n"
+ "To disable this failsafe:\n"
+ " - Increase `discard_queue_multiplier` (currently triggers after\n"
+ f" {self.discard_queue_multiplier} * min_batch_size = {self._discard_queue_limit} zero-variance groups)\n"
+ "\n"
+ )
+
+ async def _log_discarded_groups(self, step: int) -> None:
+ if not self._discard_queue:
+ return
+ discarded = list(self._discard_queue)
+ await self.model.log(discarded, split="discarded", step=step)
+ self._discard_queue.clear()
+
+ @staticmethod
+ def _group_zero_variance(group: TrajectoryGroup) -> bool:
+ rewards = [t.reward for t in group.trajectories]
+ if len(rewards) <= 1:
+ return True
+ first = rewards[0]
+ return all(abs(r - first) <= 1e-12 for r in rewards[1:])
+
+ def _group_initial_version(self, group: TrajectoryGroup) -> int | None:
+ versions = [
+ trajectory.initial_policy_version
+ for trajectory in group.trajectories
+ if trajectory.initial_policy_version is not None
+ ]
+ if not versions:
+ return None
+ return min(versions)
+
+ def _average_steps_off_policy(
+ self, current_step: int, batch: list[TrajectoryGroup]
+ ) -> float:
+ steps: list[int] = []
+ for group in batch:
+ group_version = self._group_initial_version(group)
+ if group_version is None:
+ continue
+ steps.append(current_step - group_version)
+ if not steps:
+ return 0.0
+ return sum(steps) / len(steps)
+
+ def _should_eval_step(self, step: int) -> bool:
+ if self.eval_fn is None:
+ return False
+ if self.eval_every_n_steps <= 0:
+ return False
+ return (step - self.state.last_eval_step) >= self.eval_every_n_steps
+
+ def _read_pipeline_state(self) -> dict[str, Any]:
+ state = self.model.read_state() or {}
+ return state.get(PIPELINE_STATE_KEY, {})
+
+ def _persist_state(self, training_step: int) -> None:
+ payload = {
+ "scenario_offset": self.state.scenario_offset,
+ "total_scenarios_consumed": self.state.total_scenarios_consumed,
+ "training_step": training_step,
+ "last_eval_step": self.state.last_eval_step,
+ }
+ self.model.merge_state({PIPELINE_STATE_KEY: payload})
+
+ @staticmethod
+ def _is_scalar_metadata(value: object) -> bool:
+ return value is None or isinstance(value, (str, int, float, bool))
+
+ async def _put_output_group(self, group: TrajectoryGroup) -> None:
+ assert self._output_queue is not None
+ while not self.state.done:
+ try:
+ await asyncio.wait_for(self._output_queue.put(group), timeout=1.0)
+ self._status.note_group_enqueued(group)
+ return
+ except asyncio.TimeoutError:
+ continue
diff --git a/src/art/pipeline_trainer/types.py b/src/art/pipeline_trainer/types.py
new file mode 100644
index 000000000..532acf9cd
--- /dev/null
+++ b/src/art/pipeline_trainer/types.py
@@ -0,0 +1,25 @@
+from __future__ import annotations
+
+from collections.abc import Awaitable, Callable
+from typing import TypeVar
+
+import art
+from art import Trajectory, TrajectoryGroup
+
+ScenarioT = TypeVar("ScenarioT", bound=dict)
+ConfigT = TypeVar("ConfigT")
+ScalarMetadataValue = float | int | str | bool | None
+
+
+RolloutFn = Callable[
+ [art.TrainableModel, ScenarioT, ConfigT], Awaitable[TrajectoryGroup]
+]
+
+SingleRolloutFn = Callable[
+ [art.TrainableModel, ScenarioT, ConfigT], Awaitable[Trajectory]
+]
+
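+# An eval callback may return either a flat list of trajectories (logged under
+# the "val" split) or a mapping of split name -> trajectories.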
+EvalFn = Callable[
+ [art.TrainableModel, int, ConfigT],
+ Awaitable[list[Trajectory] | dict[str, list[Trajectory]]],
+]
diff --git a/src/art/pipeline_trainer/yes_no_maybe_pipeline.py b/src/art/pipeline_trainer/yes_no_maybe_pipeline.py
new file mode 100644
index 000000000..63e3323dd
--- /dev/null
+++ b/src/art/pipeline_trainer/yes_no_maybe_pipeline.py
@@ -0,0 +1,149 @@
+"""Minimal yes/no/maybe RL training example using PipelineTrainer."""
+
+from __future__ import annotations
+
+import asyncio
+from datetime import datetime
+from functools import partial
+from itertools import cycle, permutations
+import os
+import re
+
+from dotenv import load_dotenv
+
+import art
+
+from . import PipelineTrainer
+
+# Training config
+BASE_MODEL = "Qwen/Qwen3-30B-A3B-Instruct-2507" # or Qwen/Qwen3-4B-Instruct-2507
+MODEL_NAME = "pipeline-yes-no-maybe"
+PROJECT = "yes-no-maybe-pipeline"
+ROLLOUTS_PER_SCENARIO = 32
+MAX_TOKENS = 5
+MAX_STEPS = 20
+EVAL_TRAJECTORY_COUNT = 30
+EVAL_EVERY_N_STEPS = 2
+
+
+def build_scenarios() -> list[dict]:
+ """Generate all scenario variations."""
+ scenarios: list[dict] = []
+ for prefix in ["respond", "just respond"]:
+ for use_quotes in [True, False]:
+ for n in [3, 2]:
+ for words in permutations(["yes", "no", "maybe"], n):
+ quoted = [f"'{w}'" if use_quotes else w for w in words]
+ if len(words) == 3:
+ body = ", ".join(quoted)
+ else:
+ body = " or ".join(quoted)
+ scenarios.append({"prompt": f"{prefix} with {body}"})
+ return scenarios
+
+
+def reward_for_answer(text: str) -> float:
+ """Score: maybe=1.0, no=0.75, yes=0.5, other=0.0."""
+ if not text:
+ return 0.0
+ first_word = re.split(r"\s+", text.strip().lower())[0].strip(".,!?:;\"'()[]{}")
+ return {"maybe": 1.0, "no": 0.75, "yes": 0.5}.get(first_word, 0.0)
+
+
+async def eval_fn(
+ model: art.TrainableModel,
+ step: int,
+ _config: None,
+ *,
+ scenarios: list[dict],
+) -> list[art.Trajectory]:
+ trajectories: list[art.Trajectory] = []
+ openai_client = model.openai_client()
+ model_name = model.get_inference_name(step)
+ for scenario in scenarios:
+ messages: art.Messages = [{"role": "user", "content": scenario["prompt"]}]
+ response = await openai_client.chat.completions.create(
+ messages=messages,
+ model=model_name,
+ max_tokens=MAX_TOKENS,
+ n=1,
+ )
+ choice = response.choices[0]
+ trajectories.append(
+ art.Trajectory(
+ messages_and_choices=[*messages, choice],
+ reward=reward_for_answer(choice.message.content or ""),
+ )
+ )
+ return trajectories
+
+
+async def rollout_fn(model, scenario, _config) -> art.TrajectoryGroup:
+ """Single inference call returns N completions for the group."""
+ messages: art.Messages = [{"role": "user", "content": scenario["prompt"]}]
+ response = await model.openai_client().chat.completions.create(
+ messages=messages,
+ model=model.get_inference_name(),
+ max_tokens=MAX_TOKENS,
+ n=ROLLOUTS_PER_SCENARIO,
+ )
+ return art.TrajectoryGroup(
+ [
+ art.Trajectory(
+ messages_and_choices=[*messages, choice],
+ reward=reward_for_answer(choice.message.content or ""),
+ )
+ for choice in response.choices
+ ]
+ )
+
+
+async def main() -> None:
+ load_dotenv()
+ if not os.environ.get("TINKER_API_KEY"):
+ raise RuntimeError("TINKER_API_KEY environment variable is required")
+
+ model_name = f"{MODEL_NAME}-{datetime.now().strftime('%Y%m%d-%H%M%S')}"
+
+ print("Initializing TinkerNativeBackend")
+ backend = art.TinkerNativeBackend()
+
+ print(f"Initializing TrainableModel: {model_name}")
+ model = art.TrainableModel(name=model_name, project=PROJECT, base_model=BASE_MODEL)
+
+ print("Registering model with backend")
+ await model.register(backend)
+ print("Model registered")
+
+ openai_client = model.openai_client()
+ base_scenarios = build_scenarios()
+ scenarios = cycle(base_scenarios)
+ eval_scenarios = base_scenarios[:EVAL_TRAJECTORY_COUNT]
+
+ eval_callback = partial(eval_fn, scenarios=eval_scenarios)
+
+ trainer = PipelineTrainer(
+ model=model,
+ backend=backend,
+ rollout_fn=rollout_fn,
+ scenarios=scenarios,
+ config=None,
+ learning_rate=5e-5,
+ loss_fn="cispo",
+ eval_fn=eval_callback,
+ max_steps=MAX_STEPS,
+ eval_every_n_steps=EVAL_EVERY_N_STEPS,
+ eval_step_0=False,
+ total_scenarios=None,
+ )
+
+ print(
+ f"Training {model_name}: {MAX_STEPS} steps, "
+ f"{len(base_scenarios)} unique scenarios (cycling)"
+ )
+ await trainer.train()
+ await backend.close()
+
+
+if __name__ == "__main__":
+ asyncio.run(main())
diff --git a/src/art/tinker_native/__init__.py b/src/art/tinker_native/__init__.py
new file mode 100644
index 000000000..a6dc5bc59
--- /dev/null
+++ b/src/art/tinker_native/__init__.py
@@ -0,0 +1,3 @@
+from .backend import TinkerNativeBackend
+
+__all__ = ["TinkerNativeBackend"]
diff --git a/src/art/tinker_native/backend.py b/src/art/tinker_native/backend.py
new file mode 100644
index 000000000..291621b6c
--- /dev/null
+++ b/src/art/tinker_native/backend.py
@@ -0,0 +1,766 @@
+from __future__ import annotations
+
+import asyncio
+from dataclasses import dataclass
+import os
+import re
+import time
+from typing import Any, Awaitable, Iterable, Literal, TypeVar, cast
+import uuid
+
+from fastapi import FastAPI, HTTPException
+from openai import AsyncOpenAI
+from openai.types.chat.chat_completion import ChatCompletion, Choice, ChoiceLogprobs
+from openai.types.chat.chat_completion_message import ChatCompletionMessage
+from openai.types.chat.chat_completion_message_function_tool_call import (
+ ChatCompletionMessageFunctionToolCall,
+ Function,
+)
+from openai.types.chat.chat_completion_message_tool_call import (
+ ChatCompletionMessageToolCallUnion,
+)
+from openai.types.chat.chat_completion_token_logprob import ChatCompletionTokenLogprob
+from openai.types.chat.completion_create_params import CompletionCreateParams
+from openai.types.completion_usage import CompletionUsage
+import tinker
+from tinker_cookbook import renderers, tokenizer_utils
+import uvicorn
+
+from .. import dev
+from ..backend import Backend
+from ..model import Model, TrainableModel
+from ..tinker.backend import get_renderer_name
+from ..tinker.server import get_free_port
+from ..trajectories import TrajectoryGroup
+from ..types import TrainResult
+from ..utils.output_dirs import get_model_dir
+from ..utils.trajectory_migration import auto_migrate_on_register
+from .data import (
+ convert_openai_messages_to_renderer_format,
+ parse_completion_to_openai_message,
+ trajectory_groups_to_datums,
+)
+
+STATE_KEY_RUN_IDS = "tinker_run_ids"
+STATE_KEY_LATEST_STEP = "latest_step"
+T = TypeVar("T")
+
+
+@dataclass
+class ModelState:
+ service_client: tinker.ServiceClient
+ rest_client: Any
+ training_client: tinker.TrainingClient
+ sampler_clients: dict[int, tinker.SamplingClient]
+ sampler_checkpoint_paths: dict[int, str]
+ training_checkpoint_paths: dict[int, str]
+ current_step: int
+ renderer: Any
+ tokenizer: Any
+ output_dir: str
+ tinker_run_ids: list[str]
+ model_name: str
+ server_task: asyncio.Task[None] | None = None
+ server_host: str | None = None
+ server_port: int | None = None
+ server_api_key: str | None = None
+
+
+@dataclass
+class TinkerNativeModelConfig:
+ renderer_name: str
+ training_client_args: dict[str, Any]
+
+
+class TinkerNativeBackend(Backend):
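+    """Backend that trains against the tinker API and serves rollouts locally.
+
+    Training goes through a tinker LoRA training client; inference goes through a
+    local FastAPI server exposing an OpenAI-compatible /v1/chat/completions
+    endpoint backed by tinker sampling clients, addressed as "<model>@<step>".
+    """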
+ _tinker_train_log_env = "ART_TINKER_TRAIN_LOG"
+ _tinker_sample_log_env = "ART_TINKER_SAMPLE_LOG"
+
+ def __init__(
+ self,
+ *,
+ tinker_api_key: str | None = None,
+ path: str | None = None,
+ ) -> None:
+ if not "TINKER_API_KEY" in os.environ or tinker_api_key is not None:
+ assert tinker_api_key is not None, (
+ "TINKER_API_KEY is not set and no tinker_api_key was provided"
+ )
+ print("Setting TINKER_API_KEY to", tinker_api_key, "in environment")
+ os.environ["TINKER_API_KEY"] = tinker_api_key
+
+ self._path = path or ".art"
+ os.makedirs(self._path, exist_ok=True)
+ self._model_state: dict[str, ModelState] = {}
+
+ def _env_enabled(self, env_name: str) -> bool:
+ value = os.getenv(env_name)
+ if value is None:
+ return False
+ return value.lower() not in ("", "0", "false", "no")
+
+ def _timestamp(self) -> str:
+ return time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
+
+ async def _tinker_call(
+ self,
+ label: str,
+ awaitable: Awaitable[T],
+ *,
+ env_name: str,
+ prefix: str,
+ ) -> T:
+ if not self._env_enabled(env_name):
+ return await awaitable
+ start = time.perf_counter()
+ print(f"[tinker:{prefix}] {label} start {self._timestamp()}")
+ try:
+ return await awaitable
+ finally:
+ elapsed = time.perf_counter() - start
+ print(
+ f"[tinker:{prefix}] {label} done in {elapsed:.2f}s "
+ f"at {self._timestamp()}"
+ )
+
+ async def _tinker_train_call(self, label: str, awaitable: Awaitable[T]) -> T:
+ return await self._tinker_call(
+ label,
+ awaitable,
+ env_name=self._tinker_train_log_env,
+ prefix="train",
+ )
+
+ async def _tinker_sample_call(self, label: str, awaitable: Awaitable[T]) -> T:
+ return await self._tinker_call(
+ label,
+ awaitable,
+ env_name=self._tinker_sample_log_env,
+ prefix="sample",
+ )
+
+ async def close(self) -> None:
+ for state in self._model_state.values():
+ if state.server_task is not None:
+ state.server_task.cancel()
+
+ async def register(self, model: Model) -> None:
+ model.base_path = self._path
+ output_dir = get_model_dir(model=model, art_path=self._path)
+ os.makedirs(output_dir, exist_ok=True)
+ with open(f"{output_dir}/model.json", "w") as f:
+ import json
+
+ json.dump(model.model_dump(), f)
+
+ auto_migrate_on_register(output_dir)
+
+ if not model.trainable:
+ return
+ trainable_model = cast(TrainableModel, model)
+ state = await self._build_model_state(trainable_model)
+ self._model_state[model.name] = state
+
+ async def _prepare_backend_for_training(
+ self,
+ model: TrainableModel,
+ config: dev.OpenAIServerConfig | None = None,
+ ) -> tuple[str, str]:
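+        """Start the local OpenAI-compatible server on first use.
+
+        Returns the base URL and API key the inference client should use.
+        """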
+ state = self._model_state[model.name]
+
+ raw_config: dict[str, Any] = cast(dict[str, Any], config) if config else {}
+ server_args = cast(dict[str, Any], raw_config.get("server_args", {}))
+ host = server_args.get("host", raw_config.get("host", "0.0.0.0"))
+ port = server_args.get("port", raw_config.get("port"))
+ if port is None:
+ port = get_free_port()
+ api_key = server_args.get("api_key", raw_config.get("api_key")) or "default"
+
+ if state.server_task is None:
+ state.server_host = host
+ state.server_port = port
+ state.server_api_key = api_key
+ state.server_task = asyncio.create_task(
+ self._run_openai_server(state, host=host, port=port)
+ )
+ state.server_task.add_done_callback(self._crash_on_server_exit)
+
+ base_url = f"http://{host}:{port}/v1"
+ await self._wait_for_server_ready(base_url, api_key, model)
+ return base_url, api_key
+
+ async def train( # type: ignore[override]
+ self,
+ model: TrainableModel,
+ trajectory_groups: Iterable[TrajectoryGroup],
+ *,
+ learning_rate: float = 1e-5,
+ loss_fn: Literal["cispo", "ppo", "importance_sampling", "dro"] = "cispo",
+ normalize_advantages: bool = True,
+ save_checkpoint: bool = False,
+ loss_fn_config: dict | None = None,
+ adam_params: tinker.AdamParams | None = None,
+ ) -> TrainResult:
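+        """Run one training step on the given trajectory groups.
+
+        Converts the groups to tinker datums, runs forward_backward and optim_step,
+        then saves sampler weights (and optionally full state) for the next step.
+        """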
+ state = self._model_state[model.name]
+ groups_list = list(trajectory_groups)
+
+ datums = trajectory_groups_to_datums(
+ groups_list,
+ state.renderer,
+ state.tokenizer,
+ normalize_advantages,
+ )
+
+ metrics: dict[str, float] = {
+ "num_groups_submitted": float(len(groups_list)),
+ "num_datums": float(len(datums)),
+ }
+
+ if not datums:
+ return TrainResult(step=state.current_step, metrics=metrics)
+
+ if adam_params is None:
+ adam_params = tinker.AdamParams(
+ learning_rate=learning_rate,
+ beta1=0.9,
+ beta2=0.95,
+ eps=1e-8,
+ )
+
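+        # build_datum attaches a "mask" entry to loss_fn_inputs; drop it here so only
+        # target_tokens, logprobs, and advantages are submitted to forward_backward.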
+ def remove_mask(datum: tinker.Datum) -> tinker.Datum:
+ if "mask" not in datum.loss_fn_inputs:
+ return datum
+ loss_fn_inputs = {
+ key: value
+ for key, value in datum.loss_fn_inputs.items()
+ if key != "mask"
+ }
+ return tinker.Datum(
+ model_input=datum.model_input, loss_fn_inputs=loss_fn_inputs
+ )
+
+ forward_output = await self._tinker_train_call(
+ "forward_backward",
+ state.training_client.forward_backward(
+ [remove_mask(datum) for datum in datums],
+ loss_fn=loss_fn,
+ loss_fn_config=loss_fn_config,
+ ),
+ )
+ optim_output = await self._tinker_train_call(
+ "optim_step", state.training_client.optim_step(adam_params)
+ )
+
+ if forward_output.metrics:
+ for key, value in forward_output.metrics.items():
+ if value is None:
+ continue
+ metrics[key] = float(value)
+ if optim_output.metrics:
+ for key, value in optim_output.metrics.items():
+ if value is None:
+ continue
+ metrics[key] = float(value)
+
+ next_step = state.current_step + 1
+ checkpoint_name = f"step_{next_step:06d}"
+
+ if save_checkpoint:
+ state_response, sampler_response = await self._save_checkpoint(
+ state.training_client, checkpoint_name
+ )
+ state.training_checkpoint_paths[next_step] = state_response.path
+ else:
+ sampler_response = await self._save_sampler_weights(
+ state.training_client, checkpoint_name
+ )
+ sampler_client = await self._tinker_train_call(
+ "create_sampling_client_async",
+ state.training_client.create_sampling_client_async(
+ model_path=sampler_response.path
+ ),
+ )
+ state.sampler_clients[next_step] = sampler_client
+ state.sampler_checkpoint_paths[next_step] = sampler_response.path
+
+ state.current_step = next_step
+ self._persist_model_state(model, state)
+
+ return TrainResult(step=state.current_step, metrics=metrics)
+
+ async def _get_step(self, model: TrainableModel) -> int:
+ if model.name in self._model_state:
+ return self._model_state[model.name].current_step
+ state = model.read_state() or {}
+ return int(state.get(STATE_KEY_LATEST_STEP, 0))
+
+ async def _delete_checkpoint_files(
+ self,
+ model: TrainableModel,
+ steps_to_keep: list[int],
+ ) -> None:
+ print("Checkpoint deletion is not yet implemented for TinkerNativeBackend.")
+
+ def _model_inference_name(self, model: Model, step: int | None = None) -> str:
+ base_name = model.inference_model_name or model.name
+ if "@" in base_name:
+ base_name = base_name.split("@", 1)[0]
+ if step is None:
+ state = self._model_state.get(model.name)
+ step = state.current_step if state is not None else 0
+ return f"{base_name}@{step}"
+
+ async def _run_openai_server(
+ self,
+ state: ModelState,
+ host: str,
+ port: int,
+ ) -> None:
+ app = FastAPI()
+
+ @app.post("/v1/chat/completions")
+ async def chat_completions(body: CompletionCreateParams) -> ChatCompletion:
+ model_name = body.get("model")
+ _, step = self._parse_model_name(model_name)
+ sampler_client = await self._get_sampler_client(state, step)
+
+ messages = self._normalize_messages(body["messages"])
+ tools = self._normalize_tools(body.get("tools"))
+ renderer_messages = convert_openai_messages_to_renderer_format(
+ messages=messages,
+ tools=tools,
+ renderer=state.renderer,
+ )
+ prompt = state.renderer.build_generation_prompt(renderer_messages)
+ prompt_tokens = list(prompt.to_ints())
+
+ max_tokens = body.get("max_completion_tokens")
+ if max_tokens is None:
+ max_tokens = body.get("max_tokens")
+ temperature = body.get("temperature")
+ top_k = body.get("top_k")
+ top_p = body.get("top_p")
+ sampling_params = tinker.SamplingParams(
+ max_tokens=max_tokens,
+ seed=body.get("seed"),
+ temperature=temperature if temperature is not None else 1.0,
+ top_k=top_k if top_k is not None else -1,
+ top_p=top_p if top_p is not None else 1.0,
+ stop=state.renderer.get_stop_sequences(),
+ )
+ sample_response = await self._tinker_sample_call(
+ "sample_async",
+ sampler_client.sample_async(
+ prompt=prompt,
+ num_samples=body.get("n") or 1,
+ sampling_params=sampling_params,
+ ),
+ )
+
+ choices: list[Choice] = []
+ for i, sequence in enumerate(sample_response.sequences):
+ if sequence.logprobs is None:
+ raise HTTPException(status_code=400, detail="Logprobs are required")
+ if len(sequence.tokens) != len(sequence.logprobs):
+ raise HTTPException(
+ status_code=400,
+ detail="Tokens and logprobs must have the same length",
+ )
+ parsed_message = parse_completion_to_openai_message(
+ list(sequence.tokens), state.renderer
+ )
+ tool_calls: list[ChatCompletionMessageToolCallUnion] | None = None
+ if parsed_message.get("tool_calls"):
+ tool_calls = [
+ ChatCompletionMessageFunctionToolCall(
+ type="function",
+ id=tool_call["id"],
+ function=Function(
+ name=tool_call["function"]["name"],
+ arguments=tool_call["function"]["arguments"],
+ ),
+ )
+ for tool_call in parsed_message["tool_calls"]
+ ]
+ choices.append(
+ Choice(
+ finish_reason=sequence.stop_reason,
+ index=i,
+ message=ChatCompletionMessage(
+ content=parsed_message.get("content", ""),
+ role="assistant",
+ tool_calls=tool_calls,
+ ),
+ logprobs=ChoiceLogprobs(
+ content=[
+ ChatCompletionTokenLogprob(
+ token=f"token_id:{token}",
+ logprob=logprob,
+ top_logprobs=[],
+ )
+ for token, logprob in zip(
+ sequence.tokens, sequence.logprobs
+ )
+ ]
+ ),
+ )
+ )
+
+ completion_tokens = sum(
+ len(sequence.tokens) for sequence in sample_response.sequences
+ )
+ return ChatCompletion(
+ id=str(uuid.uuid4()),
+ choices=choices,
+ created=int(time.time()),
+ model=self._format_response_model(model_name, step, state),
+ object="chat.completion",
+ usage=CompletionUsage(
+ completion_tokens=completion_tokens,
+ prompt_tokens=len(prompt_tokens),
+ total_tokens=completion_tokens + len(prompt_tokens),
+ ),
+ )
+
+ server_config = uvicorn.Config(app, host=host, port=port, log_level="error")
+ server = uvicorn.Server(server_config)
+ await server.serve()
+
+ def _crash_on_server_exit(self, task: asyncio.Task[None]) -> None:
+ try:
+ task.result()
+ except asyncio.CancelledError:
+ return
+ except Exception as exc:
+ print(f"OpenAI server crashed: {exc}")
+ else:
+ print("OpenAI server exited unexpectedly.")
+ os._exit(1)
+
+ async def _wait_for_server_ready(
+ self, base_url: str, api_key: str, model: TrainableModel
+ ) -> None:
+ client = AsyncOpenAI(base_url=base_url, api_key=api_key)
+ with_timeout = float(os.environ.get("ART_SERVER_TIMEOUT", 300.0))
+ start = time.time()
+ while True:
+ if time.time() - start > with_timeout:
+ raise TimeoutError(
+ f"Unable to reach OpenAI-compatible server within {with_timeout} seconds."
+ )
+ try:
+ await client.chat.completions.create(
+ model=self._model_inference_name(model),
+ messages=[{"role": "user", "content": "Hello, world!"}],
+ max_completion_tokens=1,
+ )
+ return
+ except Exception:
+ await asyncio.sleep(0.1)
+
+ async def _build_model_state(self, model: TrainableModel) -> ModelState:
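+        """Create or restore tinker clients for the model.
+
+        If earlier tinker run IDs are stored in the model state, the latest training
+        checkpoint is reloaded; otherwise a fresh LoRA training client is created and
+        an initial sampler checkpoint is saved at step 0.
+        """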
+ config = self._resolve_model_config(model)
+ service_client = tinker.ServiceClient()
+ rest_client = service_client.create_rest_client()
+
+ tokenizer = tokenizer_utils.get_tokenizer(model.base_model)
+ renderer = renderers.get_renderer(
+ name=config.renderer_name,
+ tokenizer=tokenizer,
+ )
+
+ saved_state = model.read_state() or {}
+ tinker_run_ids = list(saved_state.get(STATE_KEY_RUN_IDS, []))
+ training_paths, sampler_paths = await self._list_checkpoints(
+ rest_client, tinker_run_ids
+ )
+
+ training_client: tinker.TrainingClient
+ current_step = 0
+
+ if training_paths:
+ current_step = max(training_paths.keys())
+ checkpoint_path = training_paths[current_step]
+ training_client = await self._create_training_client_from_checkpoint(
+ service_client=service_client,
+ checkpoint_state_path=checkpoint_path,
+ base_model=model.base_model,
+ training_client_args=config.training_client_args,
+ reset_optimizer=False,
+ )
+ else:
+ training_client = await self._tinker_train_call(
+ "create_lora_training_client_async",
+ service_client.create_lora_training_client_async(
+ model.base_model, **config.training_client_args
+ ),
+ )
+ checkpoint_name = f"step_{current_step:06d}"
+ sampler_response = await self._save_sampler_weights(
+ training_client, checkpoint_name
+ )
+ sampler_paths[current_step] = sampler_response.path
+
+ run_id = training_client.model_id
+ if run_id not in tinker_run_ids:
+ tinker_run_ids.append(run_id)
+
+ sampler_clients: dict[int, tinker.SamplingClient] = {}
+ if current_step in sampler_paths:
+ sampler_clients[current_step] = await self._tinker_train_call(
+ "create_sampling_client_async",
+ training_client.create_sampling_client_async(
+ model_path=sampler_paths[current_step]
+ ),
+ )
+ else:
+ checkpoint_name = f"step_{current_step:06d}"
+ sampler_response = await self._save_sampler_weights(
+ training_client, checkpoint_name
+ )
+ sampler_paths[current_step] = sampler_response.path
+ sampler_clients[current_step] = await self._tinker_train_call(
+ "create_sampling_client_async",
+ training_client.create_sampling_client_async(
+ model_path=sampler_response.path
+ ),
+ )
+
+ state = ModelState(
+ service_client=service_client,
+ rest_client=rest_client,
+ training_client=training_client,
+ sampler_clients=sampler_clients,
+ sampler_checkpoint_paths=sampler_paths,
+ training_checkpoint_paths=training_paths,
+ current_step=current_step,
+ renderer=renderer,
+ tokenizer=tokenizer,
+ output_dir=get_model_dir(model=model, art_path=self._path),
+ tinker_run_ids=tinker_run_ids,
+ model_name=((model.inference_model_name or model.name).split("@", 1)[0]),
+ )
+ self._persist_model_state(model, state)
+ return state
+
+ def _resolve_model_config(self, model: TrainableModel) -> TinkerNativeModelConfig:
+ internal_config = model._internal_config or {}
+ tinker_native_args = cast(
+ dev.TinkerNativeArgs | None,
+ internal_config.get("tinker_native_args"),
+ )
+ renderer_name = (
+ tinker_native_args.get("renderer_name")
+ if tinker_native_args is not None
+ else None
+ )
+ if renderer_name is None:
+ renderer_name = get_renderer_name(model.base_model)
+
+ training_client_args = dict(
+ tinker_native_args.get("training_client_args", {})
+ if tinker_native_args is not None
+ else {}
+ )
+ if "rank" not in training_client_args:
+ training_client_args["rank"] = 8
+ if "train_unembed" not in training_client_args:
+ training_client_args["train_unembed"] = False
+
+ return TinkerNativeModelConfig(
+ renderer_name=renderer_name,
+ training_client_args=training_client_args,
+ )
+
+ async def _list_checkpoints(
+ self, rest_client: Any, tinker_run_ids: list[str]
+ ) -> tuple[dict[int, str], dict[int, str]]:
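+        """Return training and sampler checkpoint paths keyed by step for each run ID."""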
+ training_paths: dict[int, str] = {}
+ sampler_paths: dict[int, str] = {}
+ step_pattern = re.compile(r"(?:weights/)?step_(\d+)$")
+
+ for run_id in tinker_run_ids:
+ try:
+ response = await self._tinker_train_call(
+ f"list_checkpoints_async {run_id}",
+ rest_client.list_checkpoints_async(run_id),
+ )
+ except Exception as exc:
+ print(f"Warning: Could not list checkpoints for {run_id}: {exc}")
+ continue
+ for checkpoint in response.checkpoints:
+ match = step_pattern.match(checkpoint.checkpoint_id)
+ if not match:
+ continue
+ step = int(match.group(1))
+ path = f"tinker://{run_id}/{checkpoint.checkpoint_id}"
+ if checkpoint.checkpoint_type == "training":
+ training_paths[step] = path
+ elif checkpoint.checkpoint_type == "sampler":
+ sampler_paths[step] = path
+ return training_paths, sampler_paths
+
+ async def _get_sampler_client(
+ self, state: ModelState, step: int | None
+ ) -> tinker.SamplingClient:
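+        """Return a (cached) sampling client for the requested step.
+
+        Unknown steps trigger a checkpoint re-listing; if no sampler checkpoint
+        exists for the step, a 404 HTTPException is raised.
+        """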
+ actual_step = step if step is not None else state.current_step
+ if actual_step in state.sampler_clients:
+ return state.sampler_clients[actual_step]
+
+ if actual_step not in state.sampler_checkpoint_paths:
+ training_paths, sampler_paths = await self._list_checkpoints(
+ state.rest_client, state.tinker_run_ids
+ )
+ state.training_checkpoint_paths.update(training_paths)
+ state.sampler_checkpoint_paths.update(sampler_paths)
+
+ if actual_step not in state.sampler_checkpoint_paths:
+ available = sorted(state.sampler_checkpoint_paths.keys())
+ raise HTTPException(
+ status_code=404,
+ detail=f"No sampler checkpoint for step {actual_step}. Available: {available}",
+ )
+
+ sampler_client = await self._tinker_train_call(
+ "create_sampling_client_async",
+ state.training_client.create_sampling_client_async(
+ model_path=state.sampler_checkpoint_paths[actual_step]
+ ),
+ )
+ state.sampler_clients[actual_step] = sampler_client
+ return sampler_client
+
+ def _normalize_messages(self, messages: Iterable[Any]) -> list[dict[str, Any]]:
+ normalized: list[dict[str, Any]] = []
+ for message in messages:
+ if hasattr(message, "model_dump"):
+ normalized.append(message.model_dump())
+ else:
+ normalized.append(dict(message))
+ return normalized
+
+ def _normalize_tools(
+ self, tools: Iterable[Any] | None
+ ) -> list[dict[str, Any]] | None:
+ if tools is None:
+ return None
+ normalized: list[dict[str, Any]] = []
+ for tool in tools:
+ if hasattr(tool, "model_dump"):
+ normalized.append(tool.model_dump())
+ else:
+ normalized.append(dict(tool))
+ return normalized
+
+ def _parse_model_name(
+ self, model_name: str | None
+ ) -> tuple[str | None, int | None]:
+ if model_name and "@" in model_name:
+ base_name, step_str = model_name.rsplit("@", 1)
+ try:
+ return base_name, int(step_str)
+ except ValueError as exc:
+ raise HTTPException(
+ status_code=400, detail=f"Invalid model step: {model_name}"
+ ) from exc
+ return model_name, None
+
+ def _format_response_model(
+ self, model_name: str | None, step: int | None, state: ModelState
+ ) -> str:
+ if model_name is None:
+ return f"{state.model_name}@{state.current_step}"
+ if step is None and "@" not in model_name:
+ return f"{model_name}@{state.current_step}"
+ return model_name
+
+ async def _create_training_client_from_checkpoint(
+ self,
+ service_client: tinker.ServiceClient,
+ checkpoint_state_path: str,
+ base_model: str,
+ training_client_args: dict[str, Any],
+ reset_optimizer: bool = False,
+ ) -> tinker.TrainingClient:
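+        """Create a LoRA training client and load the given checkpoint.
+
+        When reset_optimizer is True only the model state is restored; otherwise
+        the optimizer state is loaded as well.
+        """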
+ training_client = await self._tinker_train_call(
+ "create_lora_training_client_async",
+ service_client.create_lora_training_client_async(
+ base_model, **training_client_args
+ ),
+ )
+
+ if reset_optimizer:
+ load_future = await self._tinker_train_call(
+ "load_state_async",
+ training_client.load_state_async(checkpoint_state_path),
+ )
+ await self._tinker_train_call(
+ "load_state_result_async", load_future.result_async()
+ )
+ else:
+ load_future = await self._tinker_train_call(
+ "load_state_with_optimizer_async",
+ training_client.load_state_with_optimizer_async(checkpoint_state_path),
+ )
+ await self._tinker_train_call(
+ "load_state_with_optimizer_result_async", load_future.result_async()
+ )
+
+ return training_client
+
+ async def _save_checkpoint(
+ self,
+ training_client: tinker.TrainingClient,
+ checkpoint_name: str,
+ ) -> tuple[Any, Any]:
+ state_future, sampler_future = await asyncio.gather(
+ self._tinker_train_call(
+ "save_state_async",
+ training_client.save_state_async(checkpoint_name),
+ ),
+ self._tinker_train_call(
+ "save_weights_for_sampler_async",
+ training_client.save_weights_for_sampler_async(checkpoint_name),
+ ),
+ )
+ state_result = await self._tinker_train_call(
+ "save_state_result_async", state_future.result_async()
+ )
+ sampler_result = await self._tinker_train_call(
+ "save_weights_for_sampler_result_async", sampler_future.result_async()
+ )
+ return state_result, sampler_result
+
+ async def _save_sampler_weights(
+ self,
+ training_client: tinker.TrainingClient,
+ checkpoint_name: str,
+ ) -> Any:
+ sampler_future = await self._tinker_train_call(
+ "save_weights_for_sampler_async",
+ training_client.save_weights_for_sampler_async(checkpoint_name),
+ )
+ return await self._tinker_train_call(
+ "save_weights_for_sampler_result_async", sampler_future.result_async()
+ )
+
+ async def _save_training_state(
+ self,
+ training_client: tinker.TrainingClient,
+ checkpoint_name: str,
+ ) -> Any:
+ state_future = await self._tinker_train_call(
+ "save_state_async",
+ training_client.save_state_async(checkpoint_name),
+ )
+ return await self._tinker_train_call(
+ "save_state_result_async", state_future.result_async()
+ )
+
+ def _persist_model_state(self, model: TrainableModel, state: ModelState) -> None:
+ model.merge_state(
+ {
+ STATE_KEY_RUN_IDS: state.tinker_run_ids,
+ STATE_KEY_LATEST_STEP: state.current_step,
+ }
+ )
diff --git a/src/art/tinker_native/data.py b/src/art/tinker_native/data.py
new file mode 100644
index 000000000..994d6e39a
--- /dev/null
+++ b/src/art/tinker_native/data.py
@@ -0,0 +1,382 @@
+from __future__ import annotations
+
+import json
+import re
+from typing import Any, Iterable, cast
+
+from openai.types.chat.chat_completion import Choice
+import tinker
+from tinker_cookbook import renderers
+import torch
+
+from ..trajectories import History, Trajectory, TrajectoryGroup, get_messages
+from ..types import MessagesAndChoices
+
+
+def _create_conversation_prefix_with_tools_fallback(
+ tools: list[dict[str, Any]], system_prompt: str = ""
+) -> list[dict[str, Any]]:
+ """Fallback implementation for create_conversation_prefix_with_tools.
+
+ Used when the installed tinker_cookbook version doesn't have this method.
+ Implements the Qwen3 tool format.
+ """
+ tools_text = ""
+ if tools:
+ # Each tool is wrapped in {"type": "function", "function": {...}} per OpenAI format
+ tool_lines = "\n".join(
+ json.dumps({"type": "function", "function": tool}, separators=(", ", ": "))
+ for tool in tools
+ )
+ tools_text = f"""# Tools
+
+You may call one or more functions to assist with the user query.
+
+You are provided with function signatures within XML tags:
+
+{tool_lines}
+
+
+For each function call, return a json object with function name and arguments within XML tags:
+
+{{"name": , "arguments": }}
+"""
+
+ # Add separator between system prompt and tools if system prompt exists
+ if system_prompt:
+ content = system_prompt + "\n\n" + tools_text
+ else:
+ content = tools_text
+
+ return [{"role": "system", "content": content}]
+
+
+def create_conversation_prefix_with_tools(
+ renderer: Any, tools: list[dict[str, Any]], system_prompt: str = ""
+) -> list[dict[str, Any]]:
+ """Create conversation prefix with tools, using renderer method or fallback."""
+ if hasattr(renderer, "create_conversation_prefix_with_tools"):
+ return renderer.create_conversation_prefix_with_tools(tools, system_prompt)
+ return _create_conversation_prefix_with_tools_fallback(tools, system_prompt)
+
+
+def compute_advantages(
+ rewards: list[float], normalize_advantages: bool = True
+) -> list[float]:
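+    """Center group rewards; optionally divide by their standard deviation.
+
+    If normalization is on and the rewards have ~zero variance, all advantages are 0.
+    """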
+ if not rewards:
+ return []
+ rewards_tensor = torch.tensor(rewards, dtype=torch.float32)
+ centered = rewards_tensor - rewards_tensor.mean()
+ if not normalize_advantages:
+ return centered.tolist()
+ std_reward = rewards_tensor.std()
+ if std_reward > 1e-8:
+ return (centered / std_reward).tolist()
+ return [0.0] * len(rewards)
+
+
+def convert_openai_messages_to_renderer_format(
+ messages: list[dict[str, Any]],
+ tools: list[dict[str, Any]] | None,
+ renderer: Any,
+) -> list[dict[str, Any]]:
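+    """Convert OpenAI-style chat messages to the renderer's message format.
+
+    When tools accompany a leading system message, the tool definitions are folded
+    into that system prompt via the renderer's tool-prefix helper (or the fallback).
+    """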
+ if tools and len(messages) > 0 and messages[0].get("role") == "system":
+ original_system = messages[0].get("content", "")
+
+ tool_specs = []
+ for tool in tools:
+ if tool.get("type") == "function":
+ func = tool.get("function", {})
+ tool_specs.append(func)
+ else:
+ tool_specs.append(tool)
+
+ tool_messages = create_conversation_prefix_with_tools(
+ renderer, tool_specs, system_prompt=original_system
+ )
+
+ converted = list(tool_messages)
+ messages = messages[1:]
+ else:
+ converted = []
+
+ for msg in messages:
+ role = msg.get("role")
+ content = msg.get("content", "")
+
+ if role == "system":
+ converted.append({"role": "system", "content": content})
+
+ elif role == "user":
+ converted.append({"role": "user", "content": content})
+
+ elif role == "assistant":
+ assistant_msg: dict[str, Any] = {
+ "role": "assistant",
+ "content": content or "",
+ }
+
+ if "tool_calls" in msg and msg["tool_calls"]:
+ tool_calls = []
+ for tool_call in msg["tool_calls"]:
+ func = tool_call.get("function", {})
+ tool_calls.append(
+ renderers.ToolCall(
+ id=tool_call.get("id", ""),
+ function=renderers.ToolCall.FunctionBody(
+ name=func.get("name", ""),
+ arguments=func.get("arguments", "{}"),
+ ),
+ )
+ )
+ assistant_msg["tool_calls"] = tool_calls
+
+ converted.append(assistant_msg)
+
+ elif role == "tool":
+ converted.append(
+ {
+ "role": "tool",
+ "content": content,
+ "tool_call_id": msg.get("tool_call_id", ""),
+ "name": msg.get("name", ""),
+ }
+ )
+
+ return converted
+
+
+def _extract_gpt_oss_tool_calls(content: str) -> tuple[str, list[dict[str, Any]]]:
+ tool_calls = []
+ cleaned_content = content
+
+ pattern = r"(\{[^}]*\})(?:<\|call\|>)?"
+
+ matches = list(re.finditer(pattern, content))
+ for i, match in enumerate(matches):
+ func_name = match.group(1)
+ args_json = match.group(2)
+
+ tool_calls.append(
+ {
+ "id": f"call_{i}",
+ "type": "function",
+ "function": {
+ "name": func_name,
+ "arguments": args_json,
+ },
+ }
+ )
+
+ cleaned_content = cleaned_content.replace(match.group(0), "").strip()
+
+ return cleaned_content, tool_calls
+
+
+def parse_completion_to_openai_message(
+ completion_tokens: list[int],
+ renderer: Any,
+) -> dict[str, Any]:
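+    """Decode sampled tokens into an OpenAI-style assistant message dict."""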
+ message, _ = renderer.parse_response(completion_tokens)
+
+ result: dict[str, Any] = {"role": "assistant"}
+
+ content = message.get("content", "")
+ if isinstance(content, str):
+ result["content"] = content
+ else:
+ text_parts = []
+ for part in content:
+ if part["type"] == "text":
+ text_parts.append(part["text"])
+ elif part["type"] == "thinking":
+ text_parts.append(part["thinking"])
+ result["content"] = "".join(text_parts)
+
+ if "tool_calls" in message and message["tool_calls"]:
+ result["tool_calls"] = [
+ {
+ "id": tool_call.id or f"call_{i}",
+ "type": "function",
+ "function": {
+ "name": tool_call.function.name,
+ "arguments": tool_call.function.arguments,
+ },
+ }
+ for i, tool_call in enumerate(message["tool_calls"])
+ ]
+ else:
+        # Some renderers (gpt-oss style) leave tool-call markup in the plain text
+        # content; lift it into structured tool calls when present.
+        if result.get("content") and "to=functions." in result["content"]:
+            cleaned_content, tool_calls = _extract_gpt_oss_tool_calls(
+                result["content"]
+            )
+            result["content"] = cleaned_content
+            if tool_calls:
+                result["tool_calls"] = tool_calls
+
+    return result
+
+
+def _trajectory_has_choice(trajectory: Trajectory) -> bool:
+ for message_or_choice in trajectory.messages_and_choices:
+ if isinstance(message_or_choice, Choice):
+ return True
+ for history in trajectory.additional_histories:
+ for message_or_choice in history.messages_and_choices:
+ if isinstance(message_or_choice, Choice):
+ return True
+ return False
+
+
+def trajectory_groups_to_datums(
+ trajectory_groups: Iterable[TrajectoryGroup],
+ renderer: Any,
+ tokenizer: Any,
+ normalize_advantages: bool = True,
+) -> list[tinker.Datum]:
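+    """Convert trajectory groups into tinker Datums.
+
+    Group rewards become per-trajectory advantages; groups whose advantages are all
+    zero are skipped, and every trajectory must contain at least one Choice.
+    """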
+ datums: list[tinker.Datum] = []
+
+ for group in trajectory_groups:
+ if not group.trajectories:
+ continue
+ for trajectory in group.trajectories:
+ if not _trajectory_has_choice(trajectory):
+ raise ValueError(
+ "Trajectory is missing a Choice object. Training requires at least one Choice "
+ "to compute logprobs. Ensure your rollout includes an OpenAI Choice in "
+ "Trajectory.messages_and_choices."
+ )
+ rewards = [trajectory.reward for trajectory in group.trajectories]
+ advantages = compute_advantages(rewards, normalize_advantages)
+
+ if all(advantage == 0.0 for advantage in advantages):
+ continue
+ for trajectory, advantage in zip(group.trajectories, advantages):
+ for history in iter_trajectory_histories(trajectory):
+ datum = history_to_datum(history, advantage, renderer, tokenizer)
+ if datum is not None:
+ datums.append(datum)
+
+ return datums
+
+
+def iter_trajectory_histories(trajectory: Trajectory) -> Iterable[History]:
+ yield History(
+ messages_and_choices=trajectory.messages_and_choices,
+ tools=trajectory.tools,
+ )
+ yield from trajectory.additional_histories
+
+
+def find_last_choice(
+ messages_and_choices: MessagesAndChoices,
+) -> tuple[int, Choice] | None:
+ for idx in range(len(messages_and_choices) - 1, -1, -1):
+ message = messages_and_choices[idx]
+ if isinstance(message, Choice):
+ return idx, message
+ return None
+
+
+def extract_logprobs_from_choice(
+ choice: Choice, tokenizer: Any
+) -> tuple[list[int], list[float]]:
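+    """Recover (token_ids, logprobs) from a Choice's logprob content.
+
+    Tokens may be "token_id:<id>" strings (as emitted by this backend's server) or
+    literal token strings resolved through the tokenizer.
+    """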
+ if choice.logprobs is None:
+ return [], []
+ token_logprobs = choice.logprobs.content or choice.logprobs.refusal or []
+ tokens: list[int] = []
+ logprobs: list[float] = []
+ for token_logprob in token_logprobs:
+ token_str = token_logprob.token or ""
+ if token_str.startswith("token_id:"):
+ try:
+ token_id = int(token_str.split(":")[1])
+ except ValueError:
+ continue
+ tokens.append(token_id)
+ logprobs.append(token_logprob.logprob)
+ else:
+ token_id = tokenizer.convert_tokens_to_ids(token_str)
+ if token_id is None:
+ continue
+ tokens.append(int(token_id))
+ logprobs.append(token_logprob.logprob)
+ return tokens, logprobs
+
+
+def history_to_datum(
+ history: History,
+ advantage: float,
+ renderer: Any,
+ tokenizer: Any,
+) -> tinker.Datum | None:
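+    """Build a tinker Datum from a history's last Choice.
+
+    Returns None when there is no Choice or its logprobs can't be recovered.
+    """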
+ choice_info = find_last_choice(history.messages_and_choices)
+ if choice_info is None:
+ return None
+ choice_index, choice = choice_info
+
+ completion_tokens, logprobs = extract_logprobs_from_choice(choice, tokenizer)
+ if not completion_tokens or len(completion_tokens) != len(logprobs):
+ return None
+
+ prompt_messages = cast(
+ list[dict[str, Any]], get_messages(history.messages_and_choices[:choice_index])
+ )
+ renderer_messages = convert_openai_messages_to_renderer_format(
+ messages=prompt_messages,
+ tools=cast(list[dict[str, Any]] | None, history.tools),
+ renderer=renderer,
+ )
+ prompt_input = renderer.build_generation_prompt(renderer_messages)
+ prompt_tokens = list(prompt_input.to_ints())
+
+ return build_datum(
+ prompt_tokens=prompt_tokens,
+ completion_tokens=completion_tokens,
+ logprobs=logprobs,
+ advantage=advantage,
+ )
+
+
+def build_datum(
+ prompt_tokens: list[int],
+ completion_tokens: list[int],
+ logprobs: list[float],
+ advantage: float,
+) -> tinker.Datum | None:
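+    """Assemble a next-token-prediction Datum.
+
+    Inputs are all tokens but the last, targets are shifted by one, and the
+    logprob/advantage/mask loss inputs are zero over the prompt region.
+    """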
+ if not prompt_tokens or not completion_tokens:
+ return None
+ ob_len = max(len(prompt_tokens) - 1, 0)
+
+ all_tokens = prompt_tokens + completion_tokens
+ input_tokens = all_tokens[:-1]
+ target_tokens = all_tokens[1:]
+
+ padded_logprobs = [0.0] * ob_len + list(logprobs)
+ padded_advantages = [0.0] * ob_len + [advantage] * len(completion_tokens)
+ action_mask = [0.0] * ob_len + [1.0] * len(completion_tokens)
+
+ if not (
+ len(input_tokens)
+ == len(target_tokens)
+ == len(padded_logprobs)
+ == len(padded_advantages)
+ == len(action_mask)
+ ):
+ return None
+
+ return tinker.Datum(
+ model_input=tinker.ModelInput.from_ints(tokens=input_tokens),
+ loss_fn_inputs={
+ "target_tokens": tinker.TensorData.from_torch(torch.tensor(target_tokens)),
+ "logprobs": tinker.TensorData.from_torch(
+ torch.tensor(padded_logprobs, dtype=torch.float32)
+ ),
+ "advantages": tinker.TensorData.from_torch(
+ torch.tensor(padded_advantages, dtype=torch.float32)
+ ),
+ "mask": tinker.TensorData.from_torch(
+ torch.tensor(action_mask, dtype=torch.float32)
+ ),
+ },
+ )
diff --git a/src/art/trajectories.py b/src/art/trajectories.py
index f74fb3039..a04762463 100644
--- a/src/art/trajectories.py
+++ b/src/art/trajectories.py
@@ -40,6 +40,8 @@ class Trajectory(pydantic.BaseModel):
tools: Tools | None = None
additional_histories: list[History] = []
reward: float
+ initial_policy_version: int | None = None
+ final_policy_version: int | None = None
metrics: dict[str, float | int | bool] = {}
auto_metrics: dict[str, float | int | bool] = {}
metadata: dict[str, MetadataValue] = {}
@@ -78,6 +80,8 @@ def messages(self) -> Messages:
def for_logging(self) -> dict[str, Any]:
loggable_dict = {
"reward": self.reward,
+ "initial_policy_version": self.initial_policy_version,
+ "final_policy_version": self.final_policy_version,
"metrics": self.metrics,
"metadata": self.metadata,
"messages": [],
diff --git a/tests/integration/test_multi_checkpoint_training.py b/tests/integration/test_multi_checkpoint_training.py
index 9252f59a4..5d74da0e4 100644
--- a/tests/integration/test_multi_checkpoint_training.py
+++ b/tests/integration/test_multi_checkpoint_training.py
@@ -61,7 +61,7 @@ async def simple_rollout(
async def run_training_loop(
model: art.TrainableModel,
- backend: Union[LocalBackend, art.ServerlessBackend, art.TinkerBackend],
+ backend: art.Backend,
num_steps: int = 1,
rollouts_per_step: int = 4,
) -> list[TrainResult]:
diff --git a/tests/integration/test_tinker_native_backend.py b/tests/integration/test_tinker_native_backend.py
new file mode 100644
index 000000000..bbaa72729
--- /dev/null
+++ b/tests/integration/test_tinker_native_backend.py
@@ -0,0 +1,114 @@
+"""Integration test for TinkerNativeBackend based on yes-no-maybe."""
+
+import asyncio
+import os
+import tempfile
+import uuid
+
+import openai
+import pytest
+
+import art
+
+DEFAULT_BASE_MODEL = "Qwen/Qwen3-30B-A3B-Instruct-2507"
+
+
+def get_base_model() -> str:
+ return os.environ.get("BASE_MODEL", DEFAULT_BASE_MODEL)
+
+
+def ensure_reward_variance(groups) -> None:
+ for group in groups:
+ rewards = [t.reward for t in group]
+ if len(rewards) < 2:
+ continue
+ if len(set(rewards)) <= 1:
+ group.trajectories[0].reward = 1.0
+ group.trajectories[1].reward = 0.0
+
+
+async def simple_rollout(
+ client: openai.AsyncOpenAI, model_name: str, prompt: str
+) -> art.Trajectory:
+ messages: art.Messages = [{"role": "user", "content": prompt}]
+ chat_completion = await client.chat.completions.create(
+ messages=messages,
+ model=model_name,
+ max_tokens=10,
+ timeout=60,
+ temperature=1,
+ )
+ choice = chat_completion.choices[0]
+ content = (choice.message.content or "").lower()
+ if "yes" in content:
+ reward = 1.0
+ elif "no" in content:
+ reward = 0.5
+ elif "maybe" in content:
+ reward = 0.25
+ else:
+ reward = 0.0
+ return art.Trajectory(messages_and_choices=[*messages, choice], reward=reward) # type: ignore[attr-defined]
+
+
+@pytest.mark.skipif(
+ "TINKER_API_KEY" not in os.environ,
+ reason="TINKER_API_KEY not set - skipping TinkerNativeBackend test",
+)
+async def test_tinker_native_backend():
+ model_name = f"test-tinker-native-{uuid.uuid4().hex[:8]}"
+ with tempfile.TemporaryDirectory() as tmpdir:
+ backend = art.TinkerNativeBackend(path=tmpdir) # type: ignore[attr-defined]
+ model = art.TrainableModel( # type: ignore[attr-defined]
+ name=model_name,
+ project="integration-tests",
+ base_model=get_base_model(),
+ )
+ try:
+ await model.register(backend)
+
+ openai_client = model.openai_client()
+ current_step = await model.get_step()
+ model_name_step = model.get_inference_name(step=current_step)
+ prompts = ["Say yes", "Say no", "Say maybe"]
+
+ async def make_group(prompt: str) -> art.TrajectoryGroup:
+ trajectories = await asyncio.gather(
+ *[
+ simple_rollout(openai_client, model_name_step, prompt)
+ for _ in range(2)
+ ]
+ )
+ return art.TrajectoryGroup(trajectories) # type: ignore[attr-defined]
+
+ train_groups = await art.gather_trajectory_groups( # type: ignore[attr-defined]
+ [make_group(prompt) for prompt in prompts]
+ )
+ ensure_reward_variance(train_groups)
+
+ result = await backend.train(
+ model,
+ train_groups,
+ learning_rate=1e-5,
+ )
+ await model.log(
+ train_groups, metrics=result.metrics, step=result.step, split="train"
+ )
+
+ assert result.step > current_step
+
+ await openai_client.chat.completions.create(
+ messages=[{"role": "user", "content": "Say hello"}],
+ model=model.get_inference_name(step=result.step),
+ max_tokens=10,
+ timeout=30,
+ )
+ await openai_client.chat.completions.create(
+ messages=[{"role": "user", "content": "Say hello"}],
+ model=model.get_inference_name(step=0),
+ max_tokens=10,
+ timeout=30,
+ )
+ finally:
+ await backend.close()
diff --git a/uv.lock b/uv.lock
index 1e73e5e41..fcc458ce5 100644
--- a/uv.lock
+++ b/uv.lock
@@ -314,9 +314,12 @@ wheels = [
[[package]]
name = "antlr4-python3-runtime"
-version = "4.9.3"
+version = "4.13.2"
source = { registry = "https://pypi.org/simple" }
-sdist = { url = "https://files.pythonhosted.org/packages/3e/38/7859ff46355f76f8d19459005ca000b6e7012f2f1ca597746cbcd1fbfe5e/antlr4-python3-runtime-4.9.3.tar.gz", hash = "sha256:f224469b4168294902bb1efa80a8bf7855f24c99aef99cbefc1bcd3cce77881b", size = 117034, upload-time = "2021-11-06T17:52:23.524Z" }
+sdist = { url = "https://files.pythonhosted.org/packages/33/5f/2cdf6f7aca3b20d3f316e9f505292e1f256a32089bd702034c29ebde6242/antlr4_python3_runtime-4.13.2.tar.gz", hash = "sha256:909b647e1d2fc2b70180ac586df3933e38919c85f98ccc656a96cd3f25ef3916", size = 117467, upload-time = "2024-08-03T19:00:12.757Z" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/89/03/a851e84fcbb85214dc637b6378121ef9a0dd61b4c65264675d8a5c9b1ae7/antlr4_python3_runtime-4.13.2-py3-none-any.whl", hash = "sha256:fe3835eb8d33daece0e799090eda89719dbccee7aa39ef94eed3818cafa5a7e8", size = 144462, upload-time = "2024-08-03T19:00:11.134Z" },
+]
[[package]]
name = "anyio"
@@ -2918,7 +2921,7 @@ wheels = [
[[package]]
name = "inspect-ai"
-version = "0.3.158"
+version = "0.3.163"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "aioboto3" },
@@ -2953,13 +2956,14 @@ dependencies = [
{ name = "sniffio" },
{ name = "tenacity" },
{ name = "textual" },
+ { name = "tiktoken" },
{ name = "typing-extensions" },
{ name = "universal-pathlib" },
{ name = "zipp" },
]
-sdist = { url = "https://files.pythonhosted.org/packages/99/ed/39847f3251ad4cc78553be3ebe458cf9962614ea17badaef0578e27d0221/inspect_ai-0.3.158.tar.gz", hash = "sha256:08b8025341c5815075e2f3fdcb2313074b0dc7a22d6ac8ee48306d7bc92700e0", size = 43435860, upload-time = "2025-12-24T12:03:47.24Z" }
+sdist = { url = "https://files.pythonhosted.org/packages/d8/1f/ceaff3a92c03196cc2503a1ec8cc865ca4695a7e25a20f3c9fb9892664da/inspect_ai-0.3.163.tar.gz", hash = "sha256:4a3b131a1d48430bf6d64ab9842fababf1ce66d64aa126f96ab09f399c4f9f61", size = 43358268, upload-time = "2026-01-21T20:36:44.792Z" }
wheels = [
- { url = "https://files.pythonhosted.org/packages/d9/ca/308f85765bcbc681bcb9f308a2443fb8a91d57b80e2e74f64e734cb6f946/inspect_ai-0.3.158-py3-none-any.whl", hash = "sha256:80d895b62c1cfea9cca20c54dffc2f7a23e822634b1cce45af7b03eb4fb92885", size = 34682434, upload-time = "2025-12-24T12:03:39.798Z" },
+ { url = "https://files.pythonhosted.org/packages/5e/a0/bc25e3c895ff462f8b901784813a4241bdef8ed6aed66837f757a5e36747/inspect_ai-0.3.163-py3-none-any.whl", hash = "sha256:c09fd251d184a77f7a69fdd75695c457ed1c328fee4dafeabd9232f7309c6741", size = 34559953, upload-time = "2026-01-21T20:36:36.141Z" },
]
[[package]]
@@ -3632,15 +3636,15 @@ wheels = [
[[package]]
name = "latex2sympy2-extended"
-version = "1.10.2"
+version = "1.11.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "antlr4-python3-runtime" },
{ name = "sympy" },
]
-sdist = { url = "https://files.pythonhosted.org/packages/f4/de/472f9115c14c6f6d8a5889cabe3418283d708bde62ce00402c29441deed4/latex2sympy2_extended-1.10.2.tar.gz", hash = "sha256:41a517ffcc5a140e910a7d1646ce6ff440817e5f9d48fc8279d88bd0925bc389", size = 206188, upload-time = "2025-07-02T15:26:06.225Z" }
+sdist = { url = "https://files.pythonhosted.org/packages/30/75/456da2da05f6380ea96e6ea804ab2c03e41fc3ed80052307fe8efe6ea20e/latex2sympy2_extended-1.11.0.tar.gz", hash = "sha256:9695657c81b50abba2636638638618db59f4663ed2a4a12d62cef74a40e28fec", size = 207023, upload-time = "2026-01-10T01:43:21.319Z" }
wheels = [
- { url = "https://files.pythonhosted.org/packages/ab/60/dfbbf40e3a371388c0e03ff65b01319b7d4023e883df6d7261125772ffdc/latex2sympy2_extended-1.10.2-py3-none-any.whl", hash = "sha256:f910442c5b02a466c1046f47d05cc5285181068b882399281f30102715337fb7", size = 207855, upload-time = "2025-07-02T15:26:04.88Z" },
+ { url = "https://files.pythonhosted.org/packages/e9/61/f75cd1fa54d8434276126034aed54dd120747de9a8fa013cdd79545ccbeb/latex2sympy2_extended-1.11.0-py3-none-any.whl", hash = "sha256:aebb77d52ce269e25028e4bea89ddb14d242ba36bcf7b636496fb5fd9728d234", size = 209050, upload-time = "2026-01-10T01:43:19.458Z" },
]
[[package]]
@@ -3950,14 +3954,14 @@ wheels = [
[[package]]
name = "math-verify"
-version = "0.8.0"
+version = "0.9.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "latex2sympy2-extended" },
]
-sdist = { url = "https://files.pythonhosted.org/packages/35/b5/b1db6fa6b6c28ebbe1889ee11a4703a72a2ca7750ec415f4559c758cf01a/math_verify-0.8.0.tar.gz", hash = "sha256:3295e0adb94bfe553ff6e3189c44f1916a85aa24ab5d1900f2086a706e28f7c4", size = 60191, upload-time = "2025-07-02T15:52:07.209Z" }
+sdist = { url = "https://files.pythonhosted.org/packages/4f/12/b8d13b581e110ac2f724a2351a8361a70fa36d057eb945d6379e8747c256/math_verify-0.9.0.tar.gz", hash = "sha256:45ac6c61344ba056b9e99a660a4bc8d044ed408f730aed68c60435aa5eec4645", size = 60329, upload-time = "2026-01-10T01:48:33.056Z" }
wheels = [
- { url = "https://files.pythonhosted.org/packages/fe/9f/59979f699b5c97334298f1295bc9fcdc9904d98d2276479bffff863d23b1/math_verify-0.8.0-py3-none-any.whl", hash = "sha256:31ca651296d817a9bb3fd58ca1fd0d192dcea709b1e5ecf2d0a4514c16f89087", size = 29994, upload-time = "2025-07-02T15:52:05.023Z" },
+ { url = "https://files.pythonhosted.org/packages/62/76/6b4969bccc842b6567f7e6ee015684b9428a9b7fcbdf479e73716f43597f/math_verify-0.9.0-py3-none-any.whl", hash = "sha256:3703e7c4885354027fa84409d762a596a2906d1fd4deb78361876bd905a76194", size = 29967, upload-time = "2026-01-10T01:48:31.674Z" },
]
[[package]]
@@ -5034,8 +5038,6 @@ dependencies = [
{ name = "polars" },
{ name = "setproctitle" },
{ name = "tblib" },
- { name = "tinker" },
- { name = "tinker-cookbook" },
{ name = "typer" },
{ name = "weave" },
]
@@ -5093,6 +5095,12 @@ dev = [
{ name = "ruff" },
{ name = "ty" },
]
+tinker = [
+ { name = "fastapi" },
+ { name = "tinker" },
+ { name = "tinker-cookbook" },
+ { name = "uvicorn" },
+]
[package.metadata]
requires-dist = [
@@ -5121,8 +5129,6 @@ requires-dist = [
{ name = "setuptools", marker = "extra == 'backend'", specifier = ">=78.1.0" },
{ name = "skypilot", extras = ["cudo", "do", "fluidstack", "gcp", "lambda", "kubernetes", "paperspace", "runpod"], marker = "extra == 'skypilot'", specifier = "==0.10.5" },
{ name = "tblib", specifier = ">=3.0.0" },
- { name = "tinker", specifier = ">=0.8.1" },
- { name = "tinker-cookbook", specifier = ">=0.1.0" },
{ name = "torch", marker = "extra == 'backend'", specifier = ">=2.8.0" },
{ name = "torchao", marker = "extra == 'backend'", specifier = "==0.14.1" },
{ name = "transformers", marker = "extra == 'backend'", specifier = ">=4.55.2,<=4.57.3" },
@@ -5152,6 +5158,12 @@ dev = [
{ name = "ruff", specifier = ">=0.12.1" },
{ name = "ty", specifier = ">=0.0.14" },
]
+tinker = [
+ { name = "fastapi", specifier = ">=0.128.0" },
+ { name = "tinker", specifier = ">=0.8.1" },
+ { name = "tinker-cookbook", specifier = ">=0.1.0" },
+ { name = "uvicorn", specifier = ">=0.35.0" },
+]
[[package]]
name = "orjson"
@@ -7946,11 +7958,11 @@ wheels = [
[[package]]
name = "soupsieve"
-version = "2.8.1"
+version = "2.8.3"
source = { registry = "https://pypi.org/simple" }
-sdist = { url = "https://files.pythonhosted.org/packages/89/23/adf3796d740536d63a6fbda113d07e60c734b6ed5d3058d1e47fc0495e47/soupsieve-2.8.1.tar.gz", hash = "sha256:4cf733bc50fa805f5df4b8ef4740fc0e0fa6218cf3006269afd3f9d6d80fd350", size = 117856, upload-time = "2025-12-18T13:50:34.655Z" }
+sdist = { url = "https://files.pythonhosted.org/packages/7b/ae/2d9c981590ed9999a0d91755b47fc74f74de286b0f5cee14c9269041e6c4/soupsieve-2.8.3.tar.gz", hash = "sha256:3267f1eeea4251fb42728b6dfb746edc9acaffc4a45b27e19450b676586e8349", size = 118627, upload-time = "2026-01-20T04:27:02.457Z" }
wheels = [
- { url = "https://files.pythonhosted.org/packages/48/f3/b67d6ea49ca9154453b6d70b34ea22f3996b9fa55da105a79d8732227adc/soupsieve-2.8.1-py3-none-any.whl", hash = "sha256:a11fe2a6f3d76ab3cf2de04eb339c1be5b506a8a47f2ceb6d139803177f85434", size = 36710, upload-time = "2025-12-18T13:50:33.267Z" },
+ { url = "https://files.pythonhosted.org/packages/46/2c/1462b1d0a634697ae9e55b3cecdcb64788e8b7d63f54d923fcd0bb140aed/soupsieve-2.8.3-py3-none-any.whl", hash = "sha256:ed64f2ba4eebeab06cc4962affce381647455978ffc1e36bb79a545b91f45a95", size = 37016, upload-time = "2026-01-20T04:27:01.012Z" },
]
[[package]]
@@ -8134,7 +8146,7 @@ wheels = [
[[package]]
name = "textual"
-version = "6.12.0"
+version = "7.3.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "markdown-it-py", extra = ["linkify"] },
@@ -8144,9 +8156,9 @@ dependencies = [
{ name = "rich" },
{ name = "typing-extensions" },
]
-sdist = { url = "https://files.pythonhosted.org/packages/39/55/29416ef63de4c37b37da217b94439a28496a4dc585209f5bf1437a61d120/textual-6.12.0.tar.gz", hash = "sha256:a32e8edbf6abdb0c42d486e96bdf419eb3aa378edb1b1271b84637f3dbd64c73", size = 1584182, upload-time = "2026-01-02T09:42:30.415Z" }
+sdist = { url = "https://files.pythonhosted.org/packages/6f/ee/620c887bfad9d6eba062dfa3b6b0e735e0259102e2667b19f21625ef598d/textual-7.3.0.tar.gz", hash = "sha256:3169e8ba5518a979b0771e60be380ab1a6c344f30a2126e360e6f38d009a3de4", size = 1590692, upload-time = "2026-01-15T16:32:02.342Z" }
wheels = [
- { url = "https://files.pythonhosted.org/packages/13/f8/2a6a6ff1d07788f635493867d5a4003dfecacad16af1fdc9814d10daca3d/textual-6.12.0-py3-none-any.whl", hash = "sha256:cf9ea9a54d213c7736efe9fef440c7f49218d4e6ab75279afd060eded9c567ec", size = 714912, upload-time = "2026-01-02T09:42:28.786Z" },
+ { url = "https://files.pythonhosted.org/packages/c3/1f/abeb4e5cb36b99dd37db72beb2a74d58598ccb35aaadf14624ee967d4a6b/textual-7.3.0-py3-none-any.whl", hash = "sha256:db235cecf969c87fe5a9c04d83595f506affc9db81f3a53ab849534d726d330a", size = 716374, upload-time = "2026-01-15T16:31:58.233Z" },
]
[[package]]
@@ -8205,7 +8217,7 @@ wheels = [
[[package]]
name = "tinker"
-version = "0.8.1"
+version = "0.9.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "anyio" },
@@ -8219,9 +8231,9 @@ dependencies = [
{ name = "transformers" },
{ name = "typing-extensions" },
]
-sdist = { url = "https://files.pythonhosted.org/packages/ec/a2/fe985e880a4a0a7ee47a0fa4824f5a1d1b24c27b7dd8a32e2606838b42a0/tinker-0.8.1.tar.gz", hash = "sha256:5ccf49f2ad7ca2dade303e9e7c33a1d21e14948d051e2a70e4a4f89d4fa52abf", size = 171213, upload-time = "2026-01-21T23:13:42.738Z" }
+sdist = { url = "https://files.pythonhosted.org/packages/a0/f6/5ad5faa0e34f18f4d921fb1afab038b4fa2f662cd199b78072c486840d75/tinker-0.9.0.tar.gz", hash = "sha256:d47dd5870f3f3c982ad515254903753ccb626a2ee096e72ebaa90beb24573926", size = 173090, upload-time = "2026-01-26T22:33:58.998Z" }
wheels = [
- { url = "https://files.pythonhosted.org/packages/70/d6/41d8c5ad6145f360f7bb476ff69cfa836cd961ffe7492d7ab3fa0da46bc1/tinker-0.8.1-py3-none-any.whl", hash = "sha256:ec743006d596735e8197b3c43adb03bdb46a5cbc01c02b920a0b7f60743c3377", size = 168164, upload-time = "2026-01-21T23:13:44.192Z" },
+ { url = "https://files.pythonhosted.org/packages/e7/00/0282156cf66331e3f2dc0f8cb7020886fdbe6843771d3afac810c94f2638/tinker-0.9.0-py3-none-any.whl", hash = "sha256:e7c4a476a3c68799654021807cd9e1a4b3954f664b30f60fe613caeb774d7f94", size = 168536, upload-time = "2026-01-26T22:33:57.478Z" },
]
[[package]]
@@ -8792,15 +8804,15 @@ wheels = [
[[package]]
name = "universal-pathlib"
-version = "0.3.7"
+version = "0.3.8"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "fsspec" },
{ name = "pathlib-abc" },
]
-sdist = { url = "https://files.pythonhosted.org/packages/d5/96/b58b00ce27cbc37fd3c79944438dd8630d2c39f9467c9e73e1a4a3eec1ef/universal_pathlib-0.3.7.tar.gz", hash = "sha256:36331056fa59a7d7cd3b61b4045f3a3418f446f23ec1a01d281c4510814b4b05", size = 253466, upload-time = "2025-12-03T00:06:43.859Z" }
+sdist = { url = "https://files.pythonhosted.org/packages/6e/ec/764b0d4593c6a8f5f66b347a19b5db9486dd0024b5e3339d468064a90c76/universal_pathlib-0.3.8.tar.gz", hash = "sha256:ead2b65bca3df6e11c3b7cb36fc9846340bc3c2db4ef57131550260422b0a3e8", size = 258837, upload-time = "2026-01-11T22:13:53.328Z" }
wheels = [
- { url = "https://files.pythonhosted.org/packages/79/77/53c2d3a0413bc55b4c91067a0c41e55451be9b0784d204e4e46ce5abf668/universal_pathlib-0.3.7-py3-none-any.whl", hash = "sha256:fb95117b20b5981f86ef9d887fddbf9c61d3596634ba42cccea444931d87c201", size = 80738, upload-time = "2025-12-03T00:06:41.997Z" },
+ { url = "https://files.pythonhosted.org/packages/86/2c/fc9416619a418e94576aef84ef263906a24f76a21a1c3e96ddae25c82df9/universal_pathlib-0.3.8-py3-none-any.whl", hash = "sha256:dac4fd9a3df918d85bb6da678e794b5dfa9ecdb5ff74675b497553dbe50134b8", size = 82608, upload-time = "2026-01-11T22:13:51.313Z" },
]
[[package]]