diff --git a/benchmark_v2/benchmark_scripts/continuous_batching_overall.py b/benchmark_v2/benchmark_scripts/continuous_batching_overall.py
new file mode 100644
index 000000000000..720dce383485
--- /dev/null
+++ b/benchmark_v2/benchmark_scripts/continuous_batching_overall.py
@@ -0,0 +1,59 @@
+import re
+import subprocess
+from pathlib import Path
+
+from tabulate import tabulate
+
+
+SCRIPT_LOCATION = (Path(__file__).parent.parent.parent / "examples/pytorch/continuous_batching.py").as_posix()
+COMMON_ARGS = "--log-level WARNING --seed 0".split()
+
+
+def run_and_parse_cb_example(args: str) -> dict:
+    print(f"Benchmarking with args: {args}")
+    output = subprocess.check_output(
+        ["python", SCRIPT_LOCATION] + args.split() + COMMON_ARGS,
+        # stderr=subprocess.DEVNULL,
+    )
+    pattern = r"CB generation took: ([\d.]+) seconds for (\d+) tokens\. ([\d.]+)tok/s"
+    match = re.search(pattern, output.decode("utf-8"))
+    if match is not None:
+        return {
+            "args": args,
+            "time_seconds": float(match.group(1)),
+            "num_tokens": int(match.group(2)),
+            "throughput_tok_per_sec": float(match.group(3)),
+        }
+    return {}
+
+
+if __name__ == "__main__":
+    results = [
+        {
+            "args": "Arguments",
+            "time_seconds": "Duration (s)",
+            "num_tokens": "Generated tokens",
+            "throughput_tok_per_sec": "Throughput (tok/s)",
+        }
+    ]
+
+    # Benchmark with different numbers of samples
+    results.append(run_and_parse_cb_example("--samples 10"))
+    results.append(run_and_parse_cb_example("--samples 50"))
+    results.append(run_and_parse_cb_example("--samples 100"))
+    results.append(run_and_parse_cb_example("--samples 500"))
+
+    # Benchmark with compile: default, flash attention 2 and sdpa
+    results.append(run_and_parse_cb_example("--samples 100 --compile"))
+    results.append(run_and_parse_cb_example("--samples 100 --compile --attn flash_attention_2"))
+    results.append(run_and_parse_cb_example("--samples 100 --compile --attn sdpa"))
+
+    # Benchmark with parallel decoding
+    results.append(run_and_parse_cb_example("--samples 50 --compile --num-return-sequences 8 --do-sample"))
+    results.append(run_and_parse_cb_example("--samples 100 --compile --num-return-sequences 4 --do-sample"))
+
+    # Benchmark with prefix sharing
+    results.append(run_and_parse_cb_example("--samples 500 --add-prefix --compile"))
+
+    print()
+    print(tabulate(results, tablefmt="github"))
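The regex in `run_and_parse_cb_example` is the only contract between this benchmark and the example script's log output, so it is worth checking in isolation. A minimal sanity check of the parsing logic, using a hypothetical log line:

```python
import re

# Hypothetical log line in the format the example script is expected to print
sample = "CB generation took: 12.34 seconds for 4096 tokens. 331.93tok/s"
pattern = r"CB generation took: ([\d.]+) seconds for (\d+) tokens\. ([\d.]+)tok/s"

match = re.search(pattern, sample)
assert match is not None
print(float(match.group(1)), int(match.group(2)), float(match.group(3)))
# 12.34 4096 331.93
```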
display") @@ -210,6 +212,10 @@ def batch_generate( else: args.attn = "kernels-community/flash-attn3" + # Set seed + if args.seed is not None: + torch.manual_seed(args.seed) + # Create model model_id = "google/gemma-2-2b-it" if args.sliding_window > 0 else "meta-llama/Llama-3.1-8B-Instruct" has_system_role = args.sliding_window == 0 @@ -272,17 +278,27 @@ def batch_generate( inputs = inputs if isinstance(inputs, list) else inputs["input_ids"] batched_inputs.append(inputs) + # If num_return_sequences > 1, automatically enable do_sample with a warning + do_sample = args.do_sample + if args.num_return_sequences != 1 and not args.do_sample: + logger.warning( + f"num_return_sequences={args.num_return_sequences} > 1, automatically enabling do_sample=True. " + "Set --do-sample explicitly to suppress this warning." + ) + do_sample = True + # Prepare generation config generation_cfg = GenerationConfig( max_new_tokens=args.max_new_tokens, use_cuda_graph=use_cuda_graph, eos_token_id=tokenizer.pad_token_id if args.force_max_length else tokenizer.eos_token_id, pad_token_id=tokenizer.pad_token_id, - do_sample=args.do_sample, + do_sample=do_sample, temperature=0.8, top_p=0.9, num_blocks=args.num_blocks, max_batch_tokens=args.max_batch_tokens, + num_return_sequences=args.num_return_sequences, ) # Add a compile config if requested diff --git a/src/transformers/generation/continuous_batching/cache.py b/src/transformers/generation/continuous_batching/cache.py index d235593b91bd..bfae16a70f88 100644 --- a/src/transformers/generation/continuous_batching/cache.py +++ b/src/transformers/generation/continuous_batching/cache.py @@ -210,7 +210,7 @@ def __init__( self.key_cache: list[torch.Tensor] = [] self.value_cache: list[torch.Tensor] = [] # We add two extra tokens to the cache to handle padding and generally discard unwanted tokens - self.cache_shape = (num_blocks * self.block_size + 2, self.num_key_value_heads, self.head_dim) + self.cache_shape = ((num_blocks + 2) * self.block_size, self.num_key_value_heads, self.head_dim) for _ in range(group_size): new_layer_key_cache = torch.empty(self.cache_shape, dtype=self.dtype, device=self.device) new_layer_value_cache = torch.empty(self.cache_shape, dtype=self.dtype, device=self.device) @@ -388,6 +388,29 @@ def mark_shareable_blocks_as_complete(self, state: RequestState) -> None: prompt_ids=(state.initial_tokens + state.generated_tokens), ) + def copy_cache(self, source_blocks: list[int], forked_blocks: list[int]) -> None: + """Copy the cache from the source blocks to the forked blocks.""" + source_blocks = torch.tensor(source_blocks, device=self.device, dtype=torch.int32) + forked_blocks = torch.tensor(forked_blocks, device=self.device, dtype=torch.int32) + for key_cache, value_cache in zip(self.key_cache, self.value_cache): + key_cache = key_cache.view(-1, self.block_size, self.num_key_value_heads, self.head_dim) + value_cache = value_cache.view(-1, self.block_size, self.num_key_value_heads, self.head_dim) + key_cache[forked_blocks] = key_cache[source_blocks] + value_cache[forked_blocks] = value_cache[source_blocks] + # FIXME: consolidate the cache into a single tensor of shape (group_size, 2, *self.k_or_v_cache_shape) + # This will allow for better .update and a single copy instead of one per cache tensor + + def fork_request(self, source_request_id: str, destination_request_ids: list[str]) -> tuple[list[int], list[int]]: + """Fork the cache of a request (state) into the one of a list of requests with the given (dst_request_ids).""" + # These lists will be the 
diff --git a/src/transformers/generation/continuous_batching/cache.py b/src/transformers/generation/continuous_batching/cache.py
index d235593b91bd..bfae16a70f88 100644
--- a/src/transformers/generation/continuous_batching/cache.py
+++ b/src/transformers/generation/continuous_batching/cache.py
@@ -210,7 +210,7 @@ def __init__(
         self.key_cache: list[torch.Tensor] = []
         self.value_cache: list[torch.Tensor] = []
-        # We add two extra tokens to the cache to handle padding and generally discard unwanted tokens
-        self.cache_shape = (num_blocks * self.block_size + 2, self.num_key_value_heads, self.head_dim)
+        # We add two extra blocks to the cache to handle padding and generally discard unwanted tokens
+        self.cache_shape = ((num_blocks + 2) * self.block_size, self.num_key_value_heads, self.head_dim)
         for _ in range(group_size):
             new_layer_key_cache = torch.empty(self.cache_shape, dtype=self.dtype, device=self.device)
             new_layer_value_cache = torch.empty(self.cache_shape, dtype=self.dtype, device=self.device)
@@ -388,6 +388,29 @@ def mark_shareable_blocks_as_complete(self, state: RequestState) -> None:
             prompt_ids=(state.initial_tokens + state.generated_tokens),
         )
 
+    def copy_cache(self, source_blocks: list[int], forked_blocks: list[int]) -> None:
+        """Copy the cache from the source blocks to the forked blocks."""
+        source_blocks = torch.tensor(source_blocks, device=self.device, dtype=torch.int32)
+        forked_blocks = torch.tensor(forked_blocks, device=self.device, dtype=torch.int32)
+        for key_cache, value_cache in zip(self.key_cache, self.value_cache):
+            key_cache = key_cache.view(-1, self.block_size, self.num_key_value_heads, self.head_dim)
+            value_cache = value_cache.view(-1, self.block_size, self.num_key_value_heads, self.head_dim)
+            key_cache[forked_blocks] = key_cache[source_blocks]
+            value_cache[forked_blocks] = value_cache[source_blocks]
+        # FIXME: consolidate the cache into a single tensor of shape (group_size, 2, *self.k_or_v_cache_shape)
+        # This will allow for better .update and a single copy instead of one per cache tensor
+
+    def fork_request(self, source_request_id: str, destination_request_ids: list[str]) -> tuple[list[int], list[int]]:
+        """Fork the cache of the request (source_request_id) into those of the requests in (destination_request_ids)."""
+        # These lists will be the accumulators for the source and destination blocks for the cache copy
+        source_blocks, destination_blocks = [], []
+        # Main fork loop
+        for cm in self.group_cache_managers:
+            src_blocks, dst_blocks = cm.fork_blocks(source_request_id, destination_request_ids, self._block_manager)
+            source_blocks.extend(src_blocks)
+            destination_blocks.extend(dst_blocks)
+        return source_blocks, destination_blocks
+
 
 # TODO: rework computation with the groups and their sizes
 class PagedAttentionMemoryHandler:
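Why the `view` in `copy_cache` matters: each per-layer cache is stored flat as `(num_blocks * block_size, heads, dim)`, and reinterpreting the first dimension as `(num_blocks, block_size)` lets a single advanced-indexing assignment move whole blocks at once. A standalone sketch with toy shapes (all sizes made up):

```python
import torch

num_blocks, block_size, num_kv_heads, head_dim = 6, 4, 2, 8
flat = torch.randn(num_blocks * block_size, num_kv_heads, head_dim)

# Reinterpret the flat slot dimension as (num_blocks, block_size)
blocked = flat.view(-1, block_size, num_kv_heads, head_dim)

src = torch.tensor([0, 1])  # blocks to read from
dst = torch.tensor([4, 5])  # blocks to overwrite
blocked[dst] = blocked[src]  # one vectorized copy per cache tensor

assert torch.equal(blocked[4], blocked[0])
assert torch.equal(blocked[5], blocked[1])
```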
diff --git a/src/transformers/generation/continuous_batching/cache_manager.py b/src/transformers/generation/continuous_batching/cache_manager.py
index 7f585c595810..c2186b00ee61 100644
--- a/src/transformers/generation/continuous_batching/cache_manager.py
+++ b/src/transformers/generation/continuous_batching/cache_manager.py
@@ -123,6 +123,62 @@ def get_free_blocks(
         # In both cases, we return the allocated block ids
         return allocated_block_ids
 
+    def fork_blocks(
+        self, parent_blocks: list[int], num_forks: int, shareable: bool, group_id: int
+    ) -> tuple[list[list[int]] | None, list[int], list[int]]:
+        """Fork a given list of (parent_blocks) as many times as (num_forks). If the blocks are (shareable), we fork
+        the blocks that are complete by reference. Otherwise, we allocate new blocks and keep track of their indices
+        to later copy the physical cache. For instance, when forking 4 blocks for 2 children:
+
+        Parent blocks: [0, 1, 2, 3], with all blocks being complete except the last one (block 3).
+
+        ----------------------------------------- IF BLOCKS ARE NOT SHAREABLE -----------------------------------------
+
+        Forked blocks lists: [[5, 6, 7, 8], [9, 10, 11, 12]]
+        Copy source:      [0, 1, 2, 3, 0,  1,  2,  3]
+                           ↓  ↓  ↓  ↓  ↓   ↓   ↓   ↓
+        Copy destination: [5, 6, 7, 8, 9, 10, 11, 12]  → 8 blocks are newly allocated and copied
+
+        ----------------------------------------- IF BLOCKS ARE SHAREABLE ---------------------------------------------
+
+        Forked blocks lists: [[0, 1, 2, 5], [0, 1, 2, 6]]
+        Copy source:      [3, 3]  (block 3 is not complete, so it is copied, not referenced)
+                           ↓  ↓
+        Copy destination: [5, 6]  → only 2 blocks are newly allocated and copied
+        """
+        # First phase: reference all complete blocks
+        forked_by_reference = []
+
+        if shareable:
+            for block_id in parent_blocks:
+                block = self._id_to_block[block_id]
+                if block.is_complete:
+                    forked_by_reference.append(block.id)
+                    block.ref_count += num_forks
+                else:
+                    break
+
+        # Early return if we have forked all blocks by reference
+        blocks_to_copy = len(parent_blocks) - len(forked_by_reference)
+        if blocks_to_copy == 0:
+            return [forked_by_reference[:] for _ in range(num_forks)], [], []
+
+        # From now on, each child will have its own list of blocks
+        forked_blocks_lists = []
+        copy_src = []
+        copy_dst = []
+
+        # Second phase: allocate new blocks if needed
+        parent_id = forked_by_reference[-1] if forked_by_reference else None
+        for _ in range(num_forks):
+            allocated_block_ids = self.get_free_blocks(blocks_to_copy, parent_id, shareable, group_id)
+            if allocated_block_ids is None:
+                return None, [], []
+            forked_blocks_lists.append(forked_by_reference + allocated_block_ids)
+            copy_src.extend(parent_blocks[-blocks_to_copy:])
+            copy_dst.extend(allocated_block_ids)
+        return forked_blocks_lists, copy_src, copy_dst
+
     def increase_ref_count(self, block_id: int) -> None:
         """Increases the reference count of a given (block_id)."""
         block = self._id_to_block[block_id]
@@ -243,6 +299,36 @@ def get_write_indices(self, request_id: str, past_length: int, query_length: int
     def get_seqlens_k(self, request_id: str, past_length: int, query_length: int) -> tuple[str, int]:
         """Returns the attention type of the cache allocator and the key sequence length for the given request_id."""
 
+    def fork_blocks(
+        self, parent_request_id: str, children_request_ids: list[str], block_manager: BlockManager
+    ) -> tuple[list[int], list[int]]:
+        """Forks the cache blocks of a (parent_request_id) to a list of (children_request_ids). To manage the blocks,
+        the (block_manager) is used. When forking, each child's blocks are either shared with the parent or copied
+        from it. Hence, we return two lists of blocks involved in the copy: one with the sources and one with the
+        destinations."""
+
+        # Sanity checks
+        if parent_request_id not in self.block_table:
+            raise ValueError(f"No block table found for request {parent_request_id}")
+
+        # Actual forking
+        parent_blocks = self.block_table[parent_request_id]
+        list_forked_blocks, copy_src, copy_dst = block_manager.fork_blocks(
+            parent_blocks=parent_blocks,
+            num_forks=len(children_request_ids),
+            shareable=self.uses_block_sharing,
+            group_id=self._index,
+        )
+        if list_forked_blocks is None:
+            raise ValueError(f"Failed to fork blocks for request {parent_request_id}")
+
+        # Update the block table for all children requests
+        for child_request_id, forked_blocks in zip(children_request_ids, list_forked_blocks):
+            if child_request_id in self.block_table:
+                raise ValueError(f"Block table already exists for request {child_request_id}")
+            self.block_table[child_request_id] = forked_blocks
+        return copy_src, copy_dst
+
 
 class FullAttentionCacheAllocator(CacheAllocator):
     """Cache manager for a group of full attention layers."""
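To make the docstring's shareable case concrete, here is a pure-Python toy model of the two-phase fork (block bookkeeping is heavily simplified; names are illustrative):

```python
# Toy model of BlockManager.fork_blocks in the shareable case: complete blocks
# are shared by bumping a reference count; the incomplete tail is duplicated.
parent_blocks = [0, 1, 2, 3]  # block 3 is still being written
complete = {0: True, 1: True, 2: True, 3: False}
ref_count = dict.fromkeys(parent_blocks, 1)
num_forks, next_free = 2, 5

shared = []
for block_id in parent_blocks:
    if not complete[block_id]:
        break
    shared.append(block_id)
    ref_count[block_id] += num_forks

children, copy_src, copy_dst = [], [], []
to_copy = parent_blocks[len(shared):]
for _ in range(num_forks):
    fresh = list(range(next_free, next_free + len(to_copy)))
    next_free += len(to_copy)
    children.append(shared + fresh)
    copy_src.extend(to_copy)
    copy_dst.extend(fresh)

print(children)  # [[0, 1, 2, 5], [0, 1, 2, 6]]
print(copy_src)  # [3, 3]
print(copy_dst)  # [5, 6]
```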
diff --git a/src/transformers/generation/continuous_batching/continuous_api.py b/src/transformers/generation/continuous_batching/continuous_api.py
index 28a154465d68..a3dc357a34b2 100644
--- a/src/transformers/generation/continuous_batching/continuous_api.py
+++ b/src/transformers/generation/continuous_batching/continuous_api.py
@@ -572,13 +572,19 @@ def _maybe_send_output(self, state: RequestState) -> None:
     def update_batch(self) -> None:
         """Update request states based on generated tokens."""
         new_tokens = self._get_new_tokens(len(self.requests_in_batch))
-        for i, state in enumerate(self.requests_in_batch):
+        current_logits_index = 0
+        for state in self.requests_in_batch:
             # If the request has no remaining prompt ids, it means prefill has already ended or just finished
             if len(state.remaining_prefill_tokens) == 0:
-                self.metrics.record_ttft_metric(state.created_time, state.request_id)
-                state.status = RequestStatus.DECODING
-                token = new_tokens[i]
+                # If there are no generated tokens yet, it means prefill just ended
+                if state.generated_len() == 0:
+                    self.metrics.record_ttft_metric(state.created_time, state.request_id)
+                    state.status = RequestStatus.DECODING
+
+                token = new_tokens[current_logits_index]
                 state.tokens_to_process = [token]
+                current_logits_index += 1
+
                 # Update the request and stop if it is complete
                 is_finished = state.update_and_check_completion(token)
                 # We mark the completed blocks as such
@@ -594,6 +600,27 @@ def update_batch(self) -> None:
             else:
                 raise ValueError(f"Request {state.request_id} is in an unexpected state: {state.status}")
 
+        # If some requests need to be forked, we do it now
+        copy_source, copy_destination = [], []
+        while self.scheduler._requests_to_fork:
+            # Get the number of children and reset it so the request is not forked again
+            state = self.scheduler._requests_to_fork.pop()
+            num_children = state.num_children
+            state.num_children = 0
+            # Create the new requests and add them to the scheduler
+            new_request_ids = [f"{state.request_id}__child#{i}" for i in range(num_children)]
+            for new_request_id in new_request_ids:
+                self.scheduler.active_requests[new_request_id] = state.fork(new_request_id)
+            # Fork the cache
+            copy_src, copy_dst = self.cache.fork_request(state.request_id, new_request_ids)
+            copy_source.extend(copy_src)
+            copy_destination.extend(copy_dst)
+            # FIXME: if the fork can't be done, create a new pending request without forking instead of crashing everything
+
+        # The copy induced by the fork is done in one go (if it's even needed)
+        if copy_source:
+            self.cache.copy_cache(copy_source, copy_destination)
+
         if self.cache.get_num_free_blocks() == 0:
             raise ValueError("No more free blocks")
 
@@ -760,29 +787,35 @@ def __init__(
             num_kv_padding_intervals: (optional) Number of intervals used to pad the keys/values dimension
             allow_block_sharing: (optional) Whether to allow block sharing if the model has some full attention layers
         """
-        # Reloade paged version if necessary
+        # Reload paged version of the attention implementation if necessary
         if "paged|" not in model.config._attn_implementation:
             model.set_attn_implementation(f"paged|{model.config._attn_implementation}")
 
+        # Internal arguments
         self.model = model.eval()
-        generation_config = model.generation_config if generation_config is None else generation_config
-        self.generation_config = generation_config
+        self.manual_eviction = manual_eviction
+        self._allow_block_sharing = allow_block_sharing
+        self._use_prefix_sharing = allow_block_sharing  # approximation until the cache is created
+
         self.input_queue = queue.Queue(maxsize=max_queue_size)
         self.output_queue = queue.Queue()
         self.stop_event = threading.Event()
-        self.log_prob_generation = getattr(generation_config, "log_prob_generation", False)
+
         self.batch_processor: ContinuousBatchProcessor | None = None
         self._generation_thread = None
         self._request_counter = 0
         self._request_lock = threading.Lock()
-        self.model.generation_config.top_p = None
+
+        # Generation config related arguments
+        generation_config = model.generation_config if generation_config is None else generation_config
+        self.generation_config = generation_config
+        self.log_prob_generation = getattr(generation_config, "log_prob_generation", False)
         self.do_sample = getattr(generation_config, "do_sample", True)
         self.logit_processor = self.model._get_logits_processor(generation_config)
-        self.profile = getattr(generation_config, "profile", False)  # TODO: not supported yet
-        self.manual_eviction = manual_eviction
-        self.batch_processor: ContinuousBatchProcessor | None = None
-        self._allow_block_sharing = allow_block_sharing
-        self._use_prefix_sharing = allow_block_sharing  # approximation until the cache is created
+        self.num_return_sequences = getattr(generation_config, "num_return_sequences", 1)
+
+        # self.model.generation_config.top_p = None  NOTE: figure out why this was here
+
+        # Cuda graph behavior is determined below using either user-specified arguments or heuristics
         self.use_cuda_graph = self._decide_use_cuda_graphs(
             use_cuda_graph=getattr(generation_config, "use_cuda_graph", None),
             num_q_padding_intervals=num_q_padding_intervals,
@@ -796,6 +829,7 @@ def __init__(
             num_kv_padding_intervals if num_kv_padding_intervals > 0 else NUM_KV_PADDING_INTERVALS
         )
 
+        # Log probability generation is not supported yet (TODO)
         if self.log_prob_generation:
             raise NotImplementedError("log_prob_generation is not supported yet")
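The fork loop in `update_batch` above deliberately defers the physical copy: per-request `(src, dst)` block pairs are accumulated and `copy_cache` runs once for the whole batch. A schematic of that accumulation (all ids and counts are made up):

```python
# Toy bookkeeping: each parent's last (incomplete) block must be duplicated once
# per child; fresh block ids are handed out sequentially starting at 20.
pending = {"req-a": (7, 1), "req-b": (9, 2)}  # parent -> (tail block, num children)

copy_source, copy_destination, next_free = [], [], 20
for parent_id, (tail_block, num_children) in pending.items():
    for _ in range(num_children):
        copy_source.append(tail_block)
        copy_destination.append(next_free)
        next_free += 1

print(copy_source)       # [7, 9, 9]
print(copy_destination)  # [20, 21, 22]
# ...followed by a single cache.copy_cache(copy_source, copy_destination)
```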
yet") @@ -929,6 +963,7 @@ def add_request( state = RequestState( request_id=request_id, initial_tokens=list(input_ids), + num_children=self.num_return_sequences - 1, record_timestamps=record_timestamps, tokens_to_process=list(input_ids), max_new_tokens=max_new_tokens, @@ -1226,24 +1261,26 @@ def generate_batch( # Initialize manager with the batch inputs results = {} - num_requests = len(inputs) - with ( - self.continuous_batching_context_manager( - generation_config=generation_config, - num_q_cuda_graphs=num_q_padding_intervals, - num_kv_cuda_graphs=num_kv_padding_intervals, - allow_block_sharing=allow_block_sharing, - block=True, - timeout=5, - ) as manager, - logging_redirect_tqdm([logger]), - tqdm( - total=num_requests, - disable=(not progress_bar), - desc=f"Solving {num_requests} requests", - unit="request", - ) as pbar, - ): + gen_cfg = self.generation_config if generation_config is None else generation_config + num_requests = len(inputs) * gen_cfg.num_return_sequences + # Prepare context managers for the main loop + manager_cm = self.continuous_batching_context_manager( + generation_config=generation_config, + num_q_cuda_graphs=num_q_padding_intervals, + num_kv_cuda_graphs=num_kv_padding_intervals, + allow_block_sharing=allow_block_sharing, + block=True, + timeout=5, + ) + logging_cm = logging_redirect_tqdm([logger]) + pbar_cm = tqdm( + total=num_requests, + disable=(not progress_bar), + desc=f"Solving {num_requests} requests", + unit="request", + ) + # Main loop + with manager_cm as manager, logging_cm, pbar_cm as pbar: try: manager.add_requests( inputs=inputs, max_new_tokens=kwargs.get("max_new_tokens"), record_timestamps=record_timestamps diff --git a/src/transformers/generation/continuous_batching/requests.py b/src/transformers/generation/continuous_batching/requests.py index 0dd1b0b2ce75..8435164db6d3 100644 --- a/src/transformers/generation/continuous_batching/requests.py +++ b/src/transformers/generation/continuous_batching/requests.py @@ -101,6 +101,8 @@ class RequestState: Attributes: request_id (str): The ID of the generation request. + initial_tokens (list[int]): The initial prompt tokens. + num_children (int): The number of children requests full_prompt_ids (list[int] | None): The tokens IDs of the full prompt. prompt_ids (list[int] | None): The tokens IDs currently being processed. remaining_prompt_ids (list[int]): The tokens IDs remaining to be processed (for split requests). 
diff --git a/src/transformers/generation/continuous_batching/scheduler.py b/src/transformers/generation/continuous_batching/scheduler.py
index 109cb635ea5c..42bd607a216f 100644
--- a/src/transformers/generation/continuous_batching/scheduler.py
+++ b/src/transformers/generation/continuous_batching/scheduler.py
@@ -36,6 +36,7 @@ def __init__(self, cache: PagedAttentionCache, retain_cache_on_finish: bool = Fa
         self.retain_cache_on_finish = retain_cache_on_finish
         self._cancellation_lock = threading.Lock()
         self._requests_to_cancel: set[str] = set()
+        self._requests_to_fork: list[RequestState] = []
 
     @traced
     def add_waiting_request(self, state: RequestState):
@@ -151,8 +152,13 @@ def _prepare_request_for_processing(
         else:
            request_tokens = state.tokens_to_process
 
+        # If the request has one or more children, we make sure not to prefill it entirely
+        if state.num_children > 0 and token_budget >= len(request_tokens) - 1:
+            token_budget = len(request_tokens) - 1
+            self._requests_to_fork.append(state)
+
+        # Case: we can process the entire prompt/remainder
         if len(request_tokens) < token_budget:
-            # Can process the entire prompt/remainder
             if state.status == RequestStatus.PENDING:
                 self.active_requests[state.request_id] = state
                 state.status = RequestStatus.PREFILLING
@@ -161,8 +167,9 @@ def _prepare_request_for_processing(
                 state.status = RequestStatus.PREFILLING
                 state.tokens_to_process = state.remaining_prefill_tokens
                 state.remaining_prefill_tokens = []
+
+        # Otherwise: we need to split the request
         else:
-            # Need to split the request
             if state.status == RequestStatus.PENDING:
                 self.active_requests[state.request_id] = state
                 state.status = RequestStatus.PREFILLING_SPLIT
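The clamp added to `_prepare_request_for_processing` is easy to sanity-check in isolation: holding back one prompt token forces the request down the split-prefill branch, so the fork happens before any token is decoded. A toy check (numbers are made up):

```python
request_tokens = list(range(10))  # 10 prompt tokens left to prefill
token_budget = 64                 # budget that would otherwise cover them all
num_children = 3

if num_children > 0 and token_budget >= len(request_tokens) - 1:
    token_budget = len(request_tokens) - 1  # leave exactly one token unprocessed

finishes_prefill = len(request_tokens) < token_budget
print(token_budget, finishes_prefill)  # 9 False -> request is split, then forked
```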
diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index ea0590c108b5..0e3c20aff2b6 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -1305,7 +1305,7 @@ def _get_logits_processor(
         if generation_config.do_sample:
             # In beam methods, we need to keep at least one non-eos token to explore continuations that might have a
             # better score (i.e. keep len(list(generation_config._eos_token_tensor)) + 1)
-            if generation_config.num_beams > 1:
+            if generation_config.num_beams is not None and generation_config.num_beams > 1:
                 if isinstance(generation_config._eos_token_tensor, list):
                     min_tokens_to_keep = len(generation_config._eos_token_tensor) + 1
                 elif isinstance(generation_config._eos_token_tensor, torch.Tensor):
diff --git a/tests/generation/test_continuous_batching.py b/tests/generation/test_continuous_batching.py
index 4bc7fca3204a..3b803708d02b 100644
--- a/tests/generation/test_continuous_batching.py
+++ b/tests/generation/test_continuous_batching.py
@@ -212,7 +212,6 @@ def _test_continuous_batching_parity(
         chats = [[{"role": "user", "content": user_message}] for user_message in user_messages]
         tokenized = [tokenizer.apply_chat_template(chat, add_generation_prompt=True) for chat in chats]
         input_ids = [(x if isinstance(x, list) else x["input_ids"]) for x in tokenized]
-        print(f"{input_ids[0] = } {type(input_ids[0]) = }")
 
         # Eager and SDPA implementations get a precision boost to account for the fact that an attention mask is used in
         # continuous batching but not in generate
@@ -504,3 +503,43 @@ def test_block_sharing_with_hybrid_model(self) -> None:
         }).get_expectation()  # fmt: skip
 
         return self._test_block_sharing(model_id, num_layer_groups, input_msg, expected_generated_tokens)
+
+    @parameterized.expand([True, False])
+    @require_torch_accelerator
+    def test_num_return_sequences(self, allow_block_sharing: bool) -> None:
+        model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
+        tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="left")
+        user_messages = [
+            "A robe takes 2 bolts of blue fiber and half that much white fiber. How many bolts in total does it take?"
+        ]
+        chats = [[{"role": "user", "content": user_message}] for user_message in user_messages]
+        tokenized = [tokenizer.apply_chat_template(chat, add_generation_prompt=True) for chat in chats]
+        input_ids = [(x if isinstance(x, list) else x["input_ids"]) for x in tokenized]
+
+        # Load the model and prepare its generation config
+        model = AutoModelForCausalLM.from_pretrained(model_id, attn_implementation="sdpa")
+        model = model.to(torch_device).eval()
+        model.generation_config.max_new_tokens = 30
+        model.generation_config.do_sample = False
+
+        # Generation with continuous batching
+        manager_cm = model.continuous_batching_context_manager(
+            allow_block_sharing=allow_block_sharing, block=True, timeout=5
+        )
+        # Collect one result per returned sequence
+        results = []
+        with manager_cm as manager:
+            manager.num_return_sequences = 2
+            manager.add_requests(inputs=input_ids, max_new_tokens=30)
+            requests_left = 2
+            while requests_left:
+                result = manager.get_result(timeout=1)
+                if result and result.is_finished():
+                    results.append(result)
+                    requests_left -= 1
+                else:
+                    if not manager.is_running():
+                        break
+
+        self.assertEqual(len(results), 2, f"Expected 2 results, but got {len(results) = }")
+        self.assertEqual(results[0].generated_tokens, results[1].generated_tokens)
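Circling back to the one-line utils.py change: continuous-batching configs can apparently leave `num_beams` unset (hence the guard), and comparing `None` with an int raises. A short illustration of the failure mode the guard avoids (values made up):

```python
num_beams = None  # e.g. a generation config field that was never set

try:
    if num_beams > 1:
        pass
except TypeError as e:
    print(e)  # '>' not supported between instances of 'NoneType' and 'int'

# The patched condition short-circuits safely instead:
if num_beams is not None and num_beams > 1:
    print("beam-search-specific branch")
```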