diff --git a/src/art/auto_trajectory.py b/src/art/auto_trajectory.py index a22d0be7c..5a7fc5aee 100644 --- a/src/art/auto_trajectory.py +++ b/src/art/auto_trajectory.py @@ -62,10 +62,8 @@ def auto_trajectory(*, required: bool = False) -> Trajectory | None: async def capture_auto_trajectory(coroutine: Coroutine[Any, Any, Any]) -> Trajectory: - with AutoTrajectoryContext(): + with AutoTrajectoryContext() as trajectory: await coroutine - trajectory = auto_trajectory_context_var.get().trajectory - trajectory.finish() return trajectory @@ -76,11 +74,13 @@ def __init__(self) -> None: reward=0.0, ) - def __enter__(self) -> None: + def __enter__(self) -> Trajectory: self.token = auto_trajectory_context_var.set(self) + return self.trajectory def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: auto_trajectory_context_var.reset(self.token) + self.trajectory.finish() def handle_httpx_response(self, response: httpx._models.Response) -> None: # Get buffered content (set by patched aiter_bytes/iter_bytes) diff --git a/src/art/trajectories.py b/src/art/trajectories.py index 69dd2d3de..1f39664c4 100644 --- a/src/art/trajectories.py +++ b/src/art/trajectories.py @@ -1,20 +1,21 @@ import asyncio -from contextlib import asynccontextmanager -from datetime import datetime import time import traceback +from contextlib import asynccontextmanager +from datetime import datetime from typing import ( Any, AsyncGenerator, Awaitable, + Coroutine, Iterable, Iterator, cast, overload, ) -from openai.types.chat.chat_completion import Choice import pydantic +from openai.types.chat.chat_completion import Choice from .types import Messages, MessagesAndChoices, Tools @@ -262,7 +263,7 @@ def __new__( metadata: dict[str, MetadataValue] | None = None, metrics: dict[str, float | int | bool] | None = None, logs: list[str] | None = None, - ) -> "TrajectoryGroup | Awaitable[TrajectoryGroup]": + ) -> "TrajectoryGroup | Coroutine[Any, Any, TrajectoryGroup]": ts = list(trajectories) if any(hasattr(t, "__await__") for t in ts):