Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 121 additions & 0 deletions src/art/costs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
"""Cost utilities for ART training and evaluation."""

from __future__ import annotations

from dataclasses import dataclass
from typing import Callable, TypeAlias


@dataclass(frozen=True)
class ModelPricing:
    """Per-million-token pricing for a model.

    All three rates are expressed in dollars per 1M tokens. Frozen so a
    pricing entry can be shared safely and used as a dict value without
    accidental mutation.
    """

    prefill: float  # $/1M tokens for prompt/prefill
    sample: float  # $/1M tokens for sampling/generation
    train: float  # $/1M tokens for training


# A token count reported by an API response; None when the provider did not
# report usage for that field.
TokenCount: TypeAlias = int | None
# Maps (prompt_tokens, completion_tokens) -> {"costs_prefill": $, "costs_sample": $}.
CostCalculator: TypeAlias = Callable[[TokenCount, TokenCount], dict[str, float]]

# Pricing per model ($/1M tokens). Keep in sync with infra pricing.
# Keys are full HuggingFace-style model identifiers.
MODEL_PRICING: dict[str, ModelPricing] = {
    # Qwen models
    "Qwen/Qwen3-4B-Instruct-2507": ModelPricing(prefill=0.07, sample=0.22, train=0.22),
    "Qwen/Qwen3-8B": ModelPricing(prefill=0.13, sample=0.40, train=0.40),
    "Qwen/Qwen3-8B-Base": ModelPricing(prefill=0.13, sample=0.40, train=0.40),
    "Qwen/Qwen3-30B-A3B": ModelPricing(prefill=0.12, sample=0.30, train=0.36),
    "Qwen/Qwen3-30B-A3B-Base": ModelPricing(prefill=0.12, sample=0.30, train=0.36),
    "Qwen/Qwen3-30B-A3B-Instruct-2507": ModelPricing(
        prefill=0.12, sample=0.30, train=0.36
    ),
    "Qwen/Qwen3-32B": ModelPricing(prefill=0.49, sample=1.47, train=1.47),
    "Qwen/Qwen3-235B-A22B-Instruct-2507": ModelPricing(
        prefill=0.68, sample=1.70, train=2.04
    ),
    "Qwen/Qwen3-VL-30B-A3B-Instruct": ModelPricing(
        prefill=0.18, sample=0.44, train=0.53
    ),
    "Qwen/Qwen3-VL-235B-A22B-Instruct": ModelPricing(
        prefill=1.02, sample=2.56, train=3.07
    ),
    # Meta Llama models
    "meta-llama/Llama-3.2-1B": ModelPricing(prefill=0.03, sample=0.09, train=0.09),
    "meta-llama/Llama-3.2-3B": ModelPricing(prefill=0.06, sample=0.18, train=0.18),
    "meta-llama/Llama-3.1-8B": ModelPricing(prefill=0.13, sample=0.40, train=0.40),
    "meta-llama/Llama-3.1-8B-Instruct": ModelPricing(
        prefill=0.13, sample=0.40, train=0.40
    ),
    "meta-llama/Llama-3.1-70B": ModelPricing(prefill=1.05, sample=3.16, train=3.16),
    "meta-llama/Llama-3.3-70B-Instruct": ModelPricing(
        prefill=1.05, sample=3.16, train=3.16
    ),
    # DeepSeek models
    "deepseek-ai/DeepSeek-V3.1": ModelPricing(prefill=1.13, sample=2.81, train=3.38),
    "deepseek-ai/DeepSeek-V3.1-Base": ModelPricing(
        prefill=1.13, sample=2.81, train=3.38
    ),
    # OpenAI models
    "openai/gpt-oss-120b": ModelPricing(prefill=0.18, sample=0.44, train=0.52),
    "openai/gpt-oss-20b": ModelPricing(prefill=0.12, sample=0.30, train=0.36),
    # Moonshot models
    "moonshotai/Kimi-K2-Thinking": ModelPricing(prefill=0.98, sample=2.44, train=2.93),
}


def get_model_pricing(
    model_name: str | None, *, strict: bool = False
) -> ModelPricing | None:
    """Look up the pricing entry for ``model_name``.

    Returns None when the name is None or has no configured pricing,
    unless ``strict`` is True, in which case a missing entry raises
    ValueError listing the known models.
    """
    if model_name is None:
        return None
    pricing = MODEL_PRICING.get(model_name)
    # Missing entries are tolerated by default; strict callers get a loud error.
    if pricing is not None or not strict:
        return pricing
    raise ValueError(
        f"No pricing configured for model '{model_name}'. "
        f"Add pricing to art.costs.MODEL_PRICING. "
        f"Available models: {list(MODEL_PRICING.keys())}"
    )


def tokens_to_cost(num_tokens: float, price_per_million: float) -> float:
    """Convert a token count into dollars at the given $/1M-token rate."""
    raw_total = float(num_tokens) * price_per_million
    return raw_total / 1_000_000


def compute_sample_costs(
    *,
    prompt_tokens: int | None,
    completion_tokens: int | None,
    pricing: ModelPricing,
) -> dict[str, float]:
    """Compute prompt+completion costs for a single API call.

    ``None`` token counts (usage not reported) are treated as zero.
    """
    # "or 0" maps both None and 0 to 0.0, which is the desired behavior here.
    return {
        "costs_prefill": tokens_to_cost(float(prompt_tokens or 0), pricing.prefill),
        "costs_sample": tokens_to_cost(float(completion_tokens or 0), pricing.sample),
    }


def build_cost_calculator(pricing: ModelPricing) -> CostCalculator:
    """Bind ``pricing`` into a callable that prices a single request."""

    def calculate(
        prompt_tokens: int | None, completion_tokens: int | None
    ) -> dict[str, float]:
        # Closure over pricing; delegates all math to compute_sample_costs.
        return compute_sample_costs(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            pricing=pricing,
        )

    return calculate


def compute_train_cost(train_tokens: float, pricing: ModelPricing) -> float:
    """Dollar cost of training on ``train_tokens`` tokens."""
    train_rate = pricing.train
    return tokens_to_cost(train_tokens, train_rate)
172 changes: 169 additions & 3 deletions src/art/model.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
from datetime import datetime
import json
import os
Expand All @@ -11,6 +12,7 @@
from typing_extensions import Never, TypeVar

from . import dev
from .costs import CostCalculator
from .trajectories import Trajectory, TrajectoryGroup
from .types import TrainConfig
from .utils.old_benchmarking.calculate_step_metrics import calculate_step_std_dev
Expand All @@ -25,6 +27,10 @@
ModelConfig = TypeVar("ModelConfig", bound=BaseModel | None)
StateType = TypeVar("StateType", bound=dict[str, Any], default=dict[str, Any])

COSTS_STATE_KEY = "_costs"
COSTS_METRIC_PREFIX = "costs_"
COSTS_TOTAL_KEY = f"{COSTS_METRIC_PREFIX}total"


class Model(
BaseModel,
Expand Down Expand Up @@ -87,6 +93,8 @@ class Model(
_s3_prefix: str | None = None
_openai_client: AsyncOpenAI | None = None
_wandb_run: Optional["Run"] = None # Private, for lazy wandb initialization
_costs_lock: asyncio.Lock
_cost_calculator: CostCalculator

def __init__(
self,
Expand Down Expand Up @@ -374,6 +382,7 @@ def _get_wandb_run(self) -> Optional["Run"]:
wandb.define_metric("training_step")
wandb.define_metric("train/*", step_metric="training_step")
wandb.define_metric("val/*", step_metric="training_step")
wandb.define_metric("costs/*", step_metric="training_step")
return self._wandb_run

def _log_metrics(
Expand Down Expand Up @@ -406,6 +415,64 @@ def _log_metrics(
if run := self._get_wandb_run():
run.log({"training_step": step, **prefixed})

async def _record_costs(
    self,
    split: str,
    step: int,
    *,
    cost_components: dict[str, float],
    cost_total_direct: float,
    cost_seen: bool,
) -> None:
    """Fold one step's cost components into cumulative cost state and log it.

    Args:
        split: metric split name (e.g. "train", "val"); used as a key prefix.
        step: training step these costs belong to.
        cost_components: per-component dollar costs for this step
            (e.g. {"prefill": ..., "sample": ...}).
        cost_total_direct: fallback total used when no components sum > 0.
        cost_seen: True only if a costs_* metric was actually observed;
            nothing is recorded otherwise.
    """
    component_total = sum(cost_components.values())
    # Prefer the component sum; fall back to the directly-supplied total.
    step_total = component_total if component_total > 0 else cost_total_direct
    if not cost_seen or step_total <= 0:
        return

    # Serialize the read-modify-write of persisted state across concurrent
    # log() calls on this model instance.
    async with self._costs_lock:
        existing_state = self.read_state() or {}
        raw_costs = existing_state.get(COSTS_STATE_KEY) or {}
        # Keep only numeric entries; the "_last_steps" sub-dict is handled
        # separately below.
        cumulative = {
            key: float(value)
            for key, value in raw_costs.items()
            if isinstance(value, (int, float))
        }
        last_steps = raw_costs.get("_last_steps")
        if not isinstance(last_steps, dict):
            last_steps = {}
        last_step = last_steps.get(split)

        if isinstance(last_step, (int, float)) and int(last_step) >= step:
            # Replay of an already-recorded step (presumably resume/retry):
            # clamp with max() instead of adding, so re-logging the same step
            # does not double-count — TODO confirm this idempotency intent.
            for component, value in cost_components.items():
                if value == 0:
                    continue
                cumulative_key = f"{split}_{component}"
                cumulative[cumulative_key] = max(
                    cumulative.get(cumulative_key, 0.0), value
                )
            cumulative[split] = max(cumulative.get(split, 0.0), step_total)
            # NOTE(review): on replay, "total" is clamped to this split's
            # cumulative value; with multiple splits this could understate the
            # grand total — verify intended.
            cumulative["total"] = max(
                cumulative.get("total", 0.0), cumulative.get(split, 0.0)
            )
            self.merge_state(
                {COSTS_STATE_KEY: {**cumulative, "_last_steps": last_steps}}
            )
            self._log_metrics(cumulative, "costs", step)
            return

        # First time this (split, step) is seen: accumulate additively and
        # advance the per-split high-water mark.
        for component, value in cost_components.items():
            if value == 0:
                continue
            cumulative_key = f"{split}_{component}"
            cumulative[cumulative_key] = cumulative.get(cumulative_key, 0.0) + value
        cumulative[split] = cumulative.get(split, 0.0) + step_total
        cumulative["total"] = cumulative.get("total", 0.0) + step_total
        last_steps[split] = step
        self.merge_state(
            {COSTS_STATE_KEY: {**cumulative, "_last_steps": last_steps}}
        )
        self._log_metrics(cumulative, "costs", step)

async def log(
self,
trajectories: (
Expand Down Expand Up @@ -439,7 +506,42 @@ async def log(
# If only metrics provided (no trajectories), just log them and return
if trajectories is None:
if metrics is not None:
self._log_metrics(metrics, split, step)
cost_step = await self.get_step()
cost_components: dict[str, float] = {}
cost_total_direct = 0.0
cost_seen = False

for metric, value in metrics.items():
if not isinstance(value, (int, float)):
continue
if metric == COSTS_TOTAL_KEY:
raise ValueError(
"Do not log 'costs_total' directly. Log costs_* components "
"(e.g., costs_prefill, costs_sample) and totals are derived."
)
elif metric.startswith(COSTS_METRIC_PREFIX):
component = metric[len(COSTS_METRIC_PREFIX) :]
if component:
cost_components[component] = cost_components.get(
component, 0.0
) + float(value)
cost_seen = True

metrics_without_costs = {
key: value
for key, value in metrics.items()
if not key.startswith(COSTS_METRIC_PREFIX)
}
if metrics_without_costs:
self._log_metrics(metrics_without_costs, split, step)

await self._record_costs(
split,
cost_step,
cost_components=cost_components,
cost_total_direct=cost_total_direct,
cost_seen=cost_seen,
)
return

# Convert to list[TrajectoryGroup]
Expand All @@ -465,13 +567,39 @@ async def log(
trajectory_groups, f"{trajectories_dir}/{file_name}"
)

# 2. Calculate aggregate metrics
# 2. Calculate aggregate metrics (excluding additive costs)
cost_step = await self.get_step()
all_metrics: dict[str, list[float]] = {"reward": [], "exception_rate": []}
group_metrics: dict[str, list[float]] = {}
cost_components: dict[str, float] = {}
cost_total_direct = 0.0
cost_seen = False

def _add_costs(metrics_dict: dict[str, float | int | bool]) -> None:
nonlocal cost_total_direct, cost_seen
for metric, value in metrics_dict.items():
if not isinstance(value, (int, float)):
continue
if metric == COSTS_TOTAL_KEY:
raise ValueError(
"Do not log 'costs_total' directly. Log costs_* components "
"(e.g., costs_prefill, costs_sample) and totals are derived."
)
elif metric.startswith(COSTS_METRIC_PREFIX):
component = metric[len(COSTS_METRIC_PREFIX) :]
if component:
cost_components[component] = cost_components.get(
component, 0.0
) + float(value)
cost_seen = True

for group in trajectory_groups:
if group.metrics:
_add_costs(group.metrics)
if group.trajectories:
for metric, value in group.metrics.items():
if metric.startswith(COSTS_METRIC_PREFIX):
continue
if metric not in group_metrics:
group_metrics[metric] = []
group_metrics[metric].append(float(value))
Expand All @@ -486,9 +614,13 @@ async def log(

# Collect other custom metrics
for metric, value in trajectory.metrics.items():
if metric.startswith(COSTS_METRIC_PREFIX):
continue
if metric not in all_metrics:
all_metrics[metric] = []
all_metrics[metric].append(float(value))
if trajectory.metrics:
_add_costs(trajectory.metrics)

# Calculate averages for all metrics
averages: dict[str, float] = {}
Expand All @@ -506,11 +638,26 @@ async def log(

# Merge in any additional metrics passed directly
if metrics is not None:
averages.update(metrics)
_add_costs(metrics)
metrics_without_costs = {
key: value
for key, value in metrics.items()
if not key.startswith(COSTS_METRIC_PREFIX)
}
averages.update(metrics_without_costs)

# 3. Log metrics (writes to history.jsonl and wandb)
self._log_metrics(averages, split, step)

# 4. Log cumulative costs (additive)
await self._record_costs(
split,
cost_step,
cost_components=cost_components,
cost_total_direct=cost_total_direct,
cost_seen=cost_seen,
)

async def get_step(self) -> int:
"""
Get the model's current training step. For non-trainable models, returns 0.
Expand Down Expand Up @@ -559,6 +706,25 @@ def __init__(
report_metrics=report_metrics,
**kwargs,
)
object.__setattr__(self, "_costs_lock", asyncio.Lock())
object.__setattr__(self, "_cost_calculator", self._noop_cost_calculator)

@property
def cost_calculator(self) -> CostCalculator:
return self._cost_calculator

def set_cost_calculator(self, calculator: CostCalculator | None) -> None:
object.__setattr__(
self,
"_cost_calculator",
calculator if calculator is not None else self._noop_cost_calculator,
)

@staticmethod
def _noop_cost_calculator(
_prompt_tokens: int | None, _completion_tokens: int | None
) -> dict[str, float]:
return {}
if _internal_config is not None:
# Bypass BaseModel __setattr__ to allow setting private attr
object.__setattr__(self, "_internal_config", _internal_config)
Expand Down
Loading