
[CB] Support the num_return_sequences argument #42921

Merged
remi-or merged 14 commits into main from cb-fork on Jan 6, 2026

Conversation

@remi-or
Collaborator

@remi-or remi-or commented Dec 17, 2025

Summary

This PR adds the option to fork requests during continuous batching: forking duplicates a request while reusing as much of the existing cache as possible. This is then leveraged to make the num_return_sequences argument available in CB.
This PR also enables parallel decoding, which will be useful for RL workflows.
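
Below is a minimal sketch of how num_return_sequences could be used with continuous batching. The `generate_batch` entry point, model checkpoint, and exact config fields are assumptions for illustration, not taken from this PR.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2.5-0.5B-Instruct", dtype=torch.bfloat16, device_map="cuda"
)

generation_config = GenerationConfig(
    max_new_tokens=64,
    do_sample=True,
    num_return_sequences=4,  # each prompt is forked into 4 decoding sequences
)

prompts = ["Explain continuous batching in one sentence."]
batch_inputs = [tokenizer(p).input_ids for p in prompts]

# `generate_batch` is assumed to be the continuous-batching entry point here.
batch_outputs = model.generate_batch(batch_inputs, generation_config=generation_config)
```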

Performance

| Samples | Attention | Add Prefix | Source | Duration (s) | Generated tokens | Throughput (tok/s) |
|---------|-----------|------------|---------|--------------|------------------|--------------------|
| 100 | flash_attention_2 | False | With PR | 6.82 | 17480 | 2563.24 |
| 100 | flash_attention_2 | False | On main | 6.91 | 17698 | 2562.81 |
| 100 | sdpa | False | With PR | 21.77 | 17637 | 810.33 |
| 100 | sdpa | False | On main | 21.81 | 17234 | 790.03 |
| 500 | flash_attention_3 | True | With PR | 16.4 | 113054 | 6895.14 |
| 500 | flash_attention_3 | True | On main | 16.73 | 112333 | 6715.63 |
  • `--compile` is always on; the number of generated tokens therefore varies between runs, since the compiled kernels introduce slight numerical differences.

Tests

Tests pass, including the one added to test the feature.

Sanity check

Looks good

@remi-or remi-or self-assigned this Dec 17, 2025
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

prompt_ids=(state.initial_tokens + state.generated_tokens),
)

def copy_cache(self, source_blocks: list[int], forked_blocks: list[int]) -> None:
Collaborator

Very interesting. I would assume we want to delay requests that are getting forked to the next batch, to do this async (I might be wrong).

Collaborator Author

On the contrary: to maximize prefix sharing, you want to schedule those requests ASAP. But there might be something to the idea that much of the CPU-side forking could be done asynchronously. The issue is that there will always be a copy of the cache, so the GPU is involved, but maybe it can be done in a side stream.
I think the best compromise is to add the feature now and later, when we get to CPU asynchronicity, add the FORKING status to let the scheduler know those requests must not be scheduled until the cache has been copied.
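
A minimal sketch of what the side-stream idea could look like, assuming PyTorch CUDA streams; the function names and wiring below are hypothetical and not part of this PR:

```python
import torch

copy_stream = torch.cuda.Stream()

def copy_cache_in_side_stream(key_cache, value_cache, source_blocks, forked_blocks):
    # Make the side stream wait for any pending writes to the source blocks.
    copy_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(copy_stream):
        key_cache[forked_blocks] = key_cache[source_blocks]
        value_cache[forked_blocks] = value_cache[source_blocks]

def wait_for_forks():
    # Called before forked requests are scheduled: this is where a FORKING
    # status would gate scheduling until the copies have completed.
    torch.cuda.current_stream().wait_stream(copy_stream)
```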

Collaborator

yep makes sense!

Comment on lines +391 to +392
source_blocks = torch.tensor(source_blocks, device=self.device, dtype=torch.int32)
forked_blocks = torch.tensor(forked_blocks, device=self.device, dtype=torch.int32)
Collaborator

This is allocating memory and might need a sync (tensor from a Python list -> CPU/GPU sync); we want to avoid that.

Collaborator Author

@remi-or remi-or Dec 18, 2025

I tried playing around with this and was surprised that this is the fastest alternative, which makes no sense to me. I will leave a TODO to deep-dive later.
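
For reference, a hedged sketch of the kind of alternative being discussed: building the index tensors on pinned CPU memory and moving them with `non_blocking=True`, to avoid the implicit sync of `torch.tensor(list, device="cuda")`. The helper name is made up, and the PR found the straightforward version to be faster anyway.

```python
import torch

def to_device_indices(blocks: list[int], device: torch.device) -> torch.Tensor:
    # Pinned host memory lets the host-to-device transfer be enqueued without blocking.
    cpu_idx = torch.tensor(blocks, dtype=torch.int32, pin_memory=True)
    return cpu_idx.to(device, non_blocking=True)
```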

Collaborator

Got it thanks for checking!

key_cache[forked_blocks] = key_cache[source_blocks]
value_cache[forked_blocks] = value_cache[source_blocks]
# FIXME: should be one copy for all CMs with only the changing blocks
# FIXME: even once per fork batch
Collaborator

copy should be async as well (async=True)

Collaborator Author

cf. above

Comment on lines +406 to +407
source_blocks, forked_blocks = cm.fork_blocks(state.request_id, new_state.request_id, self._block_manager)
self.copy_cache(source_blocks, forked_blocks)
Collaborator

Same here, we should "schedule" the copy. Remember we are in Python and the GIL is killing us.

Collaborator Author

Since the copy is device-to-device, I think the best we can do for now is a single copy, as is the case right now. Plus, I think PyTorch works asynchronously in that case, i.e. CPU operations continue after the copy is launched.
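
A quick way to check that claim (a standalone sketch, not code from the PR): a device-to-device copy is only enqueued on the current CUDA stream, so the host-side timer barely moves until we synchronize.

```python
import time
import torch

src = torch.randn(4096, 4096, device="cuda")
dst = torch.empty_like(src)

start = time.perf_counter()
dst.copy_(src)                        # enqueued on the current CUDA stream
launch_ms = (time.perf_counter() - start) * 1e3

torch.cuda.synchronize()              # wait for the copy to actually finish
total_ms = (time.perf_counter() - start) * 1e3
print(f"host-side launch: {launch_ms:.3f} ms, including sync: {total_ms:.3f} ms")
```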

Collaborator

okayyy

Comment on lines +127 to +129
"""Fork a given list of (source_blocks) into a new list of forked_blocks. If the blocks are (shareable), we
reference the existing blocks when they are complete. Otherwise, we allocate new blocks if possible. The
(group_id) of the layer group the blocks belong to is also needed."""
Collaborator

Need a doc that shows in / out.

Collaborator Author

Not sure what you mean by this, sorry!

Collaborator

Sorry, I mean the doc should help us more, showing an example with input and output!

Collaborator Author

Added an ascii table
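
For readers of this thread, a hypothetical illustration of the kind of table that was added (the block ids below are made up; the real table lives in the fork_blocks docstring):

```python
# Parent block table : [ 3 | 7 | 12 ]   # block 12 is still incomplete
# Child block table  : [ 3 | 7 | 15 ]   # 3 and 7 are shared, 15 is newly allocated
#
# fork_blocks would then report which blocks must actually be copied:
source_blocks = [12]   # contents to copy from
forked_blocks = [15]   # freshly allocated destinations
```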

Comment on lines +604 to +611
while self.scheduler._requests_to_fork:
state = self.scheduler._requests_to_fork.pop()
num_children = state.num_children
state.num_children = 0
for i in range(num_children):
# FIXME: if fork can't be done, create a new pending request without forking
new_request = self.cache.fork_request(state, f"{state.request_id}__child#{i}")
self.scheduler.active_requests[new_request.request_id] = new_request
Collaborator

Same here, we should make that async IMO (new status "FORKING" -> wait until forked?). IDK, but we need to bench a tad bit.

Collaborator Author

This has changed to be done in batch, which, without an asynchronous mode, is the best we can do on the CPU side.
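
A hedged sketch of that batched flow, written as a standalone function so it reads on its own; `cache.fork_requests` is a hypothetical batched helper, not the PR's exact API:

```python
def fork_pending_requests(scheduler, cache) -> None:
    # Drain the fork queue first, collecting every (parent state, child id) pair.
    pending = []
    while scheduler._requests_to_fork:
        state = scheduler._requests_to_fork.pop()
        for i in range(state.num_children):
            pending.append((state, f"{state.request_id}__child#{i}"))
        state.num_children = 0

    # One batched cache fork for all children instead of one copy per child.
    for new_request in cache.fork_requests(pending):
        scheduler.active_requests[new_request.request_id] = new_request
```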

Comment on lines +1256 to +1270
manager_cm = self.continuous_batching_context_manager(
generation_config=generation_config,
num_q_cuda_graphs=num_q_padding_intervals,
num_kv_cuda_graphs=num_kv_padding_intervals,
allow_block_sharing=allow_block_sharing,
block=True,
timeout=5,
)
logging_cm = logging_redirect_tqdm([logger])
pbar_cm = tqdm(
total=num_requests,
disable=(not progress_bar),
desc=f"Solving {num_requests} requests",
unit="request",
)
Collaborator

You could create a get-cm function?

Collaborator Author

Not sure I understand this -- what would it return? It seems `self.continuous_batching_context_manager(...)` already is the "get cm".

Collaborator

Yeah, sorry, I don't remember what I wanted here. LGTM
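
For completeness, a hedged sketch of what a single "get cm" helper could have looked like, bundling the three context managers above with contextlib.ExitStack; this is purely illustrative and was not adopted in the PR:

```python
from contextlib import ExitStack, contextmanager

@contextmanager
def batching_run_context(manager_cm, logging_cm, pbar_cm):
    # Enter the three managers together and tear them down in reverse order.
    with ExitStack() as stack:
        manager = stack.enter_context(manager_cm)
        stack.enter_context(logging_cm)
        pbar = stack.enter_context(pbar_cm)
        yield manager, pbar
```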

Base automatically changed from cb-block-sharing to main December 18, 2025 11:28
@remi-or remi-or marked this pull request as ready for review December 24, 2025 13:58
Collaborator

@ArthurZucker ArthurZucker left a comment

LGTM! `def fork(self, new_request_id: str) -> "RequestState":` should be optimized as much as possible.

self, parent_request_id: str, children_request_ids: list[str], block_manager: BlockManager
) -> tuple[list[int], list[int]]:
"""Forks the cache blocks of a (parent_request_id) to a list of (children_request_ids). To manage the blocks,
the (block_manager) is used. When forking, the child's block are either shared with the parent, or they need to
Collaborator

Suggested change
the (block_manager) is used. When forking, the child's block are either shared with the parent, or they need to
the block_manager is used. When forking, the child's block are either shared with the parent, or they need to

Collaborator Author

I use parentheses to denote arguments of the function; it would be weird to change convention midway.

Collaborator

ah ok did not know!

Comment on lines +313 to +314
for children_request_id, forked_blocks in zip(children_request_ids, list_forked_blocks):
self.block_table[children_request_id] = forked_blocks
Collaborator

It feels like we should not need to iterate twice over list_forked_blocks if block_manager.fork_blocks updated the block table on the fly. But encapsulation might be better this way?

Collaborator Author

You are right, moving the check to inside the loop. Thanks!

Collaborator Author

Oh my bad, I mistook your point. We do iterate twice over that list, because block_manager.fork_blocks does not update the block table directly. I tried making it handle that part and it led to a messy function that was not all that readable. IMO it's best to leave it clear for now and revisit it if it's a hotspot for optimization, which it might be? Then again, it's just an additional iteration over a small list.

Collaborator

Yeah yeah, I am just trying to avoid any extra for loop when we can; from reading the code it appeared to be removable, but if it's not, no worries!

@remi-or remi-or merged commit accb698 into main Jan 6, 2026
26 checks passed
@remi-or remi-or deleted the cb-fork branch January 6, 2026 12:40
SangbumChoi pushed a commit to SangbumChoi/transformers that referenced this pull request Jan 23, 2026
* Reformat to make the code pretty

* Allow for multiple decoding sequences in CB

* Style

* Fix a generation config bug

* Add seed to example

* Batch forking

* Cahnge the fixme (for later PR)

* Copy source is optional

* Added a benchmark script for PR

* Added a test and fixed a bug

* Deepcopy and style

* Review compliance

* Style