Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 17 additions & 18 deletions src/art/gather.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,16 @@ async def gather_trajectory_groups(
max_exceptions=max_exceptions,
max_metrics=max_metrics,
)

# Fuse the after_each callback into the gather process
async def group_forward(g: Awaitable[TrajectoryGroup]):
group = await wrap_group_awaitable(g)
if group is None or after_each is None:
return group
return await after_each(group)

with set_gather_context(context):
future = asyncio.gather(*[wrap_group_awaitable(g) for g in groups])
future = asyncio.gather(*[group_forward(g) for g in groups])
total = sum(getattr(g, "_num_trajectories", 1) for g in groups)
context.pbar = tqdm.tqdm(desc=pbar_desc, total=total)
result_groups = await future
Expand All @@ -40,23 +48,14 @@ async def gather_trajectory_groups(
context.pbar.close()

# Filter out any None results that may have been returned due to handled exceptions
processed_groups: list[TrajectoryGroup] = [
g for g in result_groups if g is not None
]

# If an after_each callback was provided, await it and collect its return values.
if after_each is not None:
ae_processed_groups = await asyncio.gather(
*(after_each(g) for g in processed_groups)
)
processed_groups = []
for g in ae_processed_groups:
if g is None:
continue
if isinstance(g, list):
processed_groups.extend(g)
elif isinstance(g, TrajectoryGroup):
processed_groups.append(g)
processed_groups = []
for g in result_groups:
if g is None:
continue
if isinstance(g, list):
processed_groups.extend(g)
elif isinstance(g, TrajectoryGroup):
processed_groups.append(g)

return processed_groups

Expand Down