[CB] Support the num_return_sequences argument #42921
Changes from all commits: 94822dd, 5165d9e, eb8152c, 3bef652, 9f2596e, ea36c6a, 703b48d, 87fe8fd, 8e4d0c3, 2e772d5, 4e6415d, 299a299, c53e3b2, fb733b2
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,59 @@ | ||
| import re | ||
| import subprocess | ||
| from pathlib import Path | ||
|
|
||
| from tabulate import tabulate | ||
|
|
||
|
|
||
| SCRIPT_LOCATION = (Path(__file__).parent.parent.parent / "examples/pytorch/continuous_batching.py").as_posix() | ||
| COMMON_ARGS = "--log-level WARNING --seed 0".split() | ||
|
|
||
|
|
||
| def run_and_parse_cb_example(args: list[str]) -> dict: | ||
| print(f"Benchmarking with args: {args}") | ||
| output = subprocess.check_output( | ||
| ["python", SCRIPT_LOCATION] + args.split() + COMMON_ARGS, | ||
| # stderr=subprocess.DEVNULL, | ||
| ) | ||
| pattern = r"CB generation took: ([\d.]+) seconds for (\d+) tokens\. ([\d.]+)tok/s" | ||
| match = re.search(pattern, output.decode("utf-8")) | ||
| if match is not None: | ||
| return { | ||
| "args": args, | ||
| "time_seconds": float(match.group(1)), | ||
| "num_tokens": int(match.group(2)), | ||
| "throughput_tok_per_sec": float(match.group(3)), | ||
| } | ||
| return {} | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| results = [ | ||
| { | ||
| "args": "Arguments", | ||
| "time_seconds": "Duration (s)", | ||
| "num_tokens": "Generated tokens", | ||
| "throughput_tok_per_sec": "Throughput (tok/s)", | ||
| } | ||
| ] | ||
|
|
||
| # Benchmark with different number of samples | ||
| results.append(run_and_parse_cb_example("--samples 10")) | ||
| results.append(run_and_parse_cb_example("--samples 50")) | ||
| results.append(run_and_parse_cb_example("--samples 100")) | ||
| results.append(run_and_parse_cb_example("--samples 500")) | ||
|
|
||
| # Benchmark with compile: default, flash attention 2 and sdpa | ||
| results.append(run_and_parse_cb_example("--samples 100 --compile")) | ||
| results.append(run_and_parse_cb_example("--samples 100 --compile --attn flash_attention_2")) | ||
| results.append(run_and_parse_cb_example("--samples 100 --compile --attn sdpa")) | ||
|
|
||
| # Benchmark with parallel decoding | ||
| results.append(run_and_parse_cb_example("--samples 50 --compile --num-return-sequences 8 --do-sample")) | ||
| results.append(run_and_parse_cb_example("--samples 100 --compile --num-return-sequences 4 --do-sample")) | ||
|
|
||
| # Benchmark with prefix sharing | ||
| results.append(run_and_parse_cb_example("--samples 500 --add-prefix --compile")) | ||
|
|
||
| print() | ||
| print(tabulate(results, tablefmt="github")) |
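For reference, the regex above implies the example script prints a timing line in the following shape; a minimal sketch with made-up numbers:

```python
# Sketch of the timing line run_and_parse_cb_example expects (values are illustrative)
import re

pattern = r"CB generation took: ([\d.]+) seconds for (\d+) tokens\. ([\d.]+)tok/s"
line = "CB generation took: 12.34 seconds for 5678 tokens. 460.13tok/s"
match = re.search(pattern, line)
assert match is not None
assert (float(match.group(1)), int(match.group(2)), float(match.group(3))) == (12.34, 5678, 460.13)
```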
Second file (paged attention cache):

@@ -210,7 +210,7 @@ def __init__(
        self.key_cache: list[torch.Tensor] = []
        self.value_cache: list[torch.Tensor] = []
        # We add two extra tokens to the cache to handle padding and generally discard unwanted tokens
-       self.cache_shape = (num_blocks * self.block_size + 2, self.num_key_value_heads, self.head_dim)
+       self.cache_shape = ((num_blocks + 2) * self.block_size, self.num_key_value_heads, self.head_dim)
        for _ in range(group_size):
            new_layer_key_cache = torch.empty(self.cache_shape, dtype=self.dtype, device=self.device)
            new_layer_value_cache = torch.empty(self.cache_shape, dtype=self.dtype, device=self.device)
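A quick sketch of why the padding moves from two extra tokens to two extra blocks (my reading of the change, with illustrative sizes): copy_cache below views the flat slot dimension as whole blocks, which requires the total slot count to be a multiple of block_size.

```python
# Illustrative sizes, not the real defaults: the new shape keeps the slot
# count block-aligned, so the cache can be viewed as whole blocks.
import torch

num_blocks, block_size, num_kv_heads, head_dim = 4, 8, 2, 16

old_slots = num_blocks * block_size + 2    # 34 slots, not a multiple of block_size
new_slots = (num_blocks + 2) * block_size  # 48 slots = 6 whole blocks

cache = torch.empty(new_slots, num_kv_heads, head_dim)
blocks = cache.view(-1, block_size, num_kv_heads, head_dim)  # works: (6, 8, 2, 16)
assert blocks.shape[0] == num_blocks + 2

# The old shape cannot be reshaped into whole blocks:
try:
    torch.empty(old_slots, num_kv_heads, head_dim).view(-1, block_size, num_kv_heads, head_dim)
except RuntimeError:
    pass  # 34 is not divisible by 8
```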
@@ -388,6 +388,29 @@ def mark_shareable_blocks_as_complete(self, state: RequestState) -> None:
            prompt_ids=(state.initial_tokens + state.generated_tokens),
        )

    def copy_cache(self, source_blocks: list[int], forked_blocks: list[int]) -> None:
        """Copy the cache from the source blocks to the forked blocks."""
        source_blocks = torch.tensor(source_blocks, device=self.device, dtype=torch.int32)
        forked_blocks = torch.tensor(forked_blocks, device=self.device, dtype=torch.int32)
Review thread on lines +393 to +394:

Collaborator: this is allocating memory + might need a sync (building a tensor from a list -> CPU/GPU sync), we want to avoid that.

Author: I tried playing around with this and I was surprised this is the fastest alternative, which makes no sense to me. Will leave a TODO to deep dive later.

Collaborator: Got it, thanks for checking!
        for key_cache, value_cache in zip(self.key_cache, self.value_cache):
            key_cache = key_cache.view(-1, self.block_size, self.num_key_value_heads, self.head_dim)
            value_cache = value_cache.view(-1, self.block_size, self.num_key_value_heads, self.head_dim)
            key_cache[forked_blocks] = key_cache[source_blocks]
            value_cache[forked_blocks] = value_cache[source_blocks]
        # FIXME: consolidate the cache into a single tensor of shape (group_size, 2, *self.k_or_v_cache_shape)
        # This will allow for better .update and a single copy instead of one per cache tensor
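A standalone sketch of the block-level copy this method performs (illustrative sizes, plain list indices instead of the on-device index tensors built above):

```python
# Sketch of the copy in copy_cache: view the flat slot dimension as
# (num_blocks, block_size, ...) blocks, then copy whole blocks with a single
# indexed assignment per cache tensor. The view shares storage with the cache,
# so the assignment writes through.
import torch

block_size, num_kv_heads, head_dim = 4, 2, 8
key_cache = torch.randn(6 * block_size, num_kv_heads, head_dim)  # 6 blocks of 4 slots

blocks = key_cache.view(-1, block_size, num_kv_heads, head_dim)
source_blocks, forked_blocks = [0, 1], [4, 5]
blocks[forked_blocks] = blocks[source_blocks]  # one copy kernel, no Python loop over blocks

assert torch.equal(blocks[4], blocks[0]) and torch.equal(blocks[5], blocks[1])
```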
    def fork_request(self, source_request_id: str, destination_request_ids: list[str]) -> tuple[list[int], list[int]]:
        """Fork the cache of a request (source_request_id) into the caches of a list of requests
        (destination_request_ids)."""
        # These lists will be the accumulators for the source and destination blocks for the cache copy
        source_blocks, destination_blocks = [], []
        # Main fork loop
        for cm in self.group_cache_managers:
            src_blocks, dst_blocks = cm.fork_blocks(source_request_id, destination_request_ids, self._block_manager)
            source_blocks.extend(src_blocks)
            destination_blocks.extend(dst_blocks)
        return source_blocks, destination_blocks


# TODO: rework computation with the groups and their sizes
class PagedAttentionMemoryHandler:
Third file (block manager and cache allocators):

@@ -123,6 +123,62 @@ def get_free_blocks(
        # In both cases, we return the allocated block ids
        return allocated_block_ids
    def fork_blocks(
        self, parent_blocks: list[int], num_forks: int, shareable: bool, group_id: int
    ) -> tuple[list[list[int]], list[int], list[int]]:
        """Fork a given list of (parent_blocks) as many times as (num_forks). If the blocks are (shareable), we use
        references for the blocks that are complete. Otherwise, we allocate new blocks and keep track of their indices
        to later copy the physical cache. For instance, when forking 4 blocks for 2 children:

        Parent blocks: [0, 1, 2, 3], with all blocks being complete except the last one (block 3).

        ----------------------------------------- IF BLOCKS ARE NOT SHAREABLE ----------------------------------------

        Forked blocks lists: [[5, 6, 7, 8], [9, 10, 11, 12]]
        Copy source:      [0, 1, 2, 3,  0,  1,  2,  3]
                           ↓  ↓  ↓  ↓   ↓   ↓   ↓   ↓
        Copy destination: [5, 6, 7, 8,  9, 10, 11, 12]  → 8 blocks are newly allocated and copied

        ----------------------------------------- IF BLOCKS ARE SHAREABLE --------------------------------------------

        Forked blocks lists: [[0, 1, 2, 5], [0, 1, 2, 6]]
        Copy source:      [3, 3]  (block 3 is not complete so it's copied, not referenced)
                           ↓  ↓
        Copy destination: [5, 6]  → only 2 blocks are newly allocated and copied
        """
        # First phase: reference all complete blocks
        forked_by_reference = []

        if shareable:
            for block_id in parent_blocks:
                block = self._id_to_block[block_id]
                if block.is_complete:
                    forked_by_reference.append(block.id)
                    block.ref_count += num_forks
                else:
                    break

        # Early return if we have forked all blocks by reference
        blocks_to_copy = len(parent_blocks) - len(forked_by_reference)
        if blocks_to_copy == 0:
            return [forked_by_reference[:] for _ in range(num_forks)], [], []

        # From now on, each child will have its own list of blocks
        forked_blocks_lists = []
        copy_src = []
        copy_dst = []

        # Second phase: allocate new blocks if needed
        parent_id = forked_by_reference[-1] if forked_by_reference else None
        for _ in range(num_forks):
            allocated_block_ids = self.get_free_blocks(blocks_to_copy, parent_id, shareable, group_id)
            if allocated_block_ids is None:
                return None, [], []
            forked_blocks_lists.append(forked_by_reference + allocated_block_ids)
            copy_src.extend(parent_blocks[-blocks_to_copy:])
            copy_dst.extend(allocated_block_ids)
        return forked_blocks_lists, copy_src, copy_dst
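To make the docstring concrete, here is a hypothetical standalone model of the reference-vs-copy split (the names and the fresh-id source are illustrative, not the real BlockManager API):

```python
# Toy model of fork_blocks: complete blocks are shared by reference when
# shareable, the rest are freshly allocated per child and scheduled for copy.
def split_fork(parent_blocks, num_complete, num_forks, shareable, fresh_ids):
    shared = parent_blocks[:num_complete] if shareable else []
    to_copy = parent_blocks[len(shared):]
    children, copy_src, copy_dst = [], [], []
    ids = iter(fresh_ids)
    for _ in range(num_forks):
        new_ids = [next(ids) for _ in to_copy]
        children.append(shared + new_ids)
        copy_src.extend(to_copy)
        copy_dst.extend(new_ids)
    return children, copy_src, copy_dst

# Not shareable: all 4 parent blocks are copied for each of the 2 children
assert split_fork([0, 1, 2, 3], 3, 2, False, range(5, 13)) == (
    [[5, 6, 7, 8], [9, 10, 11, 12]], [0, 1, 2, 3, 0, 1, 2, 3], [5, 6, 7, 8, 9, 10, 11, 12]
)
# Shareable: blocks 0-2 are complete and referenced, only block 3 is copied
assert split_fork([0, 1, 2, 3], 3, 2, True, range(5, 7)) == (
    [[0, 1, 2, 5], [0, 1, 2, 6]], [3, 3], [5, 6]
)
```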
    def increase_ref_count(self, block_id: int) -> None:
        """Increases the reference count of a given (block_id)."""
        block = self._id_to_block[block_id]
@@ -243,6 +299,36 @@ def get_write_indices(self, request_id: str, past_length: int, query_length: int
    def get_seqlens_k(self, request_id: str, past_length: int, query_length: int) -> tuple[str, int]:
        """Returns the attention type of the cache allocator and the key sequence length for the given request_id."""

    def fork_blocks(
        self, parent_request_id: str, children_request_ids: list[str], block_manager: BlockManager
    ) -> tuple[list[int], list[int]]:
        """Forks the cache blocks of a (parent_request_id) to a list of (children_request_ids). To manage the blocks,
        the (block_manager) is used. When forking, the children's blocks are either shared with the parent, or they need to
Review thread on the docstring above:

Collaborator: Suggested change (diff not shown)

Author: I use parentheses to denote arguments of the function, would be weird to change convention midway.

Collaborator: ah ok did not know!
        be copied from the parent. Hence we return two lists of blocks that need to be copied: one for the source and
        one for the destination."""

        # Sanity checks
        if parent_request_id not in self.block_table:
            raise ValueError(f"No block table found for request {parent_request_id}")

        # Actual forking
        parent_blocks = self.block_table[parent_request_id]
        list_forked_blocks, copy_src, copy_dst = block_manager.fork_blocks(
            parent_blocks=parent_blocks,
            num_forks=len(children_request_ids),
            shareable=self.uses_block_sharing,
            group_id=self._index,
        )
        if list_forked_blocks is None:
            raise ValueError(f"Failed to fork blocks for request {parent_request_id}")

        # Update the block table for all children requests
        for children_request_id, forked_blocks in zip(children_request_ids, list_forked_blocks):
            if children_request_id in self.block_table:
                raise ValueError(f"Block table already exists for request {children_request_id}")
            self.block_table[children_request_id] = forked_blocks
Review thread on lines +326 to +329:

Collaborator: it feels like we should not need to iterate twice on the children_request_ids

Author: You are right, moving the check to inside the loop. Thanks!

Author: Oh my bad, I mistook your point. We iterate twice on that list indeed, because …

Collaborator: yeah yeah, I am just trying to avoid any extra for loop when we can, from reading the code it appeared to be removable, but if it's not no worries!
        return copy_src, copy_dst
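As a toy illustration of the block-table bookkeeping above (hypothetical request ids, continuing the shareable example from the docstring):

```python
# Each child gets its own block list; a pre-existing entry would be a bug.
block_table = {"parent": [0, 1, 2, 3]}
children_request_ids = ["child-0", "child-1"]
list_forked_blocks = [[0, 1, 2, 5], [0, 1, 2, 6]]  # from block_manager.fork_blocks

for child_id, forked_blocks in zip(children_request_ids, list_forked_blocks):
    if child_id in block_table:
        raise ValueError(f"Block table already exists for request {child_id}")
    block_table[child_id] = forked_blocks

assert block_table["child-0"][:3] == block_table["parent"][:3]  # shared complete blocks
```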

class FullAttentionCacheAllocator(CacheAllocator):
    """Cache manager for a group of full attention layers."""
Review thread:

Collaborator: very interesting, I would assume we want to delay requests that are getting forked to the next batch to do this async (I might be wrong).

Author: On the contrary, to maximize prefix sharing you want to schedule those requests ASAP. But there might be something to the idea that much of the CPU side of forking can be done asynchronously. The issue is that there will always be a copy of the cache, hence the GPU intervenes, but maybe it can be done in a side stream.
I think the best compromise is to add the feature now and later, when we get to CPU asynchrony, we can add the FORKING status to let the scheduler know those requests must not be scheduled until the cache has been copied.

Collaborator: yep makes sense!