From 9e9148d11a1cf06f696ef1e029483f99eec4cd00 Mon Sep 17 00:00:00 2001 From: Kyle Corbitt Date: Sat, 31 Jan 2026 02:55:22 +0000 Subject: [PATCH] Move cost calculation to backend --- src/art/costs.py | 121 ++++++++++++ src/art/model.py | 172 +++++++++++++++++- .../binary_prefix_tool_pipeline.py | 11 +- src/art/pipeline_trainer/trainer.py | 40 +++- src/art/pipeline_trainer/types.py | 5 +- src/art/tinker/cookbook_v/__init__.py | 1 + src/art/tinker/cookbook_v/utils/__init__.py | 1 + src/art/tinker_native/backend.py | 12 ++ 8 files changed, 348 insertions(+), 15 deletions(-) create mode 100644 src/art/costs.py diff --git a/src/art/costs.py b/src/art/costs.py new file mode 100644 index 000000000..5ee5523a9 --- /dev/null +++ b/src/art/costs.py @@ -0,0 +1,121 @@ +"""Cost utilities for ART training and evaluation.""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Callable, TypeAlias + + +@dataclass(frozen=True) +class ModelPricing: + """Per-million-token pricing for a model.""" + + prefill: float # $/1M tokens for prompt/prefill + sample: float # $/1M tokens for sampling/generation + train: float # $/1M tokens for training + + +TokenCount: TypeAlias = int | None +CostCalculator: TypeAlias = Callable[[TokenCount, TokenCount], dict[str, float]] + +# Pricing per model ($/1M tokens). Keep in sync with infra pricing. +MODEL_PRICING: dict[str, ModelPricing] = { + # Qwen models + "Qwen/Qwen3-4B-Instruct-2507": ModelPricing(prefill=0.07, sample=0.22, train=0.22), + "Qwen/Qwen3-8B": ModelPricing(prefill=0.13, sample=0.40, train=0.40), + "Qwen/Qwen3-8B-Base": ModelPricing(prefill=0.13, sample=0.40, train=0.40), + "Qwen/Qwen3-30B-A3B": ModelPricing(prefill=0.12, sample=0.30, train=0.36), + "Qwen/Qwen3-30B-A3B-Base": ModelPricing(prefill=0.12, sample=0.30, train=0.36), + "Qwen/Qwen3-30B-A3B-Instruct-2507": ModelPricing( + prefill=0.12, sample=0.30, train=0.36 + ), + "Qwen/Qwen3-32B": ModelPricing(prefill=0.49, sample=1.47, train=1.47), + "Qwen/Qwen3-235B-A22B-Instruct-2507": ModelPricing( + prefill=0.68, sample=1.70, train=2.04 + ), + "Qwen/Qwen3-VL-30B-A3B-Instruct": ModelPricing( + prefill=0.18, sample=0.44, train=0.53 + ), + "Qwen/Qwen3-VL-235B-A22B-Instruct": ModelPricing( + prefill=1.02, sample=2.56, train=3.07 + ), + # Meta Llama models + "meta-llama/Llama-3.2-1B": ModelPricing(prefill=0.03, sample=0.09, train=0.09), + "meta-llama/Llama-3.2-3B": ModelPricing(prefill=0.06, sample=0.18, train=0.18), + "meta-llama/Llama-3.1-8B": ModelPricing(prefill=0.13, sample=0.40, train=0.40), + "meta-llama/Llama-3.1-8B-Instruct": ModelPricing( + prefill=0.13, sample=0.40, train=0.40 + ), + "meta-llama/Llama-3.1-70B": ModelPricing(prefill=1.05, sample=3.16, train=3.16), + "meta-llama/Llama-3.3-70B-Instruct": ModelPricing( + prefill=1.05, sample=3.16, train=3.16 + ), + # DeepSeek models + "deepseek-ai/DeepSeek-V3.1": ModelPricing(prefill=1.13, sample=2.81, train=3.38), + "deepseek-ai/DeepSeek-V3.1-Base": ModelPricing( + prefill=1.13, sample=2.81, train=3.38 + ), + # OpenAI models + "openai/gpt-oss-120b": ModelPricing(prefill=0.18, sample=0.44, train=0.52), + "openai/gpt-oss-20b": ModelPricing(prefill=0.12, sample=0.30, train=0.36), + # Moonshot models + "moonshotai/Kimi-K2-Thinking": ModelPricing(prefill=0.98, sample=2.44, train=2.93), +} + + +def get_model_pricing( + model_name: str | None, *, strict: bool = False +) -> ModelPricing | None: + """Return pricing for a model or None if missing.""" + if model_name is None: + return None + pricing = MODEL_PRICING.get(model_name) + if pricing is None and strict: + raise ValueError( + f"No pricing configured for model '{model_name}'. " + f"Add pricing to art.costs.MODEL_PRICING. " + f"Available models: {list(MODEL_PRICING.keys())}" + ) + return pricing + + +def tokens_to_cost(num_tokens: float, price_per_million: float) -> float: + """Convert token count to cost in dollars.""" + return float(num_tokens) * price_per_million / 1_000_000 + + +def compute_sample_costs( + *, + prompt_tokens: int | None, + completion_tokens: int | None, + pricing: ModelPricing, +) -> dict[str, float]: + """Compute prompt+completion costs for a single API call.""" + prompt_value = float(prompt_tokens or 0) + completion_value = float(completion_tokens or 0) + prefill_cost = tokens_to_cost(prompt_value, pricing.prefill) + sample_cost = tokens_to_cost(completion_value, pricing.sample) + return { + "costs_prefill": prefill_cost, + "costs_sample": sample_cost, + } + + +def build_cost_calculator(pricing: ModelPricing) -> CostCalculator: + """Return a callable that computes prompt+completion costs for a request.""" + + def _calculator( + prompt_tokens: int | None, completion_tokens: int | None + ) -> dict[str, float]: + return compute_sample_costs( + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + pricing=pricing, + ) + + return _calculator + + +def compute_train_cost(train_tokens: float, pricing: ModelPricing) -> float: + """Compute training cost from token count.""" + return tokens_to_cost(train_tokens, pricing.train) diff --git a/src/art/model.py b/src/art/model.py index e503c6f6a..21244ad8c 100644 --- a/src/art/model.py +++ b/src/art/model.py @@ -1,3 +1,4 @@ +import asyncio from datetime import datetime import json import os @@ -11,6 +12,7 @@ from typing_extensions import Never, TypeVar from . import dev +from .costs import CostCalculator from .trajectories import Trajectory, TrajectoryGroup from .types import TrainConfig from .utils.old_benchmarking.calculate_step_metrics import calculate_step_std_dev @@ -25,6 +27,10 @@ ModelConfig = TypeVar("ModelConfig", bound=BaseModel | None) StateType = TypeVar("StateType", bound=dict[str, Any], default=dict[str, Any]) +COSTS_STATE_KEY = "_costs" +COSTS_METRIC_PREFIX = "costs_" +COSTS_TOTAL_KEY = f"{COSTS_METRIC_PREFIX}total" + class Model( BaseModel, @@ -87,6 +93,8 @@ class Model( _s3_prefix: str | None = None _openai_client: AsyncOpenAI | None = None _wandb_run: Optional["Run"] = None # Private, for lazy wandb initialization + _costs_lock: asyncio.Lock + _cost_calculator: CostCalculator def __init__( self, @@ -374,6 +382,7 @@ def _get_wandb_run(self) -> Optional["Run"]: wandb.define_metric("training_step") wandb.define_metric("train/*", step_metric="training_step") wandb.define_metric("val/*", step_metric="training_step") + wandb.define_metric("costs/*", step_metric="training_step") return self._wandb_run def _log_metrics( @@ -406,6 +415,64 @@ def _log_metrics( if run := self._get_wandb_run(): run.log({"training_step": step, **prefixed}) + async def _record_costs( + self, + split: str, + step: int, + *, + cost_components: dict[str, float], + cost_total_direct: float, + cost_seen: bool, + ) -> None: + component_total = sum(cost_components.values()) + step_total = component_total if component_total > 0 else cost_total_direct + if not cost_seen or step_total <= 0: + return + + async with self._costs_lock: + existing_state = self.read_state() or {} + raw_costs = existing_state.get(COSTS_STATE_KEY) or {} + cumulative = { + key: float(value) + for key, value in raw_costs.items() + if isinstance(value, (int, float)) + } + last_steps = raw_costs.get("_last_steps") + if not isinstance(last_steps, dict): + last_steps = {} + last_step = last_steps.get(split) + + if isinstance(last_step, (int, float)) and int(last_step) >= step: + for component, value in cost_components.items(): + if value == 0: + continue + cumulative_key = f"{split}_{component}" + cumulative[cumulative_key] = max( + cumulative.get(cumulative_key, 0.0), value + ) + cumulative[split] = max(cumulative.get(split, 0.0), step_total) + cumulative["total"] = max( + cumulative.get("total", 0.0), cumulative.get(split, 0.0) + ) + self.merge_state( + {COSTS_STATE_KEY: {**cumulative, "_last_steps": last_steps}} + ) + self._log_metrics(cumulative, "costs", step) + return + + for component, value in cost_components.items(): + if value == 0: + continue + cumulative_key = f"{split}_{component}" + cumulative[cumulative_key] = cumulative.get(cumulative_key, 0.0) + value + cumulative[split] = cumulative.get(split, 0.0) + step_total + cumulative["total"] = cumulative.get("total", 0.0) + step_total + last_steps[split] = step + self.merge_state( + {COSTS_STATE_KEY: {**cumulative, "_last_steps": last_steps}} + ) + self._log_metrics(cumulative, "costs", step) + async def log( self, trajectories: ( @@ -439,7 +506,42 @@ async def log( # If only metrics provided (no trajectories), just log them and return if trajectories is None: if metrics is not None: - self._log_metrics(metrics, split, step) + cost_step = await self.get_step() + cost_components: dict[str, float] = {} + cost_total_direct = 0.0 + cost_seen = False + + for metric, value in metrics.items(): + if not isinstance(value, (int, float)): + continue + if metric == COSTS_TOTAL_KEY: + raise ValueError( + "Do not log 'costs_total' directly. Log costs_* components " + "(e.g., costs_prefill, costs_sample) and totals are derived." + ) + elif metric.startswith(COSTS_METRIC_PREFIX): + component = metric[len(COSTS_METRIC_PREFIX) :] + if component: + cost_components[component] = cost_components.get( + component, 0.0 + ) + float(value) + cost_seen = True + + metrics_without_costs = { + key: value + for key, value in metrics.items() + if not key.startswith(COSTS_METRIC_PREFIX) + } + if metrics_without_costs: + self._log_metrics(metrics_without_costs, split, step) + + await self._record_costs( + split, + cost_step, + cost_components=cost_components, + cost_total_direct=cost_total_direct, + cost_seen=cost_seen, + ) return # Convert to list[TrajectoryGroup] @@ -465,13 +567,39 @@ async def log( trajectory_groups, f"{trajectories_dir}/{file_name}" ) - # 2. Calculate aggregate metrics + # 2. Calculate aggregate metrics (excluding additive costs) + cost_step = await self.get_step() all_metrics: dict[str, list[float]] = {"reward": [], "exception_rate": []} group_metrics: dict[str, list[float]] = {} + cost_components: dict[str, float] = {} + cost_total_direct = 0.0 + cost_seen = False + + def _add_costs(metrics_dict: dict[str, float | int | bool]) -> None: + nonlocal cost_total_direct, cost_seen + for metric, value in metrics_dict.items(): + if not isinstance(value, (int, float)): + continue + if metric == COSTS_TOTAL_KEY: + raise ValueError( + "Do not log 'costs_total' directly. Log costs_* components " + "(e.g., costs_prefill, costs_sample) and totals are derived." + ) + elif metric.startswith(COSTS_METRIC_PREFIX): + component = metric[len(COSTS_METRIC_PREFIX) :] + if component: + cost_components[component] = cost_components.get( + component, 0.0 + ) + float(value) + cost_seen = True for group in trajectory_groups: + if group.metrics: + _add_costs(group.metrics) if group.trajectories: for metric, value in group.metrics.items(): + if metric.startswith(COSTS_METRIC_PREFIX): + continue if metric not in group_metrics: group_metrics[metric] = [] group_metrics[metric].append(float(value)) @@ -486,9 +614,13 @@ async def log( # Collect other custom metrics for metric, value in trajectory.metrics.items(): + if metric.startswith(COSTS_METRIC_PREFIX): + continue if metric not in all_metrics: all_metrics[metric] = [] all_metrics[metric].append(float(value)) + if trajectory.metrics: + _add_costs(trajectory.metrics) # Calculate averages for all metrics averages: dict[str, float] = {} @@ -506,11 +638,26 @@ async def log( # Merge in any additional metrics passed directly if metrics is not None: - averages.update(metrics) + _add_costs(metrics) + metrics_without_costs = { + key: value + for key, value in metrics.items() + if not key.startswith(COSTS_METRIC_PREFIX) + } + averages.update(metrics_without_costs) # 3. Log metrics (writes to history.jsonl and wandb) self._log_metrics(averages, split, step) + # 4. Log cumulative costs (additive) + await self._record_costs( + split, + cost_step, + cost_components=cost_components, + cost_total_direct=cost_total_direct, + cost_seen=cost_seen, + ) + async def get_step(self) -> int: """ Get the model's current training step. For non-trainable models, returns 0. @@ -559,6 +706,25 @@ def __init__( report_metrics=report_metrics, **kwargs, ) + object.__setattr__(self, "_costs_lock", asyncio.Lock()) + object.__setattr__(self, "_cost_calculator", self._noop_cost_calculator) + + @property + def cost_calculator(self) -> CostCalculator: + return self._cost_calculator + + def set_cost_calculator(self, calculator: CostCalculator | None) -> None: + object.__setattr__( + self, + "_cost_calculator", + calculator if calculator is not None else self._noop_cost_calculator, + ) + + @staticmethod + def _noop_cost_calculator( + _prompt_tokens: int | None, _completion_tokens: int | None + ) -> dict[str, float]: + return {} if _internal_config is not None: # Bypass BaseModel __setattr__ to allow setting private attr object.__setattr__(self, "_internal_config", _internal_config) diff --git a/src/art/pipeline_trainer/binary_prefix_tool_pipeline.py b/src/art/pipeline_trainer/binary_prefix_tool_pipeline.py index e646aed73..606a33318 100644 --- a/src/art/pipeline_trainer/binary_prefix_tool_pipeline.py +++ b/src/art/pipeline_trainer/binary_prefix_tool_pipeline.py @@ -223,11 +223,11 @@ async def main() -> None: project=project, base_model=base_model, _internal_config=internal_config, - report_metrics=[], # Disable wandb logging ) await model.register(backend) openai_client = model.openai_client() + cost_calculator = model.cost_calculator async def do_rollout(scenario: Scenario, temp: float) -> art.Trajectory: """Core rollout logic used by both training and eval.""" @@ -241,6 +241,9 @@ async def do_rollout(scenario: Scenario, temp: float) -> art.Trajectory: tools=TOOLS, tool_choice=TOOL_CHOICE, ) + usage = getattr(response, "usage", None) + prompt_tokens = int(getattr(usage, "prompt_tokens", 0) or 0) + completion_tokens = int(getattr(usage, "completion_tokens", 0) or 0) choice = response.choices[0] raw_guess, source = extract_guess(choice) sampled_content = choice.message.content or "" @@ -259,6 +262,12 @@ async def do_rollout(scenario: Scenario, temp: float) -> art.Trajectory: "tool_call_found": 1.0 if source != "missing" else 0.0, "tool_call_structured": 1.0 if source == "tool_call" else 0.0, } + sample_costs = cost_calculator( + prompt_tokens, + completion_tokens, + ) + if sample_costs: + metrics.update(sample_costs) return art.Trajectory( messages_and_choices=[*messages, choice], tools=TOOLS, diff --git a/src/art/pipeline_trainer/trainer.py b/src/art/pipeline_trainer/trainer.py index bf7c355c2..a061636b5 100644 --- a/src/art/pipeline_trainer/trainer.py +++ b/src/art/pipeline_trainer/trainer.py @@ -563,21 +563,21 @@ async def _run_eval(self, step: int) -> None: reward: float | None = None try: result = await self.eval_fn(self.model, step, self.config) - splits: dict[str, list[art.Trajectory]] + splits: dict[str, list[art.Trajectory | art.TrajectoryGroup]] if isinstance(result, dict): splits = result else: splits = {"val": result} - val_trajectories = splits.get("val") or [] - if val_trajectories: - reward = sum(t.reward for t in val_trajectories) / len(val_trajectories) - else: - reward = None - - for split_name, trajectories in splits.items(): - if trajectories: - await self.model.log(trajectories, split=split_name, step=step) + for split_name, items in splits.items(): + groups, trajectories = self._normalize_eval_items(items) + if split_name == "val": + if trajectories: + reward = sum(t.reward for t in trajectories) / len(trajectories) + else: + reward = None + if groups: + await self.model.log(groups, split=split_name, step=step) except asyncio.CancelledError: raise except Exception as exc: @@ -585,6 +585,26 @@ async def _run_eval(self, step: int) -> None: finally: self._status.note_val_finished(step, reward) + @staticmethod + def _normalize_eval_items( + items: list[art.Trajectory | art.TrajectoryGroup], + ) -> tuple[list[TrajectoryGroup], list[art.Trajectory]]: + if not items: + return [], [] + groups: list[TrajectoryGroup] = [] + loose: list[art.Trajectory] = [] + for item in items: + if isinstance(item, TrajectoryGroup): + groups.append(item) + else: + loose.append(item) + if loose: + groups.append(TrajectoryGroup(loose)) + trajectories: list[art.Trajectory] = [] + for group in groups: + trajectories.extend(group.trajectories) + return groups, trajectories + def _apply_policy_versions( self, group: TrajectoryGroup, diff --git a/src/art/pipeline_trainer/types.py b/src/art/pipeline_trainer/types.py index 532acf9cd..4b04891e2 100644 --- a/src/art/pipeline_trainer/types.py +++ b/src/art/pipeline_trainer/types.py @@ -21,5 +21,8 @@ EvalFn = Callable[ [art.TrainableModel, int, ConfigT], - Awaitable[list[Trajectory] | dict[str, list[Trajectory]]], + Awaitable[ + list[Trajectory | TrajectoryGroup] + | dict[str, list[Trajectory | TrajectoryGroup]] + ], ] diff --git a/src/art/tinker/cookbook_v/__init__.py b/src/art/tinker/cookbook_v/__init__.py index e69de29bb..8b1378917 100644 --- a/src/art/tinker/cookbook_v/__init__.py +++ b/src/art/tinker/cookbook_v/__init__.py @@ -0,0 +1 @@ + diff --git a/src/art/tinker/cookbook_v/utils/__init__.py b/src/art/tinker/cookbook_v/utils/__init__.py index e69de29bb..8b1378917 100644 --- a/src/art/tinker/cookbook_v/utils/__init__.py +++ b/src/art/tinker/cookbook_v/utils/__init__.py @@ -0,0 +1 @@ + diff --git a/src/art/tinker_native/backend.py b/src/art/tinker_native/backend.py index 7cb568c28..5b122a5d5 100644 --- a/src/art/tinker_native/backend.py +++ b/src/art/tinker_native/backend.py @@ -29,6 +29,7 @@ from .. import dev from ..backend import Backend +from ..costs import build_cost_calculator, compute_train_cost, get_model_pricing from ..model import Model, TrainableModel from ..tinker.backend import get_renderer_name from ..tinker.server import get_free_port @@ -159,6 +160,9 @@ async def register(self, model: Model) -> None: if not model.trainable: return trainable_model = cast(TrainableModel, model) + pricing = get_model_pricing(trainable_model.base_model) + if pricing is not None: + trainable_model.set_cost_calculator(build_cost_calculator(pricing)) state = await self._build_model_state(trainable_model) self._model_state[model.name] = state @@ -220,6 +224,14 @@ async def train( # type: ignore[override] if not datums: return TrainResult(step=state.current_step, metrics=metrics) + train_tokens = 0 + for datum in datums: + train_tokens += len(datum.model_input.to_ints()) + metrics["train_tokens"] = float(train_tokens) + pricing = get_model_pricing(model.base_model) + if pricing is not None: + metrics["costs_train"] = compute_train_cost(train_tokens, pricing) + if adam_params is None: adam_params = tinker.AdamParams( learning_rate=learning_rate,