From 94822dd82b861586eb99042e86f37cc7fa5af233 Mon Sep 17 00:00:00 2001 From: remi-or Date: Tue, 16 Dec 2025 09:33:03 +0000 Subject: [PATCH 01/13] Reformat to make the code pretty --- .../continuous_batching/continuous_api.py | 61 +++++++++++-------- 1 file changed, 34 insertions(+), 27 deletions(-) diff --git a/src/transformers/generation/continuous_batching/continuous_api.py b/src/transformers/generation/continuous_batching/continuous_api.py index 28a154465d68..16a6d5ba872c 100644 --- a/src/transformers/generation/continuous_batching/continuous_api.py +++ b/src/transformers/generation/continuous_batching/continuous_api.py @@ -760,29 +760,34 @@ def __init__( num_kv_padding_intervals: (optional) Number of intervals used to pad the keys/values dimension allow_block_sharing: (optional) Whether to allow block sharing if the model has some full attention layers """ - # Reloade paged version if necessary + # Reload paged version of the attention implementation if necessary if "paged|" not in model.config._attn_implementation: model.set_attn_implementation(f"paged|{model.config._attn_implementation}") + # Internal arguments self.model = model.eval() - generation_config = model.generation_config if generation_config is None else generation_config - self.generation_config = generation_config + self.manual_eviction = manual_eviction + self._allow_block_sharing = allow_block_sharing + self._use_prefix_sharing = allow_block_sharing # approximation until the cache is created + self.input_queue = queue.Queue(maxsize=max_queue_size) self.output_queue = queue.Queue() self.stop_event = threading.Event() - self.log_prob_generation = getattr(generation_config, "log_prob_generation", False) + self.batch_processor: ContinuousBatchProcessor | None = None self._generation_thread = None self._request_counter = 0 self._request_lock = threading.Lock() - self.model.generation_config.top_p = None + + # Generation config related arguments + generation_config = model.generation_config if 
generation_config is None else generation_config + self.generation_config = generation_config + self.log_prob_generation = getattr(generation_config, "log_prob_generation", False) self.do_sample = getattr(generation_config, "do_sample", True) self.logit_processor = self.model._get_logits_processor(generation_config) - self.profile = getattr(generation_config, "profile", False) # TODO: not supported yet - self.manual_eviction = manual_eviction - self.batch_processor: ContinuousBatchProcessor | None = None - self._allow_block_sharing = allow_block_sharing - self._use_prefix_sharing = allow_block_sharing # approximation until the cache is created + # self.model.generation_config.top_p = None NOTE: figure out why this was here + + # Cuda graph behavior is determined below using either user-specified arguments or heuristics self.use_cuda_graph = self._decide_use_cuda_graphs( use_cuda_graph=getattr(generation_config, "use_cuda_graph", None), num_q_padding_intervals=num_q_padding_intervals, @@ -796,6 +801,7 @@ def __init__( num_kv_padding_intervals if num_kv_padding_intervals > 0 else NUM_KV_PADDING_INTERVALS ) + # Log probability generation is not supported yet (TODO) if self.log_prob_generation: raise NotImplementedError("log_prob_generation is not supported yet") @@ -1227,23 +1233,24 @@ def generate_batch( # Initialize manager with the batch inputs results = {} num_requests = len(inputs) - with ( - self.continuous_batching_context_manager( - generation_config=generation_config, - num_q_cuda_graphs=num_q_padding_intervals, - num_kv_cuda_graphs=num_kv_padding_intervals, - allow_block_sharing=allow_block_sharing, - block=True, - timeout=5, - ) as manager, - logging_redirect_tqdm([logger]), - tqdm( - total=num_requests, - disable=(not progress_bar), - desc=f"Solving {num_requests} requests", - unit="request", - ) as pbar, - ): + # Prepare context managers for the main loop + manager_cm = self.continuous_batching_context_manager( + generation_config=generation_config, + 
num_q_cuda_graphs=num_q_padding_intervals, + num_kv_cuda_graphs=num_kv_padding_intervals, + allow_block_sharing=allow_block_sharing, + block=True, + timeout=5, + ) + logging_cm = logging_redirect_tqdm([logger]) + pbar_cm = tqdm( + total=num_requests, + disable=(not progress_bar), + desc=f"Solving {num_requests} requests", + unit="request", + ) + # Main loop + with manager_cm as manager, logging_cm, pbar_cm as pbar: try: manager.add_requests( inputs=inputs, max_new_tokens=kwargs.get("max_new_tokens"), record_timestamps=record_timestamps From 5165d9ec325c186aa44ccea3de6044242e59c04a Mon Sep 17 00:00:00 2001 From: remi-or Date: Wed, 17 Dec 2025 11:25:53 +0000 Subject: [PATCH 02/13] Allow for multiple decoding sequences in CB --- examples/pytorch/continuous_batching.py | 13 ++++++- .../generation/continuous_batching/cache.py | 24 +++++++++++- .../continuous_batching/cache_manager.py | 39 +++++++++++++++++++ .../continuous_batching/continuous_api.py | 30 +++++++++++--- .../continuous_batching/requests.py | 27 ++++++++++++- .../continuous_batching/scheduler.py | 11 +++++- 6 files changed, 134 insertions(+), 10 deletions(-) diff --git a/examples/pytorch/continuous_batching.py b/examples/pytorch/continuous_batching.py index ac395a455032..dffef83930bc 100644 --- a/examples/pytorch/continuous_batching.py +++ b/examples/pytorch/continuous_batching.py @@ -177,6 +177,7 @@ def batch_generate( parser.add_argument("--cuda-graph", "-cg", help="Use cuda graphs", type=str, default=None) parser.add_argument("--compile", action="store_true", help="Compile the model using torch.compile") parser.add_argument("--do-sample", action="store_true", help="Activate sampling") + parser.add_argument("--num-return-sequences", type=int, default=1, help="Number of return sequences") # Benchmark parameters parser.add_argument("--samples", type=int, default=500, help="Number of samples to generate") @@ -272,17 +273,27 @@ def batch_generate( inputs = inputs if isinstance(inputs, list) else 
inputs["input_ids"] batched_inputs.append(inputs) + # If num_return_sequences > 1, automatically enable do_sample with a warning + do_sample = args.do_sample + if args.num_return_sequences != 1 and not args.do_sample: + logger.warning( + f"num_return_sequences={args.num_return_sequences} > 1, automatically enabling do_sample=True. " + "Set --do-sample explicitly to suppress this warning." + ) + do_sample = True + # Prepare generation config generation_cfg = GenerationConfig( max_new_tokens=args.max_new_tokens, use_cuda_graph=use_cuda_graph, eos_token_id=tokenizer.pad_token_id if args.force_max_length else tokenizer.eos_token_id, pad_token_id=tokenizer.pad_token_id, - do_sample=args.do_sample, + do_sample=do_sample, temperature=0.8, top_p=0.9, num_blocks=args.num_blocks, max_batch_tokens=args.max_batch_tokens, + num_return_sequences=args.num_return_sequences, ) # Add a compile config if requested diff --git a/src/transformers/generation/continuous_batching/cache.py b/src/transformers/generation/continuous_batching/cache.py index d235593b91bd..af386cfdabad 100644 --- a/src/transformers/generation/continuous_batching/cache.py +++ b/src/transformers/generation/continuous_batching/cache.py @@ -210,7 +210,7 @@ def __init__( self.key_cache: list[torch.Tensor] = [] self.value_cache: list[torch.Tensor] = [] # We add two extra tokens to the cache to handle padding and generally discard unwanted tokens - self.cache_shape = (num_blocks * self.block_size + 2, self.num_key_value_heads, self.head_dim) + self.cache_shape = ((num_blocks + 2) * self.block_size, self.num_key_value_heads, self.head_dim) for _ in range(group_size): new_layer_key_cache = torch.empty(self.cache_shape, dtype=self.dtype, device=self.device) new_layer_value_cache = torch.empty(self.cache_shape, dtype=self.dtype, device=self.device) @@ -388,6 +388,28 @@ def mark_shareable_blocks_as_complete(self, state: RequestState) -> None: prompt_ids=(state.initial_tokens + state.generated_tokens), ) + def 
copy_cache(self, source_blocks: list[int], forked_blocks: list[int]) -> None: + """Copy the cache from the source blocks to the forked blocks.""" + source_blocks = torch.tensor(source_blocks, device=self.device, dtype=torch.int32) + forked_blocks = torch.tensor(forked_blocks, device=self.device, dtype=torch.int32) + for key_cache, value_cache in zip(self.key_cache, self.value_cache): + key_cache = key_cache.view(-1, self.block_size, self.num_key_value_heads, self.head_dim) + value_cache = value_cache.view(-1, self.block_size, self.num_key_value_heads, self.head_dim) + key_cache[forked_blocks] = key_cache[source_blocks] + value_cache[forked_blocks] = value_cache[source_blocks] + # FIXME: should be one copy for al CMs with only the changing blocks + # FIXME: even once per fork batch + + def fork_request(self, state: RequestState, new_request_id: str) -> RequestState: + """Fork a request into a new request. The new request is created by copying the state and updating the + request_id.""" + new_state = state.fork(new_request_id) + for cm in self.group_cache_managers: + source_blocks, forked_blocks = cm.fork_blocks(state.request_id, new_state.request_id, self._block_manager) + self.copy_cache(source_blocks, forked_blocks) + # FIXME: move it to the batch level + return new_state + # TODO: rework computation with the groups and their sizes class PagedAttentionMemoryHandler: diff --git a/src/transformers/generation/continuous_batching/cache_manager.py b/src/transformers/generation/continuous_batching/cache_manager.py index 7f585c595810..45321a96d12e 100644 --- a/src/transformers/generation/continuous_batching/cache_manager.py +++ b/src/transformers/generation/continuous_batching/cache_manager.py @@ -123,6 +123,27 @@ def get_free_blocks( # In both cases, we return the allocated block ids return allocated_block_ids + def fork_blocks(self, source_blocks: list[int], shareable: bool, group_id: int) -> list[int] | None: + """Fork a given list of (source_blocks) into a new list 
of forked_blocks. If the blocks are (shareable), we + reference the existing blocks when they are complete. Otherwise, we allocate new blocks if possible. The + (group_id) of the layer group the blocks belong to is also needed.""" + forked_blocks = [] + parent_id = None + for block_id in source_blocks: + block = self._id_to_block[block_id] + # If the block is shareable and complete, we just reference the existing block + if shareable and block.is_complete: + forked_blocks.append(block.id) + # Otherwise, we allocate a new block if possible + else: + # FIXME: from this point on, the blocks should be allowed as a bunch, not 1 by 1 + allocated_block_ids = self.get_free_blocks(1, parent_id, shareable, group_id) + if allocated_block_ids is None: + return None + forked_blocks.append(allocated_block_ids[0]) + parent_id = forked_blocks[-1] + return forked_blocks + def increase_ref_count(self, block_id: int) -> None: """Increases the reference count of a given (block_id).""" block = self._id_to_block[block_id] @@ -243,6 +264,24 @@ def get_write_indices(self, request_id: str, past_length: int, query_length: int def get_seqlens_k(self, request_id: str, past_length: int, query_length: int) -> tuple[str, int]: """Returns the attention type of the cache allocator and the key sequence length for the given request_id.""" + def fork_blocks(self, source_request_id: str, dst_request_id: str, block_manager: BlockManager) -> tuple[list[int], list[int]]: + """Fork the cache blocks for a given request_id into a new request_id.""" + if source_request_id not in self.block_table: + raise ValueError(f"No block table found for request {source_request_id}") + if dst_request_id in self.block_table: + raise ValueError(f"Block table already exists for request {dst_request_id}") + + source_blocks = self.block_table[source_request_id] + forked_blocks = block_manager.fork_blocks( + source_blocks=source_blocks, + shareable=self.uses_block_sharing, + group_id=self._index, + ) + if forked_blocks is 
None: + raise ValueError(f"Failed to fork blocks for request {source_request_id}") + + self.block_table[dst_request_id] = forked_blocks + return source_blocks, forked_blocks class FullAttentionCacheAllocator(CacheAllocator): """Cache manager for a group of full attention layers.""" diff --git a/src/transformers/generation/continuous_batching/continuous_api.py b/src/transformers/generation/continuous_batching/continuous_api.py index 16a6d5ba872c..32d97943b77b 100644 --- a/src/transformers/generation/continuous_batching/continuous_api.py +++ b/src/transformers/generation/continuous_batching/continuous_api.py @@ -572,13 +572,20 @@ def _maybe_send_output(self, state: RequestState) -> None: def update_batch(self) -> None: """Update request states based on generated tokens.""" new_tokens = self._get_new_tokens(len(self.requests_in_batch)) - for i, state in enumerate(self.requests_in_batch): + current_logits_index = 0 + for state in self.requests_in_batch: # If the request has no remaining prompt ids, it means prefill has already ended or just finished if len(state.remaining_prefill_tokens) == 0: - self.metrics.record_ttft_metric(state.created_time, state.request_id) - state.status = RequestStatus.DECODING - token = new_tokens[i] + + # If there are no generated tokens yet, it means prefill just ended + if state.generated_len() == 0: + self.metrics.record_ttft_metric(state.created_time, state.request_id) + state.status = RequestStatus.DECODING + + token = new_tokens[current_logits_index] state.tokens_to_process = [token] + current_logits_index += 1 + # Update the request and stop if it is complete is_finished = state.update_and_check_completion(token) # We mark the completed blocks as such @@ -594,6 +601,16 @@ def update_batch(self) -> None: else: raise ValueError(f"Request {state.request_id} is in an unexpected state: {state.status}") + # If some requests need to be forked, we do it now + while self.scheduler._requests_to_fork: + state = 
self.scheduler._requests_to_fork.pop() + num_children = state.num_children + state.num_children = 0 + for i in range(num_children): + # FIXME: if fork cant be done, create a new pending request without forking + new_request = self.cache.fork_request(state, f"{state.request_id}__child#{i}") + self.scheduler.active_requests[new_request.request_id] = new_request + if self.cache.get_num_free_blocks() == 0: raise ValueError("No more free blocks") @@ -784,6 +801,7 @@ def __init__( self.log_prob_generation = getattr(generation_config, "log_prob_generation", False) self.do_sample = getattr(generation_config, "do_sample", True) self.logit_processor = self.model._get_logits_processor(generation_config) + self.num_return_sequences = getattr(generation_config, "num_return_sequences", 1) # self.model.generation_config.top_p = None NOTE: figure out why this was here @@ -935,6 +953,7 @@ def add_request( state = RequestState( request_id=request_id, initial_tokens=list(input_ids), + num_children=self.num_return_sequences - 1, record_timestamps=record_timestamps, tokens_to_process=list(input_ids), max_new_tokens=max_new_tokens, @@ -1232,7 +1251,8 @@ def generate_batch( # Initialize manager with the batch inputs results = {} - num_requests = len(inputs) + gen_cfg = self.generation_config if generation_config is None else generation_config + num_requests = len(inputs) * gen_cfg.num_return_sequences # Prepare context managers for the main loop manager_cm = self.continuous_batching_context_manager( generation_config=generation_config, diff --git a/src/transformers/generation/continuous_batching/requests.py b/src/transformers/generation/continuous_batching/requests.py index 0dd1b0b2ce75..300c514fb9b5 100644 --- a/src/transformers/generation/continuous_batching/requests.py +++ b/src/transformers/generation/continuous_batching/requests.py @@ -101,6 +101,8 @@ class RequestState: Attributes: request_id (str): The ID of the generation request. 
+ initial_tokens (list[int]): The initial prompt tokens. + num_children (int): The number of children requests full_prompt_ids (list[int] | None): The tokens IDs of the full prompt. prompt_ids (list[int] | None): The tokens IDs currently being processed. remaining_prompt_ids (list[int]): The tokens IDs remaining to be processed (for split requests). @@ -121,6 +123,7 @@ class RequestState: initial_tokens: list[int] # Initial prompt tokens # Optional fields record_timestamps: bool = False # Whether to record timestamps for the generated tokens + num_children: int = 0 # Number of children requests # Internal fields tokens_to_process: list[int] | None = None # Tokens IDs currently being processed remaining_prefill_tokens: list[int] = field(default_factory=list) # For split requests, prefill left to process @@ -181,7 +184,7 @@ def update_and_check_completion(self, token_id: int) -> bool: Returns: bool: True if the request is now complete, False otherwise """ - # Only update if we're in decoding state + # Only update if we're in decoding state # TODO: seems useless (always true) -- remove this if self.status != RequestStatus.DECODING: return False @@ -227,3 +230,25 @@ def to_generation_output(self): error=self.error, timestamps=self.timestamps, ) + + def fork(self, new_request_id: str) -> "RequestState": + """Fork the request into a new request with the same state expect for request_id, created_time and lifespan.""" + return RequestState( + request_id=new_request_id, + initial_tokens=self.initial_tokens[:], + record_timestamps=self.record_timestamps, + num_children=self.num_children, + tokens_to_process=self.tokens_to_process[:], + remaining_prefill_tokens=self.remaining_prefill_tokens[:], + generated_tokens=self.generated_tokens[:], + allocated_blocks=self.allocated_blocks, + position_offset=self.position_offset, + _status=self.status, + max_new_tokens=self.max_new_tokens, + eos_token_id=self.eos_token_id, + streaming=self.streaming, + created_time=time.perf_counter(), 
+ error=self.error, + lifespan=(time.perf_counter(), -1), + _timestamps=self._timestamps, + ) diff --git a/src/transformers/generation/continuous_batching/scheduler.py b/src/transformers/generation/continuous_batching/scheduler.py index 109cb635ea5c..42bd607a216f 100644 --- a/src/transformers/generation/continuous_batching/scheduler.py +++ b/src/transformers/generation/continuous_batching/scheduler.py @@ -36,6 +36,7 @@ def __init__(self, cache: PagedAttentionCache, retain_cache_on_finish: bool = Fa self.retain_cache_on_finish = retain_cache_on_finish self._cancellation_lock = threading.Lock() self._requests_to_cancel: set[str] = set() + self._requests_to_fork: list[RequestState] = [] @traced def add_waiting_request(self, state: RequestState): @@ -151,8 +152,13 @@ def _prepare_request_for_processing( else: request_tokens = state.tokens_to_process + # If the request has one or more children we make sure not to prefill it entrirely + if state.num_children > 0 and token_budget >= len(request_tokens) - 1: + token_budget = len(request_tokens) - 1 + self._requests_to_fork.append(state) + + # Case: we can process the entire prompt/remainder if len(request_tokens) < token_budget: - # Can process the entire prompt/remainder if state.status == RequestStatus.PENDING: self.active_requests[state.request_id] = state state.status = RequestStatus.PREFILLING @@ -161,8 +167,9 @@ def _prepare_request_for_processing( state.status = RequestStatus.PREFILLING state.tokens_to_process = state.remaining_prefill_tokens state.remaining_prefill_tokens = [] + + # Otherwise: we need to split the request else: - # Need to split the request if state.status == RequestStatus.PENDING: self.active_requests[state.request_id] = state state.status = RequestStatus.PREFILLING_SPLIT From eb8152c0a14e5076a9c187aa495f09636daca0b4 Mon Sep 17 00:00:00 2001 From: remi-or Date: Wed, 17 Dec 2025 11:53:24 +0000 Subject: [PATCH 03/13] Style --- .../generation/continuous_batching/cache_manager.py | 5 ++++- 
.../generation/continuous_batching/continuous_api.py | 1 - 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/transformers/generation/continuous_batching/cache_manager.py b/src/transformers/generation/continuous_batching/cache_manager.py index 45321a96d12e..0a8de3757065 100644 --- a/src/transformers/generation/continuous_batching/cache_manager.py +++ b/src/transformers/generation/continuous_batching/cache_manager.py @@ -264,7 +264,9 @@ def get_write_indices(self, request_id: str, past_length: int, query_length: int def get_seqlens_k(self, request_id: str, past_length: int, query_length: int) -> tuple[str, int]: """Returns the attention type of the cache allocator and the key sequence length for the given request_id.""" - def fork_blocks(self, source_request_id: str, dst_request_id: str, block_manager: BlockManager) -> tuple[list[int], list[int]]: + def fork_blocks( + self, source_request_id: str, dst_request_id: str, block_manager: BlockManager + ) -> tuple[list[int], list[int]]: """Fork the cache blocks for a given request_id into a new request_id.""" if source_request_id not in self.block_table: raise ValueError(f"No block table found for request {source_request_id}") @@ -283,6 +285,7 @@ def fork_blocks(self, source_request_id: str, dst_request_id: str, block_manager self.block_table[dst_request_id] = forked_blocks return source_blocks, forked_blocks + class FullAttentionCacheAllocator(CacheAllocator): """Cache manager for a group of full attention layers.""" diff --git a/src/transformers/generation/continuous_batching/continuous_api.py b/src/transformers/generation/continuous_batching/continuous_api.py index 32d97943b77b..d3cb677b76e1 100644 --- a/src/transformers/generation/continuous_batching/continuous_api.py +++ b/src/transformers/generation/continuous_batching/continuous_api.py @@ -576,7 +576,6 @@ def update_batch(self) -> None: for state in self.requests_in_batch: # If the request has no remaining prompt ids, it means prefill has already 
ended or just finished if len(state.remaining_prefill_tokens) == 0: - # If there are no generated tokens yet, it means prefill just ended if state.generated_len() == 0: self.metrics.record_ttft_metric(state.created_time, state.request_id) From 3bef6520fd6584e4ded966f26206ef6424ac058b Mon Sep 17 00:00:00 2001 From: remi-or Date: Thu, 18 Dec 2025 15:12:17 +0000 Subject: [PATCH 04/13] Fix a generation config bug --- src/transformers/generation/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index cf9497d2a1f1..4e77d20c0f2e 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1305,7 +1305,7 @@ def _get_logits_processor( if generation_config.do_sample: # In beam methods, we need to keep at least one non-eos token to explore continuations that might have a # better score (i.e. keep len(list(generation_config._eos_token_tensor)) + 1) - if generation_config.num_beams > 1: + if generation_config.num_beams is not None and generation_config.num_beams > 1: if isinstance(generation_config._eos_token_tensor, list): min_tokens_to_keep = len(generation_config._eos_token_tensor) + 1 elif isinstance(generation_config._eos_token_tensor, torch.Tensor): From 9f2596e3faeb8a73fb5cfb48f834a06b1467e15e Mon Sep 17 00:00:00 2001 From: remi-or Date: Thu, 18 Dec 2025 15:20:12 +0000 Subject: [PATCH 05/13] Add seed to example --- examples/pytorch/continuous_batching.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/examples/pytorch/continuous_batching.py b/examples/pytorch/continuous_batching.py index dffef83930bc..970b390ffbc2 100644 --- a/examples/pytorch/continuous_batching.py +++ b/examples/pytorch/continuous_batching.py @@ -191,6 +191,7 @@ def batch_generate( parser.add_argument("--compare", action="store_true", help="Compare CB generation with classic generate") parser.add_argument("--profile", type=str, default=None) 
parser.add_argument("--metrics", action="store_true") + parser.add_argument("--seed", type=int, default=None, help="Random seed") # Display parameters parser.add_argument("--displayed", type=int, default=0, help="Number of samples to display") @@ -211,6 +212,10 @@ def batch_generate( else: args.attn = "kernels-community/flash-attn3" + # Set seed + if args.seed is not None: + torch.manual_seed(args.seed) + # Create model model_id = "google/gemma-2-2b-it" if args.sliding_window > 0 else "meta-llama/Llama-3.1-8B-Instruct" has_system_role = args.sliding_window == 0 From ea36c6ad7f5f4502a8539847ccd94b863608ddc2 Mon Sep 17 00:00:00 2001 From: remi-or Date: Thu, 18 Dec 2025 15:32:28 +0000 Subject: [PATCH 06/13] Batch forking --- .../generation/continuous_batching/cache.py | 17 ++-- .../continuous_batching/cache_manager.py | 89 ++++++++++++------- .../continuous_batching/continuous_api.py | 17 +++- 3 files changed, 81 insertions(+), 42 deletions(-) diff --git a/src/transformers/generation/continuous_batching/cache.py b/src/transformers/generation/continuous_batching/cache.py index af386cfdabad..848e1694f94c 100644 --- a/src/transformers/generation/continuous_batching/cache.py +++ b/src/transformers/generation/continuous_batching/cache.py @@ -400,15 +400,16 @@ def copy_cache(self, source_blocks: list[int], forked_blocks: list[int]) -> None # FIXME: should be one copy for al CMs with only the changing blocks # FIXME: even once per fork batch - def fork_request(self, state: RequestState, new_request_id: str) -> RequestState: - """Fork a request into a new request. 
The new request is created by copying the state and updating the - request_id.""" - new_state = state.fork(new_request_id) + def fork_request(self, source_request_id: str, destination_request_ids: list[str]) -> tuple[list[int], list[int]]: + """Fork the cache of a request (state) into the one of a list of requests with the given (dst_request_ids).""" + # These lists will be the accumulators for the source and destination blocks for the cache copy + source_blocks, destination_blocks = [], [] + # Main fork loop for cm in self.group_cache_managers: - source_blocks, forked_blocks = cm.fork_blocks(state.request_id, new_state.request_id, self._block_manager) - self.copy_cache(source_blocks, forked_blocks) - # FIXME: move it to the batch level - return new_state + src_blocks, dst_blocks = cm.fork_blocks(source_request_id, destination_request_ids, self._block_manager) + source_blocks.extend(src_blocks) + destination_blocks.extend(dst_blocks) + return source_blocks, destination_blocks # TODO: rework computation with the groups and their sizes diff --git a/src/transformers/generation/continuous_batching/cache_manager.py b/src/transformers/generation/continuous_batching/cache_manager.py index 0a8de3757065..89d3b5730eee 100644 --- a/src/transformers/generation/continuous_batching/cache_manager.py +++ b/src/transformers/generation/continuous_batching/cache_manager.py @@ -123,26 +123,42 @@ def get_free_blocks( # In both cases, we return the allocated block ids return allocated_block_ids - def fork_blocks(self, source_blocks: list[int], shareable: bool, group_id: int) -> list[int] | None: - """Fork a given list of (source_blocks) into a new list of forked_blocks. If the blocks are (shareable), we - reference the existing blocks when they are complete. Otherwise, we allocate new blocks if possible. 
The - (group_id) of the layer group the blocks belong to is also needed.""" - forked_blocks = [] - parent_id = None - for block_id in source_blocks: + def fork_blocks( + self, parent_blocks: list[int], num_forks: int, shareable: bool, group_id: int + ) -> tuple[list[list[int]], list[int], list[int]]: + """Fork a given list of (parent_blocks) as many times as (num_forks). If the blocks are (shareable), we use + reference on the blocks that are complete. Otherwise, we allocate new blocks and keep track of their indices to + later copy the physical cache.""" + # First phase: reference all complete blocks + forked_by_reference = [] + for block_id in parent_blocks: block = self._id_to_block[block_id] - # If the block is shareable and complete, we just reference the existing block if shareable and block.is_complete: - forked_blocks.append(block.id) - # Otherwise, we allocate a new block if possible + forked_by_reference.append(block.id) + block.ref_count += num_forks else: - # FIXME: from this point on, the blocks should be allowed as a bunch, not 1 by 1 - allocated_block_ids = self.get_free_blocks(1, parent_id, shareable, group_id) - if allocated_block_ids is None: - return None - forked_blocks.append(allocated_block_ids[0]) - parent_id = forked_blocks[-1] - return forked_blocks + break + + # Early return if we have forked all blocks by reference + blocks_to_copy = len(parent_blocks) - len(forked_by_reference) + if blocks_to_copy == 0: + return [forked_by_reference[:] for _ in range(num_forks)], [], [] + + # From now on, each child will have its own list of blocks + forked_blocks_lists = [] + copy_src = [] + copy_dst = [] + + # Second phase: allocate new blocks if needed + parent_id = forked_by_reference[-1] if forked_by_reference else None + for _ in range(num_forks): + allocated_block_ids = self.get_free_blocks(blocks_to_copy, parent_id, shareable, group_id) + if allocated_block_ids is None: + return None, [], [] + forked_blocks_lists.append(forked_by_reference + 
allocated_block_ids) + copy_src.extend(parent_blocks[-blocks_to_copy:]) + copy_dst.extend(allocated_block_ids) + return forked_blocks_lists, copy_src, copy_dst def increase_ref_count(self, block_id: int) -> None: """Increases the reference count of a given (block_id).""" @@ -265,25 +281,36 @@ def get_seqlens_k(self, request_id: str, past_length: int, query_length: int) -> """Returns the attention type of the cache allocator and the key sequence length for the given request_id.""" def fork_blocks( - self, source_request_id: str, dst_request_id: str, block_manager: BlockManager + self, parent_request_id: str, children_request_ids: list[str], block_manager: BlockManager ) -> tuple[list[int], list[int]]: - """Fork the cache blocks for a given request_id into a new request_id.""" - if source_request_id not in self.block_table: - raise ValueError(f"No block table found for request {source_request_id}") - if dst_request_id in self.block_table: - raise ValueError(f"Block table already exists for request {dst_request_id}") - - source_blocks = self.block_table[source_request_id] - forked_blocks = block_manager.fork_blocks( - source_blocks=source_blocks, + """Forks the cache blocks of a (parent_request_id) to a list of (children_request_ids). To manage the blocks, + the (block_manager) is used. When forking, the child's block are either shared with the parent, or they need to + be copied from the parent. 
Hence we return two lists of blocks that need to be copied: one for the source and + one for the destination.""" + + # Sanity checks + if parent_request_id not in self.block_table: + raise ValueError(f"No block table found for request {parent_request_id}") + # TODO: this check is good in the current context but it might be too much + slow things down + for children_request_id in children_request_ids: + if children_request_id in self.block_table: + raise ValueError(f"Block table already exists for request {children_request_id}") + + # Actual forking + parent_blocks = self.block_table[parent_request_id] + list_forked_blocks, copy_src, copy_dst = block_manager.fork_blocks( + parent_blocks=parent_blocks, + num_forks=len(children_request_ids), shareable=self.uses_block_sharing, group_id=self._index, ) - if forked_blocks is None: - raise ValueError(f"Failed to fork blocks for request {source_request_id}") + if list_forked_blocks is None: + raise ValueError(f"Failed to fork blocks for request {parent_request_id}") - self.block_table[dst_request_id] = forked_blocks - return source_blocks, forked_blocks + # Update the block table for all children requests + for children_request_id, forked_blocks in zip(children_request_ids, list_forked_blocks): + self.block_table[children_request_id] = forked_blocks + return copy_src, copy_dst class FullAttentionCacheAllocator(CacheAllocator): diff --git a/src/transformers/generation/continuous_batching/continuous_api.py b/src/transformers/generation/continuous_batching/continuous_api.py index d3cb677b76e1..38769fe5c97f 100644 --- a/src/transformers/generation/continuous_batching/continuous_api.py +++ b/src/transformers/generation/continuous_batching/continuous_api.py @@ -601,15 +601,26 @@ def update_batch(self) -> None: raise ValueError(f"Request {state.request_id} is in an unexpected state: {state.status}") # If some requests need to be forked, we do it now + copy_source, copy_destination = [], [] while self.scheduler._requests_to_fork: + 
            # Get the number of children and reset it so it's not forked again
             state = self.scheduler._requests_to_fork.pop()
             num_children = state.num_children
             state.num_children = 0
-            for i in range(num_children):
-                # FIXME: if fork cant be done, create a new pending request without forking
-                new_request = self.cache.fork_request(state, f"{state.request_id}__child#{i}")
+            # Create the new request
+            new_request_ids = [f"{state.request_id}__child#{i}" for i in range(num_children)]
+            new_requests = [state.fork(new_request_id) for new_request_id in new_request_ids]
+            # Fork the cache
+            copy_src, copy_dst = self.cache.fork_request(state.request_id, new_request_ids)
+            copy_source.extend(copy_src)
+            copy_destination.extend(copy_dst)
+            # Add the new requests to the scheduler
+            for new_request in new_requests:
                 self.scheduler.active_requests[new_request.request_id] = new_request
+            # FIXME: if fork cant be done, create a new pending request without forking instead of crashing everything
 
+        # The copy induced by the fork is done in one go
+        self.cache.copy_cache(copy_source, copy_destination)
         if self.cache.get_num_free_blocks() == 0:
             raise ValueError("No more free blocks")

From 703b48d596a65476d0605c1f30b2e65061613bca Mon Sep 17 00:00:00 2001
From: remi-or
Date: Thu, 18 Dec 2025 17:39:59 +0000
Subject: [PATCH 07/13] Change the fixme (for later PR)

---
 src/transformers/generation/continuous_batching/cache.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/transformers/generation/continuous_batching/cache.py b/src/transformers/generation/continuous_batching/cache.py
index 848e1694f94c..bfae16a70f88 100644
--- a/src/transformers/generation/continuous_batching/cache.py
+++ b/src/transformers/generation/continuous_batching/cache.py
@@ -397,8 +397,8 @@ def copy_cache(self, source_blocks: list[int], forked_blocks: list[int]) -> None
         value_cache = value_cache.view(-1, self.block_size, self.num_key_value_heads, self.head_dim)
         key_cache[forked_blocks] = 
key_cache[source_blocks] value_cache[forked_blocks] = value_cache[source_blocks] - # FIXME: should be one copy for al CMs with only the changing blocks - # FIXME: even once per fork batch + # FIXME: consolidate the cache into a single tensor of shape (group_size, 2, *self.k_or_v_cache_shape) + # This will allow for better .update and a single copy instead of one per cache tensor def fork_request(self, source_request_id: str, destination_request_ids: list[str]) -> tuple[list[int], list[int]]: """Fork the cache of a request (state) into the one of a list of requests with the given (dst_request_ids).""" From 87fe8fdb9fd15290b9ae213f2ce351c5a187ac4a Mon Sep 17 00:00:00 2001 From: remi-or Date: Fri, 19 Dec 2025 00:53:58 +0000 Subject: [PATCH 08/13] Copy source is optional --- .../generation/continuous_batching/continuous_api.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/transformers/generation/continuous_batching/continuous_api.py b/src/transformers/generation/continuous_batching/continuous_api.py index 38769fe5c97f..72946500c98f 100644 --- a/src/transformers/generation/continuous_batching/continuous_api.py +++ b/src/transformers/generation/continuous_batching/continuous_api.py @@ -619,8 +619,10 @@ def update_batch(self) -> None: self.scheduler.active_requests[new_request.request_id] = new_request # FIXME: if fork cant be done, create a new pending request without forking instead of crashing everything - # The copy induced by the fork is done in one go - self.cache.copy_cache(copy_source, copy_destination) + # The copy induced by the fork is done in one go (if it's even needed) + if copy_source: + self.cache.copy_cache(copy_source, copy_destination) + if self.cache.get_num_free_blocks() == 0: raise ValueError("No more free blocks") From 8e4d0c3fc365002f0e14b25cc9967f2efe36cc1f Mon Sep 17 00:00:00 2001 From: remi-or Date: Wed, 24 Dec 2025 10:49:56 +0000 Subject: [PATCH 09/13] Added a benchmark script for PR --- 
.../continuous_batching_overall.py | 53 +++++++++++++++++++ 1 file changed, 53 insertions(+) create mode 100644 benchmark_v2/benchmark_scripts/continuous_batching_overall.py diff --git a/benchmark_v2/benchmark_scripts/continuous_batching_overall.py b/benchmark_v2/benchmark_scripts/continuous_batching_overall.py new file mode 100644 index 000000000000..1b820e4f6455 --- /dev/null +++ b/benchmark_v2/benchmark_scripts/continuous_batching_overall.py @@ -0,0 +1,53 @@ +import re +import subprocess +from pathlib import Path + +from tabulate import tabulate + + +SCRIPT_LOCATION = (Path(__file__).parent.parent.parent / "examples/pytorch/continuous_batching.py").as_posix() +COMMON_ARGS = "--log-level WARNING --seed 0".split() + + +def run_and_parse_cb_example(args: list[str]) -> dict: + print(f"Benchmarking with args: {args}") + output = subprocess.check_output( + ["python", SCRIPT_LOCATION] + args.split() + COMMON_ARGS, + # stderr=subprocess.DEVNULL, + ) + pattern = r"CB generation took: ([\d.]+) seconds for (\d+) tokens\. 
([\d.]+)tok/s" + match = re.search(pattern, output.decode("utf-8")) + if match is not None: + return { + "args": args, + "time_seconds": float(match.group(1)), + "num_tokens": int(match.group(2)), + "throughput_tok_per_sec": float(match.group(3)), + } + return {} + + +if __name__ == "__main__": + + results = [{"args": "Arguments", "time_seconds": "Duration (s)", "num_tokens": "Generated tokens", "throughput_tok_per_sec": "Throughput (tok/s)"}] + + # Benchmark with different number of samples + results.append(run_and_parse_cb_example("--samples 10")) + results.append(run_and_parse_cb_example("--samples 50")) + results.append(run_and_parse_cb_example("--samples 100")) + results.append(run_and_parse_cb_example("--samples 500")) + + # Benchmark with compile: default, flash attention 2 and sdpa + results.append(run_and_parse_cb_example("--samples 100 --compile")) + results.append(run_and_parse_cb_example("--samples 100 --compile --attn flash_attention_2")) + results.append(run_and_parse_cb_example("--samples 100 --compile --attn sdpa")) + + # Benchmark with parallel decoding + results.append(run_and_parse_cb_example("--samples 50 --compile --num-return-sequences 8 --do-sample")) + results.append(run_and_parse_cb_example("--samples 100 --compile --num-return-sequences 4 --do-sample")) + + # Benchmark with prefix sharing + results.append(run_and_parse_cb_example("--samples 500 --add-prefix --compile")) + + print() + print(tabulate(results, tablefmt="github")) From 2e772d5f24515cf915fc861803fd1b388af3fac1 Mon Sep 17 00:00:00 2001 From: remi-or Date: Wed, 24 Dec 2025 11:22:46 +0000 Subject: [PATCH 10/13] Added a test and fixed a bug --- .../continuous_batching/cache_manager.py | 16 ++++--- tests/generation/test_continuous_batching.py | 44 ++++++++++++++++++- 2 files changed, 52 insertions(+), 8 deletions(-) diff --git a/src/transformers/generation/continuous_batching/cache_manager.py b/src/transformers/generation/continuous_batching/cache_manager.py index 
89d3b5730eee..7026eebb1253 100644 --- a/src/transformers/generation/continuous_batching/cache_manager.py +++ b/src/transformers/generation/continuous_batching/cache_manager.py @@ -131,13 +131,15 @@ def fork_blocks( later copy the physical cache.""" # First phase: reference all complete blocks forked_by_reference = [] - for block_id in parent_blocks: - block = self._id_to_block[block_id] - if shareable and block.is_complete: - forked_by_reference.append(block.id) - block.ref_count += num_forks - else: - break + + if shareable: + for block_id in parent_blocks: + block = self._id_to_block[block_id] + if block.is_complete: + forked_by_reference.append(block.id) + block.ref_count += num_forks + else: + break # Early return if we have forked all blocks by reference blocks_to_copy = len(parent_blocks) - len(forked_by_reference) diff --git a/tests/generation/test_continuous_batching.py b/tests/generation/test_continuous_batching.py index 4bc7fca3204a..4f66e55cc740 100644 --- a/tests/generation/test_continuous_batching.py +++ b/tests/generation/test_continuous_batching.py @@ -212,7 +212,6 @@ def _test_continuous_batching_parity( chats = [[{"role": "user", "content": user_message}] for user_message in user_messages] tokenized = [tokenizer.apply_chat_template(chat, add_generation_prompt=True) for chat in chats] input_ids = [(x if isinstance(x, list) else x["input_ids"]) for x in tokenized] - print(f"{input_ids[0] = } {type(input_ids[0]) = }") # Eager and SDPA implementations get a precision boost to account for the fact that an attention mask is used in # continuous batching but not in generate @@ -504,3 +503,46 @@ def test_block_sharing_with_hybrid_model(self) -> None: }).get_expectation() # fmt: skip return self._test_block_sharing(model_id, num_layer_groups, input_msg, expected_generated_tokens) + + + @parameterized.expand([True, False]) + @require_torch_accelerator + def test_num_return_sequences(self, allow_block_sharing: bool) -> None: + + model_id = 
"TinyLlama/TinyLlama-1.1B-Chat-v1.0" + tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="left") + user_messages = [ + "A robe takes 2 bolts of blue fiber and half that much white fiber. How many bolts in total does it take?" + ] + chats = [[{"role": "user", "content": user_message}] for user_message in user_messages] + tokenized = [tokenizer.apply_chat_template(chat, add_generation_prompt=True) for chat in chats] + input_ids = [(x if isinstance(x, list) else x["input_ids"]) for x in tokenized] + + # Generation with continuous batching + model = AutoModelForCausalLM.from_pretrained(model_id, attn_implementation="sdpa") + model = model.to(torch_device).eval() + model.generation_config.max_new_tokens = 30 + model.generation_config.do_sample = False + + # Generation with continuous batching + manager_cm = model.continuous_batching_context_manager( + allow_block_sharing=allow_block_sharing, block=True, timeout=5 + ) + # Main loop + results = [] + with manager_cm as manager: + manager.num_return_sequences = 2 + manager.add_requests(inputs=input_ids, max_new_tokens=30) + requests_left = 2 + while requests_left: + result = manager.get_result(timeout=1) + if result and result.is_finished(): + results.append(result) + requests_left -= 1 + else: + if not manager.is_running(): + break + + self.assertEqual(len(results), 2, f"Expected 2 results, but got {len(results) = }") + self.assertEqual(results[0].generated_tokens, results[1].generated_tokens) + From 4e6415da184d4b5ef78726f13640e777881cc674 Mon Sep 17 00:00:00 2001 From: remi-or Date: Mon, 5 Jan 2026 09:41:39 +0000 Subject: [PATCH 11/13] Deepcopy and style --- .../continuous_batching_overall.py | 10 ++++++-- .../continuous_batching/requests.py | 25 +++++-------------- tests/generation/test_continuous_batching.py | 3 --- 3 files changed, 14 insertions(+), 24 deletions(-) diff --git a/benchmark_v2/benchmark_scripts/continuous_batching_overall.py b/benchmark_v2/benchmark_scripts/continuous_batching_overall.py 
index 1b820e4f6455..720dce383485 100644 --- a/benchmark_v2/benchmark_scripts/continuous_batching_overall.py +++ b/benchmark_v2/benchmark_scripts/continuous_batching_overall.py @@ -28,8 +28,14 @@ def run_and_parse_cb_example(args: list[str]) -> dict: if __name__ == "__main__": - - results = [{"args": "Arguments", "time_seconds": "Duration (s)", "num_tokens": "Generated tokens", "throughput_tok_per_sec": "Throughput (tok/s)"}] + results = [ + { + "args": "Arguments", + "time_seconds": "Duration (s)", + "num_tokens": "Generated tokens", + "throughput_tok_per_sec": "Throughput (tok/s)", + } + ] # Benchmark with different number of samples results.append(run_and_parse_cb_example("--samples 10")) diff --git a/src/transformers/generation/continuous_batching/requests.py b/src/transformers/generation/continuous_batching/requests.py index 300c514fb9b5..a5a12298790f 100644 --- a/src/transformers/generation/continuous_batching/requests.py +++ b/src/transformers/generation/continuous_batching/requests.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import time +from copy import deepcopy from dataclasses import dataclass, field from enum import Enum @@ -233,22 +234,8 @@ def to_generation_output(self): def fork(self, new_request_id: str) -> "RequestState": """Fork the request into a new request with the same state expect for request_id, created_time and lifespan.""" - return RequestState( - request_id=new_request_id, - initial_tokens=self.initial_tokens[:], - record_timestamps=self.record_timestamps, - num_children=self.num_children, - tokens_to_process=self.tokens_to_process[:], - remaining_prefill_tokens=self.remaining_prefill_tokens[:], - generated_tokens=self.generated_tokens[:], - allocated_blocks=self.allocated_blocks, - position_offset=self.position_offset, - _status=self.status, - max_new_tokens=self.max_new_tokens, - eos_token_id=self.eos_token_id, - streaming=self.streaming, - created_time=time.perf_counter(), - error=self.error, - lifespan=(time.perf_counter(), -1), - _timestamps=self._timestamps, - ) + new = deepcopy(self) + new.request_id = new_request_id + new.created_time = time.perf_counter() + new.lifespan = (new.created_time, -1) + return new diff --git a/tests/generation/test_continuous_batching.py b/tests/generation/test_continuous_batching.py index 4f66e55cc740..3b803708d02b 100644 --- a/tests/generation/test_continuous_batching.py +++ b/tests/generation/test_continuous_batching.py @@ -504,11 +504,9 @@ def test_block_sharing_with_hybrid_model(self) -> None: return self._test_block_sharing(model_id, num_layer_groups, input_msg, expected_generated_tokens) - @parameterized.expand([True, False]) @require_torch_accelerator def test_num_return_sequences(self, allow_block_sharing: bool) -> None: - model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0" tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="left") user_messages = [ @@ -545,4 +543,3 @@ def test_num_return_sequences(self, allow_block_sharing: bool) -> None: self.assertEqual(len(results), 2, f"Expected 2 results, but got 
{len(results) = }") self.assertEqual(results[0].generated_tokens, results[1].generated_tokens) - From c53e3b26956f2015fd72c43b6c014cd64bca30ba Mon Sep 17 00:00:00 2001 From: remi-or Date: Tue, 6 Jan 2026 12:08:34 +0000 Subject: [PATCH 12/13] Review compliance --- .../continuous_batching/cache_manager.py | 25 ++++++++++++++---- .../continuous_batching/continuous_api.py | 8 +++--- .../continuous_batching/requests.py | 26 +++++++++++++++---- 3 files changed, 44 insertions(+), 15 deletions(-) diff --git a/src/transformers/generation/continuous_batching/cache_manager.py b/src/transformers/generation/continuous_batching/cache_manager.py index 7026eebb1253..c2186b00ee61 100644 --- a/src/transformers/generation/continuous_batching/cache_manager.py +++ b/src/transformers/generation/continuous_batching/cache_manager.py @@ -128,7 +128,24 @@ def fork_blocks( ) -> tuple[list[list[int]], list[int], list[int]]: """Fork a given list of (parent_blocks) as many times as (num_forks). If the blocks are (shareable), we use reference on the blocks that are complete. Otherwise, we allocate new blocks and keep track of their indices to - later copy the physical cache.""" + later copy the physical cache. For instance, when forking 4 blocks for 2 children: + + Parent blocks: [0, 1, 2, 3], with all blocks being complete except the last one (block 3). 
+ + ----------------------------------------- IF BLOCKS ARE NOT SHAREABLE ----------------------------------------- + + Forked blocks lists: [[5, 6, 7, 8], [9, 10, 11, 12]] + Copy source: [0, 1, 2, 3, 0, 1, 2, 3] + ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ + Copy destination: [5, 6, 7, 8, 9, 10, 11, 12] → 8 blocks are newly allocated and copied + + ----------------------------------------- IF BLOCKS ARE SHAREABLE --------------------------------------------- + + Forked blocks lists: [[0, 1, 2, 5], [0, 1, 2, 6]] + Copy source: [ 3, 3] (block 3 is not complete so it's copied, not referenced) + ↓ ↓ + Copy destination: [ 5, 6] → only 2 blocks are newly allocated and copied + """ # First phase: reference all complete blocks forked_by_reference = [] @@ -293,10 +310,6 @@ def fork_blocks( # Sanity checks if parent_request_id not in self.block_table: raise ValueError(f"No block table found for request {parent_request_id}") - # TODO: this check is good in the current context but it might be too much + slow things down - for children_request_id in children_request_ids: - if children_request_id in self.block_table: - raise ValueError(f"Block table already exists for request {children_request_id}") # Actual forking parent_blocks = self.block_table[parent_request_id] @@ -311,6 +324,8 @@ def fork_blocks( # Update the block table for all children requests for children_request_id, forked_blocks in zip(children_request_ids, list_forked_blocks): + if children_request_id in self.block_table: + raise ValueError(f"Block table already exists for request {children_request_id}") self.block_table[children_request_id] = forked_blocks return copy_src, copy_dst diff --git a/src/transformers/generation/continuous_batching/continuous_api.py b/src/transformers/generation/continuous_batching/continuous_api.py index 72946500c98f..a3dc357a34b2 100644 --- a/src/transformers/generation/continuous_batching/continuous_api.py +++ b/src/transformers/generation/continuous_batching/continuous_api.py @@ -607,16 +607,14 @@ def 
update_batch(self) -> None: state = self.scheduler._requests_to_fork.pop() num_children = state.num_children state.num_children = 0 - # Create the new request + # Create the new request and add them to the scheduler new_request_ids = [f"{state.request_id}__child#{i}" for i in range(num_children)] - new_requests = [state.fork(new_request_id) for new_request_id in new_request_ids] + for new_request_id in new_request_ids: + self.scheduler.active_requests[new_request_id] = state.fork(new_request_id) # Fork the cache copy_src, copy_dst = self.cache.fork_request(state.request_id, new_request_ids) copy_source.extend(copy_src) copy_destination.extend(copy_dst) - # Add the new requests to the scheduler - for new_request in new_requests: - self.scheduler.active_requests[new_request.request_id] = new_request # FIXME: if fork cant be done, create a new pending request without forking instead of crashing everything # The copy induced by the fork is done in one go (if it's even needed) diff --git a/src/transformers/generation/continuous_batching/requests.py b/src/transformers/generation/continuous_batching/requests.py index a5a12298790f..a1492d487dd2 100644 --- a/src/transformers/generation/continuous_batching/requests.py +++ b/src/transformers/generation/continuous_batching/requests.py @@ -234,8 +234,24 @@ def to_generation_output(self): def fork(self, new_request_id: str) -> "RequestState": """Fork the request into a new request with the same state expect for request_id, created_time and lifespan.""" - new = deepcopy(self) - new.request_id = new_request_id - new.created_time = time.perf_counter() - new.lifespan = (new.created_time, -1) - return new + t = time.perf_counter() + new_request = RequestState( + request_id=new_request_id, + initial_tokens=self.initial_tokens, + num_children=self.num_children, + tokens_to_process=self.tokens_to_process[:], + remaining_prefill_tokens=self.remaining_prefill_tokens[:], + generated_tokens=self.generated_tokens[:], + 
allocated_blocks=self.allocated_blocks, + position_offset=self.position_offset, + status=self.status, + max_new_tokens=self.max_new_tokens, + eos_token_id=self.eos_token_id, + streaming=self.streaming, + created_time=t, + lifespan=(t, -1), + timestamps=None if self.timestamps is None else self.timestamps[:], + error=self.error, + record_timestamps=self.record_timestamps, + ) + return new_request From fb733b231b11e3f5d3791a6a9b20628a84f3d0fc Mon Sep 17 00:00:00 2001 From: remi-or Date: Tue, 6 Jan 2026 12:26:58 +0000 Subject: [PATCH 13/13] Style --- src/transformers/generation/continuous_batching/requests.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/transformers/generation/continuous_batching/requests.py b/src/transformers/generation/continuous_batching/requests.py index a1492d487dd2..8435164db6d3 100644 --- a/src/transformers/generation/continuous_batching/requests.py +++ b/src/transformers/generation/continuous_batching/requests.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. import time -from copy import deepcopy from dataclasses import dataclass, field from enum import Enum