11 changes: 11 additions & 0 deletions src/art/model.py
@@ -467,8 +467,14 @@ async def log(

# 2. Calculate aggregate metrics
all_metrics: dict[str, list[float]] = {"reward": [], "exception_rate": []}
group_metrics: dict[str, list[float]] = {}

for group in trajectory_groups:
if group.trajectories:
for metric, value in group.metrics.items():
if metric not in group_metrics:
group_metrics[metric] = []
group_metrics[metric].append(float(value))
for trajectory in group:
if isinstance(trajectory, BaseException):
all_metrics["exception_rate"].append(1)
@@ -490,6 +496,11 @@ async def log(
if len(values) > 0:
averages[metric] = sum(values) / len(values)

# Aggregate group-level metrics once per group
for metric, values in group_metrics.items():
if len(values) > 0:
averages[f"group_metric_{metric}"] = sum(values) / len(values)

# Calculate average standard deviation of rewards within groups
averages["reward_std_dev"] = calculate_step_std_dev(trajectory_groups)

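For context, the aggregation added above collects each group's `metrics` dict once per group (skipping groups with no trajectories) and averages the values under `group_metric_*` keys. A minimal, self-contained sketch of that logic — plain dicts stand in for the real `TrajectoryGroup` objects, and the metric names are illustrative:

```python
# Minimal sketch of the group-metric aggregation above; plain dicts
# stand in for TrajectoryGroup objects, metric names are illustrative.
trajectory_groups = [
    {"metrics": {"judge_score": 1.0}},
    {"metrics": {"judge_score": 0.5, "solved": 1}},
]

group_metrics: dict[str, list[float]] = {}
for group in trajectory_groups:
    for metric, value in group["metrics"].items():
        group_metrics.setdefault(metric, []).append(float(value))

averages = {
    f"group_metric_{metric}": sum(values) / len(values)
    for metric, values in group_metrics.items()
    if values
}
print(averages)  # {'group_metric_judge_score': 0.75, 'group_metric_solved': 1.0}
```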
3 changes: 1 addition & 2 deletions src/art/pipeline_trainer/trainer.py
@@ -610,8 +610,7 @@ def _apply_scenario_metadata(
continue
if not self._is_scalar_metadata(value):
continue
for trajectory in group.trajectories:
trajectory.metadata[f"scenario_{key}"] = value
group.metadata[f"scenario_{key}"] = value

def _is_group_stale(self, group: TrajectoryGroup, min_version: int) -> bool:
group_version = self._group_initial_version(group)
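The swap above changes where scenario metadata lands: the group now carries a single `scenario_*` entry instead of every trajectory carrying a copy. A hedged sketch of the resulting shape — the key/value are hypothetical, and it assumes `Trajectory.metadata` defaults to an empty dict:

```python
from art import Trajectory, TrajectoryGroup

# Sketch of the new behavior: scenario metadata is written once on the
# group rather than duplicated onto every trajectory. "scenario_id" is a
# hypothetical key mirroring the scenario_* prefix in the diff.
group = TrajectoryGroup(
    trajectories=[
        Trajectory(
            reward=0.0,
            messages_and_choices=[{"role": "user", "content": "hi"}],
        )
    ]
)
group.metadata["scenario_id"] = "abc"
assert "scenario_id" not in group.trajectories[0].metadata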
44 changes: 42 additions & 2 deletions src/art/trajectories.py
@@ -131,6 +131,9 @@ def get_messages(messages_and_choices: MessagesAndChoices) -> Messages:
class TrajectoryGroup(pydantic.BaseModel):
trajectories: list[Trajectory]
exceptions: list[PydanticException] = []
metadata: dict[str, MetadataValue] = {}
metrics: dict[str, float | int | bool] = {}
logs: list[str] = []

def __init__(
self,
@@ -139,6 +142,9 @@ def __init__(
),
*,
exceptions: list[BaseException] = [],
metadata: dict[str, MetadataValue] | None = None,
metrics: dict[str, float | int | bool] | None = None,
logs: list[str] | None = None,
) -> None:
super().__init__(
trajectories=[
@@ -166,6 +172,11 @@ def __init__(
+ exceptions
)
],
metadata=metadata
if metadata is not None
else getattr(self, "metadata", {}),
metrics=metrics if metrics is not None else getattr(self, "metrics", {}),
logs=logs if logs is not None else getattr(self, "logs", []),
)

def __copy__(self):
@@ -176,6 +187,9 @@ def __copy__(self):
new_instance = self.__class__(
trajectories=self.trajectories[:], # Shallow copy of list
exceptions=[], # Will be set below
metadata=self.metadata.copy(),
metrics=self.metrics.copy(),
logs=self.logs[:],
)
# Manually copy exceptions since they're PydanticException objects
new_instance.exceptions = self.exceptions[:]
@@ -197,13 +211,19 @@ def __deepcopy__(self, memo: dict[int, Any] | None = None):
new_instance = self.__class__(
trajectories=copy.deepcopy(self.trajectories, memo),
exceptions=[], # Will be set below
metadata=copy.deepcopy(self.metadata, memo),
metrics=copy.deepcopy(self.metrics, memo),
logs=copy.deepcopy(self.logs, memo),
)
# Register in memo before deep copying attributes to handle circular refs
memo[id(self)] = new_instance
# Deep copy exceptions
new_instance.exceptions = copy.deepcopy(self.exceptions, memo)
return new_instance

def log(self, message: str) -> None:
self.logs.append(message)

def __iter__(self) -> Iterator[Trajectory]: # type: ignore[override]
return iter(self.trajectories)

@@ -216,6 +236,9 @@ def __new__(
trajectories: Iterable[Trajectory | BaseException],
*,
exceptions: list[BaseException] = [],
metadata: dict[str, MetadataValue] | None = None,
metrics: dict[str, float | int | bool] | None = None,
logs: list[str] | None = None,
) -> "TrajectoryGroup": ...

@overload
@@ -224,6 +247,9 @@ def __new__(
trajectories: Iterable[Awaitable[Trajectory]],
*,
exceptions: list[BaseException] = [],
metadata: dict[str, MetadataValue] | None = None,
metrics: dict[str, float | int | bool] | None = None,
logs: list[str] | None = None,
) -> Awaitable["TrajectoryGroup"]: ...

def __new__(
@@ -233,11 +259,19 @@ def __new__(
),
*,
exceptions: list[BaseException] = [],
metadata: dict[str, MetadataValue] | None = None,
metrics: dict[str, float | int | bool] | None = None,
logs: list[str] | None = None,
) -> "TrajectoryGroup | Awaitable[TrajectoryGroup]":
ts = list(trajectories)
if any(hasattr(t, "__await__") for t in ts):

async def _(exceptions: list[BaseException]):
async def _(
exceptions: list[BaseException],
metadata: dict[str, MetadataValue] | None,
metrics: dict[str, float | int | bool] | None,
logs: list[str] | None,
):
from .gather import get_gather_context, record_metrics

context = get_gather_context()
@@ -259,6 +293,9 @@ async def _(exceptions: list[BaseException]):
return TrajectoryGroup(
trajectories=trajectories,
exceptions=exceptions,
metadata=metadata,
metrics=metrics,
logs=logs,
)

class CoroutineWithMetadata:
@@ -269,12 +306,15 @@ def __init__(self, coro, num_trajectories):
def __await__(self):
return self.coro.__await__()

coro = _(exceptions.copy())
coro = _(exceptions.copy(), metadata, metrics, logs)
return CoroutineWithMetadata(coro, len(ts))
else:
group = super().__new__(cls)
group.__init__(
trajectories=cast(list[Trajectory | BaseException], ts),
exceptions=exceptions,
metadata=metadata,
metrics=metrics,
logs=logs,
)
return group
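Taken together, these constructor changes make `metadata`, `metrics`, and `logs` part of the group's public surface, for both the synchronous and the awaitable `__new__` paths. A usage sketch — the values mirror the unit test at the bottom of this diff:

```python
from art import Trajectory, TrajectoryGroup

# Usage sketch of the new group-level fields; values mirror the unit
# test added in this PR.
group = TrajectoryGroup(
    trajectories=[
        Trajectory(
            reward=1.0,
            messages_and_choices=[{"role": "user", "content": "hi"}],
        )
    ],
    metadata={"scenario_id": "abc"},   # stored once for the whole group
    metrics={"judge_score": 0.9},      # averaged into group_metric_* by log()
)
group.log("rollout batch finished")    # the new log() helper appends to logs
assert group.logs == ["rollout batch finished"]
```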
34 changes: 34 additions & 0 deletions src/art/utils/benchmarking/load_trajectories.py
@@ -69,6 +69,10 @@ async def load_trajectories(
One column for every distinct metric key found in the dataset.
metadata_* : str
One column for every distinct metadata key.
group_metric_* : float
One column for every distinct group-level metric key.
group_metadata_* : str
One column for every distinct group-level metadata key.

Parameters
----------
@@ -144,6 +148,8 @@
rows: list[dict] = []
metric_cols: set[str] = set()
metadata_cols: set[str] = set()
group_metric_cols: set[str] = set()
group_metadata_cols: set[str] = set()
# Map (model, split, step, group_index) -> unique group_number
group_key_to_number: dict[tuple[str, str, int, int], int] = {}
next_group_number = 1
@@ -195,11 +201,35 @@
except (json.JSONDecodeError, TypeError):
pass

# Parse group metrics from JSON (duplicated across group rows)
group_metrics = {}
if row_dict.get("group_metrics"):
try:
group_metrics = json.loads(row_dict["group_metrics"])
except (json.JSONDecodeError, TypeError):
pass

# Parse group metadata from JSON (duplicated across group rows)
group_metadata = {}
if row_dict.get("group_metadata"):
try:
group_metadata = json.loads(row_dict["group_metadata"])
except (json.JSONDecodeError, TypeError):
pass

# Prepare metrics and metadata columns
prepped_metrics = {f"metric_{k}": v for k, v in metrics.items()}
prepped_metadata = {f"metadata_{k}": str(v) for k, v in metadata.items()}
prepped_group_metrics = {
f"group_metric_{k}": v for k, v in group_metrics.items()
}
prepped_group_metadata = {
f"group_metadata_{k}": str(v) for k, v in group_metadata.items()
}
metric_cols.update(prepped_metrics.keys())
metadata_cols.update(prepped_metadata.keys())
group_metric_cols.update(prepped_group_metrics.keys())
group_metadata_cols.update(prepped_group_metadata.keys())

# Process messages
messages = []
@@ -255,6 +285,8 @@
"logs": row_dict.get("logs"),
**prepped_metrics,
**prepped_metadata,
**prepped_group_metrics,
**prepped_group_metadata,
}

rows.append(row_data)
@@ -295,6 +327,8 @@
}
| {k: pl.Float64 for k in metric_cols}
| {k: pl.Utf8 for k in metadata_cols}
| {k: pl.Float64 for k in group_metric_cols}
| {k: pl.Utf8 for k in group_metadata_cols}
)

return pl.DataFrame(rows, schema=schema)
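Downstream, the new columns sit alongside the per-trajectory ones and can be filtered like any other Polars column. A hedged sketch — the project and model names are placeholders, and it assumes prior runs logged a group metric named `judge_score` and group metadata `scenario_id`:

```python
import asyncio

import polars as pl

from art.utils.benchmarking.load_trajectories import load_trajectories


async def main() -> None:
    # Placeholder project/model names; assumes groups were logged with a
    # "judge_score" group metric and "scenario_id" group metadata.
    df = await load_trajectories(project_name="proj", models=["model"])
    high = df.filter(pl.col("group_metric_judge_score") > 0.8)
    print(high.select("group_metadata_scenario_id", "reward"))


asyncio.run(main())
```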
54 changes: 52 additions & 2 deletions src/art/utils/trajectory_logging.py
@@ -9,7 +9,7 @@

import json
from pathlib import Path
from typing import Any
from typing import Any, cast

from litellm.types.utils import Choices
from openai.types.chat.chat_completion import Choice
@@ -72,6 +72,9 @@ def write_trajectory_groups_parquet(

rows = []
for group_index, group in enumerate(trajectory_groups):
group_metadata = json.dumps(group.metadata) if group.metadata else None
group_metrics = json.dumps(group.metrics) if group.metrics else None
group_logs = group.logs if group.logs else None
for trajectory in group.trajectories:
if not isinstance(trajectory, Trajectory):
continue
@@ -96,6 +99,9 @@ def write_trajectory_groups_parquet(
rows.append(
{
"group_index": group_index,
"group_metadata": group_metadata,
"group_metrics": group_metrics,
"group_logs": group_logs,
"reward": trajectory.reward,
"metrics": json.dumps(trajectory.metrics)
if trajectory.metrics
@@ -123,6 +129,9 @@ def write_trajectory_groups_parquet(
schema = pa.schema(
[
("group_index", pa.int64()),
("group_metadata", pa.string()),
("group_metrics", pa.string()),
("group_logs", pa.list_(pa.string())),
("reward", pa.float64()),
("metrics", pa.string()),
("metadata", pa.string()),
@@ -158,6 +167,23 @@ def read_trajectory_groups_parquet(path: str | Path) -> list[TrajectoryGroup]:
columns = [desc[0] for desc in con.description]

groups_dict: dict[int, list[Trajectory]] = {}
group_metadata_by_index: dict[int, dict[str, Any]] = {}
group_metrics_by_index: dict[int, dict[str, Any]] = {}
group_logs_by_index: dict[int, list[str]] = {}

def _load_json_payload(payload: object | None) -> dict[str, Any]:
if payload is None:
return {}
if isinstance(payload, dict):
return cast(dict[str, Any], payload)
if isinstance(payload, (str, bytes, bytearray)):
if not payload:
return {}
try:
return json.loads(payload)
except (json.JSONDecodeError, TypeError):
return {}
return {}

for row in rows:
row_dict = dict(zip(columns, row))
@@ -166,6 +192,24 @@ def read_trajectory_groups_parquet(path: str | Path) -> list[TrajectoryGroup]:
continue

group_index = row_dict.get("group_index", 0)
if group_index not in group_metadata_by_index:
group_metadata_by_index[group_index] = _load_json_payload(
row_dict.get("group_metadata")
)
if group_index not in group_metrics_by_index:
group_metrics_by_index[group_index] = _load_json_payload(
row_dict.get("group_metrics")
)
if group_index not in group_logs_by_index:
raw_group_logs = row_dict.get("group_logs")
if isinstance(raw_group_logs, (list, tuple)):
group_logs_by_index[group_index] = [
str(item) for item in raw_group_logs
]
elif raw_group_logs is None:
group_logs_by_index[group_index] = []
else:
group_logs_by_index[group_index] = [str(raw_group_logs)]

# Convert messages
messages_and_choices = []
@@ -196,6 +240,12 @@ def read_trajectory_groups_parquet(path: str | Path) -> list[TrajectoryGroup]:
groups_dict[group_index].append(trajectory)

return [
TrajectoryGroup(trajectories=groups_dict[idx], exceptions=[])
TrajectoryGroup(
trajectories=groups_dict[idx],
exceptions=[],
metadata=group_metadata_by_index.get(idx, {}),
metrics=group_metrics_by_index.get(idx, {}),
logs=group_logs_by_index.get(idx, []),
)
for idx in sorted(groups_dict.keys())
]
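A round-trip sketch of the writer/reader pair above — the group-level fields should survive a write/read cycle. Values again mirror the unit test; the output path is arbitrary:

```python
from pathlib import Path

from art import Trajectory, TrajectoryGroup
from art.utils.trajectory_logging import (
    read_trajectory_groups_parquet,
    write_trajectory_groups_parquet,
)

# Round-trip sketch: group metadata/metrics/logs survive serialization.
groups = [
    TrajectoryGroup(
        trajectories=[
            Trajectory(
                reward=1.0,
                messages_and_choices=[{"role": "user", "content": "hi"}],
            )
        ],
        metadata={"scenario_id": "abc"},
        metrics={"judge_score": 0.9},
        logs=["group log"],
    )
]
path = Path("groups.parquet")  # arbitrary output location
write_trajectory_groups_parquet(groups, path)
loaded = read_trajectory_groups_parquet(path)
assert loaded[0].metadata == {"scenario_id": "abc"}
assert loaded[0].logs == ["group log"]
```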
40 changes: 40 additions & 0 deletions tests/unit/test_benchmarking_loader.py
@@ -0,0 +1,40 @@
import pytest

from art import Trajectory, TrajectoryGroup
from art.utils.benchmarking.load_trajectories import load_trajectories
from art.utils.trajectory_logging import write_trajectory_groups_parquet


@pytest.mark.asyncio
async def test_load_trajectories_group_columns(tmp_path):
project_name = "proj"
model_name = "model"
traj_dir = tmp_path / project_name / "models" / model_name / "trajectories" / "val"
traj_dir.mkdir(parents=True)

groups = [
TrajectoryGroup(
trajectories=[
Trajectory(
reward=1.0,
messages_and_choices=[{"role": "user", "content": "hi"}],
)
],
metadata={"scenario_id": "abc"},
metrics={"judge_score": 0.9},
logs=["group log"],
exceptions=[],
)
]
write_trajectory_groups_parquet(groups, traj_dir / "0000.parquet")

df = await load_trajectories(
project_name=project_name,
models=[model_name],
art_path=str(tmp_path),
)

assert "group_metric_judge_score" in df.columns
assert "group_metadata_scenario_id" in df.columns
assert df["group_metric_judge_score"][0] == 0.9
assert df["group_metadata_scenario_id"][0] == "abc"