From 300d11d25a3c49b53c142b328a08620ed600599b Mon Sep 17 00:00:00 2001 From: Brad Hilton Date: Thu, 29 Jan 2026 11:29:30 -0700 Subject: [PATCH] refactor: streamline AutoTrajectoryContext management - Updated the context manager to return the trajectory directly upon entering, simplifying the capture_auto_trajectory function. - Ensured the trajectory is properly finished upon exiting the context, improving resource management. --- src/art/auto_trajectory.py | 8 ++++---- src/art/trajectories.py | 11 ++++++----- 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/src/art/auto_trajectory.py b/src/art/auto_trajectory.py index 4979602f7..c9e25d2b2 100644 --- a/src/art/auto_trajectory.py +++ b/src/art/auto_trajectory.py @@ -62,10 +62,8 @@ def auto_trajectory(*, required: bool = False) -> Trajectory | None: async def capture_auto_trajectory(coroutine: Coroutine[Any, Any, Any]) -> Trajectory: - with AutoTrajectoryContext(): + with AutoTrajectoryContext() as trajectory: await coroutine - trajectory = auto_trajectory_context_var.get().trajectory - trajectory.finish() return trajectory @@ -76,11 +74,13 @@ def __init__(self) -> None: reward=0.0, ) - def __enter__(self) -> None: + def __enter__(self) -> Trajectory: self.token = auto_trajectory_context_var.set(self) + return self.trajectory def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: auto_trajectory_context_var.reset(self.token) + self.trajectory.finish() def handle_httpx_response(self, response: httpx._models.Response) -> None: # Get buffered content (set by patched aiter_bytes/iter_bytes) diff --git a/src/art/trajectories.py b/src/art/trajectories.py index 69ab17db9..1fc88582a 100644 --- a/src/art/trajectories.py +++ b/src/art/trajectories.py @@ -1,20 +1,21 @@ import asyncio -from contextlib import asynccontextmanager -from datetime import datetime import time import traceback +from contextlib import asynccontextmanager +from datetime import datetime from typing import ( Any, AsyncGenerator, Awaitable, + Coroutine, Iterable, Iterator, cast, overload, ) -from openai.types.chat.chat_completion import Choice import pydantic +from openai.types.chat.chat_completion import Choice from .types import Messages, MessagesAndChoices, Tools @@ -229,11 +230,11 @@ def __new__( ), *, exceptions: list[BaseException] = [], - ) -> "TrajectoryGroup | Awaitable[TrajectoryGroup]": + ) -> "TrajectoryGroup | Coroutine[Any, Any, TrajectoryGroup]": ts = list(trajectories) if any(hasattr(t, "__await__") for t in ts): - async def _(exceptions: list[BaseException]): + async def _(exceptions: list[BaseException]) -> TrajectoryGroup: from .gather import get_gather_context, record_metrics context = get_gather_context()