From 9b92d726d4b06746debeb8d6153d31f95087a185 Mon Sep 17 00:00:00 2001 From: Bohdan Date: Sat, 4 Oct 2025 17:47:54 -0700 Subject: [PATCH 1/2] add fail event --- src/art/client.py | 2 +- src/art/serverless/backend.py | 3 +++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/src/art/client.py b/src/art/client.py index ceb9bf1fb..eb4b0c3ea 100644 --- a/src/art/client.py +++ b/src/art/client.py @@ -291,7 +291,7 @@ def events(self) -> TrainingJobEvents: class TrainingJobEvent(BaseModel): id: str - type: Literal["training_started", "gradient_step", "training_ended"] + type: Literal["training_started", "gradient_step", "training_ended", "training_failed"] data: dict[str, Any] diff --git a/src/art/serverless/backend.py b/src/art/serverless/backend.py index 4d6a5d3be..0650ba705 100644 --- a/src/art/serverless/backend.py +++ b/src/art/serverless/backend.py @@ -156,6 +156,9 @@ async def _train_model( continue elif event.type == "training_ended": return + elif event.type == "training_failed": + error_message = event.data.get("error_message", "Training failed with an unknown error") + raise RuntimeError(f"Training job failed: {error_message}") after = event.id # ------------------------------------------------------------------ From 4deec26db0d52178ff142e3d3c40891e535e843e Mon Sep 17 00:00:00 2001 From: Bohdan Date: Sat, 4 Oct 2025 19:28:17 -0700 Subject: [PATCH 2/2] fix format --- src/art/client.py | 7 ++++--- src/art/gather.py | 3 ++- src/art/openai.py | 6 +++--- src/art/serverless/backend.py | 4 +++- 4 files changed, 12 insertions(+), 8 deletions(-) diff --git a/src/art/client.py b/src/art/client.py index eb4b0c3ea..8d23d4cd4 100644 --- a/src/art/client.py +++ b/src/art/client.py @@ -5,6 +5,7 @@ from typing import Any, AsyncIterator, Iterable, Literal, TypedDict, cast import httpx +from openai import AsyncOpenAI, BaseModel, _exceptions from openai._base_client import AsyncAPIClient, AsyncPaginator, make_request_options from openai._compat import cached_property from openai._qs import Querystring @@ -17,8 +18,6 @@ from openai.resources.models import AsyncModels # noqa: F401 from typing_extensions import override -from openai import AsyncOpenAI, BaseModel, _exceptions - from .trajectories import TrajectoryGroup @@ -291,7 +290,9 @@ def events(self) -> TrainingJobEvents: class TrainingJobEvent(BaseModel): id: str - type: Literal["training_started", "gradient_step", "training_ended", "training_failed"] + type: Literal[ + "training_started", "gradient_step", "training_ended", "training_failed" + ] data: dict[str, Any] diff --git a/src/art/gather.py b/src/art/gather.py index 830ce82d8..a9a37624c 100644 --- a/src/art/gather.py +++ b/src/art/gather.py @@ -190,7 +190,8 @@ def record_metrics(context: "GatherContext", trajectory: Trajectory) -> None: if logprobs: # TODO: probably shouldn't average this trajectory.metrics["completion_tokens"] = sum( - len(l.content or l.refusal or []) for l in logprobs # noqa: E741 + len(l.content or l.refusal or []) + for l in logprobs # noqa: E741 ) / len(logprobs) context.metric_sums["reward"] += trajectory.reward # type: ignore context.metric_divisors["reward"] += 1 diff --git a/src/art/openai.py b/src/art/openai.py index 039f42a84..a56fcabdf 100644 --- a/src/art/openai.py +++ b/src/art/openai.py @@ -128,9 +128,9 @@ def update_chat_completion( choice.message.tool_calls[tool_call.index].id = tool_call.id if tool_call.function: if tool_call.function.name: - choice.message.tool_calls[tool_call.index].function.name = ( - tool_call.function.name - ) + choice.message.tool_calls[ + tool_call.index + ].function.name = tool_call.function.name if tool_call.function.arguments: choice.message.tool_calls[ tool_call.index diff --git a/src/art/serverless/backend.py b/src/art/serverless/backend.py index 0650ba705..dda0c26a5 100644 --- a/src/art/serverless/backend.py +++ b/src/art/serverless/backend.py @@ -157,7 +157,9 @@ async def _train_model( elif event.type == "training_ended": return elif event.type == "training_failed": - error_message = event.data.get("error_message", "Training failed with an unknown error") + error_message = event.data.get( + "error_message", "Training failed with an unknown error" + ) raise RuntimeError(f"Training job failed: {error_message}") after = event.id