From 1336f75a86c826de3a91adead779bbb1f7802538 Mon Sep 17 00:00:00 2001 From: Gilad Freidkin Date: Wed, 27 Aug 2025 13:54:47 +0300 Subject: [PATCH] gather_trajectory_groups(): Fused after_each callback with group awaiting into a single async func, instead of two separate asyncio.gather() calls. --- src/art/gather.py | 35 +++++++++++++++++------------------ 1 file changed, 17 insertions(+), 18 deletions(-) diff --git a/src/art/gather.py b/src/art/gather.py index 89eea5b41..5db1b9237 100644 --- a/src/art/gather.py +++ b/src/art/gather.py @@ -30,8 +30,16 @@ async def gather_trajectory_groups( max_exceptions=max_exceptions, max_metrics=max_metrics, ) + + # Fuse the after_each callback into the gather process + async def group_forward(g: Awaitable[TrajectoryGroup]): + group = await wrap_group_awaitable(g) + if group is None or after_each is None: + return group + return await after_each(group) + with set_gather_context(context): - future = asyncio.gather(*[wrap_group_awaitable(g) for g in groups]) + future = asyncio.gather(*[group_forward(g) for g in groups]) total = sum(getattr(g, "_num_trajectories", 1) for g in groups) context.pbar = tqdm.tqdm(desc=pbar_desc, total=total) result_groups = await future @@ -40,23 +48,14 @@ async def gather_trajectory_groups( context.pbar.close() # Filter out any None results that may have been returned due to handled exceptions - processed_groups: list[TrajectoryGroup] = [ - g for g in result_groups if g is not None - ] - - # If an after_each callback was provided, await it and collect its return values. - if after_each is not None: - ae_processed_groups = await asyncio.gather( - *(after_each(g) for g in processed_groups) - ) - processed_groups = [] - for g in ae_processed_groups: - if g is None: - continue - if isinstance(g, list): - processed_groups.extend(g) - elif isinstance(g, TrajectoryGroup): - processed_groups.append(g) + processed_groups = [] + for g in result_groups: + if g is None: + continue + if isinstance(g, list): + processed_groups.extend(g) + elif isinstance(g, TrajectoryGroup): + processed_groups.append(g) return processed_groups