diff --git a/src/art/model.py b/src/art/model.py
index 05f55bd67..e503c6f6a 100644
--- a/src/art/model.py
+++ b/src/art/model.py
@@ -467,8 +467,14 @@ async def log(
         # 2. Calculate aggregate metrics
         all_metrics: dict[str, list[float]] = {"reward": [], "exception_rate": []}
+        group_metrics: dict[str, list[float]] = {}
         for group in trajectory_groups:
+            if group.trajectories:
+                for metric, value in group.metrics.items():
+                    if metric not in group_metrics:
+                        group_metrics[metric] = []
+                    group_metrics[metric].append(float(value))
             for trajectory in group:
                 if isinstance(trajectory, BaseException):
                     all_metrics["exception_rate"].append(1)
@@ -490,6 +496,11 @@ async def log(
             if len(values) > 0:
                 averages[metric] = sum(values) / len(values)

+        # Aggregate group-level metrics once per group
+        for metric, values in group_metrics.items():
+            if len(values) > 0:
+                averages[f"group_metric_{metric}"] = sum(values) / len(values)
+
         # Calculate average standard deviation of rewards within groups
         averages["reward_std_dev"] = calculate_step_std_dev(trajectory_groups)
diff --git a/src/art/pipeline_trainer/trainer.py b/src/art/pipeline_trainer/trainer.py
index 4bb062d4c..bf7c355c2 100644
--- a/src/art/pipeline_trainer/trainer.py
+++ b/src/art/pipeline_trainer/trainer.py
@@ -610,8 +610,7 @@ def _apply_scenario_metadata(
                 continue
             if not self._is_scalar_metadata(value):
                 continue
-            for trajectory in group.trajectories:
-                trajectory.metadata[f"scenario_{key}"] = value
+            group.metadata[f"scenario_{key}"] = value

     def _is_group_stale(self, group: TrajectoryGroup, min_version: int) -> bool:
         group_version = self._group_initial_version(group)
diff --git a/src/art/trajectories.py b/src/art/trajectories.py
index a04762463..69dd2d3de 100644
--- a/src/art/trajectories.py
+++ b/src/art/trajectories.py
@@ -131,6 +131,9 @@ def get_messages(messages_and_choices: MessagesAndChoices) -> Messages:
 class TrajectoryGroup(pydantic.BaseModel):
     trajectories: list[Trajectory]
     exceptions: list[PydanticException] = []
+    metadata: dict[str, MetadataValue] = {}
+    metrics: dict[str, float | int | bool] = {}
+    logs: list[str] = []

     def __init__(
         self,
@@ -139,6 +142,9 @@ def __init__(
         ),
         *,
         exceptions: list[BaseException] = [],
+        metadata: dict[str, MetadataValue] | None = None,
+        metrics: dict[str, float | int | bool] | None = None,
+        logs: list[str] | None = None,
     ) -> None:
         super().__init__(
             trajectories=[
@@ -166,6 +172,11 @@ def __init__(
                     + exceptions
                 )
             ],
+            metadata=metadata
+            if metadata is not None
+            else getattr(self, "metadata", {}),
+            metrics=metrics if metrics is not None else getattr(self, "metrics", {}),
+            logs=logs if logs is not None else getattr(self, "logs", []),
         )

     def __copy__(self):
@@ -176,6 +187,9 @@ def __copy__(self):
         new_instance = self.__class__(
             trajectories=self.trajectories[:],  # Shallow copy of list
             exceptions=[],  # Will be set below
+            metadata=self.metadata.copy(),
+            metrics=self.metrics.copy(),
+            logs=self.logs[:],
         )
         # Manually copy exceptions since they're PydanticException objects
         new_instance.exceptions = self.exceptions[:]
@@ -197,6 +211,9 @@ def __deepcopy__(self, memo: dict[int, Any] | None = None):
         new_instance = self.__class__(
             trajectories=copy.deepcopy(self.trajectories, memo),
             exceptions=[],  # Will be set below
+            metadata=copy.deepcopy(self.metadata, memo),
+            metrics=copy.deepcopy(self.metrics, memo),
+            logs=copy.deepcopy(self.logs, memo),
         )
         # Register in memo before deep copying attributes to handle circular refs
         memo[id(self)] = new_instance
@@ -204,6 +221,9 @@ def __deepcopy__(self, memo: dict[int, Any] | None = None):
         new_instance.exceptions = copy.deepcopy(self.exceptions, memo)
         return new_instance

+    def log(self, message: str) -> None:
+        self.logs.append(message)
+
     def __iter__(self) -> Iterator[Trajectory]:  # type: ignore[override]
         return iter(self.trajectories)

@@ -216,6 +236,9 @@ def __new__(
         trajectories: Iterable[Trajectory | BaseException],
         *,
         exceptions: list[BaseException] = [],
+        metadata: dict[str, MetadataValue] | None = None,
+        metrics: dict[str, float | int | bool] | None = None,
+        logs: list[str] | None = None,
     ) -> "TrajectoryGroup": ...

     @overload
@@ -224,6 +247,9 @@ def __new__(
         trajectories: Iterable[Awaitable[Trajectory]],
         *,
         exceptions: list[BaseException] = [],
+        metadata: dict[str, MetadataValue] | None = None,
+        metrics: dict[str, float | int | bool] | None = None,
+        logs: list[str] | None = None,
     ) -> Awaitable["TrajectoryGroup"]: ...

     def __new__(
@@ -233,11 +259,19 @@ def __new__(
         ),
         *,
         exceptions: list[BaseException] = [],
+        metadata: dict[str, MetadataValue] | None = None,
+        metrics: dict[str, float | int | bool] | None = None,
+        logs: list[str] | None = None,
     ) -> "TrajectoryGroup | Awaitable[TrajectoryGroup]":
         ts = list(trajectories)
         if any(hasattr(t, "__await__") for t in ts):

-            async def _(exceptions: list[BaseException]):
+            async def _(
+                exceptions: list[BaseException],
+                metadata: dict[str, MetadataValue] | None,
+                metrics: dict[str, float | int | bool] | None,
+                logs: list[str] | None,
+            ):
                 from .gather import get_gather_context, record_metrics

                 context = get_gather_context()
@@ -259,6 +293,9 @@ async def _(exceptions: list[BaseException]):
                 return TrajectoryGroup(
                     trajectories=trajectories,
                     exceptions=exceptions,
+                    metadata=metadata,
+                    metrics=metrics,
+                    logs=logs,
                 )

             class CoroutineWithMetadata:
@@ -269,12 +306,15 @@ def __init__(self, coro, num_trajectories):
                 def __await__(self):
                     return self.coro.__await__()

-            coro = _(exceptions.copy())
+            coro = _(exceptions.copy(), metadata, metrics, logs)
             return CoroutineWithMetadata(coro, len(ts))
         else:
             group = super().__new__(cls)
             group.__init__(
                 trajectories=cast(list[Trajectory | BaseException], ts),
                 exceptions=exceptions,
+                metadata=metadata,
+                metrics=metrics,
+                logs=logs,
             )
             return group
diff --git a/src/art/utils/benchmarking/load_trajectories.py b/src/art/utils/benchmarking/load_trajectories.py
index 2b494ffe4..7be80239d 100644
--- a/src/art/utils/benchmarking/load_trajectories.py
+++ b/src/art/utils/benchmarking/load_trajectories.py
@@ -69,6 +69,10 @@ async def load_trajectories(
         One column for every distinct metric key found in the dataset.
     metadata_* : str
         One column for every distinct metadata key.
+    group_metric_* : float
+        One column for every distinct group-level metric key.
+    group_metadata_* : str
+        One column for every distinct group-level metadata key.

     Parameters
     ----------
@@ -144,6 +148,8 @@ async def load_trajectories(
     rows: list[dict] = []
     metric_cols: set[str] = set()
     metadata_cols: set[str] = set()
+    group_metric_cols: set[str] = set()
+    group_metadata_cols: set[str] = set()
     # Map (model, split, step, group_index) -> unique group_number
     group_key_to_number: dict[tuple[str, str, int, int], int] = {}
     next_group_number = 1
@@ -195,11 +201,35 @@ async def load_trajectories(
                 except (json.JSONDecodeError, TypeError):
                     pass

+            # Parse group metrics from JSON (duplicated across group rows)
+            group_metrics = {}
+            if row_dict.get("group_metrics"):
+                try:
+                    group_metrics = json.loads(row_dict["group_metrics"])
+                except (json.JSONDecodeError, TypeError):
+                    pass
+
+            # Parse group metadata from JSON (duplicated across group rows)
+            group_metadata = {}
+            if row_dict.get("group_metadata"):
+                try:
+                    group_metadata = json.loads(row_dict["group_metadata"])
+                except (json.JSONDecodeError, TypeError):
+                    pass
+
             # Prepare metrics and metadata columns
             prepped_metrics = {f"metric_{k}": v for k, v in metrics.items()}
             prepped_metadata = {f"metadata_{k}": str(v) for k, v in metadata.items()}
+            prepped_group_metrics = {
+                f"group_metric_{k}": v for k, v in group_metrics.items()
+            }
+            prepped_group_metadata = {
+                f"group_metadata_{k}": str(v) for k, v in group_metadata.items()
+            }
             metric_cols.update(prepped_metrics.keys())
             metadata_cols.update(prepped_metadata.keys())
+            group_metric_cols.update(prepped_group_metrics.keys())
+            group_metadata_cols.update(prepped_group_metadata.keys())

             # Process messages
             messages = []
@@ -255,6 +285,8 @@ async def load_trajectories(
                 "logs": row_dict.get("logs"),
                 **prepped_metrics,
                 **prepped_metadata,
+                **prepped_group_metrics,
+                **prepped_group_metadata,
             }
             rows.append(row_data)

@@ -295,6 +327,8 @@ async def load_trajectories(
         }
         | {k: pl.Float64 for k in metric_cols}
         | {k: pl.Utf8 for k in metadata_cols}
+        | {k: pl.Float64 for k in group_metric_cols}
+        | {k: pl.Utf8 for k in group_metadata_cols}
     )

     return pl.DataFrame(rows, schema=schema)
diff --git a/src/art/utils/trajectory_logging.py b/src/art/utils/trajectory_logging.py
index 77ddb7167..481c7c8c1 100644
--- a/src/art/utils/trajectory_logging.py
+++ b/src/art/utils/trajectory_logging.py
@@ -9,7 +9,7 @@
 import json
 from pathlib import Path
-from typing import Any
+from typing import Any, cast

 from litellm.types.utils import Choices
 from openai.types.chat.chat_completion import Choice
@@ -72,6 +72,9 @@ def write_trajectory_groups_parquet(
     rows = []

     for group_index, group in enumerate(trajectory_groups):
+        group_metadata = json.dumps(group.metadata) if group.metadata else None
+        group_metrics = json.dumps(group.metrics) if group.metrics else None
+        group_logs = group.logs if group.logs else None
         for trajectory in group.trajectories:
             if not isinstance(trajectory, Trajectory):
                 continue
@@ -96,6 +99,9 @@ def write_trajectory_groups_parquet(
             rows.append(
                 {
                     "group_index": group_index,
+                    "group_metadata": group_metadata,
+                    "group_metrics": group_metrics,
+                    "group_logs": group_logs,
                     "reward": trajectory.reward,
                     "metrics": json.dumps(trajectory.metrics)
                     if trajectory.metrics
@@ -123,6 +129,9 @@ def write_trajectory_groups_parquet(
     schema = pa.schema(
         [
             ("group_index", pa.int64()),
+            ("group_metadata", pa.string()),
+            ("group_metrics", pa.string()),
+            ("group_logs", pa.list_(pa.string())),
             ("reward", pa.float64()),
             ("metrics", pa.string()),
             ("metadata", pa.string()),
@@ -158,6 +167,23 @@ def read_trajectory_groups_parquet(path: str | Path) -> list[TrajectoryGroup]:
         columns = [desc[0] for desc in con.description]

     groups_dict: dict[int, list[Trajectory]] = {}
+    group_metadata_by_index: dict[int, dict[str, Any]] = {}
+    group_metrics_by_index: dict[int, dict[str, Any]] = {}
+    group_logs_by_index: dict[int, list[str]] = {}
+
+    def _load_json_payload(payload: object | None) -> dict[str, Any]:
+        if payload is None:
+            return {}
+        if isinstance(payload, dict):
+            return cast(dict[str, Any], payload)
+        if isinstance(payload, (str, bytes, bytearray)):
+            if not payload:
+                return {}
+            try:
+                return json.loads(payload)
+            except (json.JSONDecodeError, TypeError):
+                return {}
+        return {}

     for row in rows:
         row_dict = dict(zip(columns, row))
@@ -166,6 +192,24 @@ def read_trajectory_groups_parquet(path: str | Path) -> list[TrajectoryGroup]:
             continue

         group_index = row_dict.get("group_index", 0)
+        if group_index not in group_metadata_by_index:
+            group_metadata_by_index[group_index] = _load_json_payload(
+                row_dict.get("group_metadata")
+            )
+        if group_index not in group_metrics_by_index:
+            group_metrics_by_index[group_index] = _load_json_payload(
+                row_dict.get("group_metrics")
+            )
+        if group_index not in group_logs_by_index:
+            raw_group_logs = row_dict.get("group_logs")
+            if isinstance(raw_group_logs, (list, tuple)):
+                group_logs_by_index[group_index] = [
+                    str(item) for item in raw_group_logs
+                ]
+            elif raw_group_logs is None:
+                group_logs_by_index[group_index] = []
+            else:
+                group_logs_by_index[group_index] = [str(raw_group_logs)]

         # Convert messages
         messages_and_choices = []
@@ -196,6 +240,12 @@ def read_trajectory_groups_parquet(path: str | Path) -> list[TrajectoryGroup]:
         groups_dict[group_index].append(trajectory)

     return [
-        TrajectoryGroup(trajectories=groups_dict[idx], exceptions=[])
+        TrajectoryGroup(
+            trajectories=groups_dict[idx],
+            exceptions=[],
+            metadata=group_metadata_by_index.get(idx, {}),
+            metrics=group_metrics_by_index.get(idx, {}),
+            logs=group_logs_by_index.get(idx, []),
+        )
         for idx in sorted(groups_dict.keys())
     ]
diff --git a/tests/unit/test_benchmarking_loader.py b/tests/unit/test_benchmarking_loader.py
new file mode 100644
index 000000000..defdb0e89
--- /dev/null
+++ b/tests/unit/test_benchmarking_loader.py
@@ -0,0 +1,40 @@
+import pytest
+
+from art import Trajectory, TrajectoryGroup
+from art.utils.benchmarking.load_trajectories import load_trajectories
+from art.utils.trajectory_logging import write_trajectory_groups_parquet
+
+
+@pytest.mark.asyncio
+async def test_load_trajectories_group_columns(tmp_path):
+    project_name = "proj"
+    model_name = "model"
+    traj_dir = tmp_path / project_name / "models" / model_name / "trajectories" / "val"
+    traj_dir.mkdir(parents=True)
+
+    groups = [
+        TrajectoryGroup(
+            trajectories=[
+                Trajectory(
+                    reward=1.0,
+                    messages_and_choices=[{"role": "user", "content": "hi"}],
+                )
+            ],
+            metadata={"scenario_id": "abc"},
+            metrics={"judge_score": 0.9},
+            logs=["group log"],
+            exceptions=[],
+        )
+    ]
+    write_trajectory_groups_parquet(groups, traj_dir / "0000.parquet")
+
+    df = await load_trajectories(
+        project_name=project_name,
+        models=[model_name],
+        art_path=str(tmp_path),
+    )
+
+    assert "group_metric_judge_score" in df.columns
+    assert "group_metadata_scenario_id" in df.columns
+    assert df["group_metric_judge_score"][0] == 0.9
+    assert df["group_metadata_scenario_id"][0] == "abc"
diff --git a/tests/unit/test_frontend_logging.py b/tests/unit/test_frontend_logging.py
index 8ae0f9453..eb0a1c595 100644
--- a/tests/unit/test_frontend_logging.py
+++ b/tests/unit/test_frontend_logging.py
@@ -75,6 +75,7 @@ async def test_parquet_readable_by_read_trajectory_groups_parquet(
             name="test-model",
             project="test-project",
             base_path=str(tmp_path),
+            report_metrics=[],
         )

         # Mock get_step to return 0 for non-trainable model
@@ -102,6 +103,7 @@ async def test_parquet_schema_preserved(
             name="test-model",
             project="test-project",
             base_path=str(tmp_path),
+            report_metrics=[],
         )

         await model.log(sample_trajectory_groups, split="val")
@@ -114,6 +116,9 @@ async def test_parquet_schema_preserved(
         # Check expected columns exist
         expected_columns = [
             "group_index",
+            "group_metadata",
+            "group_metrics",
+            "group_logs",
             "reward",
             "metrics",
             "metadata",
@@ -164,6 +169,7 @@ async def test_history_jsonl_format(
             name="test-model",
             project="test-project",
             base_path=str(tmp_path),
+            report_metrics=[],
         )

         await model.log(sample_trajectory_groups, split="val")
@@ -188,6 +194,7 @@ async def test_history_readable_by_polars(
             name="test-model",
             project="test-project",
             base_path=str(tmp_path),
+            report_metrics=[],
         )

         await model.log(sample_trajectory_groups, split="val")
@@ -208,6 +215,7 @@ async def test_history_appends_entries(
             name="test-model",
             project="test-project",
             base_path=str(tmp_path),
+            report_metrics=[],
         )

         # Log twice
@@ -236,6 +244,7 @@ async def test_file_locations_match_localbackend(self, tmp_path: Path):
             name="mymodel",
             project="myproj",
             base_path=str(tmp_path),
+            report_metrics=[],
         )

         trajectories = [
@@ -267,6 +276,7 @@ async def test_step_numbering_format(self, tmp_path: Path):
             project="myproj",
             base_model="gpt-4",
             base_path=str(tmp_path),
+            report_metrics=[],
         )

         # Mock the backend and get_step
@@ -304,6 +314,7 @@ async def test_metric_prefixes(self, tmp_path: Path):
             name="test",
             project="test",
             base_path=str(tmp_path),
+            report_metrics=[],
         )

         trajectories = [
@@ -338,6 +349,7 @@ async def test_standard_metrics_present(self, tmp_path: Path):
             name="test",
             project="test",
             base_path=str(tmp_path),
+            report_metrics=[],
         )

         trajectory_groups = [
@@ -369,6 +381,47 @@ async def test_standard_metrics_present(self, tmp_path: Path):
         # Check reward average is correct
         assert entry["val/reward"] == 0.7  # (0.8 + 0.6) / 2

+    @pytest.mark.asyncio
+    async def test_group_metric_aggregation(self, tmp_path: Path):
+        """Verify group-level metrics are aggregated once per group."""
+        model = Model(
+            name="test",
+            project="test",
+            base_path=str(tmp_path),
+            report_metrics=[],
+        )
+
+        trajectory_groups = [
+            TrajectoryGroup(
+                trajectories=[
+                    Trajectory(
+                        reward=0.8,
+                        messages_and_choices=[{"role": "user", "content": "a"}],
+                    )
+                ],
+                metrics={"judge_score": 0.2},
+                exceptions=[],
+            ),
+            TrajectoryGroup(
+                trajectories=[
+                    Trajectory(
+                        reward=0.6,
+                        messages_and_choices=[{"role": "user", "content": "b"}],
+                    )
+                ],
+                metrics={"judge_score": 0.6},
+                exceptions=[],
+            ),
+        ]
+
+        await model.log(trajectory_groups, split="val")
+
+        history_path = tmp_path / "test/models/test/history.jsonl"
+        with open(history_path) as f:
+            entry = json.loads(f.readline())
+
+        assert entry["val/group_metric_judge_score"] == 0.4
+
     @pytest.mark.asyncio
     async def test_exception_rate_calculation(self, tmp_path: Path):
         """Verify exception_rate is calculated correctly for successful trajectories."""
@@ -376,6 +429,7 @@ async def test_exception_rate_calculation(self, tmp_path: Path):
             name="test",
             project="test",
             base_path=str(tmp_path),
+            report_metrics=[],
         )

         # TrajectoryGroup stores trajectories and exceptions separately
diff --git a/tests/unit/test_trajectory_parquet.py b/tests/unit/test_trajectory_parquet.py
index 597d93e71..c48608ee0 100644
--- a/tests/unit/test_trajectory_parquet.py
+++ b/tests/unit/test_trajectory_parquet.py
@@ -173,13 +173,50 @@ def test_tool_calls(self, tmp_path: Path):
         assert tool_calls, "Assistant message should include tool calls"
         first_call = tool_calls[0]
         assert first_call["type"] == "function"
-        function_call = cast(ChatCompletionMessageFunctionToolCallParam, first_call)  # ty:ignore[redundant-cast]
+        function_call = cast(ChatCompletionMessageFunctionToolCallParam, first_call)
         assert function_call["function"]["name"] == "search"

         # Check tool result message
         tool_result_msg = _ensure_tool_message(traj.messages_and_choices[2])
         assert tool_result_msg["tool_call_id"] == "call_123"

+    def test_group_level_fields_round_trip(self, tmp_path: Path):
+        """Group-level metadata/metrics/logs should survive round-trip."""
+        original = [
+            TrajectoryGroup(
+                trajectories=[
+                    Trajectory(
+                        reward=0.4,
+                        metrics={"idx": 0},
+                        metadata={},
+                        messages_and_choices=[{"role": "user", "content": "msg0"}],
+                        logs=[],
+                    ),
+                    Trajectory(
+                        reward=0.6,
+                        metrics={"idx": 1},
+                        metadata={},
+                        messages_and_choices=[{"role": "user", "content": "msg1"}],
+                        logs=[],
+                    ),
+                ],
+                metadata={"scenario_id": "abc-123", "difficulty": "hard"},
+                metrics={"judge_score": 0.7, "pass_rate": 1},
+                logs=["group log 1", "group log 2"],
+                exceptions=[],
+            )
+        ]
+
+        parquet_path = tmp_path / "test.parquet"
+        write_trajectory_groups_parquet(original, parquet_path)
+        loaded = read_trajectory_groups_parquet(parquet_path)
+
+        assert len(loaded) == 1
+        group = loaded[0]
+        assert group.metadata == {"scenario_id": "abc-123", "difficulty": "hard"}
+        assert group.metrics == {"judge_score": 0.7, "pass_rate": 1}
+        assert group.logs == ["group log 1", "group log 2"]
+
     def test_choice_format(self, tmp_path: Path):
         """Test trajectories with Choice format (finish_reason) are flattened to messages."""
         original = [
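
Usage note (not part of the patch): a minimal sketch of how the group-level fields introduced above might be populated during a rollout. It is based on the TrajectoryGroup constructor signature in src/art/trajectories.py and the imports used in the new tests; the scenario id, judge score, and log message are illustrative values only.

# Illustrative sketch; names follow the diff, values are made up.
import art

group = art.TrajectoryGroup(
    trajectories=[
        art.Trajectory(
            reward=1.0,
            messages_and_choices=[{"role": "user", "content": "hi"}],
        )
    ],
    # Serialized to the parquet "group_metadata" column and surfaced by
    # load_trajectories() as a group_metadata_scenario_id column.
    metadata={"scenario_id": "abc"},
    # Averaged once per group by Model.log() into val/group_metric_judge_score.
    metrics={"judge_score": 0.9},
)
group.log("judge finished scoring this group")  # appended to group.logs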