
Suggestion: Potential major performance improvement for gather_trajectory_groups with after_each #342

@giladfrid009


The current implementation of gather_trajectory_groups() is as follows:

async def gather_trajectory_groups(
    groups: Iterable[Awaitable[TrajectoryGroup]],
    after_each: Callable[
        [TrajectoryGroup], Awaitable[TrajectoryGroup | None | list[TrajectoryGroup]]
    ]
    | None = None,
) -> list[TrajectoryGroup]:
    
    ...

    # FIRST AWAIT:
    # First await all trajectory groups to finish, only then proceed

    with set_gather_context(context):
        future = asyncio.gather(*[wrap_group_awaitable(g) for g in groups])
        total = sum(getattr(g, "_num_trajectories", 1) for g in groups)
        context.pbar = tqdm.tqdm(desc=pbar_desc, total=total)
        result_groups = await future

    ...

    # SECOND AWAIT:
    # Only after *ALL* trajectory groups have been constructed is the `after_each` callback called on any of them

    # If an after_each callback was provided, await it and collect its return values.
    if after_each is not None:
        ae_processed_groups = await asyncio.gather(
            *(after_each(g) for g in processed_groups)
        )
    
    ...

    return processed_groups

Notice that the current implementation makes two separate asyncio.gather() calls, and the second one, which runs the after_each callbacks, does not start until the first has finished. This can cause a severe performance bottleneck: if we use after_each=art.rewards.ruler_score_group, for example, we can't begin scoring any group until every group is ready.

See the following image for a demonstration. Green and Red are the GPUs on which a local RULER model is deployed; Orange and Blue are the GPUs on which rollouts are performed. Notice that there is NO OVERLAP: the RULER GPUs are idle throughout the entire rollout periods, and vice versa.

[Image: GPU activity over time for the RULER GPUs (Green, Red) and the rollout GPUs (Orange, Blue)]
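To illustrate why the barrier matters, here is a small, self-contained simulation (plain asyncio, not ART code). `rollout` and `score` are hypothetical stand-ins that only sleep, and the judge is modeled as serving one request at a time through an `asyncio.Semaphore`; that capacity limit is an assumption standing in for the finite throughput of the RULER GPUs. `two_phase` mirrors the current structure, while `pipelined` chains each group's callback directly after its own rollout.

```python
import asyncio
import time

# Illustrative numbers only; nothing here is measured from ART.
ROLLOUT_SECONDS = [0.05, 0.10, 0.15, 0.20]  # uneven rollout durations
SCORE_SECONDS = 0.10                          # duration of one judge call
N = len(ROLLOUT_SECONDS)


async def rollout(i: int) -> int:
    # Hypothetical stand-in for constructing one trajectory group.
    await asyncio.sleep(ROLLOUT_SECONDS[i])
    return i


async def score(group: int, judge: asyncio.Semaphore) -> int:
    # Hypothetical stand-in for an after_each callback. The semaphore
    # models a judge with finite capacity (one request at a time).
    async with judge:
        await asyncio.sleep(SCORE_SECONDS)
    return group


async def two_phase() -> list[int]:
    judge = asyncio.Semaphore(1)
    # Current structure: no scoring starts until every rollout is done.
    groups = await asyncio.gather(*(rollout(i) for i in range(N)))
    return await asyncio.gather(*(score(g, judge) for g in groups))


async def pipelined() -> list[int]:
    judge = asyncio.Semaphore(1)

    # Proposed structure: each group is scored as soon as its own rollout ends.
    async def forward(i: int) -> int:
        return await score(await rollout(i), judge)

    return await asyncio.gather(*(forward(i) for i in range(N)))


def timed(make_coro) -> tuple[list[int], float]:
    # Each call gets a fresh event loop, and each coroutine creates its own semaphore.
    start = time.perf_counter()
    result = asyncio.run(make_coro())
    return result, time.perf_counter() - start


two_phase_result, two_phase_time = timed(two_phase)
pipelined_result, pipelined_time = timed(pipelined)
# Roughly 0.60s (0.20 + 4 * 0.10) vs 0.45s with the numbers above.
print(f"two-phase: {two_phase_time:.2f}s, pipelined: {pipelined_time:.2f}s")
```

With unlimited judge capacity both versions would finish at about the same time (the slowest rollout plus one scoring call). The gain comes from keeping a capacity-limited judge busy while rollouts are still running, which is the idle time visible in the GPU trace.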

I propose modifying gather_trajectory_groups to have the following structure (or at least its spirit; I'm not 100% sure what's going on with the pbars and the gather_context):

async def gather_trajectory_groups(
    groups: Iterable[Awaitable[TrajectoryGroup]],
    after_each: Callable[
        [TrajectoryGroup], Awaitable[TrajectoryGroup | None | list[TrajectoryGroup]]
    ]
    | None = None,
) -> list[TrajectoryGroup]:
    
    ...

    # A single awaitable that constructs the trajectory group and then invokes the callback, if any
    async def forward_group(
        awaitable: Awaitable[TrajectoryGroup],
    ) -> TrajectoryGroup | None | list[TrajectoryGroup]:
        group = await wrap_group_awaitable(awaitable)
        if group is not None and after_each is not None:
            return await after_each(group)
        return group

    # Simultaneously await group construction and `after_each` callbacks

    with set_gather_context(context):
        future = asyncio.gather(*[forward_group(g) for g in groups])
        total = sum(getattr(g, "_num_trajectories", 1) for g in groups)
        context.pbar = tqdm.tqdm(desc=pbar_desc, total=total)
        result_groups = await future

    ...


    return processed_groups
