Conversation
        prompt_ids=(state.initial_tokens + state.generated_tokens),
    )

    def copy_cache(self, source_blocks: list[int], forked_blocks: list[int]) -> None:
very interesting, I would assume we want to delay requests that are getting forked to the next batch so we can do this async (I might be wrong).

On the contrary, to maximize prefix sharing, you want to schedule those requests ASAP. But there might be something to the idea that we can do much of the CPU side of forking asynchronously. The issue is that there will always be a copy of the cache, hence the GPU intervenes, but maybe it can be done in a side stream.
I think the best compromise is to add the feature now and later, when we get to CPU asynchrony, we can add the FORKING status to let the scheduler know those requests should not be scheduled -- until the cache has been copied.
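The FORKING-status idea could be sketched like this (a minimal sketch; `RequestStatus`, `FORKING`, and the `schedulable` filter are hypothetical names, not the PR's actual API):

```python
from enum import Enum, auto
from dataclasses import dataclass

class RequestStatus(Enum):
    PENDING = auto()
    RUNNING = auto()
    FORKING = auto()  # hypothetical: cache copy still in flight, do not schedule

@dataclass
class Request:
    request_id: str
    status: RequestStatus = RequestStatus.PENDING

def schedulable(requests: list[Request]) -> list[Request]:
    # The scheduler would simply skip requests whose cache is still being copied.
    return [r for r in requests if r.status is not RequestStatus.FORKING]

reqs = [Request("a"), Request("a__child#0", RequestStatus.FORKING)]
print([r.request_id for r in schedulable(reqs)])  # only "a" until the copy lands
```

Once the copy completes, the child would flip back to PENDING and become eligible for the next batch.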
    source_blocks = torch.tensor(source_blocks, device=self.device, dtype=torch.int32)
    forked_blocks = torch.tensor(forked_blocks, device=self.device, dtype=torch.int32)
this is allocating memory + might need a sync (tensor from a Python list -> CPU-GPU sync); we want to avoid that
I tried playing around with this and I was surprised this is the fastest alternative, which makes no sense to me. Will leave a TODO to deep dive later.
Got it, thanks for checking!
    key_cache[forked_blocks] = key_cache[source_blocks]
    value_cache[forked_blocks] = value_cache[source_blocks]
    # FIXME: should be one copy for all CMs with only the changing blocks
    # FIXME: even once per fork batch
copy should be async as well (`non_blocking=True`)
    source_blocks, forked_blocks = cm.fork_blocks(state.request_id, new_state.request_id, self._block_manager)
    self.copy_cache(source_blocks, forked_blocks)
same here, we should "schedule" the copy. Remember we are in Python and the GIL is killing us.
Since the copy is device-to-device, I think the best we can do for now is one copy, as is the case right now. Plus, I think PyTorch works asynchronously in that case, i.e. CPU operations continue after the copy is launched.
| """Fork a given list of (source_blocks) into a new list of forked_blocks. If the blocks are (shareable), we | ||
| reference the existing blocks when they are complete. Otherwise, we allocate new blocks if possible. The | ||
| (group_id) of the layer group the blocks belong to is also needed.""" |
need doc that shows in / out
Not sure what you mean by this, sorry!
sorry, I mean the doc should help us more, showing an example with input and output!
Added an ASCII table
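To illustrate the in/out behavior being asked for, here is a hedged pure-Python sketch of the sharing rule: complete blocks are shared by reference, while the trailing incomplete block gets a freshly allocated block. All names here are illustrative, not the PR's real `fork_blocks`:

```python
def fork_blocks(source_blocks, num_complete, free_blocks):
    """Share every complete block with the parent; allocate a new block
    for the trailing, still-being-written block (if any)."""
    forked = list(source_blocks[:num_complete])  # shared by reference
    for _ in source_blocks[num_complete:]:       # incomplete tail: a real copy is needed
        forked.append(free_blocks.pop())         # allocate a fresh block for the child
    return forked

# Parent owns blocks [3, 7, 9]; the first two are complete, block 9 is partial.
free = [12, 13]
print(fork_blocks([3, 7, 9], num_complete=2, free_blocks=free))  # [3, 7, 13]
```

Only the freshly allocated block would then need a cache copy, which is what makes forking cheap when prefixes are long.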
    while self.scheduler._requests_to_fork:
        state = self.scheduler._requests_to_fork.pop()
        num_children = state.num_children
        state.num_children = 0
        for i in range(num_children):
            # FIXME: if fork can't be done, create a new pending request without forking
            new_request = self.cache.fork_request(state, f"{state.request_id}__child#{i}")
            self.scheduler.active_requests[new_request.request_id] = new_request
same here, we should make that async IMO (new status "FORKING" -> wait until forked? IDK, but we need to bench a tad bit)
this has changed to be done in batch, which, without asynchronous mode, is the best we can do on the CPU side.
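The batched version could be sketched as follows: gather every (source, destination) block pair first, then issue a single copy for the whole fork batch instead of one per child. All names are illustrative, not the PR's actual code:

```python
def collect_fork_batch(requests_to_fork, fork_one):
    """Accumulate block pairs for all pending forks so the cache copy
    can be issued once per batch instead of once per child."""
    all_sources, all_dests = [], []
    while requests_to_fork:
        state = requests_to_fork.pop()
        for child_idx in range(state["num_children"]):
            src, dst = fork_one(state, child_idx)  # per-child CPU bookkeeping
            all_sources.extend(src)
            all_dests.extend(dst)
    # A single kernel-side copy (e.g. key_cache[all_dests] = key_cache[all_sources])
    # would then handle the entire batch.
    return all_sources, all_dests

# Toy fork_one: child i gets blocks offset by 10 * (i + 1).
fork_one = lambda s, i: (s["blocks"], [b + 10 * (i + 1) for b in s["blocks"]])
srcs, dsts = collect_fork_batch([{"num_children": 2, "blocks": [1, 2]}], fork_one)
print(srcs, dsts)  # [1, 2, 1, 2] [11, 12, 21, 22]
```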
    manager_cm = self.continuous_batching_context_manager(
        generation_config=generation_config,
        num_q_cuda_graphs=num_q_padding_intervals,
        num_kv_cuda_graphs=num_kv_padding_intervals,
        allow_block_sharing=allow_block_sharing,
        block=True,
        timeout=5,
    )
    logging_cm = logging_redirect_tqdm([logger])
    pbar_cm = tqdm(
        total=num_requests,
        disable=(not progress_bar),
        desc=f"Solving {num_requests} requests",
        unit="request",
    )
you can create a get_cm func?
not sure I understand this -- what would it return? It seems `self.continuous_batching_context_manager(...)` is the get cm.
Yeah sorry, I don't remember what I wanted here. LGTM
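If the goal was merely to bundle the three context managers, `contextlib.ExitStack` is one stdlib way to do it. A sketch with stand-in context managers (not the PR's actual ones):

```python
from contextlib import ExitStack, contextmanager

@contextmanager
def named_cm(log, name):
    # Stand-in for manager_cm / logging_cm / pbar_cm.
    log.append(f"enter {name}")
    try:
        yield name
    finally:
        log.append(f"exit {name}")

log = []
with ExitStack() as stack:
    # Enter all of them; ExitStack unwinds in reverse order, even on error.
    for name in ("manager", "logging", "pbar"):
        stack.enter_context(named_cm(log, name))
print(log)
```

This keeps the three `with` blocks from nesting three levels deep, at the cost of slightly less explicit code.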
ArthurZucker left a comment:
LGTM! `def fork(self, new_request_id: str) -> "RequestState"` should be optimized as much as possible.
        self, parent_request_id: str, children_request_ids: list[str], block_manager: BlockManager
    ) -> tuple[list[int], list[int]]:
        """Forks the cache blocks of a (parent_request_id) to a list of (children_request_ids). To manage the blocks,
        the (block_manager) is used. When forking, the child's block are either shared with the parent, or they need to
Suggested change:

    - the (block_manager) is used. When forking, the child's block are either shared with the parent, or they need to
    + the block_manager is used. When forking, the child's block are either shared with the parent, or they need to
I use parentheses to denote arguments of the function; it would be weird to change convention midway.
    for children_request_id, forked_blocks in zip(children_request_ids, list_forked_blocks):
        self.block_table[children_request_id] = forked_blocks
it feels like we should not need to iterate twice on the list_forked_blocks if block_manager.fork_blocks updates the block table on the fly. But encapsulation might be better this way?
You are right, moving the check to inside the loop. Thanks!
Oh my bad, I mistook your point. We iterate twice on that list indeed, because block_manager.fork_blocks does not update the block table directly. I tried making it handle that part and it led to a messy function that was not all that readable. IMO it's best to leave it clear for now and revisit it if it's a hotspot for optimization, which it might be? Then again, it's just an additional iteration on a small list.
yeah yeah, I am just trying to avoid any extra for loop when we can. From reading the code it appeared to be removable, but if it's not, no worries!
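The encapsulation trade-off being discussed can be sketched in two small functions: the block manager only produces per-child block lists, and a second, cheap loop in the caller writes them into the block table. All names here are illustrative stand-ins, not the PR's real code:

```python
def fork_blocks_for_children(parent_blocks, num_children, free_blocks):
    # Pass 1 (inside the block manager): produce one block list per child.
    # Here every child gets fresh blocks; the real code would share complete ones.
    return [[free_blocks.pop() for _ in parent_blocks] for _ in range(num_children)]

def apply_fork(block_table, children_ids, list_forked_blocks):
    # Pass 2 (in the cache): a second, small loop that only touches the table,
    # keeping block allocation and table bookkeeping encapsulated separately.
    for child_id, forked in zip(children_ids, list_forked_blocks):
        block_table[child_id] = forked

table = {"p": [0, 1]}
forked = fork_blocks_for_children(table["p"], 2, free_blocks=[9, 8, 7, 6])
apply_fork(table, ["p__child#0", "p__child#1"], forked)
print(table)
```

Merging the two passes would save one iteration over a short list but would couple the block manager to the cache's table layout, which is the readability cost the author mentions.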
* Reformat to make the code pretty
* Allow for multiple decoding sequences in CB
* Style
* Fix a generation config bug
* Add seed to example
* Batch forking
* Change the fixme (for later PR)
* Copy source is optional
* Added a benchmark script for PR
* Added a test and fixed a bug
* Deepcopy and style
* Review compliance
* Style
Summary
This PR adds the option to fork requests during continuous batching, which duplicates the request and reuses the existing cache as much as possible. This is then leveraged to make the `num_return_sequences` argument available in CB. This PR enables parallel decoding, which will be useful for RL workflows.
Performance
`--compile` is always on, thus the number of tokens varies between runs because the compiled kernels do
Tests
Tests pass, including the one added to test the feature.
Sanity check
Looks good