From 86c931c04eee80537332f51bf5c8b689977ac641 Mon Sep 17 00:00:00 2001 From: arcticfly Date: Thu, 2 Oct 2025 23:08:41 +0000 Subject: [PATCH 1/5] Log metrics to W&B Models and Training Endpoints --- src/art/client.py | 14 +++++ src/art/local/backend.py | 30 +---------- src/art/serverless/backend.py | 81 ++++++++++++++++++++++++++--- src/art/utils/trajectory_logging.py | 30 +++++++++++ 4 files changed, 119 insertions(+), 36 deletions(-) diff --git a/src/art/client.py b/src/art/client.py index 53576b770..e99fb0e97 100644 --- a/src/art/client.py +++ b/src/art/client.py @@ -36,6 +36,10 @@ class DeleteCheckpointsResponse(BaseModel): not_found_steps: list[int] +class ReportMetricsResponse(BaseModel): + success: bool + + class Checkpoints(AsyncAPIResource): async def retrieve( self, *, model_id: str, step: int | Literal["latest"] @@ -81,6 +85,16 @@ async def delete( options=dict(max_retries=0), ) + async def report_metrics( + self, *, model_id: str, step: int | Literal["latest"], metrics: dict[str, float] + ) -> ReportMetricsResponse: + return await self._post( + f"/preview/models/{model_id}/checkpoints/{step}/report_metrics", + body={"metrics": metrics}, + cast_to=ReportMetricsResponse, + options=dict(max_retries=0), + ) + class Model(BaseModel): id: str diff --git a/src/art/local/backend.py b/src/art/local/backend.py index daf08adb0..8781e2cf1 100644 --- a/src/art/local/backend.py +++ b/src/art/local/backend.py @@ -39,7 +39,7 @@ pull_model_from_s3, push_model_to_s3, ) -from art.utils.trajectory_logging import serialize_trajectory_groups +from art.utils.trajectory_logging import get_metric_averages, serialize_trajectory_groups from mp_actors import close_proxy, move_to_child_process from .. import dev @@ -351,33 +351,7 @@ async def _log( with open(f"{parent_dir}/{file_name}", "w") as f: f.write(serialize_trajectory_groups(trajectory_groups)) - # Collect all metrics (including reward) across all trajectories - all_metrics: dict[str, list[float]] = {"reward": [], "exception_rate": []} - - for group in trajectory_groups: - for trajectory in group: - if isinstance(trajectory, BaseException): - all_metrics["exception_rate"].append(1) - continue - else: - all_metrics["exception_rate"].append(0) - # Add reward metric - all_metrics["reward"].append(trajectory.reward) - - # Collect other custom metrics - for metric, value in trajectory.metrics.items(): - if metric not in all_metrics: - all_metrics[metric] = [] - all_metrics[metric].append(float(value)) - - # Calculate averages for all metrics - averages = {} - for metric, values in all_metrics.items(): - if len(values) > 0: - averages[metric] = sum(values) / len(values) - - # Calculate average standard deviation of rewards within groups - averages["reward_std_dev"] = calculate_step_std_dev(trajectory_groups) + averages = get_metric_averages(trajectory_groups) self._log_metrics(model, averages, split) diff --git a/src/art/serverless/backend.py b/src/art/serverless/backend.py index 0246a476d..facae587f 100644 --- a/src/art/serverless/backend.py +++ b/src/art/serverless/backend.py @@ -1,8 +1,14 @@ import asyncio -from typing import TYPE_CHECKING, AsyncIterator, Literal +from typing import TYPE_CHECKING, AsyncIterator, Literal, cast +import os from art.client import Client from art.utils.deploy_model import LoRADeploymentJob, LoRADeploymentProvider +from art.utils.trajectory_logging import get_metric_averages +import wandb +import weave +from wandb.sdk.wandb_run import Run +from weave.trace.weave_client import WeaveClient from .. import dev from ..backend import Backend @@ -20,6 +26,8 @@ def __init__( client = Client(api_key=api_key, base_url=base_url) super().__init__(base_url=str(client.base_url)) self._client = client + self._wandb_runs: dict[str, Run] = {} + self._weave_clients: dict[str, WeaveClient] = {} async def close(self) -> None: await self._client.close() @@ -56,12 +64,16 @@ def _model_inference_name(self, model: "TrainableModel") -> str: assert model.entity is not None, "Model entity is required" return f"{model.entity}/{model.project}/{model.name}" - async def _get_step(self, model: "TrainableModel") -> int: - assert model.id is not None, "Model ID is required" - checkpoint = await self._client.checkpoints.retrieve( - model_id=model.id, step="latest" - ) - return checkpoint.step + async def __get_step(self, model: "Model") -> int: + if model.trainable: + model = cast(TrainableModel, model) + assert model.id is not None, "Model ID is required" + checkpoint = await self._client.checkpoints.retrieve( + model_id=model.id, step="latest" + ) + return checkpoint.step + # Non-trainable models do not have checkpoints/steps; default to 0 + return 0 async def _delete_checkpoints( self, @@ -99,7 +111,60 @@ async def _log( trajectory_groups: list[TrajectoryGroup], split: str = "val", ) -> None: - raise NotImplementedError + # TODO: log trajectories to local file system? + + averages = get_metric_averages(trajectory_groups) + await self._log_metrics(model, averages, split) + + async def _log_metrics( + self, + model: Model, + metrics: dict[str, float], + split: str, + step: int | None = None, + ) -> None: + metrics = {f"{split}/{metric}": value for metric, value in metrics.items()} + step = step if step is not None else await self.__get_step(model) + + # TODO: Write to history.jsonl like we do in LocalBackend? + + # If we have a W&B run, log the data there + if run := self._get_wandb_run(model): + # Mark the step metric itself as hidden so W&B doesn't create an automatic chart for it + wandb.define_metric("training_step", hidden=True) + + # Enabling the following line will cause W&B to use the training_step metric as the x-axis for all metrics + # wandb.define_metric(f"{split}/*", step_metric="training_step") + run.log({"training_step": step, **metrics}, step=step) + + # Report metrics to the W&B Training API + if model.trainable and model.id is not None: + await self._client.checkpoints.report_metrics( + model_id=model.id, step=step, metrics=metrics + ) + + + def _get_wandb_run(self, model: Model) -> Run | None: + if "WANDB_API_KEY" not in os.environ: + return None + if ( + model.name not in self._wandb_runs + or self._wandb_runs[model.name]._is_finished + ): + run = wandb.init( + project=model.project, + name=model.name, + id=model.name, + resume="allow", + ) + self._wandb_runs[model.name] = run + os.environ["WEAVE_PRINT_CALL_LINK"] = os.getenv( + "WEAVE_PRINT_CALL_LINK", "False" + ) + os.environ["WEAVE_LOG_LEVEL"] = os.getenv("WEAVE_LOG_LEVEL", "CRITICAL") + self._weave_clients[model.name] = weave.init(model.project) + return self._wandb_runs[model.name] + async def _train_model( self, diff --git a/src/art/utils/trajectory_logging.py b/src/art/utils/trajectory_logging.py index 85bb2b653..78d61104e 100644 --- a/src/art/utils/trajectory_logging.py +++ b/src/art/utils/trajectory_logging.py @@ -6,6 +6,7 @@ from art import Trajectory, TrajectoryGroup from art.trajectories import History from art.types import Choice, Message, MessageOrChoice +from art.utils.old_benchmarking.calculate_step_metrics import calculate_step_std_dev # serialize trajectory groups to a jsonl string @@ -115,3 +116,32 @@ def dict_to_message_or_choice(dict: dict[str, Any]) -> MessageOrChoice: return Choice(**dict) else: return cast(Message, dict) + +def get_metric_averages(trajectory_groups: list[TrajectoryGroup]) -> dict[str, float]: + # Collect all metrics (including reward) across all trajectories + all_metrics: dict[str, list[float]] = {"reward": [], "exception_rate": []} + + for group in trajectory_groups: + for trajectory in group: + if isinstance(trajectory, BaseException): + all_metrics["exception_rate"].append(1) + continue + else: + all_metrics["exception_rate"].append(0) + # Add reward metric + all_metrics["reward"].append(trajectory.reward) + + # Collect other custom metrics + for metric, value in trajectory.metrics.items(): + if metric not in all_metrics: + all_metrics[metric] = [] + all_metrics[metric].append(float(value)) + + # Calculate averages for all metrics + averages = {} + for metric, values in all_metrics.items(): + if len(values) > 0: + averages[metric] = sum(values) / len(values) + + # Calculate average standard deviation of rewards within groups + averages["reward_std_dev"] = calculate_step_std_dev(trajectory_groups) \ No newline at end of file From 999624d2b78d2ab5daab7560ef6ebd46dd44de11 Mon Sep 17 00:00:00 2001 From: arcticfly Date: Fri, 3 Oct 2025 00:51:40 +0000 Subject: [PATCH 2/5] Fix logging --- src/art/serverless/backend.py | 10 +++++----- src/art/utils/trajectory_logging.py | 4 +++- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/src/art/serverless/backend.py b/src/art/serverless/backend.py index facae587f..0c8ef8b3a 100644 --- a/src/art/serverless/backend.py +++ b/src/art/serverless/backend.py @@ -64,9 +64,9 @@ def _model_inference_name(self, model: "TrainableModel") -> str: assert model.entity is not None, "Model entity is required" return f"{model.entity}/{model.project}/{model.name}" - async def __get_step(self, model: "Model") -> int: + + async def _get_step(self, model: "Model") -> int: if model.trainable: - model = cast(TrainableModel, model) assert model.id is not None, "Model ID is required" checkpoint = await self._client.checkpoints.retrieve( model_id=model.id, step="latest" @@ -118,13 +118,13 @@ async def _log( async def _log_metrics( self, - model: Model, + model: "Model", metrics: dict[str, float], split: str, step: int | None = None, ) -> None: metrics = {f"{split}/{metric}": value for metric, value in metrics.items()} - step = step if step is not None else await self.__get_step(model) + step = step if step is not None else await self._get_step(model) # TODO: Write to history.jsonl like we do in LocalBackend? @@ -144,7 +144,7 @@ async def _log_metrics( ) - def _get_wandb_run(self, model: Model) -> Run | None: + def _get_wandb_run(self, model: "Model") -> Run | None: if "WANDB_API_KEY" not in os.environ: return None if ( diff --git a/src/art/utils/trajectory_logging.py b/src/art/utils/trajectory_logging.py index 78d61104e..c49396202 100644 --- a/src/art/utils/trajectory_logging.py +++ b/src/art/utils/trajectory_logging.py @@ -144,4 +144,6 @@ def get_metric_averages(trajectory_groups: list[TrajectoryGroup]) -> dict[str, f averages[metric] = sum(values) / len(values) # Calculate average standard deviation of rewards within groups - averages["reward_std_dev"] = calculate_step_std_dev(trajectory_groups) \ No newline at end of file + averages["reward_std_dev"] = calculate_step_std_dev(trajectory_groups) + + return averages \ No newline at end of file From d84040639147c1267a7664ca41d007c0d583d8bf Mon Sep 17 00:00:00 2001 From: arcticfly Date: Fri, 3 Oct 2025 10:28:47 +0000 Subject: [PATCH 3/5] Update serverless backend logging --- src/art/client.py | 27 ++++++++++++++ src/art/local/backend.py | 30 ++++++++++++++-- src/art/serverless/backend.py | 56 ++++------------------------- src/art/utils/trajectory_logging.py | 31 ---------------- 4 files changed, 61 insertions(+), 83 deletions(-) diff --git a/src/art/client.py b/src/art/client.py index e99fb0e97..4b751df98 100644 --- a/src/art/client.py +++ b/src/art/client.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import asyncio import os from typing import AsyncIterator, Iterable, Literal, TypedDict, cast @@ -40,6 +42,10 @@ class ReportMetricsResponse(BaseModel): success: bool +class LogResponse(BaseModel): + success: bool + + class Checkpoints(AsyncAPIResource): async def retrieve( self, *, model_id: str, step: int | Literal["latest"] @@ -95,6 +101,27 @@ async def report_metrics( options=dict(max_retries=0), ) + async def log_trajectories( + self, + *, + model_id: str, + trajectory_groups: list[TrajectoryGroup], + split: str = "val", + ) -> LogResponse: + return await self._post( + f"/preview/models/{model_id}/log", + body={ + "model_id": model_id, + "trajectory_groups": [ + trajectory_group.model_dump() + for trajectory_group in trajectory_groups + ], + "split": split, + }, + cast_to=LogResponse, + options=dict(max_retries=0), + ) + class Model(BaseModel): id: str diff --git a/src/art/local/backend.py b/src/art/local/backend.py index 8781e2cf1..daf08adb0 100644 --- a/src/art/local/backend.py +++ b/src/art/local/backend.py @@ -39,7 +39,7 @@ pull_model_from_s3, push_model_to_s3, ) -from art.utils.trajectory_logging import get_metric_averages, serialize_trajectory_groups +from art.utils.trajectory_logging import serialize_trajectory_groups from mp_actors import close_proxy, move_to_child_process from .. import dev @@ -351,7 +351,33 @@ async def _log( with open(f"{parent_dir}/{file_name}", "w") as f: f.write(serialize_trajectory_groups(trajectory_groups)) - averages = get_metric_averages(trajectory_groups) + # Collect all metrics (including reward) across all trajectories + all_metrics: dict[str, list[float]] = {"reward": [], "exception_rate": []} + + for group in trajectory_groups: + for trajectory in group: + if isinstance(trajectory, BaseException): + all_metrics["exception_rate"].append(1) + continue + else: + all_metrics["exception_rate"].append(0) + # Add reward metric + all_metrics["reward"].append(trajectory.reward) + + # Collect other custom metrics + for metric, value in trajectory.metrics.items(): + if metric not in all_metrics: + all_metrics[metric] = [] + all_metrics[metric].append(float(value)) + + # Calculate averages for all metrics + averages = {} + for metric, values in all_metrics.items(): + if len(values) > 0: + averages[metric] = sum(values) / len(values) + + # Calculate average standard deviation of rewards within groups + averages["reward_std_dev"] = calculate_step_std_dev(trajectory_groups) self._log_metrics(model, averages, split) diff --git a/src/art/serverless/backend.py b/src/art/serverless/backend.py index 0c8ef8b3a..0e5eab423 100644 --- a/src/art/serverless/backend.py +++ b/src/art/serverless/backend.py @@ -113,58 +113,14 @@ async def _log( ) -> None: # TODO: log trajectories to local file system? - averages = get_metric_averages(trajectory_groups) - await self._log_metrics(model, averages, split) - - async def _log_metrics( - self, - model: "Model", - metrics: dict[str, float], - split: str, - step: int | None = None, - ) -> None: - metrics = {f"{split}/{metric}": value for metric, value in metrics.items()} - step = step if step is not None else await self._get_step(model) - - # TODO: Write to history.jsonl like we do in LocalBackend? - - # If we have a W&B run, log the data there - if run := self._get_wandb_run(model): - # Mark the step metric itself as hidden so W&B doesn't create an automatic chart for it - wandb.define_metric("training_step", hidden=True) - - # Enabling the following line will cause W&B to use the training_step metric as the x-axis for all metrics - # wandb.define_metric(f"{split}/*", step_metric="training_step") - run.log({"training_step": step, **metrics}, step=step) - - # Report metrics to the W&B Training API - if model.trainable and model.id is not None: - await self._client.checkpoints.report_metrics( - model_id=model.id, step=step, metrics=metrics - ) + if not model.trainable: + print(f"Model {model.name} is not trainable; skipping logging.") + return + await self._client.checkpoints.log_trajectories( + model_id=model.id, trajectory_groups=trajectory_groups, split=split + ) - def _get_wandb_run(self, model: "Model") -> Run | None: - if "WANDB_API_KEY" not in os.environ: - return None - if ( - model.name not in self._wandb_runs - or self._wandb_runs[model.name]._is_finished - ): - run = wandb.init( - project=model.project, - name=model.name, - id=model.name, - resume="allow", - ) - self._wandb_runs[model.name] = run - os.environ["WEAVE_PRINT_CALL_LINK"] = os.getenv( - "WEAVE_PRINT_CALL_LINK", "False" - ) - os.environ["WEAVE_LOG_LEVEL"] = os.getenv("WEAVE_LOG_LEVEL", "CRITICAL") - self._weave_clients[model.name] = weave.init(model.project) - return self._wandb_runs[model.name] - async def _train_model( self, diff --git a/src/art/utils/trajectory_logging.py b/src/art/utils/trajectory_logging.py index c49396202..80b6d5aa2 100644 --- a/src/art/utils/trajectory_logging.py +++ b/src/art/utils/trajectory_logging.py @@ -116,34 +116,3 @@ def dict_to_message_or_choice(dict: dict[str, Any]) -> MessageOrChoice: return Choice(**dict) else: return cast(Message, dict) - -def get_metric_averages(trajectory_groups: list[TrajectoryGroup]) -> dict[str, float]: - # Collect all metrics (including reward) across all trajectories - all_metrics: dict[str, list[float]] = {"reward": [], "exception_rate": []} - - for group in trajectory_groups: - for trajectory in group: - if isinstance(trajectory, BaseException): - all_metrics["exception_rate"].append(1) - continue - else: - all_metrics["exception_rate"].append(0) - # Add reward metric - all_metrics["reward"].append(trajectory.reward) - - # Collect other custom metrics - for metric, value in trajectory.metrics.items(): - if metric not in all_metrics: - all_metrics[metric] = [] - all_metrics[metric].append(float(value)) - - # Calculate averages for all metrics - averages = {} - for metric, values in all_metrics.items(): - if len(values) > 0: - averages[metric] = sum(values) / len(values) - - # Calculate average standard deviation of rewards within groups - averages["reward_std_dev"] = calculate_step_std_dev(trajectory_groups) - - return averages \ No newline at end of file From f3bfc37e1a01b897215b8c878c4f7396a8eefd68 Mon Sep 17 00:00:00 2001 From: arcticfly Date: Fri, 3 Oct 2025 11:26:54 +0000 Subject: [PATCH 4/5] Remove imports --- src/art/serverless/backend.py | 3 --- src/art/utils/trajectory_logging.py | 1 - 2 files changed, 4 deletions(-) diff --git a/src/art/serverless/backend.py b/src/art/serverless/backend.py index 0e5eab423..03dd8e37d 100644 --- a/src/art/serverless/backend.py +++ b/src/art/serverless/backend.py @@ -4,9 +4,6 @@ from art.client import Client from art.utils.deploy_model import LoRADeploymentJob, LoRADeploymentProvider -from art.utils.trajectory_logging import get_metric_averages -import wandb -import weave from wandb.sdk.wandb_run import Run from weave.trace.weave_client import WeaveClient diff --git a/src/art/utils/trajectory_logging.py b/src/art/utils/trajectory_logging.py index 80b6d5aa2..85bb2b653 100644 --- a/src/art/utils/trajectory_logging.py +++ b/src/art/utils/trajectory_logging.py @@ -6,7 +6,6 @@ from art import Trajectory, TrajectoryGroup from art.trajectories import History from art.types import Choice, Message, MessageOrChoice -from art.utils.old_benchmarking.calculate_step_metrics import calculate_step_std_dev # serialize trajectory groups to a jsonl string From 82fff8fa0abaac6be306401299089e633f2362ab Mon Sep 17 00:00:00 2001 From: arcticfly Date: Fri, 3 Oct 2025 18:14:46 +0000 Subject: [PATCH 5/5] Remove report metrics stuff --- src/art/client.py | 13 ------------- src/art/serverless/backend.py | 4 ---- 2 files changed, 17 deletions(-) diff --git a/src/art/client.py b/src/art/client.py index 4b751df98..9bb56b6bc 100644 --- a/src/art/client.py +++ b/src/art/client.py @@ -38,9 +38,6 @@ class DeleteCheckpointsResponse(BaseModel): not_found_steps: list[int] -class ReportMetricsResponse(BaseModel): - success: bool - class LogResponse(BaseModel): success: bool @@ -91,16 +88,6 @@ async def delete( options=dict(max_retries=0), ) - async def report_metrics( - self, *, model_id: str, step: int | Literal["latest"], metrics: dict[str, float] - ) -> ReportMetricsResponse: - return await self._post( - f"/preview/models/{model_id}/checkpoints/{step}/report_metrics", - body={"metrics": metrics}, - cast_to=ReportMetricsResponse, - options=dict(max_retries=0), - ) - async def log_trajectories( self, *, diff --git a/src/art/serverless/backend.py b/src/art/serverless/backend.py index 03dd8e37d..70185912d 100644 --- a/src/art/serverless/backend.py +++ b/src/art/serverless/backend.py @@ -4,8 +4,6 @@ from art.client import Client from art.utils.deploy_model import LoRADeploymentJob, LoRADeploymentProvider -from wandb.sdk.wandb_run import Run -from weave.trace.weave_client import WeaveClient from .. import dev from ..backend import Backend @@ -23,8 +21,6 @@ def __init__( client = Client(api_key=api_key, base_url=base_url) super().__init__(base_url=str(client.base_url)) self._client = client - self._wandb_runs: dict[str, Run] = {} - self._weave_clients: dict[str, WeaveClient] = {} async def close(self) -> None: await self._client.close()