59 changes: 59 additions & 0 deletions benchmark_v2/benchmark_scripts/continuous_batching_overall.py
@@ -0,0 +1,59 @@
import re
import subprocess
from pathlib import Path

from tabulate import tabulate


SCRIPT_LOCATION = (Path(__file__).parent.parent.parent / "examples/pytorch/continuous_batching.py").as_posix()
COMMON_ARGS = "--log-level WARNING --seed 0".split()


def run_and_parse_cb_example(args: str) -> dict:
print(f"Benchmarking with args: {args}")
output = subprocess.check_output(
["python", SCRIPT_LOCATION] + args.split() + COMMON_ARGS,
# stderr=subprocess.DEVNULL,
)
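# The example script prints a summary line like: "CB generation took: 12.34 seconds for 5678 tokens. 456.78tok/s"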
pattern = r"CB generation took: ([\d.]+) seconds for (\d+) tokens\. ([\d.]+)tok/s"
match = re.search(pattern, output.decode("utf-8"))
if match is not None:
return {
"args": args,
"time_seconds": float(match.group(1)),
"num_tokens": int(match.group(2)),
"throughput_tok_per_sec": float(match.group(3)),
}
return {}


if __name__ == "__main__":
results = [
{
"args": "Arguments",
"time_seconds": "Duration (s)",
"num_tokens": "Generated tokens",
"throughput_tok_per_sec": "Throughput (tok/s)",
}
]

# Benchmark with different number of samples
results.append(run_and_parse_cb_example("--samples 10"))
results.append(run_and_parse_cb_example("--samples 50"))
results.append(run_and_parse_cb_example("--samples 100"))
results.append(run_and_parse_cb_example("--samples 500"))

# Benchmark with compile: default, flash attention 2 and sdpa
results.append(run_and_parse_cb_example("--samples 100 --compile"))
results.append(run_and_parse_cb_example("--samples 100 --compile --attn flash_attention_2"))
results.append(run_and_parse_cb_example("--samples 100 --compile --attn sdpa"))

# Benchmark with parallel decoding
results.append(run_and_parse_cb_example("--samples 50 --compile --num-return-sequences 8 --do-sample"))
results.append(run_and_parse_cb_example("--samples 100 --compile --num-return-sequences 4 --do-sample"))

# Benchmark with prefix sharing
results.append(run_and_parse_cb_example("--samples 500 --add-prefix --compile"))

print()
print(tabulate(results, tablefmt="github"))
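Assuming the repository layout above, the benchmark can be launched directly with `python benchmark_v2/benchmark_scripts/continuous_batching_overall.py` from any working directory, since SCRIPT_LOCATION is resolved relative to this file. Each entry shells out to the continuous batching example, and the final comparison table is printed in GitHub-flavored format by tabulate.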
18 changes: 17 additions & 1 deletion examples/pytorch/continuous_batching.py
@@ -177,6 +177,7 @@ def batch_generate(
parser.add_argument("--cuda-graph", "-cg", help="Use cuda graphs", type=str, default=None)
parser.add_argument("--compile", action="store_true", help="Compile the model using torch.compile")
parser.add_argument("--do-sample", action="store_true", help="Activate sampling")
parser.add_argument("--num-return-sequences", type=int, default=1, help="Number of return sequences")

# Benchmark parameters
parser.add_argument("--samples", type=int, default=500, help="Number of samples to generate")
@@ -190,6 +191,7 @@ def batch_generate(
parser.add_argument("--compare", action="store_true", help="Compare CB generation with classic generate")
parser.add_argument("--profile", type=str, default=None)
parser.add_argument("--metrics", action="store_true")
parser.add_argument("--seed", type=int, default=None, help="Random seed")

# Display parameters
parser.add_argument("--displayed", type=int, default=0, help="Number of samples to display")
@@ -210,6 +212,10 @@ def batch_generate(
else:
args.attn = "kernels-community/flash-attn3"

# Set seed
if args.seed is not None:
torch.manual_seed(args.seed)

# Create model
model_id = "google/gemma-2-2b-it" if args.sliding_window > 0 else "meta-llama/Llama-3.1-8B-Instruct"
has_system_role = args.sliding_window == 0
@@ -272,17 +278,27 @@ def batch_generate(
inputs = inputs if isinstance(inputs, list) else inputs["input_ids"]
batched_inputs.append(inputs)

# If num_return_sequences > 1, automatically enable do_sample with a warning
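# (without sampling, greedy decoding is deterministic, so all returned sequences would be identical)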
do_sample = args.do_sample
if args.num_return_sequences > 1 and not args.do_sample:
logger.warning(
f"num_return_sequences={args.num_return_sequences} > 1, automatically enabling do_sample=True. "
"Set --do-sample explicitly to suppress this warning."
)
do_sample = True

# Prepare generation config
generation_cfg = GenerationConfig(
max_new_tokens=args.max_new_tokens,
use_cuda_graph=use_cuda_graph,
eos_token_id=tokenizer.pad_token_id if args.force_max_length else tokenizer.eos_token_id,
pad_token_id=tokenizer.pad_token_id,
do_sample=args.do_sample,
do_sample=do_sample,
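# temperature and top_p below only take effect when do_sample is True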
temperature=0.8,
top_p=0.9,
num_blocks=args.num_blocks,
max_batch_tokens=args.max_batch_tokens,
num_return_sequences=args.num_return_sequences,
)

# Add a compile config if requested
25 changes: 24 additions & 1 deletion src/transformers/generation/continuous_batching/cache.py
@@ -210,7 +210,7 @@ def __init__(
self.key_cache: list[torch.Tensor] = []
self.value_cache: list[torch.Tensor] = []
# We add two extra blocks to the cache to handle padding and generally discard unwanted tokens
- self.cache_shape = (num_blocks * self.block_size + 2, self.num_key_value_heads, self.head_dim)
+ self.cache_shape = ((num_blocks + 2) * self.block_size, self.num_key_value_heads, self.head_dim)
for _ in range(group_size):
new_layer_key_cache = torch.empty(self.cache_shape, dtype=self.dtype, device=self.device)
new_layer_value_cache = torch.empty(self.cache_shape, dtype=self.dtype, device=self.device)
@@ -388,6 +388,29 @@ def mark_shareable_blocks_as_complete(self, state: RequestState) -> None:
prompt_ids=(state.initial_tokens + state.generated_tokens),
)

def copy_cache(self, source_blocks: list[int], forked_blocks: list[int]) -> None:
"""Copy the cache from the source blocks to the forked blocks."""

[Review thread on copy_cache]
Collaborator: Very interesting. I would assume we want to delay requests that are getting forked to the next batch, to do this async (I might be wrong).
Author: On the contrary, to maximize prefix sharing you want to schedule those requests ASAP. But there might be something to the idea that we can do much of the CPU side of forking asynchronously. The issue is that there will always be a copy of the cache, hence the GPU intervenes, but maybe it can be done in a side stream. I think the best compromise is to add the feature now and, later, when we get to CPU asynchronousness, add a FORKING status to let the scheduler know those requests must not be scheduled until the cache has been copied.
Collaborator: Yep, makes sense!
source_blocks = torch.tensor(source_blocks, device=self.device, dtype=torch.int32)
forked_blocks = torch.tensor(forked_blocks, device=self.device, dtype=torch.int32)
[Review thread on lines +393 to +394, the two torch.tensor calls above]
Collaborator: This is allocating memory and might need a sync (tensor from list → CPU-GPU sync); we want to avoid that.
Author (remi-or, Dec 18, 2025): I tried playing around with this and I was surprised this is the fastest alternative, which makes no sense to me. Will leave a TODO to deep dive later.
Collaborator: Got it, thanks for checking!
for key_cache, value_cache in zip(self.key_cache, self.value_cache):
key_cache = key_cache.view(-1, self.block_size, self.num_key_value_heads, self.head_dim)
value_cache = value_cache.view(-1, self.block_size, self.num_key_value_heads, self.head_dim)
key_cache[forked_blocks] = key_cache[source_blocks]
value_cache[forked_blocks] = value_cache[source_blocks]
# FIXME: consolidate the cache into a single tensor of shape (group_size, 2, *self.k_or_v_cache_shape)
# This will allow for better .update and a single copy instead of one per cache tensor
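For intuition, here is a minimal, self-contained sketch (toy shapes and hypothetical values, independent of transformers) of the view-plus-index pattern copy_cache relies on: indexing the block view copies whole blocks, and the write goes through the view into the flat cache tensor.

import torch

block_size, num_heads, head_dim = 4, 2, 8
num_blocks = 6

# Flat cache laid out like self.cache_shape: (num_blocks * block_size, num_heads, head_dim)
cache = torch.arange(num_blocks * block_size * num_heads * head_dim, dtype=torch.float32)
cache = cache.reshape(num_blocks * block_size, num_heads, head_dim)

# Same .view trick as copy_cache: expose a leading block dimension without copying
blocks = cache.view(-1, block_size, num_heads, head_dim)

src = torch.tensor([0, 2], dtype=torch.long)  # long indices for portability across torch versions
dst = torch.tensor([4, 5], dtype=torch.long)
blocks[dst] = blocks[src]  # copies whole blocks; the write goes through the view into `cache`

assert torch.equal(blocks[4], blocks[0]) and torch.equal(blocks[5], blocks[2])
assert torch.equal(cache[4 * block_size : 5 * block_size], cache[0:block_size])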

def fork_request(self, source_request_id: str, destination_request_ids: list[str]) -> tuple[list[int], list[int]]:
"""Fork the cache of a request (state) into the one of a list of requests with the given (dst_request_ids)."""
# These lists will be the accumulators for the source and destination blocks for the cache copy
source_blocks, destination_blocks = [], []
# Main fork loop
for cm in self.group_cache_managers:
src_blocks, dst_blocks = cm.fork_blocks(source_request_id, destination_request_ids, self._block_manager)
source_blocks.extend(src_blocks)
destination_blocks.extend(dst_blocks)
return source_blocks, destination_blocks
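Note that the (source_blocks, destination_blocks) pair returned here matches the signature of copy_cache above, which suggests a caller first forks the block tables and then performs the physical copy in a single pass over the cache tensors.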


# TODO: rework computation with the groups and their sizes
class PagedAttentionMemoryHandler:
86 changes: 86 additions & 0 deletions src/transformers/generation/continuous_batching/cache_manager.py
@@ -123,6 +123,62 @@ def get_free_blocks(
# In both cases, we return the allocated block ids
return allocated_block_ids

def fork_blocks(
self, parent_blocks: list[int], num_forks: int, shareable: bool, group_id: int
) -> tuple[list[list[int]], list[int], list[int]]:
"""Fork a given list of (parent_blocks) as many times as (num_forks). If the blocks are (shareable), we use
reference on the blocks that are complete. Otherwise, we allocate new blocks and keep track of their indices to
later copy the physical cache. For instance, when forking 4 blocks for 2 children:

Parent blocks: [0, 1, 2, 3], with all blocks being complete except the last one (block 3).

----------------------------------------- IF BLOCKS ARE NOT SHAREABLE -----------------------------------------

Forked blocks lists: [[5, 6, 7, 8], [9, 10, 11, 12]]
Copy source: [0, 1, 2, 3, 0, 1, 2, 3]
↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓
Copy destination: [5, 6, 7, 8, 9, 10, 11, 12] → 8 blocks are newly allocated and copied

----------------------------------------- IF BLOCKS ARE SHAREABLE ---------------------------------------------

Forked blocks lists: [[0, 1, 2, 5], [0, 1, 2, 6]]
Copy source: [ 3, 3] (block 3 is not complete so it's copied, not referenced)
↓ ↓
Copy destination: [ 5, 6] → only 2 blocks are newly allocated and copied
"""
# First phase: reference all complete blocks
forked_by_reference = []

if shareable:
for block_id in parent_blocks:
block = self._id_to_block[block_id]
if block.is_complete:
forked_by_reference.append(block.id)
block.ref_count += num_forks
else:
break

# Early return if we have forked all blocks by reference
blocks_to_copy = len(parent_blocks) - len(forked_by_reference)
if blocks_to_copy == 0:
return [forked_by_reference[:] for _ in range(num_forks)], [], []

# From now on, each child will have its own list of blocks
forked_blocks_lists = []
copy_src = []
copy_dst = []

# Second phase: allocate new blocks if needed
parent_id = forked_by_reference[-1] if forked_by_reference else None
for _ in range(num_forks):
allocated_block_ids = self.get_free_blocks(blocks_to_copy, parent_id, shareable, group_id)
if allocated_block_ids is None:
return None, [], []
forked_blocks_lists.append(forked_by_reference + allocated_block_ids)
copy_src.extend(parent_blocks[-blocks_to_copy:])
copy_dst.extend(allocated_block_ids)
return forked_blocks_lists, copy_src, copy_dst
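To make the two phases concrete, below is a toy, dependency-free sketch of the same reference-vs-copy split. The Block dataclass and the explicit free-list are hypothetical simplifications (the real code goes through get_free_blocks and also threads a parent_id for prefix bookkeeping, omitted here); it reproduces the shareable example from the docstring.

from dataclasses import dataclass

@dataclass
class Block:
    id: int
    is_complete: bool
    ref_count: int = 1

def fork(parent_blocks: list[int], blocks: dict[int, Block], free_ids: list[int], num_forks: int, shareable: bool):
    # Phase 1: share the leading run of complete blocks by bumping their ref_count
    shared: list[int] = []
    if shareable:
        for bid in parent_blocks:
            if blocks[bid].is_complete:
                shared.append(bid)
                blocks[bid].ref_count += num_forks
            else:
                break
    n_copy = len(parent_blocks) - len(shared)
    if n_copy == 0:  # everything was forked by reference
        return [shared[:] for _ in range(num_forks)], [], []
    # Phase 2: allocate fresh blocks per child and record the copy plan
    forks, src, dst = [], [], []
    for _ in range(num_forks):
        new_ids = [free_ids.pop() for _ in range(n_copy)]
        forks.append(shared + new_ids)
        src.extend(parent_blocks[-n_copy:])
        dst.extend(new_ids)
    return forks, src, dst

# Docstring example: blocks 0-2 are complete, block 3 is still being written
blocks = {i: Block(i, is_complete=(i != 3)) for i in range(4)}
free_ids = [9, 8, 7, 6, 5]  # pretend blocks 5..9 are free (popped from the end)
forks, src, dst = fork([0, 1, 2, 3], blocks, free_ids, num_forks=2, shareable=True)
print(forks)     # [[0, 1, 2, 5], [0, 1, 2, 6]]
print(src, dst)  # [3, 3] [5, 6]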

def increase_ref_count(self, block_id: int) -> None:
"""Increases the reference count of a given (block_id)."""
block = self._id_to_block[block_id]
@@ -243,6 +299,36 @@ def get_write_indices(self, request_id: str, past_length: int, query_length: int
def get_seqlens_k(self, request_id: str, past_length: int, query_length: int) -> tuple[str, int]:
"""Returns the attention type of the cache allocator and the key sequence length for the given request_id."""

def fork_blocks(
self, parent_request_id: str, children_request_ids: list[str], block_manager: BlockManager
) -> tuple[list[int], list[int]]:
"""Forks the cache blocks of a (parent_request_id) to a list of (children_request_ids). To manage the blocks,
the (block_manager) is used. When forking, the child's block are either shared with the parent, or they need to
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
the (block_manager) is used. When forking, the child's block are either shared with the parent, or they need to
the block_manager is used. When forking, the child's block are either shared with the parent, or they need to

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I use parenthesis to denote arguments of the function, would be weird to change convention midway

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah ok did not know!

be copied from the parent. Hence we return two lists of blocks that need to be copied: one for the source and
one for the destination."""

# Sanity checks
if parent_request_id not in self.block_table:
raise ValueError(f"No block table found for request {parent_request_id}")

# Actual forking
parent_blocks = self.block_table[parent_request_id]
list_forked_blocks, copy_src, copy_dst = block_manager.fork_blocks(
parent_blocks=parent_blocks,
num_forks=len(children_request_ids),
shareable=self.uses_block_sharing,
group_id=self._index,
)
if list_forked_blocks is None:
raise ValueError(f"Failed to fork blocks for request {parent_request_id}")

# Update the block table for all children requests
for children_request_id, forked_blocks in zip(children_request_ids, list_forked_blocks):
if children_request_id in self.block_table:
raise ValueError(f"Block table already exists for request {children_request_id}")
self.block_table[children_request_id] = forked_blocks
return copy_src, copy_dst

[Review thread on the block-table update loop (lines +326 to +329)]
Collaborator: It feels like we should not need to iterate twice on list_forked_blocks if block_manager.fork_blocks updates the block table on the fly. But encapsulation might be better this way?
Author: You are right, moving the check to inside the loop. Thanks!
Author: Oh my bad, I mistook your point. We do iterate twice on that list, because block_manager.fork_blocks does not update the block table directly. I tried making it handle that part and it led to a messy function that was not all that readable. IMO it's best to leave it clear for now and revisit it if it's a hotspot for optimization, which it might be? Then again, it's just an additional iteration over a small list.
Collaborator: Yeah, I am just trying to avoid any extra for loop when we can; from reading the code it appeared to be removable, but if it's not, no worries!


class FullAttentionCacheAllocator(CacheAllocator):
"""Cache manager for a group of full attention layers."""