From 9495e0e196ec27b733d50159df5badc9f92d1d74 Mon Sep 17 00:00:00 2001 From: remi-or Date: Mon, 18 Aug 2025 14:39:03 +0000 Subject: [PATCH 01/26] Rework of the CB example --- examples/pytorch/continuous_batching.py | 165 +++++++++++++++--------- 1 file changed, 106 insertions(+), 59 deletions(-) diff --git a/examples/pytorch/continuous_batching.py b/examples/pytorch/continuous_batching.py index 821e3d9a271b..e94005e5bf91 100644 --- a/examples/pytorch/continuous_batching.py +++ b/examples/pytorch/continuous_batching.py @@ -1,5 +1,5 @@ import time - +import argparse import datasets import torch @@ -7,71 +7,118 @@ from transformers.generation import GenerationConfig -torch.set_float32_matmul_precision("high") +def batch_generate( + model: AutoModelForCausalLM, + simple_batch_inputs: list, + generation_config: GenerationConfig, + tokenizer: AutoTokenizer, + displayed_samples: int = 0, # -1: no display, 0: display stats, >0: display inputs and some outputs +) -> tuple[float, float]: + + # Actual batch generation + if displayed_samples >= 0: + print("--- Running CB Generation Example ---") + start_time_simple = time.time() + batch_outputs = model.generate_batch( + inputs=simple_batch_inputs, + generation_config=generation_config, + ) + end_time_simple = time.time() + if displayed_samples >= 0: + print("Done with batch generation.") + + # Decode outputs + token_count = 0 + for i, request in enumerate(batch_outputs): + input_text = tokenizer.decode(batch_outputs[request].prompt_ids, skip_special_tokens=False) + try: + output_text = tokenizer.decode(batch_outputs[request].generated_tokens, skip_special_tokens=False) + token_count += len(batch_outputs[request].generated_tokens[1:]) + except Exception as e: + print(f"Decoding failed for request {request}: {e}") + token_count += len(batch_outputs[request].generated_tokens[1:]) + output_text = tokenizer.decode(batch_outputs[request].generated_tokens[1:], skip_special_tokens=False) + if i < displayed_samples: + if len(output_text) > 0: + print("-" * 20) + print(f"{request} Input: {input_text}") + print(f"{request} Output: {output_text}") + else: + print(f"{request} Input: {input_text}") + print("[WARN]") + print(f"{request} Output was empty!") + + # Compute stats and maybe print them + gen_time = end_time_simple - start_time_simple + tok_per_sec = token_count / gen_time + if displayed_samples >= 0: + print("-" * 20) + print("--- Finished CB Generation Example ---\n") + print(f"CB generation took: {gen_time:.2f} seconds for {token_count} tokens. 
{tok_per_sec:.2f}tok/s") + return gen_time, tok_per_sec + + +if __name__ == "__main__": + + # Parse args + parser = argparse.ArgumentParser() + parser.add_argument("--attn-implementation", type=str, default="paged_attention|kernels-community/flash-attn") + parser.add_argument("--matmul-precision", type=str, default="high") # set to "none" to disable + parser.add_argument("--samples", type=int, default=500) + parser.add_argument("--use-cuda-graph", action="store_true") + args = parser.parse_args() -model_id = "meta-llama/Llama-3.2-3b-Instruct" -model = ( - AutoModelForCausalLM.from_pretrained( + # Set matmul precision + if args.matmul_precision != "none": + torch.set_float32_matmul_precision(args.matmul_precision) + + # Prepare model + model_id = "meta-llama/Llama-3.2-3b-Instruct" + model = AutoModelForCausalLM.from_pretrained( model_id, - attn_implementation="paged_attention|kernels-community/flash-attn", + attn_implementation=args.attn_implementation, dtype=torch.bfloat16, + torch_dtype=torch.bfloat16, + ) + model = model.cuda().eval() + model.forward = torch.compile(model.forward, mode="max-autotune-no-cudagraphs") + + # Prepare tokenizer and dataset + tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="left") + dataset = datasets.load_dataset("openai/gsm8k", "socratic", split="test") + dataset = dataset.select(range(args.samples)) # Use only 5 examples for the simple version + tokenized_datasets = dataset.map(lambda x: tokenizer(x["question"]), batched=True) + simple_batch_inputs = [item["input_ids"] for item in tokenized_datasets] + + # Prepare generation config + generation_config = GenerationConfig( + max_new_tokens=512, + use_cuda_graph=args.use_cuda_graph, + eos_token_id=tokenizer.eos_token_id, + pad_token_id=tokenizer.pad_token_id, + do_sample=False, ) - .eval() - .cuda() -) -tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="left") - -generation_config = GenerationConfig( - max_new_tokens=512, - # use_cuda_graph=False, - eos_token_id=tokenizer.eos_token_id, - pad_token_id=tokenizer.pad_token_id, - do_sample=False, -) - -train_dataset = datasets.load_dataset("openai/gsm8k", "socratic", split="test") -train_dataset = train_dataset.select(range(500)) # Use only 5 examples for the simple version -print("--- Running CB Generation Example ---") - - -def tokenize_function(examples): - return tokenizer(examples["question"]) - - -tokenized_datasets = train_dataset.map(tokenize_function, batched=True) -simple_batch_inputs = [item["input_ids"] for item in tokenized_datasets] - -start_time_simple = time.time() -model.forward = torch.compile(model.forward, mode="max-autotune-no-cudagraphs") -batch_outputs = model.generate_batch( - inputs=simple_batch_inputs, - generation_config=generation_config, -) -end_time_simple = time.time() -token_count = 0 -for request in batch_outputs: - input_text = tokenizer.decode(batch_outputs[request].prompt_ids, skip_special_tokens=False) - try: - output_text = tokenizer.decode(batch_outputs[request].generated_tokens, skip_special_tokens=False) - token_count += len(batch_outputs[request].generated_tokens[1:]) - except Exception as e: - print(f"Decoding failed for request {request}: {e}") - token_count += len(batch_outputs[request].generated_tokens[1:]) - output_text = tokenizer.decode(batch_outputs[request].generated_tokens[1:], skip_special_tokens=False) - if len(output_text) > 0: - print("-" * 20) - print(f"{request} Input: {input_text}") - print(f"{request} Output: {output_text}") - else: - print("", end="\r\r\r\r") -print("-" * 
20) -print("--- Finished CB Generation Example ---\n\n") + # Run warmup batch generation + batch_generate( + model, + simple_batch_inputs[:min(5, args.samples)], + generation_config, + tokenizer, + displayed_samples=-1, + ) + + # Run batch generation + gen_time, tok_per_sec = batch_generate( + model, + simple_batch_inputs, + generation_config, + tokenizer, + displayed_samples=5, + ) -print( - f"CB generation took: {end_time_simple - start_time_simple:.2f} seconds for {token_count} tokens. {token_count / (end_time_simple - start_time_simple)}tok/s" -) +# TODO: remove this or incorporate it into the script above # train_dataset = train_dataset.select(range(5)) # Use only 5 examples for the simple version From cce99f0794cd0dea4cfbab3fe2406b47fe4a322a Mon Sep 17 00:00:00 2001 From: remi-or Date: Wed, 20 Aug 2025 16:19:02 +0000 Subject: [PATCH 02/26] Further rework of CB example --- examples/pytorch/continuous_batching.py | 80 ++++++++++--------------- 1 file changed, 33 insertions(+), 47 deletions(-) diff --git a/examples/pytorch/continuous_batching.py b/examples/pytorch/continuous_batching.py index e94005e5bf91..3335ba4ffa6b 100644 --- a/examples/pytorch/continuous_batching.py +++ b/examples/pytorch/continuous_batching.py @@ -2,6 +2,7 @@ import argparse import datasets import torch +import json from transformers import AutoModelForCausalLM, AutoTokenizer from transformers.generation import GenerationConfig @@ -13,6 +14,7 @@ def batch_generate( generation_config: GenerationConfig, tokenizer: AutoTokenizer, displayed_samples: int = 0, # -1: no display, 0: display stats, >0: display inputs and some outputs + output_file: str = None, ) -> tuple[float, float]: # Actual batch generation @@ -29,15 +31,17 @@ def batch_generate( # Decode outputs token_count = 0 + data = [] for i, request in enumerate(batch_outputs): input_text = tokenizer.decode(batch_outputs[request].prompt_ids, skip_special_tokens=False) + data.append({"input": input_text}) try: output_text = tokenizer.decode(batch_outputs[request].generated_tokens, skip_special_tokens=False) token_count += len(batch_outputs[request].generated_tokens[1:]) + data[-1]["output"] = output_text except Exception as e: print(f"Decoding failed for request {request}: {e}") - token_count += len(batch_outputs[request].generated_tokens[1:]) - output_text = tokenizer.decode(batch_outputs[request].generated_tokens[1:], skip_special_tokens=False) + data[-1]["output"] = "__ERROR__" if i < displayed_samples: if len(output_text) > 0: print("-" * 20) @@ -48,6 +52,12 @@ def batch_generate( print("[WARN]") print(f"{request} Output was empty!") + # If an output file is provided, save the reordered data to it + data.sort(key=lambda x: x["input"]) + if output_file is not None: + with open(output_file, "w") as f: + json.dump(data, f, indent=4) + # Compute stats and maybe print them gen_time = end_time_simple - start_time_simple tok_per_sec = token_count / gen_time @@ -62,10 +72,16 @@ def batch_generate( # Parse args parser = argparse.ArgumentParser() - parser.add_argument("--attn-implementation", type=str, default="paged_attention|kernels-community/flash-attn") - parser.add_argument("--matmul-precision", type=str, default="high") # set to "none" to disable + parser.add_argument("--num-blocks", type=int, default=None) + parser.add_argument("--max-batch-tokens", type=int, default=None) + + parser.add_argument("--attn", type=str, default="paged_attention|kernels-community/flash-attn", help="Attention implementation") + parser.add_argument("--matmul-precision", "-mp", type=str, 
default="high") # set to "none" to disable + parser.add_argument("--use-cuda-graph", action="store_true", default=False) + parser.add_argument("--samples", type=int, default=500) - parser.add_argument("--use-cuda-graph", action="store_true") + parser.add_argument("--displayed", type=int, default=1, help="Number of samples to display") + parser.add_argument("--output-file", type=str, default=None) args = parser.parse_args() # Set matmul precision @@ -76,12 +92,12 @@ def batch_generate( model_id = "meta-llama/Llama-3.2-3b-Instruct" model = AutoModelForCausalLM.from_pretrained( model_id, - attn_implementation=args.attn_implementation, + attn_implementation=args.attn, dtype=torch.bfloat16, torch_dtype=torch.bfloat16, ) model = model.cuda().eval() - model.forward = torch.compile(model.forward, mode="max-autotune-no-cudagraphs") + # model.forward = torch.compile(model.forward, mode="max-autotune-no-cudagraphs") # Prepare tokenizer and dataset tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="left") @@ -97,6 +113,8 @@ def batch_generate( eos_token_id=tokenizer.eos_token_id, pad_token_id=tokenizer.pad_token_id, do_sample=False, + num_blocks=args.num_blocks, + max_batch_tokens=args.max_batch_tokens, ) # Run warmup batch generation @@ -114,48 +132,16 @@ def batch_generate( simple_batch_inputs, generation_config, tokenizer, - displayed_samples=5, + displayed_samples=args.displayed, + output_file=args.output_file, ) -# TODO: remove this or incorporate it into the script above - -# train_dataset = train_dataset.select(range(5)) # Use only 5 examples for the simple version - -# tokenized_test_prompts = tokenizer(_TEST_PROMPTS, padding=True, padding_side="left", truncation=True, max_length=512) -# simple_batch_inputs = list(tokenized_test_prompts["input_ids"]) - -# def tokenize_function(examples): -# # Truncate to avoid overly long prompts exceeding max context length -# return tokenizer(examples["question"], padding=True, truncation=True, max_length=512) - - -# tokenized_datasets = train_dataset.map(tokenize_function, batched=True) -# simple_batch_inputs = [item["input_ids"] for item in tokenized_datasets] - - -# model.config.attn_implementation = "sdpa" -# start_time_simple = time.time() -# batch_size = 64 -# full_outputs = [] -# from tqdm import tqdm - -# for i in tqdm(range(0, len(simple_batch_inputs)-batch_size, batch_size)): -# outputs = model.generate( -# torch.tensor(simple_batch_inputs[i:i+batch_size], device=model.device), -# generation_config=GenerationConfig( -# max_new_tokens=16, eos_token_id=tokenizer.eos_token_id, pad_token_id=tokenizer.pad_token_id -# ), -# ) -# full_outputs.extend(outputs.tolist()) +# python examples/pytorch/continuous_batching.py --attn sdpa_paged --matmul-precision none --samples 50 --displayed 0 +# Using calculated self.num_blocks = 4096, self.block_size = 32, self.max_batch_tokens = 2048 +# CB generation took: 18.80 seconds for 13775 tokens. 
732.74tok/s -# end_time_simple = time.time() -# print(f"\nSimple batch generation took: {end_time_simple - start_time_simple:.2f} seconds") -# print("\nResults from simple generate_batch:") -# for i, request in enumerate(full_outputs): -# output_text = tokenizer.decode(request, skip_special_tokens=False) -# print("-" * 20) -# print(f" Output: {output_text}") -# print("-" * 20) -# print("--- Finished Simple Batch Generation Example ---\n\n") +# python examples/pytorch/continuous_batching.py --attn sdpa_paged --matmul-precision none --samples 100 --displayed 1 +# Setting up static tensors with T = 4096, max_token_budget = 524288, 139538202624 bytes available +# CB generation took: 29.53 seconds for 26384 tokens. 893.41tok/s From 74a0d73d5f9bd8de1bf503eb0f04bcddc6034214 Mon Sep 17 00:00:00 2001 From: remi-or Date: Wed, 20 Aug 2025 18:59:35 +0000 Subject: [PATCH 03/26] Refactor PA cache, slice on tokens, add debug prints -- WIP --- examples/pytorch/continuous_batching.py | 4 + .../generation/cb/memory_management.py | 192 ++++++++++++++ .../generation/continuous_batching.py | 251 +++++++++--------- 3 files changed, 316 insertions(+), 131 deletions(-) create mode 100644 src/transformers/generation/cb/memory_management.py diff --git a/examples/pytorch/continuous_batching.py b/examples/pytorch/continuous_batching.py index 3335ba4ffa6b..6cd92c756b39 100644 --- a/examples/pytorch/continuous_batching.py +++ b/examples/pytorch/continuous_batching.py @@ -145,3 +145,7 @@ def batch_generate( # python examples/pytorch/continuous_batching.py --attn sdpa_paged --matmul-precision none --samples 100 --displayed 1 # Setting up static tensors with T = 4096, max_token_budget = 524288, 139538202624 bytes available # CB generation took: 29.53 seconds for 26384 tokens. 893.41tok/s + +# Without changes to continuous_batching.py +# Using calculated num_blocks=369, block_size=32, max concurrent requests 23 +# CB generation took: 79.58 seconds for 25813 tokens. 
324.38tok/s diff --git a/src/transformers/generation/cb/memory_management.py b/src/transformers/generation/cb/memory_management.py new file mode 100644 index 000000000000..d014eb097153 --- /dev/null +++ b/src/transformers/generation/cb/memory_management.py @@ -0,0 +1,192 @@ +from typing import Optional +from math import sqrt, floor +import torch +from ...utils.logging import logging +from ...utils.metrics import traced + + +logger = logging.getLogger(__name__) + + +class PagedAttentionMemoryHandler: + + _activation_dtype = torch.bfloat16 + _activation_safety_factor = 2 + _input_dtype = torch.int32 + _upper_bound_max_batch_tokens = 1024 + _upper_bound_num_blocks = 1024 + + def __init__( + self, + block_size: int, + head_dim: int, + num_heads: int, + num_layers: int, + hidden_size: int, + vocab_size: int, + ) -> None: + self.block_size = block_size + self.head_dim = head_dim + self.num_heads = num_heads + self.num_layers = num_layers + self.hidden_size = hidden_size + self.vocab_size = vocab_size + + @staticmethod + def get_device_and_memory_breakdown() -> tuple[torch.device, int, int, int]: + if torch.cuda.is_available(): + device = torch.device("cuda") + torch.cuda.empty_cache() + torch.cuda.synchronize() + total_memory = torch.cuda.get_device_properties(device).total_memory + reserved_memory = torch.cuda.memory_reserved(device) + allocated_memory = torch.cuda.memory_allocated(device) + elif torch.backends.mps.is_available() and torch.backends.mps.is_built(): + device = torch.device("mps") + # MPS memory reporting (PyTorch 2.0+) + total_memory = torch.mps.driver_allocated_memory() + allocated_memory = total_memory - torch.mps.recommended_max_memory() + reserved_memory = 0 # MPS does not track reserved separately + else: + device = torch.device("cpu") + total_memory = None + reserved_memory = 0 + allocated_memory = 0 + return device, total_memory, reserved_memory, allocated_memory + + @staticmethod + def get_available_memory(max_memory_percent: float = 1.0) -> int: + _, total, reserved, allocated = PagedAttentionMemoryHandler.get_device_and_memory_breakdown() + available_memory = total - max(allocated, reserved) + available_memory = int(available_memory * max_memory_percent) + return available_memory + + def infer_num_blocks_and_max_batch_tokens( + self, + num_blocks: Optional[int] = None, + max_batch_tokens: Optional[int] = None, + max_memory_percent: float = 0.9, + cache_dtype: torch.dtype = torch.float16, + ) -> tuple[int, int]: + # If neither num_blocks nor max_batch_tokens are provided, we use a second-order polynomial + if num_blocks is None and max_batch_tokens is None: + num_blocks, max_batch_tokens = self.compute_num_blocks_and_max_batch_tokens(max_memory_percent, cache_dtype) + # If only num_blocks is provided, we infer the max_batch_tokens + elif num_blocks is not None and max_batch_tokens is None: + max_batch_tokens = self.compute_max_batch_tokens(num_blocks, max_memory_percent, cache_dtype) + # If only max_batch_tokens is provided, we infer the num_blocks + elif max_batch_tokens is not None and num_blocks is None: + num_blocks = self.compute_num_blocks(max_batch_tokens, max_memory_percent, cache_dtype) + + # We check if the memory footprint is too large in all cases + available_memory = self.get_available_memory(max_memory_percent) + memory_footprint = self.compute_memory_footprint( + max_batch_tokens=max_batch_tokens, + num_blocks=num_blocks, + cache_dtype=cache_dtype, + ) + logger.warning(f"{available_memory = }, {memory_footprint = }, {num_blocks = }, {max_batch_tokens = }") 
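+        # Note: compute_memory_footprint returns a (activations, KV cache, static tensors)
+        # byte estimate for the chosen sizes; if their sum does not fit in the available
+        # memory budget, the requested num_blocks / max_batch_tokens cannot be honoured.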
+ if sum(memory_footprint) > available_memory: + raise MemoryError(f"Memory footprint {memory_footprint} is more than available memory {available_memory}") + return num_blocks, max_batch_tokens + + + def compute_num_blocks_and_max_batch_tokens( + self, + max_memory_percent: float = 0.9, + cache_dtype: torch.dtype = torch.float16, + m: float = 0.1, + ) -> tuple[int, int]: + cache_memory = self.get_available_memory(max_memory_percent) + + # Compute second-degree polynomial coefficients + a = m * self._activation_dtype.itemsize + b = 8 * m * self._input_dtype.itemsize + b += 2 * self.num_heads * self.head_dim * self.num_layers * cache_dtype.itemsize + c = self._activation_dtype.itemsize * (self.hidden_size + self.vocab_size) * self._activation_safety_factor + c += 2 * self._input_dtype.itemsize + c -= cache_memory + + # Compute discriminant and greatest solution + discriminant = b**2 - 4 * a * c + if discriminant < 0: + raise ValueError(f"Discriminant is negative: {discriminant = }") + greatest_solution = (-b + sqrt(discriminant)) / (2 * a) + if greatest_solution < 0: + raise ValueError(f"Greatest solution is negative: {greatest_solution = }") + + # Infer number of blocks and max batch tokens + num_blocks = int(greatest_solution) // self.block_size + if num_blocks > self._upper_bound_num_blocks: + logger.warning(f"{num_blocks = } is too large, setting to {self._upper_bound_num_blocks = }") + num_blocks = self._upper_bound_num_blocks + max_batch_tokens = int(greatest_solution * m) + if max_batch_tokens > self._upper_bound_max_batch_tokens: + logger.warning(f"{max_batch_tokens = } is too large, setting to {self._upper_bound_max_batch_tokens = }") + max_batch_tokens = self._upper_bound_max_batch_tokens + return num_blocks, max_batch_tokens + + def compute_max_batch_tokens( + self, + num_blocks: int, + max_memory_percent: float = 0.9, + cache_dtype: torch.dtype = torch.float16, + ) -> int: + cache_memory = self.get_available_memory(max_memory_percent) + cache_size = num_blocks * self.block_size + # Compute numerator + num = cache_memory + num -= self._activation_dtype.itemsize * (self.hidden_size + self.vocab_size) * self._activation_safety_factor + num -= 2 * self._input_dtype.itemsize + num -= cache_size * 2 * self.num_heads * self.head_dim * self.num_layers * cache_dtype.itemsize + # Compute denominator + denum = 8 * self._input_dtype.itemsize + cache_size * self._activation_dtype.itemsize + # Compute max batch tokens and return + return int(num / denum) + + def compute_num_blocks( + self, + max_batch_tokens: int, + max_memory_percent: float = 0.9, + cache_dtype: torch.dtype = torch.float16, + ) -> int: + cache_memory = self.get_available_memory(max_memory_percent) + # Compute numerator + num = cache_memory + num -= self._activation_dtype.itemsize * (self.hidden_size + self.vocab_size) * self._activation_safety_factor + num -= 8 * max_batch_tokens * self._input_dtype.itemsize + num -= 2 * self._input_dtype.itemsize + # Compute denominator + denum = 2 * self.num_heads * self.head_dim * self.num_layers * cache_dtype.itemsize + denum += max_batch_tokens * self._activation_dtype.itemsize + # Compute cache size and return number of blocks + cache_size = int(num / denum) + return floor(cache_size / self.block_size) + + def compute_memory_footprint( + self, + num_blocks: Optional[int] = None, + max_batch_tokens: Optional[int] = None, + cache_dtype: torch.dtype = torch.float16, + ) -> tuple[int, int, int]: + # Compute activation memory footprint + activation_memory_footprint = 
self._activation_dtype.itemsize * (self.hidden_size + self.vocab_size) + activation_memory_footprint *= self._activation_safety_factor + # Compute cache memory footprint if num_blocks is provided + if num_blocks is not None: + cache_size = num_blocks * self.block_size + bytes_per_token = 2 * self.num_heads * self.head_dim * self.num_layers * cache_dtype.itemsize + cache_memory_footprint = cache_size * bytes_per_token + else: + cache_memory_footprint = -1 + # Compute static tensors memory footprint if num_blocks and max_batch_tokens is provided + if num_blocks is not None and max_batch_tokens is not None: + static_memory_footprint = sum([ + 3 * max_batch_tokens * self._input_dtype.itemsize, # input_ids, position_ids, output_ids + max_batch_tokens * cache_size * self._activation_dtype.itemsize, # attention_mask + 2 * (max_batch_tokens + 1) * self._input_dtype.itemsize, # cumulative_seqlens_qk + 3 * max_batch_tokens * self._input_dtype.itemsize, # write_index, read_index, logits_indices + ]) + else: + static_memory_footprint = -1 + return activation_memory_footprint, cache_memory_footprint, static_memory_footprint diff --git a/src/transformers/generation/continuous_batching.py b/src/transformers/generation/continuous_batching.py index 2d903da05b62..0bd925b5f6d0 100644 --- a/src/transformers/generation/continuous_batching.py +++ b/src/transformers/generation/continuous_batching.py @@ -21,7 +21,7 @@ from dataclasses import dataclass, field from enum import Enum from functools import partial -from typing import Optional, Union +from typing import Any, Optional, TypeVar, Union import torch import torch.nn as nn @@ -33,6 +33,7 @@ from ..tokenization_utils_fast import PreTrainedTokenizerFast from ..utils.logging import logging from ..utils.metrics import ContinuousBatchProcessorMetrics, attach_tracer, traced +from .cb.memory_management import PagedAttentionMemoryHandler class RequestStatus(Enum): @@ -60,6 +61,9 @@ class GenerationOutput: generated_tokens (list[int]): The generated tokens. logprobs (list[float]): The log probabilities of the generated tokens. error (Optional[str]): Any error message associated with the request. When None, the request was successful. + status (RequestStatus): The status of the request. + created_time (float): The time the request was created. + next_token (Optional[int]): The next token to be generated. """ request_id: str @@ -77,24 +81,36 @@ class RequestState: """Tracks the state of a generation request through its lifecycle. Attributes: - status (RequestStatus): can be one of PENDING, PREFILLING, PREFILLING_SPLIT, + request_id (str): The ID of the generation request. + full_prompt_ids (list[int] | None): The tokens IDs of the full prompt. + prompt_ids (list[int] | None): The tokens IDs currently being processed. + remaining_prompt_ids (list[int]): The tokens IDs remaining to be processed (for split requests). + static_outputs (list[int]): The generated tokens. + allocated_blocks (list[int]): The identifiers of the allocated blocks to the request. + position_offset (int): The current position in the sequence for position_ids. + status (RequestStatus): The status of the request: can be one of PENDING, PREFILLING, PREFILLING_SPLIT, SPLIT_PENDING_REMAINDER, DECODING, FINISHED, FAILED + max_new_tokens (int): The maximum number of new tokens to generate. + eos_token_id (int): The ID of the end-of-sequence token. + created_time (float): The time the request was created. + error (Optional[str]): Any error message associated with the request. 
When None, has had no error yet. + next_token (Optional[str]): The next token to be generated. """ # Required fields request_id: str - prompt_ids: Optional[list[int]] = None # the one being processed - full_prompt_ids: Optional[list[int]] = None # the full prompt - remaining_prompt_ids: list[int] = field(default_factory=list) # For split requests - static_outputs: list[int] = field(default_factory=list) - allocated_blocks: list[int] = field(default_factory=list) + full_prompt_ids: Optional[list[int]] = None # Full initial prompt + prompt_ids: Optional[list[int]] = None # Tokens IDs currently being processed (initial + generated) + remaining_prompt_ids: list[int] = field(default_factory=list) # For split requests, prefill left to process + static_outputs: list[int] = field(default_factory=list) # Generated tokens + allocated_blocks: list[int] = field(default_factory=list) # Block IDs allocated to the request position_offset: int = 0 # Current position in the sequence for position_ids - status: RequestStatus = RequestStatus.PENDING - max_new_tokens: int = 20 - eos_token_id: int = -1 - created_time: float = field(default_factory=time.time) - error: Optional[str] = None - next_token: Optional[str] = None + status: RequestStatus = RequestStatus.PENDING # Status of the request + max_new_tokens: int = 20 # Maximum number of new tokens to generate + eos_token_id: int = -1 # ID of the end-of-sequence token + created_time: float = field(default_factory=time.time) # Time the request was created + error: Optional[str] = None # Error message if the request failed + next_token: Optional[str] = None # Next token to be generated def current_len(self) -> int: """Get the current length of the sequence (prompt + generated tokens).""" @@ -104,6 +120,7 @@ def generated_len(self) -> int: """Get the number of tokens generated so far.""" return len(self.static_outputs) + # TODO: this logic seems one token off, check it out @traced def update_with_token(self, token_id: int) -> bool: """Update the request with a newly generated token and check for completion. @@ -147,6 +164,12 @@ def to_generation_output(self): ) +T = TypeVar("T") +def getattr_no_none(obj: Any, attr: str, default: T) -> T: + x = getattr(obj, attr, None) + return x if x is not None else default + + @attach_tracer() class PagedAttentionCache: def __init__( @@ -169,58 +192,54 @@ def __init__( layer_device_map: Optional mapping of layer indices to devices initial_prompt_shapes: Optional sample prompts to help calculate optimal cache size """ + self.dtype = dtype + self.device = device + # Extract model dimensions - self.num_key_value_heads = ( - config.num_attention_heads - if getattr(config, "num_key_value_heads", None) is None - else config.num_key_value_heads - ) - num_key_value_heads = self.num_key_value_heads + self.num_key_value_heads: int = getattr_no_none(config, "num_key_value_heads", config.num_attention_heads) + self.head_dim: int = getattr_no_none(config, "head_dim", config.hidden_size // config.num_attention_heads) + + self.num_hidden_layers = config.num_hidden_layers + self.block_size = getattr(generation_config, "block_size", 32) + + # Handle TP if tp_size is not None and tp_size > 1: - if num_key_value_heads % tp_size != 0: + if self.num_key_value_heads % tp_size != 0: raise ValueError( - f"Number of key value heads {num_key_value_heads} must be divisible by tensor parallel size {tp_size}." + f"Number of key value heads {self.num_key_value_heads} must be divisible by tensor parallel size {tp_size}." 
) # If the model is using tensor parallelism, we need to adjust the number of heads accordingly. - # self.num_key_value_heads //= tp_size - - self.head_dim = ( - config.head_dim - if hasattr(config, "head_dim") and config.head_dim is not None - else config.hidden_size // config.num_attention_heads + # self.num_key_value_heads //= tp_size # TODO: why is this commented out? + + # Infer number of blocks and max batch tokens + memory_handler = PagedAttentionMemoryHandler( + block_size=self.block_size, + head_dim=self.head_dim, + num_heads=self.num_key_value_heads, + num_layers=self.num_hidden_layers, + hidden_size=config.hidden_size, + vocab_size=config.vocab_size, ) - self.num_hidden_layers = config.num_hidden_layers - - # Calculate optimal block size and number if not provided - num_blocks = getattr(generation_config, "num_blocks", 1024) - block_size = getattr(generation_config, "block_size", 32) - max_memory_percent = getattr(generation_config, "max_memory", 0.9) - max_batch_tokens = getattr(generation_config, "max_batch_tokens", 256) - if num_blocks is None or max_batch_tokens is None: - num_blocks, max_batch_tokens = compute_optimal_blocks( - generation_config.max_new_tokens, - block_size=block_size, - head_dim=self.head_dim, - num_layers=self.num_hidden_layers, - num_heads=self.num_key_value_heads, - max_memory_percent=max_memory_percent, - dtype=dtype, - num_blocks=num_blocks, - ) - logger.warning( - f"Using calculated num_blocks={num_blocks}, block_size={block_size}, max concurrent requests {max_batch_tokens}" + num_blocks, max_batch_tokens = memory_handler.infer_num_blocks_and_max_batch_tokens( + num_blocks=getattr(generation_config, "num_blocks", None), + max_batch_tokens=getattr(generation_config, "max_batch_tokens", None), + max_memory_percent=getattr(generation_config, "max_memory", 0.9), + cache_dtype=self.dtype, ) - self.max_batch_tokens = max_batch_tokens - self.block_size = block_size - self.num_blocks = num_blocks - self.cache_shape = (num_key_value_heads, num_blocks, self.block_size, self.head_dim) - self.dtype = dtype - self.device = device + # Add the infered attributes to the class + self.num_blocks = num_blocks + self.max_batch_tokens = max_batch_tokens + logger.warning(f"Using calculated {self.num_blocks = }, {self.block_size = }, {self.max_batch_tokens = }") + logger.warning(f"Using {self.num_key_value_heads = }, {self.head_dim = }, {self.num_hidden_layers = }") + # Initialize the cache + self.cache_shape = (self.num_key_value_heads, num_blocks, self.block_size, self.head_dim) self.key_cache: list[torch.Tensor] = [] self.value_cache: list[torch.Tensor] = [] for idx in range(config.num_hidden_layers): + available_memory = memory_handler.get_available_memory() // 1024 + logger.warning(f"Initializing cache for layer {idx}, available memory: {available_memory} kb") layer_device = layer_device_map[idx] if layer_device_map is not None else device new_layer_key_cache = torch.zeros(self.cache_shape, dtype=self.dtype, device=layer_device) new_layer_value_cache = torch.zeros(self.cache_shape, dtype=self.dtype, device=layer_device) @@ -603,30 +622,6 @@ def finish_request(self, request_id: str, evict_from_cache: bool = True): del self.active_requests[request_id] -def get_device_and_memory(): - # Select best available device - if torch.cuda.is_available(): - device = torch.device("cuda") - total_memory = torch.cuda.get_device_properties(device).total_memory - reserved_memory = torch.cuda.memory_reserved(device) - allocated_memory = torch.cuda.memory_allocated(device) - - elif 
torch.backends.mps.is_available() and torch.backends.mps.is_built(): - device = torch.device("mps") - # MPS memory reporting (PyTorch 2.0+) - total_memory = torch.mps.driver_allocated_memory() - allocated_memory = total_memory - torch.mps.recommended_max_memory() - reserved_memory = 0 # MPS does not track reserved separately - - else: - device = torch.device("cpu") - total_memory = None - reserved_memory = 0 - allocated_memory = 0 - - return device, total_memory, reserved_memory, allocated_memory - - @traced(standalone=True) def compute_optimal_blocks( max_num_tokens, @@ -638,7 +633,7 @@ def compute_optimal_blocks( num_blocks=None, dtype=torch.float16, ): - device, total, reserved, allocated = get_device_and_memory() + device, total, reserved, allocated = PagedAttentionMemoryHandler.get_device_and_memory_breakdown() available_memory = int((total - max(allocated, reserved)) * max_memory_percent) dtype_size = torch.tensor([], dtype=dtype).element_size() @@ -676,36 +671,6 @@ class PagedAttentionArgs: use_cache: bool = False -@traced -def create_document_mask(cumulative_seqlens_q, cumulative_seqlens_k): - # Number of documents - valid_docs_q = cumulative_seqlens_q[1:] > cumulative_seqlens_q[:-1] - valid_docs_k = cumulative_seqlens_k[1:] > cumulative_seqlens_k[:-1] - num_valid_docs = min(valid_docs_q.sum(), valid_docs_k.sum()) - - # Trim to valid docs - cumulative_seqlens_q = cumulative_seqlens_q[: num_valid_docs + 1] - cumulative_seqlens_k = cumulative_seqlens_k[: num_valid_docs + 1] - - total_q = cumulative_seqlens_q[-1] - total_k = cumulative_seqlens_k[-1] - - q_indices = torch.arange(total_q, device=cumulative_seqlens_q.device) - k_indices = torch.arange(total_k, device=cumulative_seqlens_k.device) - - q_doc_ids = torch.bucketize(q_indices, cumulative_seqlens_q[1:], right=True) - k_doc_ids = torch.bucketize(k_indices, cumulative_seqlens_k[1:], right=False) - doc_mask = q_doc_ids[:, None] == k_doc_ids[None, :] - # apply causal mask where no decoding (same nb of q than k) - - is_causal = ~(cumulative_seqlens_q[1:] - cumulative_seqlens_q[:-1] == 1) * cumulative_seqlens_q[1:] - apply_causal = torch.bucketize(q_indices, is_causal, right=True)[:, None] == k_doc_ids - # TODO don't apply on prefill splitting - causal_mask = torch.triu(torch.ones(total_q, total_k, device=q_doc_ids.device), diagonal=1).bool() - doc_mask.masked_fill_((apply_causal & causal_mask), False) - return doc_mask - - # Continuous Batch Processor (Internal Logic) @attach_tracer() class ContinuousBatchProcessor: @@ -763,49 +728,68 @@ def setup_static_tensors(self): T = self.max_batch_tokens max_token_budget = self.cache.num_blocks * self.cache.block_size tensor_metadata = {"dtype": torch.int32, "device": self.model_device} + + available_memory = PagedAttentionMemoryHandler.get_available_memory() + print(f"Setting up static tensors with {T = }, {max_token_budget = }, {available_memory} bytes available") + self.tensor_metadata = tensor_metadata self.input_ids = torch.zeros((1, T), **tensor_metadata) self.position_ids = torch.zeros((1, T), **tensor_metadata) - self.attention_mask = torch.zeros( - (1, 1, T, max_token_budget), dtype=self.model_dtype, device=self.model_device + self.attention_mask = torch.full( + (1, 1, T, max_token_budget), + torch.finfo(self.model_dtype).min, + dtype=self.model_dtype, + device=self.model_device ) self.cumulative_seqlens_q = torch.zeros((T + 1,), **tensor_metadata) self.cumulative_seqlens_k = torch.zeros((T + 1,), **tensor_metadata) - self.write_index = torch.zeros((T,), **tensor_metadata) - 
self.read_index = torch.zeros((max_token_budget,), **tensor_metadata) + self.write_index = torch.full((T,), -1, **tensor_metadata) + self.read_index = torch.full((max_token_budget,), -1, **tensor_metadata) self.logits_indices = torch.full((T,), -1, **tensor_metadata) self.max_seqlen_q = 0 self.max_seqlen_k = 0 self.output_ids = torch.full((1, T), -1, **tensor_metadata) + torch.cuda.synchronize() + available_memory = PagedAttentionMemoryHandler.get_available_memory() + print("Allocated static tensors,", available_memory, "bytes available") + # self.actual_tokens = T + # self.cache_used = max_token_budget + self.actual_tokens = 0 + self.cache_used = 0 @traced @torch.no_grad() def reset_static_tensors(self): """Reset static tensors for the next batch.""" - self.input_ids.zero_() - self.position_ids.zero_() - self.attention_mask.fill_(torch.finfo(self.model_dtype).min) - self.cumulative_seqlens_q.zero_() - self.cumulative_seqlens_k.zero_() - self.write_index.fill_(-1) - self.read_index.fill_(-1) - self.logits_indices.fill_(-1) + t = self.actual_tokens + c = self.cache_used + self.input_ids[:, :t].zero_() + self.position_ids[:, :t].zero_() + self.attention_mask[:, :, :t, :c].fill_(torch.finfo(self.model_dtype).min) + self.cumulative_seqlens_q[:t+1].zero_() + self.cumulative_seqlens_k[:t+1].zero_() + self.write_index[:t].fill_(-1) + self.read_index[:c].fill_(-1) + self.logits_indices[:t].fill_(-1) self.max_seqlen_q = 0 self.max_seqlen_k = 0 - self.output_ids.zero_() + self.output_ids[:, :t].fill_(-1) + def get_model_kwargs(self) -> PagedAttentionArgs: """Get model keyword arguments for the current batch.""" # torch.set_printoptions(threshold=100000,linewidth=10000) + t = self.actual_tokens + c = self.attention_mask.size(-1) # TODO: figure out why this does not work self.cache_used return { - "input_ids": self.input_ids, - "position_ids": self.position_ids, - "attention_mask": self.attention_mask, - "cu_seq_lens_q": self.cumulative_seqlens_q, - "cu_seq_lens_k": self.cumulative_seqlens_k, - "write_index": self.write_index, - "read_index": self.read_index, - "logits_indices": self.logits_indices, + "input_ids": self.input_ids[:, :t], + "position_ids": self.position_ids[:, :t], + "attention_mask": self.attention_mask[:, :, :t, :c], # NOTE: this is probably not used for paged attention + "cu_seq_lens_q": self.cumulative_seqlens_q[:t+1], + "cu_seq_lens_k": self.cumulative_seqlens_k[:t+1], + "write_index": self.write_index[:t], + "read_index": self.read_index[:c], + "logits_indices": self.logits_indices[:t], "max_seqlen_q": self.max_seqlen_q, "max_seqlen_k": self.max_seqlen_k, "block_tables": self.cache._block_tables, @@ -934,6 +918,10 @@ def _build_tensors( self.cumulative_seqlens_q[: len(cumulative_seqlens_q)] = to_tensor(cumulative_seqlens_q) self.cumulative_seqlens_k[: len(cumulative_seqlens_k)] = to_tensor(cumulative_seqlens_k) self.logits_indices[: len(logits_indices)] = to_tensor(logits_indices) + + self.actual_tokens = len(input_ids) + self.cache_used = len(read_index) + min_value = torch.finfo(self.model_dtype).min if self.config._attn_implementation != "paged_attention": # we set `is_causal` to True in paged call` for i in range(len(cumulative_seqlens_q) - 1): @@ -1249,7 +1237,8 @@ def _sample(self, batch_processor: ContinuousBatchProcessor, probs): next_tokens = torch.multinomial(probs[0], num_samples=1).squeeze(1) else: next_tokens = torch.argmax(probs, dim=-1) - batch_processor.output_ids.copy_(next_tokens) + tokens = batch_processor.actual_tokens + batch_processor.output_ids[:, 
:tokens].copy_(next_tokens) def _run_generation_loop(self): """Main processing loop running in the background thread.""" @@ -1305,7 +1294,7 @@ def _inner_generation_loop(self, batch_processor: ContinuousBatchProcessor, is_f if torch.cuda.is_available(): torch.cuda.synchronize() batch_processor.prepare_next_batch() - device, total, reserved, allocated = get_device_and_memory() + device, total, reserved, allocated = PagedAttentionMemoryHandler.get_device_and_memory_breakdown() logger.info(f"[Memory] Device: {device}, Total: {total}, Reserved: {reserved}, Allocated: {allocated}") if torch.cuda.is_available() and self.use_cuda_graph: if is_first: From 79118c5346e9451b064e16b17d4ebcf96fa6f2f9 Mon Sep 17 00:00:00 2001 From: remi-or Date: Thu, 21 Aug 2025 00:10:35 +0000 Subject: [PATCH 04/26] Slice cache -- WIP --- src/transformers/generation/continuous_batching.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/transformers/generation/continuous_batching.py b/src/transformers/generation/continuous_batching.py index 0bd925b5f6d0..7d2ab2ba193d 100644 --- a/src/transformers/generation/continuous_batching.py +++ b/src/transformers/generation/continuous_batching.py @@ -752,8 +752,6 @@ def setup_static_tensors(self): torch.cuda.synchronize() available_memory = PagedAttentionMemoryHandler.get_available_memory() print("Allocated static tensors,", available_memory, "bytes available") - # self.actual_tokens = T - # self.cache_used = max_token_budget self.actual_tokens = 0 self.cache_used = 0 @@ -780,7 +778,7 @@ def get_model_kwargs(self) -> PagedAttentionArgs: """Get model keyword arguments for the current batch.""" # torch.set_printoptions(threshold=100000,linewidth=10000) t = self.actual_tokens - c = self.attention_mask.size(-1) # TODO: figure out why this does not work self.cache_used + c = self.cache_used return { "input_ids": self.input_ids[:, :t], "position_ids": self.position_ids[:, :t], From dc53ad6096f3cfc4ba9c4d1898ed45cf4ba34446 Mon Sep 17 00:00:00 2001 From: remi-or Date: Thu, 21 Aug 2025 14:22:55 +0000 Subject: [PATCH 05/26] Added a mechanism to check batched outputs in CB script --- examples/pytorch/continuous_batching.py | 60 ++++++++++++++++++++++--- 1 file changed, 54 insertions(+), 6 deletions(-) diff --git a/examples/pytorch/continuous_batching.py b/examples/pytorch/continuous_batching.py index 6cd92c756b39..253a8cbba3ec 100644 --- a/examples/pytorch/continuous_batching.py +++ b/examples/pytorch/continuous_batching.py @@ -3,11 +3,39 @@ import datasets import torch import json +from typing import Optional from transformers import AutoModelForCausalLM, AutoTokenizer from transformers.generation import GenerationConfig +MODEL_ID = "meta-llama/Llama-3.2-3b-Instruct" + + +def generate_simple(attn_implementation: str, simple_batch_inputs: list[int], generation_config: GenerationConfig) -> list[str]: + attn_implementation = { + "sdpa_paged": "sdpa", + "eager_paged": "eager", + }[attn_implementation] + + model = AutoModelForCausalLM.from_pretrained( + MODEL_ID, + torch_dtype=torch.bfloat16, + attn_implementation=attn_implementation, + ).cuda().eval() + + decoded_outputs = [] + for input_ids in simple_batch_inputs: + input_ids = torch.tensor([input_ids]).to("cuda") + attention_mask = torch.ones_like(input_ids) + outputs = model.generate(input_ids, attention_mask=attention_mask, generation_config=generation_config) + generated_tokens = outputs[0][input_ids.shape[1]:] + decoded_output = tokenizer.decode(generated_tokens, skip_special_tokens=True) + 
decoded_outputs.append(decoded_output) + + return decoded_outputs + + def batch_generate( model: AutoModelForCausalLM, simple_batch_inputs: list, @@ -15,6 +43,7 @@ def batch_generate( tokenizer: AutoTokenizer, displayed_samples: int = 0, # -1: no display, 0: display stats, >0: display inputs and some outputs output_file: str = None, + expected_outputs: Optional[list[str]] = None, ) -> tuple[float, float]: # Actual batch generation @@ -33,15 +62,20 @@ def batch_generate( token_count = 0 data = [] for i, request in enumerate(batch_outputs): - input_text = tokenizer.decode(batch_outputs[request].prompt_ids, skip_special_tokens=False) + input_text = tokenizer.decode(batch_outputs[request].prompt_ids, skip_special_tokens=True) data.append({"input": input_text}) + + # Try to decode the output try: - output_text = tokenizer.decode(batch_outputs[request].generated_tokens, skip_special_tokens=False) + output_text = tokenizer.decode(batch_outputs[request].generated_tokens, skip_special_tokens=True) token_count += len(batch_outputs[request].generated_tokens[1:]) data[-1]["output"] = output_text except Exception as e: print(f"Decoding failed for request {request}: {e}") data[-1]["output"] = "__ERROR__" + continue + + # Display sample if asked if i < displayed_samples: if len(output_text) > 0: print("-" * 20) @@ -52,6 +86,13 @@ def batch_generate( print("[WARN]") print(f"{request} Output was empty!") + # Compare with classic generate if asked + if expected_outputs is not None: + matches = output_text == expected_outputs[i] + data[-1]["ref"] = expected_outputs[i] + data[-1]["matches"] = matches + print(f"Request {i} matches" if matches else f"Request {i} does NOT match!") + # If an output file is provided, save the reordered data to it data.sort(key=lambda x: x["input"]) if output_file is not None: @@ -80,8 +121,9 @@ def batch_generate( parser.add_argument("--use-cuda-graph", action="store_true", default=False) parser.add_argument("--samples", type=int, default=500) - parser.add_argument("--displayed", type=int, default=1, help="Number of samples to display") + parser.add_argument("--displayed", type=int, default=0, help="Number of samples to display") parser.add_argument("--output-file", type=str, default=None) + parser.add_argument("--compare", action="store_true", default=False) args = parser.parse_args() # Set matmul precision @@ -89,9 +131,8 @@ def batch_generate( torch.set_float32_matmul_precision(args.matmul_precision) # Prepare model - model_id = "meta-llama/Llama-3.2-3b-Instruct" model = AutoModelForCausalLM.from_pretrained( - model_id, + MODEL_ID, attn_implementation=args.attn, dtype=torch.bfloat16, torch_dtype=torch.bfloat16, @@ -100,7 +141,7 @@ def batch_generate( # model.forward = torch.compile(model.forward, mode="max-autotune-no-cudagraphs") # Prepare tokenizer and dataset - tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="left") + tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, padding_side="left") dataset = datasets.load_dataset("openai/gsm8k", "socratic", split="test") dataset = dataset.select(range(args.samples)) # Use only 5 examples for the simple version tokenized_datasets = dataset.map(lambda x: tokenizer(x["question"]), batched=True) @@ -117,6 +158,9 @@ def batch_generate( max_batch_tokens=args.max_batch_tokens, ) + # If we need to compare, we need to generate the reference outputs + expected_outputs = generate_simple(args.attn, simple_batch_inputs, generation_config) if args.compare else None + # Run warmup batch generation batch_generate( model, @@ -134,6 
+178,7 @@ def batch_generate( tokenizer, displayed_samples=args.displayed, output_file=args.output_file, + expected_outputs=expected_outputs, ) @@ -149,3 +194,6 @@ def batch_generate( # Without changes to continuous_batching.py # Using calculated num_blocks=369, block_size=32, max concurrent requests 23 # CB generation took: 79.58 seconds for 25813 tokens. 324.38tok/s + + +# python examples/pytorch/continuous_batching.py --num-blocks 369 --max-batch-tokens 23 --attn sdpa_paged -mp none --samples 1 --displayed 0 --output-file sliced.json From b107785c9ffb7630ddfbdce78506c65b7fb8c807 Mon Sep 17 00:00:00 2001 From: remi-or Date: Thu, 21 Aug 2025 14:24:38 +0000 Subject: [PATCH 06/26] Less logging, debug flag for slice, !better reset! -- WIP --- .../generation/cb/memory_management.py | 1 - .../generation/continuous_batching.py | 61 +++++++++---------- 2 files changed, 30 insertions(+), 32 deletions(-) diff --git a/src/transformers/generation/cb/memory_management.py b/src/transformers/generation/cb/memory_management.py index d014eb097153..90c282f19d44 100644 --- a/src/transformers/generation/cb/memory_management.py +++ b/src/transformers/generation/cb/memory_management.py @@ -85,7 +85,6 @@ def infer_num_blocks_and_max_batch_tokens( num_blocks=num_blocks, cache_dtype=cache_dtype, ) - logger.warning(f"{available_memory = }, {memory_footprint = }, {num_blocks = }, {max_batch_tokens = }") if sum(memory_footprint) > available_memory: raise MemoryError(f"Memory footprint {memory_footprint} is more than available memory {available_memory}") return num_blocks, max_batch_tokens diff --git a/src/transformers/generation/continuous_batching.py b/src/transformers/generation/continuous_batching.py index 7d2ab2ba193d..3f2be79ff580 100644 --- a/src/transformers/generation/continuous_batching.py +++ b/src/transformers/generation/continuous_batching.py @@ -230,16 +230,16 @@ def __init__( # Add the infered attributes to the class self.num_blocks = num_blocks self.max_batch_tokens = max_batch_tokens - logger.warning(f"Using calculated {self.num_blocks = }, {self.block_size = }, {self.max_batch_tokens = }") - logger.warning(f"Using {self.num_key_value_heads = }, {self.head_dim = }, {self.num_hidden_layers = }") + logger.info( + f"After init, {self.num_blocks = }, {self.block_size = }, {self.max_batch_tokens = } " + f"{self.num_key_value_heads = }, {self.head_dim = }, {self.num_hidden_layers = }" + ) # Initialize the cache self.cache_shape = (self.num_key_value_heads, num_blocks, self.block_size, self.head_dim) self.key_cache: list[torch.Tensor] = [] self.value_cache: list[torch.Tensor] = [] for idx in range(config.num_hidden_layers): - available_memory = memory_handler.get_available_memory() // 1024 - logger.warning(f"Initializing cache for layer {idx}, available memory: {available_memory} kb") layer_device = layer_device_map[idx] if layer_device_map is not None else device new_layer_key_cache = torch.zeros(self.cache_shape, dtype=self.dtype, device=layer_device) new_layer_value_cache = torch.zeros(self.cache_shape, dtype=self.dtype, device=layer_device) @@ -687,6 +687,7 @@ def __init__( scheduler: Scheduler, streaming: bool = False, manual_eviction: bool = False, + slice_inputs: bool = True, # TODO: remove this once parity is ensured ): """Initialize the continuous batch processor. 
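+
+        The `slice_inputs` flag added above controls whether the static tensors (input ids,
+        attention mask, read/write indices) passed to the model are sliced down to the number
+        of tokens and cache slots actually used by the current batch, rather than always
+        using the full preallocated buffers.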
@@ -711,6 +712,7 @@ def __init__( self.scheduler = scheduler self.streaming = streaming self.manual_eviction = manual_eviction + self.slice_inputs = slice_inputs self.requests_in_batch: list[RequestState] = [] @@ -728,30 +730,24 @@ def setup_static_tensors(self): T = self.max_batch_tokens max_token_budget = self.cache.num_blocks * self.cache.block_size tensor_metadata = {"dtype": torch.int32, "device": self.model_device} - - available_memory = PagedAttentionMemoryHandler.get_available_memory() - print(f"Setting up static tensors with {T = }, {max_token_budget = }, {available_memory} bytes available") - + # Prepare empty tensors self.tensor_metadata = tensor_metadata - self.input_ids = torch.zeros((1, T), **tensor_metadata) - self.position_ids = torch.zeros((1, T), **tensor_metadata) - self.attention_mask = torch.full( - (1, 1, T, max_token_budget), - torch.finfo(self.model_dtype).min, - dtype=self.model_dtype, - device=self.model_device - ) - self.cumulative_seqlens_q = torch.zeros((T + 1,), **tensor_metadata) - self.cumulative_seqlens_k = torch.zeros((T + 1,), **tensor_metadata) - self.write_index = torch.full((T,), -1, **tensor_metadata) - self.read_index = torch.full((max_token_budget,), -1, **tensor_metadata) - self.logits_indices = torch.full((T,), -1, **tensor_metadata) + self.input_ids = torch.empty((1, T), **tensor_metadata) + self.position_ids = torch.empty((1, T), **tensor_metadata) + self.attention_mask = torch.empty((1, 1, T, max_token_budget), dtype=self.model_dtype, device=self.model_device) + self.cumulative_seqlens_q = torch.empty((T + 1,), **tensor_metadata) + self.cumulative_seqlens_k = torch.empty((T + 1,), **tensor_metadata) + self.write_index = torch.empty((T,), **tensor_metadata) + self.read_index = torch.empty((max_token_budget,), **tensor_metadata) + self.logits_indices = torch.empty((T,), **tensor_metadata) self.max_seqlen_q = 0 self.max_seqlen_k = 0 - self.output_ids = torch.full((1, T), -1, **tensor_metadata) - torch.cuda.synchronize() - available_memory = PagedAttentionMemoryHandler.get_available_memory() - print("Allocated static tensors,", available_memory, "bytes available") + self.output_ids = torch.empty((1, T), **tensor_metadata) + # Initialize the tensors by pretending they are in full use + self.actual_tokens = T + self.cache_used = max_token_budget + self.reset_static_tensors() + # Reset stats to 0 self.actual_tokens = 0 self.cache_used = 0 @@ -759,8 +755,10 @@ def setup_static_tensors(self): @torch.no_grad() def reset_static_tensors(self): """Reset static tensors for the next batch.""" - t = self.actual_tokens - c = self.cache_used + # Compute the slice to reset + t = self.actual_tokens if self.slice_inputs else self.write_index.size(0) + c = self.cache_used if self.slice_inputs else self.read_index.size(0) + # Reset the tensors self.input_ids[:, :t].zero_() self.position_ids[:, :t].zero_() self.attention_mask[:, :, :t, :c].fill_(torch.finfo(self.model_dtype).min) @@ -776,9 +774,10 @@ def reset_static_tensors(self): def get_model_kwargs(self) -> PagedAttentionArgs: """Get model keyword arguments for the current batch.""" - # torch.set_printoptions(threshold=100000,linewidth=10000) - t = self.actual_tokens - c = self.cache_used + # Compute the slice to return + t = self.actual_tokens if self.slice_inputs else self.write_index.size(0) + c = self.cache_used if self.slice_inputs else self.read_index.size(0) + # Return the tensors return { "input_ids": self.input_ids[:, :t], "position_ids": self.position_ids[:, :t], @@ -1235,7 +1234,7 @@ def _sample(self, 
batch_processor: ContinuousBatchProcessor, probs): next_tokens = torch.multinomial(probs[0], num_samples=1).squeeze(1) else: next_tokens = torch.argmax(probs, dim=-1) - tokens = batch_processor.actual_tokens + tokens = batch_processor.actual_tokens if self.slice_inputs else batch_processor.output_ids.size(1) batch_processor.output_ids[:, :tokens].copy_(next_tokens) def _run_generation_loop(self): From bababa46018f4b345a0f608828f522749d598795 Mon Sep 17 00:00:00 2001 From: remi-or Date: Thu, 21 Aug 2025 21:52:54 +0000 Subject: [PATCH 07/26] QOL and safety margins --- examples/pytorch/continuous_batching.py | 29 +++++++++++----- .../generation/cb/memory_management.py | 4 +-- .../generation/continuous_batching.py | 33 +++++++++++++------ 3 files changed, 46 insertions(+), 20 deletions(-) diff --git a/examples/pytorch/continuous_batching.py b/examples/pytorch/continuous_batching.py index 253a8cbba3ec..26fa60a84559 100644 --- a/examples/pytorch/continuous_batching.py +++ b/examples/pytorch/continuous_batching.py @@ -93,12 +93,6 @@ def batch_generate( data[-1]["matches"] = matches print(f"Request {i} matches" if matches else f"Request {i} does NOT match!") - # If an output file is provided, save the reordered data to it - data.sort(key=lambda x: x["input"]) - if output_file is not None: - with open(output_file, "w") as f: - json.dump(data, f, indent=4) - # Compute stats and maybe print them gen_time = end_time_simple - start_time_simple tok_per_sec = token_count / gen_time @@ -106,6 +100,21 @@ def batch_generate( print("-" * 20) print("--- Finished CB Generation Example ---\n") print(f"CB generation took: {gen_time:.2f} seconds for {token_count} tokens. {tok_per_sec:.2f}tok/s") + stats = { + "num_blocks": generation_config.num_blocks, + "max_batch_tokens": generation_config.max_batch_tokens, + "gen_time": gen_time, + "token_count": token_count, + "tok_per_sec": tok_per_sec, + } + + # If an output file is provided, save the reordered data to it + data.sort(key=lambda x: x["input"]) + data = [stats] + data + if output_file is not None: + with open(output_file, "w") as f: + json.dump(data, f, indent=4) + return gen_time, tok_per_sec @@ -113,8 +122,8 @@ def batch_generate( # Parse args parser = argparse.ArgumentParser() - parser.add_argument("--num-blocks", type=int, default=None) - parser.add_argument("--max-batch-tokens", type=int, default=None) + parser.add_argument("--num-blocks", "-n", type=int, default=None) + parser.add_argument("--max-batch-tokens", "-b", type=int, default=None) parser.add_argument("--attn", type=str, default="paged_attention|kernels-community/flash-attn", help="Attention implementation") parser.add_argument("--matmul-precision", "-mp", type=str, default="high") # set to "none" to disable @@ -161,6 +170,10 @@ def batch_generate( # If we need to compare, we need to generate the reference outputs expected_outputs = generate_simple(args.attn, simple_batch_inputs, generation_config) if args.compare else None + # If no output file is provided, we pick a name based on the args + if args.output_file is None: + args.output_file = f"cb_{args.num_blocks}_{args.max_batch_tokens}_{args.attn}_{args.matmul_precision}_{args.samples}_.json" + # Run warmup batch generation batch_generate( model, diff --git a/src/transformers/generation/cb/memory_management.py b/src/transformers/generation/cb/memory_management.py index 90c282f19d44..15fd95264653 100644 --- a/src/transformers/generation/cb/memory_management.py +++ b/src/transformers/generation/cb/memory_management.py @@ -13,8 +13,8 @@ class 
PagedAttentionMemoryHandler: _activation_dtype = torch.bfloat16 _activation_safety_factor = 2 _input_dtype = torch.int32 - _upper_bound_max_batch_tokens = 1024 - _upper_bound_num_blocks = 1024 + _upper_bound_max_batch_tokens = 2048 + _upper_bound_num_blocks = 16384 def __init__( self, diff --git a/src/transformers/generation/continuous_batching.py b/src/transformers/generation/continuous_batching.py index 3f2be79ff580..81112864ba4a 100644 --- a/src/transformers/generation/continuous_batching.py +++ b/src/transformers/generation/continuous_batching.py @@ -49,6 +49,7 @@ class RequestStatus(Enum): logger = logging.getLogger(__name__) +# logger.setLevel(logging.INFO) @dataclass @@ -386,6 +387,10 @@ def get_active_request_static_outputs(self, request_id: str) -> list[int]: @attach_tracer() class FIFOScheduler(Scheduler): + def __init__(self, cache: PagedAttentionCache, retain_cache_on_finish: bool = False, safety_margin: float = 0.1): + super().__init__(cache, retain_cache_on_finish) + self.safety_margin = safety_margin + @traced def _allocate_blocks_if_needed(self, state: RequestState, len_next_tokens: int): # 1. we check that the occupancy is less than the requested length @@ -458,13 +463,19 @@ def schedule_batch(self, token_budget: int) -> list[RequestState]: candidates = priority_states + second_priority_states request_ids_to_remove_from_waiting = set() + safety_margins = self.safety_margin * self.cache.num_blocks for state in candidates: + + # If we are out the safety margin, we only accept decoding requests or the first prefill request + num_free_blocks = self.cache.get_num_free_blocks() + outside_safety_margin = num_free_blocks < safety_margins + if outside_safety_margin and scheduled_requests and state.status != RequestStatus.DECODING: + break + self._prepare_request_for_processing(state, token_budget, request_ids_to_remove_from_waiting) request_len = len(state.prompt_ids) - if not self._allocate_blocks_if_needed( - state, len(state.prompt_ids) - ): # don't schedule if we can't allocate blocks + if not self._allocate_blocks_if_needed(state, len(state.prompt_ids)): # don't schedule if we can't allocate blocks if len(self.cache._free_blocks) == 0: break continue @@ -1234,7 +1245,7 @@ def _sample(self, batch_processor: ContinuousBatchProcessor, probs): next_tokens = torch.multinomial(probs[0], num_samples=1).squeeze(1) else: next_tokens = torch.argmax(probs, dim=-1) - tokens = batch_processor.actual_tokens if self.slice_inputs else batch_processor.output_ids.size(1) + tokens = next_tokens.size(1) batch_processor.output_ids[:, :tokens].copy_(next_tokens) def _run_generation_loop(self): @@ -1274,11 +1285,10 @@ def _run_generation_loop(self): self.manual_eviction, ) self.batch_processor = batch_processor - is_first = True + self.current_batch = 0 while (not self.stop_event.is_set()) or batch_processor.has_pending_requests(): - self._inner_generation_loop(batch_processor, is_first) - if is_first: - is_first = False + self._inner_generation_loop(batch_processor) + self.current_batch += 1 except Exception as e: logger.error(f"Error in generation loop: {e}", exc_info=True) @@ -1287,14 +1297,14 @@ def _run_generation_loop(self): logger.info("Generation loop finished.") @traced(span_name="generation_loop") - def _inner_generation_loop(self, batch_processor: ContinuousBatchProcessor, is_first: bool = False): + def _inner_generation_loop(self, batch_processor: ContinuousBatchProcessor): if torch.cuda.is_available(): torch.cuda.synchronize() batch_processor.prepare_next_batch() device, total, 
reserved, allocated = PagedAttentionMemoryHandler.get_device_and_memory_breakdown() logger.info(f"[Memory] Device: {device}, Total: {total}, Reserved: {reserved}, Allocated: {allocated}") if torch.cuda.is_available() and self.use_cuda_graph: - if is_first: + if self.current_batch == 0: self.warmup(batch_processor) elif hasattr(self, "graph"): try: @@ -1406,6 +1416,9 @@ def generate_batch( """ if not inputs: return [] + if logger.getEffectiveLevel() <= logging.INFO: + logger.warning("Progress bar is disabled when logger level is less than INFO") + progress_bar = False # Initialize manager with the batch inputs manager = self.init_continuous_batching(generation_config=generation_config) From f01e9db47f70522b0c1cfe3893fd6f907ae64fc3 Mon Sep 17 00:00:00 2001 From: remi-or Date: Fri, 22 Aug 2025 10:20:39 +0000 Subject: [PATCH 08/26] Refactor and style --- examples/pytorch/continuous_batching.py | 49 +- .../generation/cb/memory_management.py | 191 ----- .../continuous_batching/__init__.py | 6 + .../generation/continuous_batching/cache.py | 370 ++++++++++ .../continuous_api.py} | 677 +----------------- .../generation/continuous_batching/core.py | 172 +++++ .../continuous_batching/scheduler.py | 300 ++++++++ 7 files changed, 897 insertions(+), 868 deletions(-) delete mode 100644 src/transformers/generation/cb/memory_management.py create mode 100644 src/transformers/generation/continuous_batching/__init__.py create mode 100644 src/transformers/generation/continuous_batching/cache.py rename src/transformers/generation/{continuous_batching.py => continuous_batching/continuous_api.py} (53%) create mode 100644 src/transformers/generation/continuous_batching/core.py create mode 100644 src/transformers/generation/continuous_batching/scheduler.py diff --git a/examples/pytorch/continuous_batching.py b/examples/pytorch/continuous_batching.py index 26fa60a84559..2bfcd0d3ba50 100644 --- a/examples/pytorch/continuous_batching.py +++ b/examples/pytorch/continuous_batching.py @@ -1,10 +1,11 @@ -import time import argparse -import datasets -import torch import json +import time from typing import Optional +import datasets +import torch + from transformers import AutoModelForCausalLM, AutoTokenizer from transformers.generation import GenerationConfig @@ -12,27 +13,33 @@ MODEL_ID = "meta-llama/Llama-3.2-3b-Instruct" -def generate_simple(attn_implementation: str, simple_batch_inputs: list[int], generation_config: GenerationConfig) -> list[str]: +def generate_simple( + attn_implementation: str, simple_batch_inputs: list[int], generation_config: GenerationConfig +) -> list[str]: attn_implementation = { "sdpa_paged": "sdpa", "eager_paged": "eager", }[attn_implementation] - model = AutoModelForCausalLM.from_pretrained( - MODEL_ID, - torch_dtype=torch.bfloat16, - attn_implementation=attn_implementation, - ).cuda().eval() - + model = ( + AutoModelForCausalLM.from_pretrained( + MODEL_ID, + torch_dtype=torch.bfloat16, + attn_implementation=attn_implementation, + ) + .cuda() + .eval() + ) + decoded_outputs = [] for input_ids in simple_batch_inputs: input_ids = torch.tensor([input_ids]).to("cuda") attention_mask = torch.ones_like(input_ids) outputs = model.generate(input_ids, attention_mask=attention_mask, generation_config=generation_config) - generated_tokens = outputs[0][input_ids.shape[1]:] + generated_tokens = outputs[0][input_ids.shape[1] :] decoded_output = tokenizer.decode(generated_tokens, skip_special_tokens=True) decoded_outputs.append(decoded_output) - + return decoded_outputs @@ -41,11 +48,10 @@ def 
batch_generate( simple_batch_inputs: list, generation_config: GenerationConfig, tokenizer: AutoTokenizer, - displayed_samples: int = 0, # -1: no display, 0: display stats, >0: display inputs and some outputs - output_file: str = None, + displayed_samples: int = 0, # -1: no display, 0: display stats, >0: display inputs and some outputs + output_file: Optional[str] = None, expected_outputs: Optional[list[str]] = None, ) -> tuple[float, float]: - # Actual batch generation if displayed_samples >= 0: print("--- Running CB Generation Example ---") @@ -119,14 +125,15 @@ def batch_generate( if __name__ == "__main__": - # Parse args parser = argparse.ArgumentParser() parser.add_argument("--num-blocks", "-n", type=int, default=None) parser.add_argument("--max-batch-tokens", "-b", type=int, default=None) - parser.add_argument("--attn", type=str, default="paged_attention|kernels-community/flash-attn", help="Attention implementation") - parser.add_argument("--matmul-precision", "-mp", type=str, default="high") # set to "none" to disable + parser.add_argument( + "--attn", type=str, default="paged_attention|kernels-community/flash-attn", help="Attention implementation" + ) + parser.add_argument("--matmul-precision", "-mp", type=str, default="high") # set to "none" to disable parser.add_argument("--use-cuda-graph", action="store_true", default=False) parser.add_argument("--samples", type=int, default=500) @@ -172,12 +179,14 @@ def batch_generate( # If no output file is provided, we pick a name based on the args if args.output_file is None: - args.output_file = f"cb_{args.num_blocks}_{args.max_batch_tokens}_{args.attn}_{args.matmul_precision}_{args.samples}_.json" + args.output_file = ( + f"cb/{args.num_blocks}_{args.max_batch_tokens}_{args.attn}_{args.matmul_precision}_{args.samples}.json" + ) # Run warmup batch generation batch_generate( model, - simple_batch_inputs[:min(5, args.samples)], + simple_batch_inputs[: min(5, args.samples)], generation_config, tokenizer, displayed_samples=-1, diff --git a/src/transformers/generation/cb/memory_management.py b/src/transformers/generation/cb/memory_management.py deleted file mode 100644 index 15fd95264653..000000000000 --- a/src/transformers/generation/cb/memory_management.py +++ /dev/null @@ -1,191 +0,0 @@ -from typing import Optional -from math import sqrt, floor -import torch -from ...utils.logging import logging -from ...utils.metrics import traced - - -logger = logging.getLogger(__name__) - - -class PagedAttentionMemoryHandler: - - _activation_dtype = torch.bfloat16 - _activation_safety_factor = 2 - _input_dtype = torch.int32 - _upper_bound_max_batch_tokens = 2048 - _upper_bound_num_blocks = 16384 - - def __init__( - self, - block_size: int, - head_dim: int, - num_heads: int, - num_layers: int, - hidden_size: int, - vocab_size: int, - ) -> None: - self.block_size = block_size - self.head_dim = head_dim - self.num_heads = num_heads - self.num_layers = num_layers - self.hidden_size = hidden_size - self.vocab_size = vocab_size - - @staticmethod - def get_device_and_memory_breakdown() -> tuple[torch.device, int, int, int]: - if torch.cuda.is_available(): - device = torch.device("cuda") - torch.cuda.empty_cache() - torch.cuda.synchronize() - total_memory = torch.cuda.get_device_properties(device).total_memory - reserved_memory = torch.cuda.memory_reserved(device) - allocated_memory = torch.cuda.memory_allocated(device) - elif torch.backends.mps.is_available() and torch.backends.mps.is_built(): - device = torch.device("mps") - # MPS memory reporting (PyTorch 2.0+) 
- total_memory = torch.mps.driver_allocated_memory() - allocated_memory = total_memory - torch.mps.recommended_max_memory() - reserved_memory = 0 # MPS does not track reserved separately - else: - device = torch.device("cpu") - total_memory = None - reserved_memory = 0 - allocated_memory = 0 - return device, total_memory, reserved_memory, allocated_memory - - @staticmethod - def get_available_memory(max_memory_percent: float = 1.0) -> int: - _, total, reserved, allocated = PagedAttentionMemoryHandler.get_device_and_memory_breakdown() - available_memory = total - max(allocated, reserved) - available_memory = int(available_memory * max_memory_percent) - return available_memory - - def infer_num_blocks_and_max_batch_tokens( - self, - num_blocks: Optional[int] = None, - max_batch_tokens: Optional[int] = None, - max_memory_percent: float = 0.9, - cache_dtype: torch.dtype = torch.float16, - ) -> tuple[int, int]: - # If neither num_blocks nor max_batch_tokens are provided, we use a second-order polynomial - if num_blocks is None and max_batch_tokens is None: - num_blocks, max_batch_tokens = self.compute_num_blocks_and_max_batch_tokens(max_memory_percent, cache_dtype) - # If only num_blocks is provided, we infer the max_batch_tokens - elif num_blocks is not None and max_batch_tokens is None: - max_batch_tokens = self.compute_max_batch_tokens(num_blocks, max_memory_percent, cache_dtype) - # If only max_batch_tokens is provided, we infer the num_blocks - elif max_batch_tokens is not None and num_blocks is None: - num_blocks = self.compute_num_blocks(max_batch_tokens, max_memory_percent, cache_dtype) - - # We check if the memory footprint is too large in all cases - available_memory = self.get_available_memory(max_memory_percent) - memory_footprint = self.compute_memory_footprint( - max_batch_tokens=max_batch_tokens, - num_blocks=num_blocks, - cache_dtype=cache_dtype, - ) - if sum(memory_footprint) > available_memory: - raise MemoryError(f"Memory footprint {memory_footprint} is more than available memory {available_memory}") - return num_blocks, max_batch_tokens - - - def compute_num_blocks_and_max_batch_tokens( - self, - max_memory_percent: float = 0.9, - cache_dtype: torch.dtype = torch.float16, - m: float = 0.1, - ) -> tuple[int, int]: - cache_memory = self.get_available_memory(max_memory_percent) - - # Compute second-degree polynomial coefficients - a = m * self._activation_dtype.itemsize - b = 8 * m * self._input_dtype.itemsize - b += 2 * self.num_heads * self.head_dim * self.num_layers * cache_dtype.itemsize - c = self._activation_dtype.itemsize * (self.hidden_size + self.vocab_size) * self._activation_safety_factor - c += 2 * self._input_dtype.itemsize - c -= cache_memory - - # Compute discriminant and greatest solution - discriminant = b**2 - 4 * a * c - if discriminant < 0: - raise ValueError(f"Discriminant is negative: {discriminant = }") - greatest_solution = (-b + sqrt(discriminant)) / (2 * a) - if greatest_solution < 0: - raise ValueError(f"Greatest solution is negative: {greatest_solution = }") - - # Infer number of blocks and max batch tokens - num_blocks = int(greatest_solution) // self.block_size - if num_blocks > self._upper_bound_num_blocks: - logger.warning(f"{num_blocks = } is too large, setting to {self._upper_bound_num_blocks = }") - num_blocks = self._upper_bound_num_blocks - max_batch_tokens = int(greatest_solution * m) - if max_batch_tokens > self._upper_bound_max_batch_tokens: - logger.warning(f"{max_batch_tokens = } is too large, setting to 
{self._upper_bound_max_batch_tokens = }") - max_batch_tokens = self._upper_bound_max_batch_tokens - return num_blocks, max_batch_tokens - - def compute_max_batch_tokens( - self, - num_blocks: int, - max_memory_percent: float = 0.9, - cache_dtype: torch.dtype = torch.float16, - ) -> int: - cache_memory = self.get_available_memory(max_memory_percent) - cache_size = num_blocks * self.block_size - # Compute numerator - num = cache_memory - num -= self._activation_dtype.itemsize * (self.hidden_size + self.vocab_size) * self._activation_safety_factor - num -= 2 * self._input_dtype.itemsize - num -= cache_size * 2 * self.num_heads * self.head_dim * self.num_layers * cache_dtype.itemsize - # Compute denominator - denum = 8 * self._input_dtype.itemsize + cache_size * self._activation_dtype.itemsize - # Compute max batch tokens and return - return int(num / denum) - - def compute_num_blocks( - self, - max_batch_tokens: int, - max_memory_percent: float = 0.9, - cache_dtype: torch.dtype = torch.float16, - ) -> int: - cache_memory = self.get_available_memory(max_memory_percent) - # Compute numerator - num = cache_memory - num -= self._activation_dtype.itemsize * (self.hidden_size + self.vocab_size) * self._activation_safety_factor - num -= 8 * max_batch_tokens * self._input_dtype.itemsize - num -= 2 * self._input_dtype.itemsize - # Compute denominator - denum = 2 * self.num_heads * self.head_dim * self.num_layers * cache_dtype.itemsize - denum += max_batch_tokens * self._activation_dtype.itemsize - # Compute cache size and return number of blocks - cache_size = int(num / denum) - return floor(cache_size / self.block_size) - - def compute_memory_footprint( - self, - num_blocks: Optional[int] = None, - max_batch_tokens: Optional[int] = None, - cache_dtype: torch.dtype = torch.float16, - ) -> tuple[int, int, int]: - # Compute activation memory footprint - activation_memory_footprint = self._activation_dtype.itemsize * (self.hidden_size + self.vocab_size) - activation_memory_footprint *= self._activation_safety_factor - # Compute cache memory footprint if num_blocks is provided - if num_blocks is not None: - cache_size = num_blocks * self.block_size - bytes_per_token = 2 * self.num_heads * self.head_dim * self.num_layers * cache_dtype.itemsize - cache_memory_footprint = cache_size * bytes_per_token - else: - cache_memory_footprint = -1 - # Compute static tensors memory footprint if num_blocks and max_batch_tokens is provided - if num_blocks is not None and max_batch_tokens is not None: - static_memory_footprint = sum([ - 3 * max_batch_tokens * self._input_dtype.itemsize, # input_ids, position_ids, output_ids - max_batch_tokens * cache_size * self._activation_dtype.itemsize, # attention_mask - 2 * (max_batch_tokens + 1) * self._input_dtype.itemsize, # cumulative_seqlens_qk - 3 * max_batch_tokens * self._input_dtype.itemsize, # write_index, read_index, logits_indices - ]) - else: - static_memory_footprint = -1 - return activation_memory_footprint, cache_memory_footprint, static_memory_footprint diff --git a/src/transformers/generation/continuous_batching/__init__.py b/src/transformers/generation/continuous_batching/__init__.py new file mode 100644 index 000000000000..7e7f3fb7f925 --- /dev/null +++ b/src/transformers/generation/continuous_batching/__init__.py @@ -0,0 +1,6 @@ +from .cache import PagedAttentionCache +from .continuous_api import ContinuousBatchingManager, ContinuousMixin +from .core import RequestState, RequestStatus + + +__all__ = ["PagedAttentionCache", "RequestState", "RequestStatus", 
"ContinuousMixin", "ContinuousBatchingManager"] diff --git a/src/transformers/generation/continuous_batching/cache.py b/src/transformers/generation/continuous_batching/cache.py new file mode 100644 index 000000000000..c8d263281935 --- /dev/null +++ b/src/transformers/generation/continuous_batching/cache.py @@ -0,0 +1,370 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from collections import deque +from math import floor, sqrt +from typing import Any, Optional, TypeVar, Union + +import torch + +from ...configuration_utils import PretrainedConfig +from ...generation.configuration_utils import GenerationConfig +from ...utils.metrics import attach_tracer, traced +from .core import RequestState, logger + + +T = TypeVar("T") + + +def getattr_no_none(obj: Any, attr: str, default: T) -> T: + x = getattr(obj, attr, None) + return x if x is not None else default + + +@attach_tracer() +class PagedAttentionCache: + def __init__( + self, + config: PretrainedConfig, + generation_config: GenerationConfig, + device: torch.device, + dtype: torch.dtype = torch.float16, + num_requests: int = 100, + layer_device_map: Optional[dict[int, Union[str, torch.device, int]]] = None, + tp_size: Optional[int] = None, + ) -> None: + """Initialize a paged attention cache for efficient memory usage. + + Args: + config: Model configuration + generation_config: Generation configuration containing cache parameters + device: Device for the cache tensors + dtype: Data type for the cache tensors + layer_device_map: Optional mapping of layer indices to devices + initial_prompt_shapes: Optional sample prompts to help calculate optimal cache size + """ + self.dtype = dtype + self.device = device + + # Extract model dimensions + self.num_key_value_heads: int = getattr_no_none(config, "num_key_value_heads", config.num_attention_heads) + self.head_dim: int = getattr_no_none(config, "head_dim", config.hidden_size // config.num_attention_heads) + + self.num_hidden_layers = config.num_hidden_layers + self.block_size = getattr(generation_config, "block_size", 32) + + # Handle TP + if tp_size is not None and tp_size > 1: + if self.num_key_value_heads % tp_size != 0: + raise ValueError( + f"Number of key value heads {self.num_key_value_heads} must be divisible by tensor parallel size {tp_size}." + ) + # If the model is using tensor parallelism, we need to adjust the number of heads accordingly. + # self.num_key_value_heads //= tp_size # TODO: why is this commented out? 
+ + # Infer number of blocks and max batch tokens + memory_handler = PagedAttentionMemoryHandler( + block_size=self.block_size, + head_dim=self.head_dim, + num_heads=self.num_key_value_heads, + num_layers=self.num_hidden_layers, + hidden_size=config.hidden_size, + vocab_size=config.vocab_size, + ) + num_blocks, max_batch_tokens = memory_handler.infer_num_blocks_and_max_batch_tokens( + num_blocks=getattr(generation_config, "num_blocks", None), + max_batch_tokens=getattr(generation_config, "max_batch_tokens", None), + max_memory_percent=getattr(generation_config, "max_memory", 0.9), + cache_dtype=self.dtype, + ) + + # Add the infered attributes to the class + self.num_blocks = num_blocks + self.max_batch_tokens = max_batch_tokens + logger.info( + f"After init, {self.num_blocks = }, {self.block_size = }, {self.max_batch_tokens = } " + f"{self.num_key_value_heads = }, {self.head_dim = }, {self.num_hidden_layers = }" + ) + + # Initialize the cache + self.cache_shape = (self.num_key_value_heads, num_blocks, self.block_size, self.head_dim) + self.key_cache: list[torch.Tensor] = [] + self.value_cache: list[torch.Tensor] = [] + for idx in range(config.num_hidden_layers): + layer_device = layer_device_map[idx] if layer_device_map is not None else device + new_layer_key_cache = torch.zeros(self.cache_shape, dtype=self.dtype, device=layer_device) + new_layer_value_cache = torch.zeros(self.cache_shape, dtype=self.dtype, device=layer_device) + # Note: `mark_static_address` is used to tag the cache as a fixed data pointer, + # preventing compiled graph breaks when updating the cache. + torch._dynamo.mark_static_address(new_layer_key_cache) + torch._dynamo.mark_static_address(new_layer_value_cache) + self.key_cache.append(new_layer_key_cache) + self.value_cache.append(new_layer_value_cache) + + # Block management data structures + self._free_blocks = deque(range(num_blocks)) + self._block_tables: dict[str, list[int]] = {} + + @traced + def allocate_blocks(self, n_blocks: int, request_id: str) -> list[int]: + """Allocates n_blocks for a given request_id.""" + if len(self._free_blocks) < n_blocks: + return False + + allocated = [] + for _ in range(n_blocks): + allocated.append(self._free_blocks.popleft()) + + if request_id not in self._block_tables: + self._block_tables[request_id] = [] + self._block_tables[request_id].extend(allocated) + return allocated + + @traced + def free_blocks(self, request_id: str) -> None: + """Frees all blocks associated with a request_id.""" + if request_id in self._block_tables: + blocks_to_free = self._block_tables.pop(request_id) + self._free_blocks.extend(blocks_to_free) + else: + logger.info(f"Attempted to free blocks for non-existent request_id: {request_id}") + + def get_num_free_blocks(self) -> int: + """Returns the number of free blocks available.""" + return len(self._free_blocks) + + def get_block_table(self, request_id: str) -> list[int]: + """Returns the block table for a request.""" + return self._block_tables.get(request_id, []) + + @traced + def _get_physical_indices(self, state: RequestState, logical_indices: list[int]) -> list[int]: + """ + Maps logical sequence indices to physical cache indices using the block table, using PyTorch. + + Args: + request_id: The request ID. + logical_indices: A list of logical indices. + + Returns: + A list of physical indices. + + Raises: + ValueError: If no block table is found for the request ID. + IndexError: If a logical index maps to a block index that is out of bounds. 
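# Editor's sketch (illustrative only, not part of the patch): the logical-to-physical mapping
# described in this docstring is plain block arithmetic over the request's block table.
# The numbers below are made up; only the formula mirrors _get_physical_indices.
block_size = 32
block_table = [7, 2, 9]                     # physical block ids owned by one request

def to_physical(logical_idx: int) -> int:
    block_idx, block_offset = divmod(logical_idx, block_size)
    if block_idx >= len(block_table):
        raise IndexError(f"logical index {logical_idx} is outside the allocated blocks")
    return block_table[block_idx] * block_size + block_offset

print([to_physical(i) for i in (0, 31, 32, 70)])   # [224, 255, 64, 294]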
+ """ + request_id = state.request_id + block_table = self._block_tables.get(request_id) + if not block_table: + raise ValueError(f"No block table found for request {request_id}") + + block_size = self.block_size + physical_indices = [] + + for idx in logical_indices: + block_idx = idx // block_size + block_offset = idx % block_size + + if block_idx >= len(block_table): + raise IndexError( + f"Logical index {idx} maps to block index {block_idx} which is out of bounds " + f"for request {request_id}" + ) + + physical_block_num = block_table[block_idx] + physical_index = physical_block_num * block_size + block_offset + physical_indices.append(physical_index) + + return physical_indices + + @traced + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + read_index, + write_index, + **kwargs, + ) -> tuple[torch.Tensor, torch.Tensor]: + # Reshape cache for easier indexing + total_slots = self.num_blocks * self.block_size + k_cache_flat = self.key_cache[layer_idx].view(self.num_key_value_heads, total_slots, self.head_dim) + v_cache_flat = self.value_cache[layer_idx].view(self.num_key_value_heads, total_slots, self.head_dim) + k_cache_flat[:, write_index, :] = key_states[0] + v_cache_flat[:, write_index, :] = value_states[0] + return k_cache_flat[None, :, read_index, :], v_cache_flat[None, :, read_index, :] + + +class PagedAttentionMemoryHandler: + _activation_dtype = torch.bfloat16 + _activation_safety_factor = 2 + _input_dtype = torch.int32 + _upper_bound_max_batch_tokens = 2048 + _upper_bound_num_blocks = 16384 + + def __init__( + self, + block_size: int, + head_dim: int, + num_heads: int, + num_layers: int, + hidden_size: int, + vocab_size: int, + ) -> None: + self.block_size = block_size + self.head_dim = head_dim + self.num_heads = num_heads + self.num_layers = num_layers + self.hidden_size = hidden_size + self.vocab_size = vocab_size + + @staticmethod + def get_available_memory(max_memory_percent: float = 1.0) -> int: + _, total, reserved, allocated = PagedAttentionMemoryHandler.get_device_and_memory_breakdown() + available_memory = total - max(allocated, reserved) + available_memory = int(available_memory * max_memory_percent) + return available_memory + + def infer_num_blocks_and_max_batch_tokens( + self, + num_blocks: Optional[int] = None, + max_batch_tokens: Optional[int] = None, + max_memory_percent: float = 0.9, + cache_dtype: torch.dtype = torch.float16, + ) -> tuple[int, int]: + # If neither num_blocks nor max_batch_tokens are provided, we use a second-order polynomial + if num_blocks is None and max_batch_tokens is None: + num_blocks, max_batch_tokens = self.compute_num_blocks_and_max_batch_tokens( + max_memory_percent, cache_dtype + ) + # If only num_blocks is provided, we infer the max_batch_tokens + elif num_blocks is not None and max_batch_tokens is None: + max_batch_tokens = self.compute_max_batch_tokens(num_blocks, max_memory_percent, cache_dtype) + # If only max_batch_tokens is provided, we infer the num_blocks + elif max_batch_tokens is not None and num_blocks is None: + num_blocks = self.compute_num_blocks(max_batch_tokens, max_memory_percent, cache_dtype) + + # We check if the memory footprint is too large in all cases + available_memory = self.get_available_memory(max_memory_percent) + memory_footprint = self.compute_memory_footprint( + max_batch_tokens=max_batch_tokens, + num_blocks=num_blocks, + cache_dtype=cache_dtype, + ) + if sum(memory_footprint) > available_memory: + raise MemoryError(f"Memory footprint 
{memory_footprint} is more than available memory {available_memory}") + return num_blocks, max_batch_tokens + + def compute_num_blocks_and_max_batch_tokens( + self, + max_memory_percent: float = 0.9, + cache_dtype: torch.dtype = torch.float16, + m: float = 0.1, + ) -> tuple[int, int]: + cache_memory = self.get_available_memory(max_memory_percent) + + # Compute second-degree polynomial coefficients + a = m * self._activation_dtype.itemsize + b = 8 * m * self._input_dtype.itemsize + b += 2 * self.num_heads * self.head_dim * self.num_layers * cache_dtype.itemsize + c = self._activation_dtype.itemsize * (self.hidden_size + self.vocab_size) * self._activation_safety_factor + c += 2 * self._input_dtype.itemsize + c -= cache_memory + + # Compute discriminant and greatest solution + discriminant = b**2 - 4 * a * c + if discriminant < 0: + raise ValueError(f"Discriminant is negative: {discriminant = }") + greatest_solution = (-b + sqrt(discriminant)) / (2 * a) + if greatest_solution < 0: + raise ValueError(f"Greatest solution is negative: {greatest_solution = }") + + # Infer number of blocks and max batch tokens + num_blocks = int(greatest_solution) // self.block_size + if num_blocks > self._upper_bound_num_blocks: + logger.warning(f"{num_blocks = } is too large, setting to {self._upper_bound_num_blocks = }") + num_blocks = self._upper_bound_num_blocks + max_batch_tokens = int(greatest_solution * m) + if max_batch_tokens > self._upper_bound_max_batch_tokens: + logger.warning(f"{max_batch_tokens = } is too large, setting to {self._upper_bound_max_batch_tokens = }") + max_batch_tokens = self._upper_bound_max_batch_tokens + return num_blocks, max_batch_tokens + + def compute_max_batch_tokens( + self, + num_blocks: int, + max_memory_percent: float = 0.9, + cache_dtype: torch.dtype = torch.float16, + ) -> int: + cache_memory = self.get_available_memory(max_memory_percent) + cache_size = num_blocks * self.block_size + # Compute numerator + num = cache_memory + num -= self._activation_dtype.itemsize * (self.hidden_size + self.vocab_size) * self._activation_safety_factor + num -= 2 * self._input_dtype.itemsize + num -= cache_size * 2 * self.num_heads * self.head_dim * self.num_layers * cache_dtype.itemsize + # Compute denominator + denum = 8 * self._input_dtype.itemsize + cache_size * self._activation_dtype.itemsize + # Compute max batch tokens and return + return int(num / denum) + + def compute_num_blocks( + self, + max_batch_tokens: int, + max_memory_percent: float = 0.9, + cache_dtype: torch.dtype = torch.float16, + ) -> int: + cache_memory = self.get_available_memory(max_memory_percent) + # Compute numerator + num = cache_memory + num -= self._activation_dtype.itemsize * (self.hidden_size + self.vocab_size) * self._activation_safety_factor + num -= 8 * max_batch_tokens * self._input_dtype.itemsize + num -= 2 * self._input_dtype.itemsize + # Compute denominator + denum = 2 * self.num_heads * self.head_dim * self.num_layers * cache_dtype.itemsize + denum += max_batch_tokens * self._activation_dtype.itemsize + # Compute cache size and return number of blocks + cache_size = int(num / denum) + return floor(cache_size / self.block_size) + + def compute_memory_footprint( + self, + num_blocks: Optional[int] = None, + max_batch_tokens: Optional[int] = None, + cache_dtype: torch.dtype = torch.float16, + ) -> tuple[int, int, int]: + # Compute activation memory footprint + activation_memory_footprint = self._activation_dtype.itemsize * (self.hidden_size + self.vocab_size) + activation_memory_footprint *= 
self._activation_safety_factor + # Compute cache memory footprint if num_blocks is provided + if num_blocks is not None: + cache_size = num_blocks * self.block_size + bytes_per_token = 2 * self.num_heads * self.head_dim * self.num_layers * cache_dtype.itemsize + cache_memory_footprint = cache_size * bytes_per_token + else: + cache_memory_footprint = -1 + # Compute static tensors memory footprint if num_blocks and max_batch_tokens is provided + if num_blocks is not None and max_batch_tokens is not None: + static_memory_footprint = sum( + [ + 3 * max_batch_tokens * self._input_dtype.itemsize, # input_ids, position_ids, output_ids + max_batch_tokens * cache_size * self._activation_dtype.itemsize, # attention_mask + 2 * (max_batch_tokens + 1) * self._input_dtype.itemsize, # cumulative_seqlens_qk + 3 * max_batch_tokens * self._input_dtype.itemsize, # write_index, read_index, logits_indices + ] + ) + else: + static_memory_footprint = -1 + return activation_memory_footprint, cache_memory_footprint, static_memory_footprint diff --git a/src/transformers/generation/continuous_batching.py b/src/transformers/generation/continuous_batching/continuous_api.py similarity index 53% rename from src/transformers/generation/continuous_batching.py rename to src/transformers/generation/continuous_batching/continuous_api.py index 81112864ba4a..81c8c6eb6828 100644 --- a/src/transformers/generation/continuous_batching.py +++ b/src/transformers/generation/continuous_batching/continuous_api.py @@ -15,654 +15,22 @@ # limitations under the License. import queue import threading -import time -from abc import ABC, abstractmethod -from collections import deque -from dataclasses import dataclass, field -from enum import Enum +from dataclasses import dataclass from functools import partial -from typing import Any, Optional, TypeVar, Union +from typing import Optional import torch -import torch.nn as nn from tokenizers.decoders import DecodeStream +from torch import nn from tqdm import tqdm -from ..configuration_utils import PretrainedConfig -from ..generation.configuration_utils import GenerationConfig -from ..tokenization_utils_fast import PreTrainedTokenizerFast -from ..utils.logging import logging -from ..utils.metrics import ContinuousBatchProcessorMetrics, attach_tracer, traced -from .cb.memory_management import PagedAttentionMemoryHandler - - -class RequestStatus(Enum): - """Status of a generation request through its lifecycle.""" - - PENDING = "pending" - PREFILLING = "prefilling" - PREFILLING_SPLIT = "prefilling_split" - SPLIT_PENDING_REMAINDER = "split_pending_remainder" - DECODING = "decoding" - FINISHED = "finished" - FAILED = "failed" - - -logger = logging.getLogger(__name__) -# logger.setLevel(logging.INFO) - - -@dataclass -class GenerationOutput: - """Tracks the output of a generation request. - - Attributes: - request_id (str): The ID of the generation request. - prompt_ids (list[int]): The IDs of the prompt tokens. - generated_tokens (list[int]): The generated tokens. - logprobs (list[float]): The log probabilities of the generated tokens. - error (Optional[str]): Any error message associated with the request. When None, the request was successful. - status (RequestStatus): The status of the request. - created_time (float): The time the request was created. - next_token (Optional[int]): The next token to be generated. 
- """ - - request_id: str - prompt_ids: list[int] = field(default_factory=list) - generated_tokens: list[int] = field(default_factory=list) - logprobs: list[float] = field(default_factory=list) - error: Optional[str] = None - status: RequestStatus = RequestStatus.PENDING - created_time: float = field(default_factory=time.time) - next_token: Optional[int] = field(default_factory=int) - - -@dataclass -class RequestState: - """Tracks the state of a generation request through its lifecycle. - - Attributes: - request_id (str): The ID of the generation request. - full_prompt_ids (list[int] | None): The tokens IDs of the full prompt. - prompt_ids (list[int] | None): The tokens IDs currently being processed. - remaining_prompt_ids (list[int]): The tokens IDs remaining to be processed (for split requests). - static_outputs (list[int]): The generated tokens. - allocated_blocks (list[int]): The identifiers of the allocated blocks to the request. - position_offset (int): The current position in the sequence for position_ids. - status (RequestStatus): The status of the request: can be one of PENDING, PREFILLING, PREFILLING_SPLIT, - SPLIT_PENDING_REMAINDER, DECODING, FINISHED, FAILED - max_new_tokens (int): The maximum number of new tokens to generate. - eos_token_id (int): The ID of the end-of-sequence token. - created_time (float): The time the request was created. - error (Optional[str]): Any error message associated with the request. When None, has had no error yet. - next_token (Optional[str]): The next token to be generated. - """ - - # Required fields - request_id: str - full_prompt_ids: Optional[list[int]] = None # Full initial prompt - prompt_ids: Optional[list[int]] = None # Tokens IDs currently being processed (initial + generated) - remaining_prompt_ids: list[int] = field(default_factory=list) # For split requests, prefill left to process - static_outputs: list[int] = field(default_factory=list) # Generated tokens - allocated_blocks: list[int] = field(default_factory=list) # Block IDs allocated to the request - position_offset: int = 0 # Current position in the sequence for position_ids - status: RequestStatus = RequestStatus.PENDING # Status of the request - max_new_tokens: int = 20 # Maximum number of new tokens to generate - eos_token_id: int = -1 # ID of the end-of-sequence token - created_time: float = field(default_factory=time.time) # Time the request was created - error: Optional[str] = None # Error message if the request failed - next_token: Optional[str] = None # Next token to be generated - - def current_len(self) -> int: - """Get the current length of the sequence (prompt + generated tokens).""" - return self.position_offset - - def generated_len(self) -> int: - """Get the number of tokens generated so far.""" - return len(self.static_outputs) - - # TODO: this logic seems one token off, check it out - @traced - def update_with_token(self, token_id: int) -> bool: - """Update the request with a newly generated token and check for completion. 
- - Args: - token_id: The token ID to add to the output sequence - - Returns: - bool: True if the request is now complete, False otherwise - """ - # Only update if we're in decoding state - if self.status != RequestStatus.DECODING: - return False - - is_eos = token_id == self.eos_token_id and self.eos_token_id != -1 - is_max_len = self.generated_len() >= self.max_new_tokens - - # Only add the token if we're not finishing due to max length - # (EOS tokens should still be added to the output) - if not (is_max_len and not is_eos): - self.static_outputs.extend([token_id]) - - if is_eos or is_max_len: - self.status = RequestStatus.FINISHED - return True - return False - - def __repr__(self): - return f"RequestState(\n\trequest_id={self.request_id},\n\tstatus={self.status},\n\tout_tokens={self.generated_len()},\n\tquery_length={len(self.prompt_ids)}, \n\tremaining_tokens={len(self.remaining_prompt_ids)}, \n\tkv_length={self.position_offset}\n\tfull_prompt_lenght={len(self.full_prompt_ids)},\n\tallocated_blocks={self.allocated_blocks},\n\tgenerated_tokens={self.static_outputs}\n)" - - def to_generation_output(self): - """Convert the request state to a GenerationOutput object.""" - return GenerationOutput( - request_id=self.request_id, - prompt_ids=self.full_prompt_ids, - status=self.status, - generated_tokens=self.static_outputs, - logprobs=[], - error=self.error, - next_token=self.next_token, - ) - - -T = TypeVar("T") -def getattr_no_none(obj: Any, attr: str, default: T) -> T: - x = getattr(obj, attr, None) - return x if x is not None else default - - -@attach_tracer() -class PagedAttentionCache: - def __init__( - self, - config: PretrainedConfig, - generation_config: GenerationConfig, - device: torch.device, - dtype: torch.dtype = torch.float16, - num_requests: int = 100, - layer_device_map: Optional[dict[int, Union[str, torch.device, int]]] = None, - tp_size: Optional[int] = None, - ) -> None: - """Initialize a paged attention cache for efficient memory usage. - - Args: - config: Model configuration - generation_config: Generation configuration containing cache parameters - device: Device for the cache tensors - dtype: Data type for the cache tensors - layer_device_map: Optional mapping of layer indices to devices - initial_prompt_shapes: Optional sample prompts to help calculate optimal cache size - """ - self.dtype = dtype - self.device = device - - # Extract model dimensions - self.num_key_value_heads: int = getattr_no_none(config, "num_key_value_heads", config.num_attention_heads) - self.head_dim: int = getattr_no_none(config, "head_dim", config.hidden_size // config.num_attention_heads) - - self.num_hidden_layers = config.num_hidden_layers - self.block_size = getattr(generation_config, "block_size", 32) - - # Handle TP - if tp_size is not None and tp_size > 1: - if self.num_key_value_heads % tp_size != 0: - raise ValueError( - f"Number of key value heads {self.num_key_value_heads} must be divisible by tensor parallel size {tp_size}." - ) - # If the model is using tensor parallelism, we need to adjust the number of heads accordingly. - # self.num_key_value_heads //= tp_size # TODO: why is this commented out? 
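# Editor's sketch (illustrative only, not part of the patch): the "infer number of blocks"
# step that follows sizes the paged KV cache from a per-token byte cost. The formula mirrors
# bytes_per_token in compute_memory_footprint; the model dimensions below are assumptions.
def kv_cache_bytes(num_blocks: int, block_size: int, num_kv_heads: int, head_dim: int,
                   num_layers: int, dtype_size: int = 2) -> int:
    # 2x for keys and values, one copy per layer, dtype_size bytes per element
    bytes_per_token = 2 * num_kv_heads * head_dim * num_layers * dtype_size
    return num_blocks * block_size * bytes_per_token

# Example with Llama-3.2-3B-like dimensions (hypothetical) in a 2-byte cache dtype
print(kv_cache_bytes(num_blocks=1024, block_size=32, num_kv_heads=8, head_dim=128, num_layers=28) / 2**30, "GiB")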
- - # Infer number of blocks and max batch tokens - memory_handler = PagedAttentionMemoryHandler( - block_size=self.block_size, - head_dim=self.head_dim, - num_heads=self.num_key_value_heads, - num_layers=self.num_hidden_layers, - hidden_size=config.hidden_size, - vocab_size=config.vocab_size, - ) - num_blocks, max_batch_tokens = memory_handler.infer_num_blocks_and_max_batch_tokens( - num_blocks=getattr(generation_config, "num_blocks", None), - max_batch_tokens=getattr(generation_config, "max_batch_tokens", None), - max_memory_percent=getattr(generation_config, "max_memory", 0.9), - cache_dtype=self.dtype, - ) - - # Add the infered attributes to the class - self.num_blocks = num_blocks - self.max_batch_tokens = max_batch_tokens - logger.info( - f"After init, {self.num_blocks = }, {self.block_size = }, {self.max_batch_tokens = } " - f"{self.num_key_value_heads = }, {self.head_dim = }, {self.num_hidden_layers = }" - ) - - # Initialize the cache - self.cache_shape = (self.num_key_value_heads, num_blocks, self.block_size, self.head_dim) - self.key_cache: list[torch.Tensor] = [] - self.value_cache: list[torch.Tensor] = [] - for idx in range(config.num_hidden_layers): - layer_device = layer_device_map[idx] if layer_device_map is not None else device - new_layer_key_cache = torch.zeros(self.cache_shape, dtype=self.dtype, device=layer_device) - new_layer_value_cache = torch.zeros(self.cache_shape, dtype=self.dtype, device=layer_device) - # Note: `mark_static_address` is used to tag the cache as a fixed data pointer, - # preventing compiled graph breaks when updating the cache. - torch._dynamo.mark_static_address(new_layer_key_cache) - torch._dynamo.mark_static_address(new_layer_value_cache) - self.key_cache.append(new_layer_key_cache) - self.value_cache.append(new_layer_value_cache) - - # Block management data structures - self._free_blocks = deque(range(num_blocks)) - self._block_tables: dict[str, list[int]] = {} - - @traced - def allocate_blocks(self, n_blocks: int, request_id: str) -> list[int]: - """Allocates n_blocks for a given request_id.""" - if len(self._free_blocks) < n_blocks: - return False - - allocated = [] - for _ in range(n_blocks): - allocated.append(self._free_blocks.popleft()) - - if request_id not in self._block_tables: - self._block_tables[request_id] = [] - self._block_tables[request_id].extend(allocated) - return allocated - - @traced - def free_blocks(self, request_id: str) -> None: - """Frees all blocks associated with a request_id.""" - if request_id in self._block_tables: - blocks_to_free = self._block_tables.pop(request_id) - self._free_blocks.extend(blocks_to_free) - else: - logger.info(f"Attempted to free blocks for non-existent request_id: {request_id}") - - def get_num_free_blocks(self) -> int: - """Returns the number of free blocks available.""" - return len(self._free_blocks) - - def get_block_table(self, request_id: str) -> list[int]: - """Returns the block table for a request.""" - return self._block_tables.get(request_id, []) - - @traced - def _get_physical_indices(self, state: RequestState, logical_indices: list[int]) -> list[int]: - """ - Maps logical sequence indices to physical cache indices using the block table, using PyTorch. - - Args: - request_id: The request ID. - logical_indices: A list of logical indices. - - Returns: - A list of physical indices. - - Raises: - ValueError: If no block table is found for the request ID. - IndexError: If a logical index maps to a block index that is out of bounds. 
- """ - request_id = state.request_id - block_table = self._block_tables.get(request_id) - if not block_table: - raise ValueError(f"No block table found for request {request_id}") - - block_size = self.block_size - physical_indices = [] - - for idx in logical_indices: - block_idx = idx // block_size - block_offset = idx % block_size - - if block_idx >= len(block_table): - raise IndexError( - f"Logical index {idx} maps to block index {block_idx} which is out of bounds " - f"for request {request_id}" - ) - - physical_block_num = block_table[block_idx] - physical_index = physical_block_num * block_size + block_offset - physical_indices.append(physical_index) - - return physical_indices - - @traced - def update( - self, - key_states: torch.Tensor, - value_states: torch.Tensor, - layer_idx: int, - read_index, - write_index, - **kwargs, - ) -> tuple[torch.Tensor, torch.Tensor]: - # Reshape cache for easier indexing - total_slots = self.num_blocks * self.block_size - k_cache_flat = self.key_cache[layer_idx].view(self.num_key_value_heads, total_slots, self.head_dim) - v_cache_flat = self.value_cache[layer_idx].view(self.num_key_value_heads, total_slots, self.head_dim) - k_cache_flat[:, write_index, :] = key_states[0] - v_cache_flat[:, write_index, :] = value_states[0] - return k_cache_flat[None, :, read_index, :], v_cache_flat[None, :, read_index, :] - - -class Scheduler(ABC): - """ - Abstract base class for scheduling requests in the continuous batch processor. - It is expected that cache allocation and scheduling logic will be implemented in subclasses. - """ - - def __init__(self, cache: PagedAttentionCache, retain_cache_on_finish: bool = False): - self.active_requests: dict[str, RequestState] = {} - self.waiting_requests: dict[str, RequestState] = {} - self.waiting_requests_order: deque[str] = deque() - self.cache = cache - self.retain_cache_on_finish = retain_cache_on_finish - - @abstractmethod - def add_waiting_request(self, state: RequestState): - """Add a request to the waiting list.""" - pass - - @abstractmethod - def schedule_batch(self, token_budget: int) -> list[RequestState]: - pass - - @traced - def has_pending_requests(self) -> bool: - """Check if there are requests ready to be processed.""" - return len(self.active_requests) or len(self.waiting_requests) - - @abstractmethod - def finish_request(self, request_id: str, evict_from_cache: bool = True): - """Finish processing a request and free its allocated blocks.""" - pass - - @traced - def get_active_request_static_outputs(self, request_id: str) -> list[int]: - if request_id in self.active_requests: - return self.active_requests[request_id].static_outputs - return [] - - -@attach_tracer() -class FIFOScheduler(Scheduler): - def __init__(self, cache: PagedAttentionCache, retain_cache_on_finish: bool = False, safety_margin: float = 0.1): - super().__init__(cache, retain_cache_on_finish) - self.safety_margin = safety_margin - - @traced - def _allocate_blocks_if_needed(self, state: RequestState, len_next_tokens: int): - # 1. we check that the occupancy is less than the requested length - # 2. 
we allocate enough blocks to cover the requested length - current_len = state.current_len() - occupancy = len(state.allocated_blocks) * self.cache.block_size - current_len - if occupancy < len_next_tokens or (len(state.allocated_blocks) == 0): - blocks_needed = ((len_next_tokens - occupancy + 1) // self.cache.block_size) + 1 - allocated = self.cache.allocate_blocks(blocks_needed, state.request_id) - if not allocated: - return False - state.allocated_blocks.extend(allocated) - return True - - @traced(span_name="prepare_request") - def _prepare_request_for_processing( - self, state: RequestState, token_budget: int, request_ids_to_remove_from_waiting: set[str] - ): - """Prepare a request for processing in the current batch.""" - request_tokens = ( - state.remaining_prompt_ids if state.status == RequestStatus.SPLIT_PENDING_REMAINDER else state.prompt_ids - ) - if len(request_tokens) < token_budget: - # Can process the entire prompt/remainder - if state.status == RequestStatus.PENDING: - self.active_requests[state.request_id] = state - state.status = RequestStatus.PREFILLING - request_ids_to_remove_from_waiting.add(state.request_id) - elif state.status == RequestStatus.SPLIT_PENDING_REMAINDER: - state.status = RequestStatus.PREFILLING - state.prompt_ids = state.remaining_prompt_ids - state.remaining_prompt_ids = [] - else: - # Need to split the request - if state.status == RequestStatus.PENDING: - self.active_requests[state.request_id] = state - state.status = RequestStatus.PREFILLING_SPLIT - request_ids_to_remove_from_waiting.add(state.request_id) - elif state.status == RequestStatus.SPLIT_PENDING_REMAINDER: - state.status = RequestStatus.PREFILLING_SPLIT - state.remaining_prompt_ids = request_tokens[token_budget:] - state.prompt_ids = request_tokens[:token_budget] - - @traced - def add_waiting_request(self, state: RequestState): - """Add a request to the waiting list.""" - if self.retain_cache_on_finish and state.request_id in self.active_requests: - old_state = self.active_requests.pop(state.request_id) - state.prompt_ids = state.prompt_ids[len(old_state.full_prompt_ids) :] - state.allocated_blocks = old_state.allocated_blocks - state.position_offset = old_state.position_offset - self.waiting_requests[state.request_id] = state - self.waiting_requests_order.append(state.request_id) - - @traced - def schedule_batch(self, token_budget: int) -> list[RequestState]: - priority_states: list[RequestState] = [] - second_priority_states: list[RequestState] = [] - scheduled_requests = [] - - for state in self.active_requests.values(): - if state.status == RequestStatus.DECODING: - priority_states.append(state) - if state.status == RequestStatus.SPLIT_PENDING_REMAINDER: - second_priority_states.append(state) - - # Add waiting requests to second priority - for req_id in self.waiting_requests_order: - second_priority_states.append(self.waiting_requests[req_id]) - - candidates = priority_states + second_priority_states - request_ids_to_remove_from_waiting = set() - safety_margins = self.safety_margin * self.cache.num_blocks - - for state in candidates: - - # If we are out the safety margin, we only accept decoding requests or the first prefill request - num_free_blocks = self.cache.get_num_free_blocks() - outside_safety_margin = num_free_blocks < safety_margins - if outside_safety_margin and scheduled_requests and state.status != RequestStatus.DECODING: - break - - self._prepare_request_for_processing(state, token_budget, request_ids_to_remove_from_waiting) - request_len = len(state.prompt_ids) - if not 
self._allocate_blocks_if_needed(state, len(state.prompt_ids)): # don't schedule if we can't allocate blocks - if len(self.cache._free_blocks) == 0: - break - continue - - @traced - def _add_to_scheduled_requests(state: RequestState): - scheduled_requests.append(state) - - _add_to_scheduled_requests(state) - - token_budget -= request_len - - @traced - def _remove_from_waiting_requests(state: RequestState): - req_id = state.request_id - if req_id in self.waiting_requests: - del self.waiting_requests[req_id] - request_ids_to_remove_from_waiting.add(req_id) - - _remove_from_waiting_requests(state) - - if token_budget == 0: - break - - self.waiting_requests_order = deque( - [req_id for req_id in self.waiting_requests_order if req_id not in request_ids_to_remove_from_waiting] - ) - - return scheduled_requests - - @traced - def finish_request(self, request_id: str, evict_from_cache: bool = True): - if evict_from_cache: - self.cache.free_blocks(request_id) - if request_id in self.active_requests: - del self.active_requests[request_id] - - -@attach_tracer() -class PrefillFirstScheduler(Scheduler): - @traced - def _allocate_blocks_if_needed(self, state: RequestState, len_next_tokens: int): - # 1. we check that the occupancy is less than the requested length - # 2. we allocate enough blocks to cover the requested length - current_len = state.current_len() - occupancy = len(state.allocated_blocks) * self.cache.block_size - current_len - if occupancy < len_next_tokens or (len(state.allocated_blocks) == 0): - blocks_needed = ((len_next_tokens - occupancy + 1) // self.cache.block_size) + 1 - allocated = self.cache.allocate_blocks(blocks_needed, state.request_id) - if not allocated: - return False - state.allocated_blocks.extend(allocated) - return True - - @traced(span_name="prepare_request") - def _prepare_request_for_processing( - self, state: RequestState, token_budget: int, request_ids_to_remove_from_waiting: set[str] - ): - """Prepare a request for processing in the current batch.""" - request_tokens = ( - state.remaining_prompt_ids if state.status == RequestStatus.SPLIT_PENDING_REMAINDER else state.prompt_ids - ) - if len(request_tokens) < token_budget: - # Can process the entire prompt/remainder - if state.status == RequestStatus.PENDING: - self.active_requests[state.request_id] = state - state.status = RequestStatus.PREFILLING - request_ids_to_remove_from_waiting.add(state.request_id) - elif state.status == RequestStatus.SPLIT_PENDING_REMAINDER: - state.status = RequestStatus.PREFILLING - state.prompt_ids = state.remaining_prompt_ids - state.remaining_prompt_ids = [] - else: - # Need to split the request - if state.status == RequestStatus.PENDING: - self.active_requests[state.request_id] = state - state.status = RequestStatus.PREFILLING_SPLIT - request_ids_to_remove_from_waiting.add(state.request_id) - elif state.status == RequestStatus.SPLIT_PENDING_REMAINDER: - state.status = RequestStatus.PREFILLING_SPLIT - state.remaining_prompt_ids = request_tokens[token_budget:] - state.prompt_ids = request_tokens[:token_budget] - - @traced - def add_waiting_request(self, state: RequestState): - """Add a request to the waiting list.""" - if self.retain_cache_on_finish and state.request_id in self.active_requests: - old_state = self.active_requests.pop(state.request_id) - state.prompt_ids = state.prompt_ids[len(old_state.full_prompt_ids) :] # XXX: check for indexing error? 
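# Editor's sketch (illustrative only, not part of the patch): _allocate_blocks_if_needed only
# requests new blocks when the current allocation cannot hold the tokens scheduled for the
# next forward pass. With made-up numbers, the arithmetic works out like this:
block_size = 32
allocated_blocks = 2                   # the request already owns 2 blocks
current_len = 50                       # tokens already written to the cache
len_next_tokens = 40                   # tokens scheduled for this forward pass

occupancy = allocated_blocks * block_size - current_len        # 14 free slots left
if occupancy < len_next_tokens or allocated_blocks == 0:
    blocks_needed = ((len_next_tokens - occupancy + 1) // block_size) + 1
    print(blocks_needed)               # 1 extra block covers the remaining 26 tokens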
- state.allocated_blocks = old_state.allocated_blocks - state.position_offset = old_state.position_offset - self.waiting_requests[state.request_id] = state - self.waiting_requests_order.append(state.request_id) - - @traced - def schedule_batch(self, token_budget: int) -> list[RequestState]: - priority_states: list[RequestState] = [] - second_priority_states: list[RequestState] = [] - scheduled_requests = [] - - for state in self.active_requests.values(): - if state.status == RequestStatus.SPLIT_PENDING_REMAINDER: - priority_states.append(state) - elif state.status == RequestStatus.DECODING: - second_priority_states.append(state) - - for req_id in self.waiting_requests_order: - second_priority_states.append(self.waiting_requests[req_id]) - - candidates = priority_states + second_priority_states - - request_ids_to_remove_from_waiting = set() - - for state in candidates: - self._prepare_request_for_processing(state, token_budget, request_ids_to_remove_from_waiting) - request_len = len(state.prompt_ids) - if not self._allocate_blocks_if_needed( - state, len(state.prompt_ids) - ): # don't schedule if we can't allocate blocks - if len(self.cache._free_blocks) == 0: - break - continue - - @traced - def _add_to_scheduled_requests(state: RequestState): - scheduled_requests.append(state) - - _add_to_scheduled_requests(state) - - token_budget -= request_len - - @traced - def _remove_from_waiting_requests(state: RequestState): - req_id = state.request_id - if req_id in self.waiting_requests: - del self.waiting_requests[req_id] - request_ids_to_remove_from_waiting.add(req_id) - - _remove_from_waiting_requests(state) - - if token_budget == 0: - break - - self.waiting_requests_order = deque( - [req_id for req_id in self.waiting_requests_order if req_id not in request_ids_to_remove_from_waiting] - ) - - return scheduled_requests - - @traced - def finish_request(self, request_id: str, evict_from_cache: bool = True): - if evict_from_cache: - self.cache.free_blocks(request_id) - if request_id in self.active_requests: - del self.active_requests[request_id] - - -@traced(standalone=True) -def compute_optimal_blocks( - max_num_tokens, - block_size, - head_dim, - num_heads, - num_layers, - max_memory_percent=0.9, - num_blocks=None, - dtype=torch.float16, -): - device, total, reserved, allocated = PagedAttentionMemoryHandler.get_device_and_memory_breakdown() - available_memory = int((total - max(allocated, reserved)) * max_memory_percent) - - dtype_size = torch.tensor([], dtype=dtype).element_size() - bytes_per_token = 2 * num_heads * head_dim * dtype_size * num_layers - if num_blocks is not None: - # TODO - max_possible_concurrent_requests = num_blocks * bytes_per_token - # FIXME: forgot to add the inintial prompt length in the mix.... 
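# Editor's sketch (illustrative only, not part of the patch): the safety margin added to
# FIFOScheduler.schedule_batch in PATCH 07/26 keeps a fraction of the cache blocks in reserve;
# once free blocks drop below that margin, only requests already DECODING (or the first
# scheduled request) are admitted. A stripped-down version of that check, with assumed names:
def admit(status: str, num_free_blocks: int, num_blocks: int,
          already_scheduled: bool, safety_margin: float = 0.1) -> bool:
    below_margin = num_free_blocks < safety_margin * num_blocks
    return not (below_margin and already_scheduled and status != "DECODING")

print(admit("DECODING", num_free_blocks=50, num_blocks=1024, already_scheduled=True))  # True
print(admit("PENDING", num_free_blocks=50, num_blocks=1024, already_scheduled=True))   # False: prefill held back until blocks free up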
- max_possible_concurrent_requests = int( - available_memory // (bytes_per_token * max_num_tokens * max_num_tokens // 4) - ) - if max_possible_concurrent_requests <= 0: - logger.warning("you are trying to generate a bit too many tokens") - max_possible_concurrent_requests = 32 - max_concurrent_tokens = min(64, max_possible_concurrent_requests) - # FIXME: Optimal means uses all memory - optimal_num_blocks = max(((max_concurrent_tokens * max_num_tokens) // block_size) + 1, 64) - return optimal_num_blocks, max_concurrent_tokens +from ...configuration_utils import PretrainedConfig +from ...generation.configuration_utils import GenerationConfig +from ...tokenization_utils_fast import PreTrainedTokenizerFast +from ...utils.metrics import ContinuousBatchProcessorMetrics, attach_tracer, traced +from .cache import PagedAttentionCache +from .core import GenerationOutput, RequestState, RequestStatus, get_device_and_memory_breakdown, logger +from .scheduler import SCHEDULER_MAPPING, FIFOScheduler, Scheduler @dataclass @@ -698,7 +66,7 @@ def __init__( scheduler: Scheduler, streaming: bool = False, manual_eviction: bool = False, - slice_inputs: bool = True, # TODO: remove this once parity is ensured + slice_inputs: bool = True, # TODO: remove this once parity is ensured ): """Initialize the continuous batch processor. @@ -745,7 +113,9 @@ def setup_static_tensors(self): self.tensor_metadata = tensor_metadata self.input_ids = torch.empty((1, T), **tensor_metadata) self.position_ids = torch.empty((1, T), **tensor_metadata) - self.attention_mask = torch.empty((1, 1, T, max_token_budget), dtype=self.model_dtype, device=self.model_device) + self.attention_mask = torch.empty( + (1, 1, T, max_token_budget), dtype=self.model_dtype, device=self.model_device + ) self.cumulative_seqlens_q = torch.empty((T + 1,), **tensor_metadata) self.cumulative_seqlens_k = torch.empty((T + 1,), **tensor_metadata) self.write_index = torch.empty((T,), **tensor_metadata) @@ -773,8 +143,8 @@ def reset_static_tensors(self): self.input_ids[:, :t].zero_() self.position_ids[:, :t].zero_() self.attention_mask[:, :, :t, :c].fill_(torch.finfo(self.model_dtype).min) - self.cumulative_seqlens_q[:t+1].zero_() - self.cumulative_seqlens_k[:t+1].zero_() + self.cumulative_seqlens_q[: t + 1].zero_() + self.cumulative_seqlens_k[: t + 1].zero_() self.write_index[:t].fill_(-1) self.read_index[:c].fill_(-1) self.logits_indices[:t].fill_(-1) @@ -782,10 +152,9 @@ def reset_static_tensors(self): self.max_seqlen_k = 0 self.output_ids[:, :t].fill_(-1) - def get_model_kwargs(self) -> PagedAttentionArgs: """Get model keyword arguments for the current batch.""" - # Compute the slice to return + # Compute the slice to return t = self.actual_tokens if self.slice_inputs else self.write_index.size(0) c = self.cache_used if self.slice_inputs else self.read_index.size(0) # Return the tensors @@ -1035,12 +404,6 @@ def fail_all_requests(self, error): self.scheduler.waiting_requests_order.clear() -SCHEDULER_MAPPING = { - "fifo": FIFOScheduler, - "prefill_first": PrefillFirstScheduler, -} - - # Manager Class (User Interface) @attach_tracer() class ContinuousBatchingManager: @@ -1263,7 +626,7 @@ def _run_generation_loop(self): scheduler = None if hasattr(self.generation_config, "scheduler"): - scheduler = SCHEDULER_MAPPING.get(self.generation_config.scheduler) + scheduler = SCHEDULER_MAPPING.get(self.generation_config.scheduler, None) if scheduler is None: logger.warning(f"Scheduler '{scheduler}' not found. 
Defaulting to FIFO.") scheduler = FIFOScheduler @@ -1301,7 +664,7 @@ def _inner_generation_loop(self, batch_processor: ContinuousBatchProcessor): if torch.cuda.is_available(): torch.cuda.synchronize() batch_processor.prepare_next_batch() - device, total, reserved, allocated = PagedAttentionMemoryHandler.get_device_and_memory_breakdown() + device, total, reserved, allocated = get_device_and_memory_breakdown() logger.info(f"[Memory] Device: {device}, Total: {total}, Reserved: {reserved}, Allocated: {allocated}") if torch.cuda.is_available() and self.use_cuda_graph: if self.current_batch == 0: @@ -1416,7 +779,7 @@ def generate_batch( """ if not inputs: return [] - if logger.getEffectiveLevel() <= logging.INFO: + if logger.getEffectiveLevel() <= logger.INFO: logger.warning("Progress bar is disabled when logger level is less than INFO") progress_bar = False diff --git a/src/transformers/generation/continuous_batching/core.py b/src/transformers/generation/continuous_batching/core.py new file mode 100644 index 000000000000..3f476ef1b99a --- /dev/null +++ b/src/transformers/generation/continuous_batching/core.py @@ -0,0 +1,172 @@ +import time +from dataclasses import dataclass, field +from enum import Enum +from typing import Optional + +import torch + +from ...utils.logging import logging +from ...utils.metrics import traced + + +# We centralize the logger here to coordinate between logging and progress bar +logger = logging.getLogger("ContinuousBatchingLogger") + + +@staticmethod +def get_device_and_memory_breakdown() -> tuple[torch.device, int, int, int]: + if torch.cuda.is_available(): + device = torch.device("cuda") + torch.cuda.empty_cache() + torch.cuda.synchronize() + total_memory = torch.cuda.get_device_properties(device).total_memory + reserved_memory = torch.cuda.memory_reserved(device) + allocated_memory = torch.cuda.memory_allocated(device) + elif torch.backends.mps.is_available() and torch.backends.mps.is_built(): + device = torch.device("mps") + # MPS memory reporting (PyTorch 2.0+) + total_memory = torch.mps.driver_allocated_memory() + allocated_memory = total_memory - torch.mps.recommended_max_memory() + reserved_memory = 0 # MPS does not track reserved separately + else: + device = torch.device("cpu") + total_memory = None + reserved_memory = 0 + allocated_memory = 0 + return device, total_memory, reserved_memory, allocated_memory + + +class RequestStatus(Enum): + """Status of a generation request through its lifecycle.""" + + PENDING = "pending" + PREFILLING = "prefilling" + PREFILLING_SPLIT = "prefilling_split" + SPLIT_PENDING_REMAINDER = "split_pending_remainder" + DECODING = "decoding" + FINISHED = "finished" + FAILED = "failed" + + +@dataclass +class GenerationOutput: + """Tracks the output of a generation request. + + Attributes: + request_id (str): The ID of the generation request. + prompt_ids (list[int]): The IDs of the prompt tokens. + generated_tokens (list[int]): The generated tokens. + logprobs (list[float]): The log probabilities of the generated tokens. + error (Optional[str]): Any error message associated with the request. When None, the request was successful. + status (RequestStatus): The status of the request. + created_time (float): The time the request was created. + next_token (Optional[int]): The next token to be generated. 
+ """ + + request_id: str + prompt_ids: list[int] = field(default_factory=list) + generated_tokens: list[int] = field(default_factory=list) + logprobs: list[float] = field(default_factory=list) + error: Optional[str] = None + status: RequestStatus = RequestStatus.PENDING + created_time: float = field(default_factory=time.time) + next_token: Optional[int] = field(default_factory=int) + + +@dataclass +class RequestState: + """Tracks the state of a generation request through its lifecycle. + + Attributes: + request_id (str): The ID of the generation request. + full_prompt_ids (list[int] | None): The tokens IDs of the full prompt. + prompt_ids (list[int] | None): The tokens IDs currently being processed. + remaining_prompt_ids (list[int]): The tokens IDs remaining to be processed (for split requests). + static_outputs (list[int]): The generated tokens. + allocated_blocks (list[int]): The identifiers of the allocated blocks to the request. + position_offset (int): The current position in the sequence for position_ids. + status (RequestStatus): The status of the request: can be one of PENDING, PREFILLING, PREFILLING_SPLIT, + SPLIT_PENDING_REMAINDER, DECODING, FINISHED, FAILED + max_new_tokens (int): The maximum number of new tokens to generate. + eos_token_id (int): The ID of the end-of-sequence token. + created_time (float): The time the request was created. + error (Optional[str]): Any error message associated with the request. When None, has had no error yet. + next_token (Optional[str]): The next token to be generated. + """ + + # Required fields + request_id: str + full_prompt_ids: Optional[list[int]] = None # Full initial prompt + prompt_ids: Optional[list[int]] = None # Tokens IDs currently being processed (initial + generated) + remaining_prompt_ids: list[int] = field(default_factory=list) # For split requests, prefill left to process + static_outputs: list[int] = field(default_factory=list) # Generated tokens + allocated_blocks: list[int] = field(default_factory=list) # Block IDs allocated to the request + position_offset: int = 0 # Current position in the sequence for position_ids + status: RequestStatus = RequestStatus.PENDING # Status of the request + max_new_tokens: int = 20 # Maximum number of new tokens to generate + eos_token_id: int = -1 # ID of the end-of-sequence token + created_time: float = field(default_factory=time.time) # Time the request was created + error: Optional[str] = None # Error message if the request failed + next_token: Optional[str] = None # Next token to be generated + + def current_len(self) -> int: + """Get the current length of the sequence (prompt + generated tokens).""" + return self.position_offset + + def generated_len(self) -> int: + """Get the number of tokens generated so far.""" + return len(self.static_outputs) + + # TODO: this logic seems one token off, check it out + @traced + def update_with_token(self, token_id: int) -> bool: + """Update the request with a newly generated token and check for completion. 
+ + Args: + token_id: The token ID to add to the output sequence + + Returns: + bool: True if the request is now complete, False otherwise + """ + # Only update if we're in decoding state + if self.status != RequestStatus.DECODING: + return False + + is_eos = token_id == self.eos_token_id and self.eos_token_id != -1 + is_max_len = self.generated_len() >= self.max_new_tokens + + # Only add the token if we're not finishing due to max length + # (EOS tokens should still be added to the output) + if not (is_max_len and not is_eos): + self.static_outputs.extend([token_id]) + + if is_eos or is_max_len: + self.status = RequestStatus.FINISHED + return True + return False + + def __repr__(self): + msg = [ + f"request_id={self.request_id}", + f"status={self.status}", + f"out_tokens={self.generated_len()}", + f"query_length={len(self.prompt_ids)}", + f"remaining_tokens={len(self.remaining_prompt_ids)}", + f"kv_length={self.position_offset}", + f"full_prompt_lenght={len(self.full_prompt_ids)}", + f"allocated_blocks={self.allocated_blocks}", + f"generated_tokens={self.static_outputs}", + ] + return "RequestState(\n\t" + ",\n\t".join(msg) + "\n)" + + def to_generation_output(self): + """Convert the request state to a GenerationOutput object.""" + return GenerationOutput( + request_id=self.request_id, + prompt_ids=self.full_prompt_ids, + status=self.status, + generated_tokens=self.static_outputs, + logprobs=[], + error=self.error, + next_token=self.next_token, + ) diff --git a/src/transformers/generation/continuous_batching/scheduler.py b/src/transformers/generation/continuous_batching/scheduler.py new file mode 100644 index 000000000000..ad236b173584 --- /dev/null +++ b/src/transformers/generation/continuous_batching/scheduler.py @@ -0,0 +1,300 @@ +from abc import ABC, abstractmethod +from collections import deque + +from ...utils.metrics import attach_tracer, traced +from .cache import PagedAttentionCache +from .core import RequestState, RequestStatus + + +class Scheduler(ABC): + """ + Abstract base class for scheduling requests in the continuous batch processor. + It is expected that cache allocation and scheduling logic will be implemented in subclasses. 
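+    Subclasses must implement `add_waiting_request`, `schedule_batch` and `finish_request`.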
+ """ + + def __init__(self, cache: PagedAttentionCache, retain_cache_on_finish: bool = False): + self.active_requests: dict[str, RequestState] = {} + self.waiting_requests: dict[str, RequestState] = {} + self.waiting_requests_order: deque[str] = deque() + self.cache = cache + self.retain_cache_on_finish = retain_cache_on_finish + + @abstractmethod + def add_waiting_request(self, state: RequestState): + """Add a request to the waiting list.""" + pass + + @abstractmethod + def schedule_batch(self, token_budget: int) -> list[RequestState]: + pass + + @traced + def has_pending_requests(self) -> bool: + """Check if there are requests ready to be processed.""" + return len(self.active_requests) or len(self.waiting_requests) + + @abstractmethod + def finish_request(self, request_id: str, evict_from_cache: bool = True): + """Finish processing a request and free its allocated blocks.""" + pass + + @traced + def get_active_request_static_outputs(self, request_id: str) -> list[int]: + if request_id in self.active_requests: + return self.active_requests[request_id].static_outputs + return [] + + +@attach_tracer() +class FIFOScheduler(Scheduler): + def __init__(self, cache: PagedAttentionCache, retain_cache_on_finish: bool = False, safety_margin: float = 0.1): + super().__init__(cache, retain_cache_on_finish) + self.safety_margin = safety_margin + + @traced + def _allocate_blocks_if_needed(self, state: RequestState, len_next_tokens: int): + # 1. we check that the occupancy is less than the requested length + # 2. we allocate enough blocks to cover the requested length + current_len = state.current_len() + occupancy = len(state.allocated_blocks) * self.cache.block_size - current_len + if occupancy < len_next_tokens or (len(state.allocated_blocks) == 0): + blocks_needed = ((len_next_tokens - occupancy + 1) // self.cache.block_size) + 1 + allocated = self.cache.allocate_blocks(blocks_needed, state.request_id) + if not allocated: + return False + state.allocated_blocks.extend(allocated) + return True + + @traced(span_name="prepare_request") + def _prepare_request_for_processing( + self, state: RequestState, token_budget: int, request_ids_to_remove_from_waiting: set[str] + ): + """Prepare a request for processing in the current batch.""" + request_tokens = ( + state.remaining_prompt_ids if state.status == RequestStatus.SPLIT_PENDING_REMAINDER else state.prompt_ids + ) + if len(request_tokens) < token_budget: + # Can process the entire prompt/remainder + if state.status == RequestStatus.PENDING: + self.active_requests[state.request_id] = state + state.status = RequestStatus.PREFILLING + request_ids_to_remove_from_waiting.add(state.request_id) + elif state.status == RequestStatus.SPLIT_PENDING_REMAINDER: + state.status = RequestStatus.PREFILLING + state.prompt_ids = state.remaining_prompt_ids + state.remaining_prompt_ids = [] + else: + # Need to split the request + if state.status == RequestStatus.PENDING: + self.active_requests[state.request_id] = state + state.status = RequestStatus.PREFILLING_SPLIT + request_ids_to_remove_from_waiting.add(state.request_id) + elif state.status == RequestStatus.SPLIT_PENDING_REMAINDER: + state.status = RequestStatus.PREFILLING_SPLIT + state.remaining_prompt_ids = request_tokens[token_budget:] + state.prompt_ids = request_tokens[:token_budget] + + @traced + def add_waiting_request(self, state: RequestState): + """Add a request to the waiting list.""" + if self.retain_cache_on_finish and state.request_id in self.active_requests: + old_state = 
self.active_requests.pop(state.request_id) + state.prompt_ids = state.prompt_ids[len(old_state.full_prompt_ids) :] + state.allocated_blocks = old_state.allocated_blocks + state.position_offset = old_state.position_offset + self.waiting_requests[state.request_id] = state + self.waiting_requests_order.append(state.request_id) + + @traced + def schedule_batch(self, token_budget: int) -> list[RequestState]: + priority_states: list[RequestState] = [] + second_priority_states: list[RequestState] = [] + scheduled_requests = [] + + for state in self.active_requests.values(): + if state.status == RequestStatus.DECODING: + priority_states.append(state) + if state.status == RequestStatus.SPLIT_PENDING_REMAINDER: + second_priority_states.append(state) + + # Add waiting requests to second priority + for req_id in self.waiting_requests_order: + second_priority_states.append(self.waiting_requests[req_id]) + + candidates = priority_states + second_priority_states + request_ids_to_remove_from_waiting = set() + safety_margins = self.safety_margin * self.cache.num_blocks + + for state in candidates: + # If we are out the safety margin, we only accept decoding requests or the first prefill request + num_free_blocks = self.cache.get_num_free_blocks() + outside_safety_margin = num_free_blocks < safety_margins + if outside_safety_margin and scheduled_requests and state.status != RequestStatus.DECODING: + break + + self._prepare_request_for_processing(state, token_budget, request_ids_to_remove_from_waiting) + request_len = len(state.prompt_ids) + if not self._allocate_blocks_if_needed( + state, len(state.prompt_ids) + ): # don't schedule if we can't allocate blocks + if len(self.cache._free_blocks) == 0: + break + continue + + @traced + def _add_to_scheduled_requests(state: RequestState): + scheduled_requests.append(state) + + _add_to_scheduled_requests(state) + + token_budget -= request_len + + @traced + def _remove_from_waiting_requests(state: RequestState): + req_id = state.request_id + if req_id in self.waiting_requests: + del self.waiting_requests[req_id] + request_ids_to_remove_from_waiting.add(req_id) + + _remove_from_waiting_requests(state) + + if token_budget == 0: + break + + self.waiting_requests_order = deque( + [req_id for req_id in self.waiting_requests_order if req_id not in request_ids_to_remove_from_waiting] + ) + + return scheduled_requests + + @traced + def finish_request(self, request_id: str, evict_from_cache: bool = True): + if evict_from_cache: + self.cache.free_blocks(request_id) + if request_id in self.active_requests: + del self.active_requests[request_id] + + +@attach_tracer() +class PrefillFirstScheduler(Scheduler): + @traced + def _allocate_blocks_if_needed(self, state: RequestState, len_next_tokens: int): + # 1. we check that the occupancy is less than the requested length + # 2. 
we allocate enough blocks to cover the requested length + current_len = state.current_len() + occupancy = len(state.allocated_blocks) * self.cache.block_size - current_len + if occupancy < len_next_tokens or (len(state.allocated_blocks) == 0): + blocks_needed = ((len_next_tokens - occupancy + 1) // self.cache.block_size) + 1 + allocated = self.cache.allocate_blocks(blocks_needed, state.request_id) + if not allocated: + return False + state.allocated_blocks.extend(allocated) + return True + + @traced(span_name="prepare_request") + def _prepare_request_for_processing( + self, state: RequestState, token_budget: int, request_ids_to_remove_from_waiting: set[str] + ): + """Prepare a request for processing in the current batch.""" + request_tokens = ( + state.remaining_prompt_ids if state.status == RequestStatus.SPLIT_PENDING_REMAINDER else state.prompt_ids + ) + if len(request_tokens) < token_budget: + # Can process the entire prompt/remainder + if state.status == RequestStatus.PENDING: + self.active_requests[state.request_id] = state + state.status = RequestStatus.PREFILLING + request_ids_to_remove_from_waiting.add(state.request_id) + elif state.status == RequestStatus.SPLIT_PENDING_REMAINDER: + state.status = RequestStatus.PREFILLING + state.prompt_ids = state.remaining_prompt_ids + state.remaining_prompt_ids = [] + else: + # Need to split the request + if state.status == RequestStatus.PENDING: + self.active_requests[state.request_id] = state + state.status = RequestStatus.PREFILLING_SPLIT + request_ids_to_remove_from_waiting.add(state.request_id) + elif state.status == RequestStatus.SPLIT_PENDING_REMAINDER: + state.status = RequestStatus.PREFILLING_SPLIT + state.remaining_prompt_ids = request_tokens[token_budget:] + state.prompt_ids = request_tokens[:token_budget] + + @traced + def add_waiting_request(self, state: RequestState): + """Add a request to the waiting list.""" + if self.retain_cache_on_finish and state.request_id in self.active_requests: + old_state = self.active_requests.pop(state.request_id) + state.prompt_ids = state.prompt_ids[len(old_state.full_prompt_ids) :] # XXX: check for indexing error? 
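+            # The KV cache of the finished request is retained, so only the new suffix of the prompt
+            # needs to be prefilled; its allocated blocks and position offset are reused below.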
+ state.allocated_blocks = old_state.allocated_blocks + state.position_offset = old_state.position_offset + self.waiting_requests[state.request_id] = state + self.waiting_requests_order.append(state.request_id) + + @traced + def schedule_batch(self, token_budget: int) -> list[RequestState]: + priority_states: list[RequestState] = [] + second_priority_states: list[RequestState] = [] + scheduled_requests = [] + + for state in self.active_requests.values(): + if state.status == RequestStatus.SPLIT_PENDING_REMAINDER: + priority_states.append(state) + elif state.status == RequestStatus.DECODING: + second_priority_states.append(state) + + for req_id in self.waiting_requests_order: + second_priority_states.append(self.waiting_requests[req_id]) + + candidates = priority_states + second_priority_states + + request_ids_to_remove_from_waiting = set() + + for state in candidates: + self._prepare_request_for_processing(state, token_budget, request_ids_to_remove_from_waiting) + request_len = len(state.prompt_ids) + if not self._allocate_blocks_if_needed( + state, len(state.prompt_ids) + ): # don't schedule if we can't allocate blocks + if len(self.cache._free_blocks) == 0: + break + continue + + @traced + def _add_to_scheduled_requests(state: RequestState): + scheduled_requests.append(state) + + _add_to_scheduled_requests(state) + + token_budget -= request_len + + @traced + def _remove_from_waiting_requests(state: RequestState): + req_id = state.request_id + if req_id in self.waiting_requests: + del self.waiting_requests[req_id] + request_ids_to_remove_from_waiting.add(req_id) + + _remove_from_waiting_requests(state) + + if token_budget == 0: + break + + self.waiting_requests_order = deque( + [req_id for req_id in self.waiting_requests_order if req_id not in request_ids_to_remove_from_waiting] + ) + + return scheduled_requests + + @traced + def finish_request(self, request_id: str, evict_from_cache: bool = True): + if evict_from_cache: + self.cache.free_blocks(request_id) + if request_id in self.active_requests: + del self.active_requests[request_id] + + +SCHEDULER_MAPPING = { + "fifo": FIFOScheduler, + "prefill_first": PrefillFirstScheduler, +} From 3cffe20e2b736ba931ed573af233f495481a4092 Mon Sep 17 00:00:00 2001 From: remi-or Date: Fri, 22 Aug 2025 12:52:29 +0000 Subject: [PATCH 09/26] Better saving of cb example --- examples/pytorch/continuous_batching.py | 4 +++- .../generation/continuous_batching/continuous_api.py | 3 ++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/examples/pytorch/continuous_batching.py b/examples/pytorch/continuous_batching.py index 2bfcd0d3ba50..42168fa44494 100644 --- a/examples/pytorch/continuous_batching.py +++ b/examples/pytorch/continuous_batching.py @@ -1,5 +1,6 @@ import argparse import json +import os import time from typing import Optional @@ -179,8 +180,9 @@ def batch_generate( # If no output file is provided, we pick a name based on the args if args.output_file is None: + os.makedirs("runs/cb", exist_ok=True) args.output_file = ( - f"cb/{args.num_blocks}_{args.max_batch_tokens}_{args.attn}_{args.matmul_precision}_{args.samples}.json" + f"runs/cb/{args.num_blocks}_{args.max_batch_tokens}_{args.attn}_{args.matmul_precision}_{args.samples}.json" ) # Run warmup batch generation diff --git a/src/transformers/generation/continuous_batching/continuous_api.py b/src/transformers/generation/continuous_batching/continuous_api.py index 81c8c6eb6828..c930d8049e91 100644 --- a/src/transformers/generation/continuous_batching/continuous_api.py +++ 
b/src/transformers/generation/continuous_batching/continuous_api.py @@ -27,6 +27,7 @@ from ...configuration_utils import PretrainedConfig from ...generation.configuration_utils import GenerationConfig from ...tokenization_utils_fast import PreTrainedTokenizerFast +from ...utils.logging import logging from ...utils.metrics import ContinuousBatchProcessorMetrics, attach_tracer, traced from .cache import PagedAttentionCache from .core import GenerationOutput, RequestState, RequestStatus, get_device_and_memory_breakdown, logger @@ -779,7 +780,7 @@ def generate_batch( """ if not inputs: return [] - if logger.getEffectiveLevel() <= logger.INFO: + if logger.getEffectiveLevel() <= logging.INFO: logger.warning("Progress bar is disabled when logger level is less than INFO") progress_bar = False From 7cd70ac1fcc305e8787d2e4d0384ca718ae1bc1a Mon Sep 17 00:00:00 2001 From: remi-or Date: Fri, 22 Aug 2025 13:02:36 +0000 Subject: [PATCH 10/26] Fix --- src/transformers/generation/continuous_batching/cache.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/generation/continuous_batching/cache.py b/src/transformers/generation/continuous_batching/cache.py index c8d263281935..49859cbaad77 100644 --- a/src/transformers/generation/continuous_batching/cache.py +++ b/src/transformers/generation/continuous_batching/cache.py @@ -21,7 +21,7 @@ from ...configuration_utils import PretrainedConfig from ...generation.configuration_utils import GenerationConfig from ...utils.metrics import attach_tracer, traced -from .core import RequestState, logger +from .core import RequestState, logger, get_device_and_memory_breakdown T = TypeVar("T") @@ -232,7 +232,7 @@ def __init__( @staticmethod def get_available_memory(max_memory_percent: float = 1.0) -> int: - _, total, reserved, allocated = PagedAttentionMemoryHandler.get_device_and_memory_breakdown() + _, total, reserved, allocated = get_device_and_memory_breakdown() available_memory = total - max(allocated, reserved) available_memory = int(available_memory * max_memory_percent) return available_memory From 2933099b7b2a1c1085c381dd9a636b161de4cb43 Mon Sep 17 00:00:00 2001 From: remi-or Date: Fri, 22 Aug 2025 16:08:22 +0000 Subject: [PATCH 11/26] Fixes and QOL --- examples/pytorch/continuous_batching.py | 4 +++- src/transformers/generation/continuous_batching/cache.py | 7 ++----- .../generation/continuous_batching/scheduler.py | 2 +- 3 files changed, 6 insertions(+), 7 deletions(-) diff --git a/examples/pytorch/continuous_batching.py b/examples/pytorch/continuous_batching.py index 42168fa44494..555cf5b1fd84 100644 --- a/examples/pytorch/continuous_batching.py +++ b/examples/pytorch/continuous_batching.py @@ -136,6 +136,7 @@ def batch_generate( ) parser.add_argument("--matmul-precision", "-mp", type=str, default="high") # set to "none" to disable parser.add_argument("--use-cuda-graph", action="store_true", default=False) + parser.add_argument("--compile", action="store_true", default=False) parser.add_argument("--samples", type=int, default=500) parser.add_argument("--displayed", type=int, default=0, help="Number of samples to display") @@ -155,7 +156,8 @@ def batch_generate( torch_dtype=torch.bfloat16, ) model = model.cuda().eval() - # model.forward = torch.compile(model.forward, mode="max-autotune-no-cudagraphs") + if args.compile: + model.forward = torch.compile(model.forward, mode="max-autotune-no-cudagraphs") # Prepare tokenizer and dataset tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, padding_side="left") diff --git 
a/src/transformers/generation/continuous_batching/cache.py b/src/transformers/generation/continuous_batching/cache.py
index 49859cbaad77..412eb0b4a8ae 100644
--- a/src/transformers/generation/continuous_batching/cache.py
+++ b/src/transformers/generation/continuous_batching/cache.py
@@ -21,7 +21,7 @@
 from ...configuration_utils import PretrainedConfig
 from ...generation.configuration_utils import GenerationConfig
 from ...utils.metrics import attach_tracer, traced
-from .core import RequestState, logger, get_device_and_memory_breakdown
+from .core import RequestState, get_device_and_memory_breakdown, logger
 
 
 T = TypeVar("T")
@@ -92,10 +92,7 @@ def __init__(
         # Add the infered attributes to the class
         self.num_blocks = num_blocks
         self.max_batch_tokens = max_batch_tokens
-        logger.info(
-            f"After init, {self.num_blocks = }, {self.block_size = }, {self.max_batch_tokens = } "
-            f"{self.num_key_value_heads = }, {self.head_dim = }, {self.num_hidden_layers = }"
-        )
+        logger.warning(f"PagedAttentionCache initialized with {self.num_blocks = } and {self.max_batch_tokens = } ")
 
         # Initialize the cache
         self.cache_shape = (self.num_key_value_heads, num_blocks, self.block_size, self.head_dim)
diff --git a/src/transformers/generation/continuous_batching/scheduler.py b/src/transformers/generation/continuous_batching/scheduler.py
index ad236b173584..cc5cb538d087 100644
--- a/src/transformers/generation/continuous_batching/scheduler.py
+++ b/src/transformers/generation/continuous_batching/scheduler.py
@@ -47,7 +47,7 @@ def get_active_request_static_outputs(self, request_id: str) -> list[int]:
 
 @attach_tracer()
 class FIFOScheduler(Scheduler):
-    def __init__(self, cache: PagedAttentionCache, retain_cache_on_finish: bool = False, safety_margin: float = 0.1):
+    def __init__(self, cache: PagedAttentionCache, retain_cache_on_finish: bool = False, safety_margin: float = 0.0):
         super().__init__(cache, retain_cache_on_finish)
         self.safety_margin = safety_margin
 
From bfcf6117b0885d6ff87bde132cbad8d455e52b84 Mon Sep 17 00:00:00 2001
From: remi-or
Date: Mon, 25 Aug 2025 12:05:01 +0000
Subject: [PATCH 12/26] More information about metrics

---
 examples/metrics-monitoring/README.md   | 37 +++++++++++++++++++++++++
 examples/pytorch/continuous_batching.py | 33 +++++++++++++++++++++-
 2 files changed, 69 insertions(+), 1 deletion(-)

diff --git a/examples/metrics-monitoring/README.md b/examples/metrics-monitoring/README.md
index 64ef1160c66b..62150f3a10b7 100644
--- a/examples/metrics-monitoring/README.md
+++ b/examples/metrics-monitoring/README.md
@@ -2,3 +2,40 @@
 
 ## Continuous Batching Metrics in Transformers
 
+To set up metric monitoring with continuous batching, you will want to have Tempo and Prometheus running.
+
+For this, we provide a docker compose image in `examples/metrics-monitoring`. 
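+It starts the Prometheus and Tempo services that the setup snippet below points to (metrics on `http://localhost:9090/api/v1/otlp/v1/metrics`, traces on `http://localhost:4318/v1/traces`).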
+ +To run it: + +```sh +cd examples/metrics-monitoring +docker compose up +``` + +Then, in your srcipt running CB, you will need to create a MeterProvider and TracerProvider as follows: + +```py +from opentelemetry import metrics, trace +from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter +from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter +from opentelemetry.sdk.metrics import MeterProvider +from opentelemetry.sdk.metrics.export import PeriodicExportingMetricReader +from opentelemetry.sdk.resources import Resource +from opentelemetry.sdk.trace import TracerProvider +from opentelemetry.sdk.trace.export import BatchSpanProcessor + +resource = Resource.create({"service.name": "transformers"}) + +metrics_exporter = PeriodicExportingMetricReader( + OTLPMetricExporter(endpoint="http://localhost:9090/api/v1/otlp/v1/metrics"), # Uses OTEL_EXPORTER_OTLP_METRICS_ENDPOINT env var + export_interval_millis=1000 +) +meter_provider = MeterProvider(resource=resource, metric_readers=[metrics_exporter]) +metrics.set_meter_provider(meter_provider) + +trace_exporter = OTLPSpanExporter(endpoint="http://localhost:4318/v1/traces") # Uses OTEL_EXPORTER_OTLP_TRACES_ENDPOINT env var +tracer_provider = TracerProvider(resource=resource) +tracer_provider.add_span_processor(BatchSpanProcessor(trace_exporter)) +trace.set_tracer_provider(tracer_provider) +``` diff --git a/examples/pytorch/continuous_batching.py b/examples/pytorch/continuous_batching.py index 555cf5b1fd84..930f1ab11ec2 100644 --- a/examples/pytorch/continuous_batching.py +++ b/examples/pytorch/continuous_batching.py @@ -11,7 +11,7 @@ from transformers.generation import GenerationConfig -MODEL_ID = "meta-llama/Llama-3.2-3b-Instruct" +MODEL_ID = "Qwen/Qwen3-4B-Instruct-2507" def generate_simple( @@ -44,6 +44,33 @@ def generate_simple( return decoded_outputs +def setup_metrics(): + try: + from opentelemetry import metrics, trace + from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter + from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter + from opentelemetry.sdk.metrics import MeterProvider + from opentelemetry.sdk.metrics.export import PeriodicExportingMetricReader + from opentelemetry.sdk.resources import Resource + from opentelemetry.sdk.trace import TracerProvider + from opentelemetry.sdk.trace.export import BatchSpanProcessor + + + resource = Resource.create({"service.name": "transformers"}) + metrics_exporter = PeriodicExportingMetricReader( + OTLPMetricExporter(endpoint="http://localhost:9090/api/v1/otlp/v1/metrics"), # Uses OTEL_EXPORTER_OTLP_METRICS_ENDPOINT env var + export_interval_millis=1000 + ) + meter_provider = MeterProvider(resource=resource, metric_readers=[metrics_exporter]) + metrics.set_meter_provider(meter_provider) + trace_exporter = OTLPSpanExporter(endpoint="http://localhost:4318/v1/traces") # Uses OTEL_EXPORTER_OTLP_TRACES_ENDPOINT env var + tracer_provider = TracerProvider(resource=resource) + tracer_provider.add_span_processor(BatchSpanProcessor(trace_exporter)) + trace.set_tracer_provider(tracer_provider) + except Exception as e: + print(f"Error setting up metrics: {e}") + + def batch_generate( model: AutoModelForCausalLM, simple_batch_inputs: list, @@ -142,8 +169,12 @@ def batch_generate( parser.add_argument("--displayed", type=int, default=0, help="Number of samples to display") parser.add_argument("--output-file", type=str, default=None) parser.add_argument("--compare", action="store_true", 
default=False) + parser.add_argument("--metrics", action="store_true", default=False) args = parser.parse_args() + if args.metrics: + setup_metrics() + # Set matmul precision if args.matmul_precision != "none": torch.set_float32_matmul_precision(args.matmul_precision) From f000b17dbe5cd61653f2e8fb53dfa6421e26a168 Mon Sep 17 00:00:00 2001 From: remi-or Date: Mon, 25 Aug 2025 12:05:18 +0000 Subject: [PATCH 13/26] Further logging --- .../generation/continuous_batching/cache.py | 20 ++++++++----- .../continuous_batching/continuous_api.py | 12 ++++---- .../generation/continuous_batching/core.py | 28 +++++++++++++++++-- 3 files changed, 46 insertions(+), 14 deletions(-) diff --git a/src/transformers/generation/continuous_batching/cache.py b/src/transformers/generation/continuous_batching/cache.py index 412eb0b4a8ae..c3544222909b 100644 --- a/src/transformers/generation/continuous_batching/cache.py +++ b/src/transformers/generation/continuous_batching/cache.py @@ -208,8 +208,8 @@ class PagedAttentionMemoryHandler: _activation_dtype = torch.bfloat16 _activation_safety_factor = 2 _input_dtype = torch.int32 - _upper_bound_max_batch_tokens = 2048 - _upper_bound_num_blocks = 16384 + _upper_bound_max_batch_tokens = 256 + _upper_bound_num_blocks = 4096 def __init__( self, @@ -271,14 +271,20 @@ def compute_num_blocks_and_max_batch_tokens( m: float = 0.1, ) -> tuple[int, int]: cache_memory = self.get_available_memory(max_memory_percent) + logger.info(f"Cache memory: {cache_memory}") + + # Compute memory footprints # TODO: check and explain better + mem_per_activation_token = self._activation_dtype.itemsize * (self.hidden_size + self.vocab_size) * self._activation_safety_factor + mem_per_cache_token = 2 * self.num_heads * self.head_dim * self.num_layers * cache_dtype.itemsize + mem_per_input_token = 8 * m * self._input_dtype.itemsize + logger.info(f"Memory per activation token: {mem_per_activation_token}") + logger.info(f"Memory per cache token: {mem_per_cache_token}") + logger.info(f"Memory per input token: {mem_per_input_token}") # Compute second-degree polynomial coefficients a = m * self._activation_dtype.itemsize - b = 8 * m * self._input_dtype.itemsize - b += 2 * self.num_heads * self.head_dim * self.num_layers * cache_dtype.itemsize - c = self._activation_dtype.itemsize * (self.hidden_size + self.vocab_size) * self._activation_safety_factor - c += 2 * self._input_dtype.itemsize - c -= cache_memory + b = mem_per_input_token + mem_per_cache_token + c = mem_per_activation_token + 2 * self._input_dtype.itemsize - cache_memory # Compute discriminant and greatest solution discriminant = b**2 - 4 * a * c diff --git a/src/transformers/generation/continuous_batching/continuous_api.py b/src/transformers/generation/continuous_batching/continuous_api.py index c930d8049e91..e2beb2c5f5b5 100644 --- a/src/transformers/generation/continuous_batching/continuous_api.py +++ b/src/transformers/generation/continuous_batching/continuous_api.py @@ -262,8 +262,10 @@ def prepare_next_batch(self): self.max_seqlen_k = max(self.max_seqlen_k, key_length) state.position_offset += query_length - logger.info( - f"Scheduled: {len(self.requests_in_batch)}, Waiting: {len(self.scheduler.waiting_requests)}, Active: {len(self.scheduler.active_requests)}. cum Q: {cumulative_seqlens_q[-1]}. cum KV: {cumulative_seqlens_k[-1]}, free blocks: {self.cache.get_num_free_blocks()}" + logger.debug( + f"Scheduled: {len(self.requests_in_batch)}, Waiting: {len(self.scheduler.waiting_requests)}, " + f"Active: {len(self.scheduler.active_requests)}. 
cum Q: {cumulative_seqlens_q[-1]}. " + f"cum KV: {cumulative_seqlens_k[-1]}, free blocks: {self.cache.get_num_free_blocks()}" ) self._build_tensors( input_ids, @@ -666,7 +668,7 @@ def _inner_generation_loop(self, batch_processor: ContinuousBatchProcessor): torch.cuda.synchronize() batch_processor.prepare_next_batch() device, total, reserved, allocated = get_device_and_memory_breakdown() - logger.info(f"[Memory] Device: {device}, Total: {total}, Reserved: {reserved}, Allocated: {allocated}") + logger.debug(f"[Memory] Device: {device}, Total: {total}, Reserved: {reserved}, Allocated: {allocated}") if torch.cuda.is_available() and self.use_cuda_graph: if self.current_batch == 0: self.warmup(batch_processor) @@ -780,8 +782,8 @@ def generate_batch( """ if not inputs: return [] - if logger.getEffectiveLevel() <= logging.INFO: - logger.warning("Progress bar is disabled when logger level is less than INFO") + if logger.getEffectiveLevel() <= logging.DEBUG: + logger.warning("Progress bar is disabled when logger level is less than DEBUG") progress_bar = False # Initialize manager with the batch inputs diff --git a/src/transformers/generation/continuous_batching/core.py b/src/transformers/generation/continuous_batching/core.py index 3f476ef1b99a..cab98c755e54 100644 --- a/src/transformers/generation/continuous_batching/core.py +++ b/src/transformers/generation/continuous_batching/core.py @@ -11,6 +11,7 @@ # We centralize the logger here to coordinate between logging and progress bar logger = logging.getLogger("ContinuousBatchingLogger") +logger.setLevel(logging.INFO) @staticmethod @@ -102,12 +103,35 @@ class RequestState: static_outputs: list[int] = field(default_factory=list) # Generated tokens allocated_blocks: list[int] = field(default_factory=list) # Block IDs allocated to the request position_offset: int = 0 # Current position in the sequence for position_ids - status: RequestStatus = RequestStatus.PENDING # Status of the request + _status: RequestStatus = RequestStatus.PENDING # Status of the request, hidden behind a property max_new_tokens: int = 20 # Maximum number of new tokens to generate eos_token_id: int = -1 # ID of the end-of-sequence token created_time: float = field(default_factory=time.time) # Time the request was created error: Optional[str] = None # Error message if the request failed next_token: Optional[str] = None # Next token to be generated + lifespan: tuple[float, float] = (-1, -1) # (time request was no longer pending, time request finished) + + @property + def status(self) -> RequestStatus: + return self._status + + @status.setter + def status(self, value: RequestStatus): + if self._status == RequestStatus.PENDING: + self.lifespan = (time.time(), -1) + elif value == RequestStatus.FINISHED: + self.lifespan = (self.lifespan[0], time.time()) + self.log_end_of_request() + self._status = value + + def log_end_of_request(self): + prefill_len = len(self.full_prompt_ids) + decode_len = self.generated_len() + start_time = self.lifespan[0] - self.created_time + end_time = self.lifespan[1] - self.created_time + logger.info( + f"Request {self.request_id} finished: {prefill_len = } {decode_len = } {start_time = } {end_time = }" + ) def current_len(self) -> int: """Get the current length of the sequence (prompt + generated tokens).""" @@ -148,7 +172,7 @@ def update_with_token(self, token_id: int) -> bool: def __repr__(self): msg = [ f"request_id={self.request_id}", - f"status={self.status}", + f"status={self._status}", f"out_tokens={self.generated_len()}", 
f"query_length={len(self.prompt_ids)}", f"remaining_tokens={len(self.remaining_prompt_ids)}", From 042e87dddc566b7d53f9b09c63c5f50053935e99 Mon Sep 17 00:00:00 2001 From: remi-or Date: Mon, 25 Aug 2025 12:07:26 +0000 Subject: [PATCH 14/26] Style --- examples/pytorch/continuous_batching.py | 15 ++++++++------- .../generation/continuous_batching/cache.py | 4 +++- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/examples/pytorch/continuous_batching.py b/examples/pytorch/continuous_batching.py index 930f1ab11ec2..b6a14bc4ebc4 100644 --- a/examples/pytorch/continuous_batching.py +++ b/examples/pytorch/continuous_batching.py @@ -55,15 +55,18 @@ def setup_metrics(): from opentelemetry.sdk.trace import TracerProvider from opentelemetry.sdk.trace.export import BatchSpanProcessor - resource = Resource.create({"service.name": "transformers"}) metrics_exporter = PeriodicExportingMetricReader( - OTLPMetricExporter(endpoint="http://localhost:9090/api/v1/otlp/v1/metrics"), # Uses OTEL_EXPORTER_OTLP_METRICS_ENDPOINT env var - export_interval_millis=1000 + OTLPMetricExporter( + endpoint="http://localhost:9090/api/v1/otlp/v1/metrics" + ), # Uses OTEL_EXPORTER_OTLP_METRICS_ENDPOINT env var + export_interval_millis=1000, ) meter_provider = MeterProvider(resource=resource, metric_readers=[metrics_exporter]) metrics.set_meter_provider(meter_provider) - trace_exporter = OTLPSpanExporter(endpoint="http://localhost:4318/v1/traces") # Uses OTEL_EXPORTER_OTLP_TRACES_ENDPOINT env var + trace_exporter = OTLPSpanExporter( + endpoint="http://localhost:4318/v1/traces" + ) # Uses OTEL_EXPORTER_OTLP_TRACES_ENDPOINT env var tracer_provider = TracerProvider(resource=resource) tracer_provider.add_span_processor(BatchSpanProcessor(trace_exporter)) trace.set_tracer_provider(tracer_provider) @@ -214,9 +217,7 @@ def batch_generate( # If no output file is provided, we pick a name based on the args if args.output_file is None: os.makedirs("runs/cb", exist_ok=True) - args.output_file = ( - f"runs/cb/{args.num_blocks}_{args.max_batch_tokens}_{args.attn}_{args.matmul_precision}_{args.samples}.json" - ) + args.output_file = f"runs/cb/{args.num_blocks}_{args.max_batch_tokens}_{args.attn}_{args.matmul_precision}_{args.samples}.json" # Run warmup batch generation batch_generate( diff --git a/src/transformers/generation/continuous_batching/cache.py b/src/transformers/generation/continuous_batching/cache.py index c3544222909b..9257e4ce8232 100644 --- a/src/transformers/generation/continuous_batching/cache.py +++ b/src/transformers/generation/continuous_batching/cache.py @@ -274,7 +274,9 @@ def compute_num_blocks_and_max_batch_tokens( logger.info(f"Cache memory: {cache_memory}") # Compute memory footprints # TODO: check and explain better - mem_per_activation_token = self._activation_dtype.itemsize * (self.hidden_size + self.vocab_size) * self._activation_safety_factor + mem_per_activation_token = ( + self._activation_dtype.itemsize * (self.hidden_size + self.vocab_size) * self._activation_safety_factor + ) mem_per_cache_token = 2 * self.num_heads * self.head_dim * self.num_layers * cache_dtype.itemsize mem_per_input_token = 8 * m * self._input_dtype.itemsize logger.info(f"Memory per activation token: {mem_per_activation_token}") From 604fe6e52b68b4834e7dae8183b5db81940f1970 Mon Sep 17 00:00:00 2001 From: remi-or Date: Mon, 25 Aug 2025 12:09:48 +0000 Subject: [PATCH 15/26] Licenses --- examples/pytorch/continuous_batching.py | 14 ++++++++++++++ .../generation/continuous_batching/__init__.py | 14 ++++++++++++++ 
.../generation/continuous_batching/core.py | 14 ++++++++++++++ .../generation/continuous_batching/scheduler.py | 14 ++++++++++++++ 4 files changed, 56 insertions(+) diff --git a/examples/pytorch/continuous_batching.py b/examples/pytorch/continuous_batching.py index b6a14bc4ebc4..27d3cc54c378 100644 --- a/examples/pytorch/continuous_batching.py +++ b/examples/pytorch/continuous_batching.py @@ -1,3 +1,17 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import argparse import json import os diff --git a/src/transformers/generation/continuous_batching/__init__.py b/src/transformers/generation/continuous_batching/__init__.py index 7e7f3fb7f925..1d939594d64b 100644 --- a/src/transformers/generation/continuous_batching/__init__.py +++ b/src/transformers/generation/continuous_batching/__init__.py @@ -1,3 +1,17 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from .cache import PagedAttentionCache from .continuous_api import ContinuousBatchingManager, ContinuousMixin from .core import RequestState, RequestStatus diff --git a/src/transformers/generation/continuous_batching/core.py b/src/transformers/generation/continuous_batching/core.py index cab98c755e54..f2c3a9eda455 100644 --- a/src/transformers/generation/continuous_batching/core.py +++ b/src/transformers/generation/continuous_batching/core.py @@ -1,3 +1,17 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import time from dataclasses import dataclass, field from enum import Enum diff --git a/src/transformers/generation/continuous_batching/scheduler.py b/src/transformers/generation/continuous_batching/scheduler.py index cc5cb538d087..b08bd812fe33 100644 --- a/src/transformers/generation/continuous_batching/scheduler.py +++ b/src/transformers/generation/continuous_batching/scheduler.py @@ -1,3 +1,17 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. 
team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from abc import ABC, abstractmethod from collections import deque From ef6354779f0a7de4337b20ebf6310acb74a35b07 Mon Sep 17 00:00:00 2001 From: remi-or Date: Mon, 25 Aug 2025 12:17:09 +0000 Subject: [PATCH 16/26] Removed some comments --- examples/pytorch/continuous_batching.py | 16 +--------------- 1 file changed, 1 insertion(+), 15 deletions(-) diff --git a/examples/pytorch/continuous_batching.py b/examples/pytorch/continuous_batching.py index 27d3cc54c378..628704675241 100644 --- a/examples/pytorch/continuous_batching.py +++ b/examples/pytorch/continuous_batching.py @@ -253,19 +253,5 @@ def batch_generate( expected_outputs=expected_outputs, ) - -# python examples/pytorch/continuous_batching.py --attn sdpa_paged --matmul-precision none --samples 50 --displayed 0 -# Using calculated self.num_blocks = 4096, self.block_size = 32, self.max_batch_tokens = 2048 -# CB generation took: 18.80 seconds for 13775 tokens. 732.74tok/s - - -# python examples/pytorch/continuous_batching.py --attn sdpa_paged --matmul-precision none --samples 100 --displayed 1 -# Setting up static tensors with T = 4096, max_token_budget = 524288, 139538202624 bytes available -# CB generation took: 29.53 seconds for 26384 tokens. 893.41tok/s - -# Without changes to continuous_batching.py -# Using calculated num_blocks=369, block_size=32, max concurrent requests 23 -# CB generation took: 79.58 seconds for 25813 tokens. 
324.38tok/s
-
-
+# Example usage:
 # python examples/pytorch/continuous_batching.py --num-blocks 369 --max-batch-tokens 23 --attn sdpa_paged -mp none --samples 1 --displayed 0 --output-file sliced.json

From d403b02f894f2e7e0c1459805087df1bc1c4efb3 Mon Sep 17 00:00:00 2001
From: remi-or
Date: Mon, 25 Aug 2025 13:45:25 +0000
Subject: [PATCH 17/26] Add a slice input flag

---
 examples/pytorch/continuous_batching.py      | 13 ++++++++++++-
 .../continuous_batching/continuous_api.py    |  8 +++++++-
 2 files changed, 19 insertions(+), 2 deletions(-)

diff --git a/examples/pytorch/continuous_batching.py b/examples/pytorch/continuous_batching.py
index 628704675241..40cc65dc3f09 100644
--- a/examples/pytorch/continuous_batching.py
+++ b/examples/pytorch/continuous_batching.py
@@ -96,6 +96,7 @@ def batch_generate(
     displayed_samples: int = 0,  # -1: no display, 0: display stats, >0: display inputs and some outputs
     output_file: Optional[str] = None,
     expected_outputs: Optional[list[str]] = None,
+    slice_inputs: bool = True,
 ) -> tuple[float, float]:
     # Actual batch generation
     if displayed_samples >= 0:
@@ -104,6 +105,7 @@ def batch_generate(
     batch_outputs = model.generate_batch(
         inputs=simple_batch_inputs,
         generation_config=generation_config,
+        slice_inputs=slice_inputs,  # TODO: move this to the generation config
     )
     end_time_simple = time.time()
     if displayed_samples >= 0:
@@ -179,6 +181,7 @@ def batch_generate(
         "--attn", type=str, default="paged_attention|kernels-community/flash-attn", help="Attention implementation"
     )
     parser.add_argument("--matmul-precision", "-mp", type=str, default="high")  # set to "none" to disable
+    parser.add_argument("--slice-inputs", action="store_true", default=False)
     parser.add_argument("--use-cuda-graph", action="store_true", default=False)
     parser.add_argument("--compile", action="store_true", default=False)
 
@@ -189,10 +192,11 @@ def batch_generate(
     parser.add_argument("--metrics", action="store_true", default=False)
     args = parser.parse_args()
 
+    # If turned on, we set up metrics
     if args.metrics:
         setup_metrics()
 
-    # Set matmul precision
+    # Set matmul precision if not none
     if args.matmul_precision != "none":
         torch.set_float32_matmul_precision(args.matmul_precision)
 
@@ -204,8 +208,13 @@ def batch_generate(
         MODEL_ID,
         attn_implementation=args.attn,
         dtype=torch.bfloat16,
         torch_dtype=torch.bfloat16,
     )
     model = model.cuda().eval()
+
+    # If turned on, we compile the model
     if args.compile:
         model.forward = torch.compile(model.forward, mode="max-autotune-no-cudagraphs")
+    if args.slice_inputs:
+        assert not args.compile, "Slicing inputs is not compatible with a compiled model"
+        assert not args.use_cuda_graph, "Slicing inputs is not compatible with CUDA graphs"
 
     # Prepare tokenizer and dataset
     tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, padding_side="left")
@@ -240,6 +249,7 @@ def batch_generate(
         generation_config,
         tokenizer,
         displayed_samples=-1,
+        slice_inputs=args.slice_inputs,
     )
 
     # Run batch generation
@@ -251,6 +261,7 @@ def batch_generate(
         displayed_samples=args.displayed,
         output_file=args.output_file,
         expected_outputs=expected_outputs,
+        slice_inputs=args.slice_inputs,
     )
 
 # Example usage:
diff --git a/src/transformers/generation/continuous_batching/continuous_api.py b/src/transformers/generation/continuous_batching/continuous_api.py
index e2beb2c5f5b5..9e629f30f316 100644
--- a/src/transformers/generation/continuous_batching/continuous_api.py
+++ b/src/transformers/generation/continuous_batching/continuous_api.py
@@ -423,6 +423,7 @@ def __init__(
         manual_eviction: bool = False,
         max_queue_size=0,
         streaming: bool = True,
+        slice_inputs: bool = True,
     ): 
"""Initialize the continuous batching manager. @@ -451,6 +452,7 @@ def __init__( self.manual_eviction = manual_eviction self.batch_processor: Optional[ContinuousBatchProcessor] = None self.decode_stream = DecodeStream(skip_special_tokens=True) + self.slice_inputs = slice_inputs @traced def start(self): @@ -649,6 +651,7 @@ def _run_generation_loop(self): scheduler(paged_attention_cache, self.manual_eviction), self.streaming, self.manual_eviction, + slice_inputs=self.slice_inputs, ) self.batch_processor = batch_processor self.current_batch = 0 @@ -728,6 +731,7 @@ def init_continuous_batching( manual_eviction: bool = False, max_queue_size: int = 0, streaming: bool = False, + slice_inputs: bool = True, ) -> ContinuousBatchingManager: """Initialize a manager for continuous batching inference. @@ -757,6 +761,7 @@ def init_continuous_batching( manual_eviction=manual_eviction, max_queue_size=max_queue_size, streaming=streaming, + slice_inputs=slice_inputs, ) @traced @@ -766,6 +771,7 @@ def generate_batch( inputs: list[list[int]], generation_config: Optional[GenerationConfig] = None, progress_bar: bool = True, + slice_inputs: bool = True, **kwargs, ) -> list[list[int]]: """Generate sequences for a batch of prompts using continuous batching. @@ -787,7 +793,7 @@ def generate_batch( progress_bar = False # Initialize manager with the batch inputs - manager = self.init_continuous_batching(generation_config=generation_config) + manager = self.init_continuous_batching(generation_config=generation_config, slice_inputs=slice_inputs) manager.start() results = {} num_requests = len(inputs) From 023774fd03ad32a43a6a6bba56f5babf4cacabd4 Mon Sep 17 00:00:00 2001 From: remi-or Date: Mon, 25 Aug 2025 14:04:36 +0000 Subject: [PATCH 18/26] Fix in example --- examples/pytorch/continuous_batching.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/pytorch/continuous_batching.py b/examples/pytorch/continuous_batching.py index 40cc65dc3f09..9425eb26eb86 100644 --- a/examples/pytorch/continuous_batching.py +++ b/examples/pytorch/continuous_batching.py @@ -240,7 +240,8 @@ def batch_generate( # If no output file is provided, we pick a name based on the args if args.output_file is None: os.makedirs("runs/cb", exist_ok=True) - args.output_file = f"runs/cb/{args.num_blocks}_{args.max_batch_tokens}_{args.attn}_{args.matmul_precision}_{args.samples}.json" + attn = args.attn.replace("|", "_").replace("/", "_") + args.output_file = f"runs/cb/{args.num_blocks}_{args.max_batch_tokens}_{attn}_{args.matmul_precision}_{args.samples}.json" # Run warmup batch generation batch_generate( From c327f08d2ac027fc8b6047aab8e8561fa23b03f9 Mon Sep 17 00:00:00 2001 From: remi-or Date: Mon, 25 Aug 2025 14:24:52 +0000 Subject: [PATCH 19/26] Added back some open-telemetry deps --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 3b67610db313..79bb0f9ef0d8 100644 --- a/setup.py +++ b/setup.py @@ -445,7 +445,7 @@ def run(self): extras["benchmark"] = deps_list("optimum-benchmark") # OpenTelemetry dependencies for metrics collection in continuous batching -extras["open-telemetry"] = deps_list("opentelemetry-api") +extras["open-telemetry"] = deps_list("opentelemetry-api") + ["opentelemetry-exporter-otlp", "opentelemetry-sdk"] # when modifying the following list, make sure to update src/transformers/dependency_versions_check.py install_requires = [ From 173b497ac25b674fa48e8f55b13dad5a08d273ba Mon Sep 17 00:00:00 2001 From: remi-or Date: Mon, 25 Aug 2025 14:45:04 +0000 
Subject: [PATCH 20/26] Removed some aux function --- .../generation/continuous_batching/cache.py | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/src/transformers/generation/continuous_batching/cache.py b/src/transformers/generation/continuous_batching/cache.py index 9257e4ce8232..bb901b38101d 100644 --- a/src/transformers/generation/continuous_batching/cache.py +++ b/src/transformers/generation/continuous_batching/cache.py @@ -24,14 +24,6 @@ from .core import RequestState, get_device_and_memory_breakdown, logger -T = TypeVar("T") - - -def getattr_no_none(obj: Any, attr: str, default: T) -> T: - x = getattr(obj, attr, None) - return x if x is not None else default - - @attach_tracer() class PagedAttentionCache: def __init__( @@ -58,8 +50,10 @@ def __init__( self.device = device # Extract model dimensions - self.num_key_value_heads: int = getattr_no_none(config, "num_key_value_heads", config.num_attention_heads) - self.head_dim: int = getattr_no_none(config, "head_dim", config.hidden_size // config.num_attention_heads) + kv_heads = getattr(config, "num_key_value_heads", None) + self.num_key_value_heads: int = kv_heads if kv_heads is not None else config.num_attention_heads + head_dim = getattr(config, "head_dim", None) + self.head_dim: int = head_dim if head_dim is not None else config.hidden_size // config.num_attention_heads self.num_hidden_layers = config.num_hidden_layers self.block_size = getattr(generation_config, "block_size", 32) From fff2ee8a8f260fce3fc4266ca68a8a3c9a8494cd Mon Sep 17 00:00:00 2001 From: remi-or Date: Mon, 25 Aug 2025 15:22:36 +0000 Subject: [PATCH 21/26] Added FA2 option to example script --- examples/pytorch/continuous_batching.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/pytorch/continuous_batching.py b/examples/pytorch/continuous_batching.py index 9425eb26eb86..731916dbd1b3 100644 --- a/examples/pytorch/continuous_batching.py +++ b/examples/pytorch/continuous_batching.py @@ -34,6 +34,7 @@ def generate_simple( attn_implementation = { "sdpa_paged": "sdpa", "eager_paged": "eager", + "flash_paged": "flash_attention_2", }[attn_implementation] model = ( @@ -205,7 +206,6 @@ def batch_generate( MODEL_ID, attn_implementation=args.attn, dtype=torch.bfloat16, - torch_dtype=torch.bfloat16, ) model = model.cuda().eval() From 7353aef1fcec94081498d4e2b226346c5ed4e525 Mon Sep 17 00:00:00 2001 From: remi-or Date: Mon, 25 Aug 2025 16:17:22 +0000 Subject: [PATCH 22/26] Fixed math (all of it) --- .../generation/continuous_batching/cache.py | 51 +++++++++++++++---- 1 file changed, 40 insertions(+), 11 deletions(-) diff --git a/src/transformers/generation/continuous_batching/cache.py b/src/transformers/generation/continuous_batching/cache.py index bb901b38101d..cdc228075c22 100644 --- a/src/transformers/generation/continuous_batching/cache.py +++ b/src/transformers/generation/continuous_batching/cache.py @@ -235,6 +235,17 @@ def infer_num_blocks_and_max_batch_tokens( max_memory_percent: float = 0.9, cache_dtype: torch.dtype = torch.float16, ) -> tuple[int, int]: + """ + The memory footprint depends on the cache size C and the max batch tokens M in the following way: + Mem = Mem(cache) + Mem(activation) + Mem(static_tensors) + where: + Mem(cache) = 2 * num_heads * head_dim * num_layers * cache_dtype.itemsize * C + Mem(activation) = M * (hidden_size + vocab_size) * activation_dtype.itemsize + Mem(static_tensors) ~= 8M * input_dtype.itemsize + M * C * activation_dtype.itemsize + + Depending on if C or M is given, we use 
different methods to infer the values (C = num_blocks * block_size) and + since block_size is fixed, num_blocks is the true variable to find. + """ # If neither num_blocks nor max_batch_tokens are provided, we use a second-order polynomial if num_blocks is None and max_batch_tokens is None: num_blocks, max_batch_tokens = self.compute_num_blocks_and_max_batch_tokens( @@ -262,14 +273,22 @@ def compute_num_blocks_and_max_batch_tokens( self, max_memory_percent: float = 0.9, cache_dtype: torch.dtype = torch.float16, - m: float = 0.1, + m: float = 0.01, ) -> tuple[int, int]: + """ + If neither M nor C is given, we assume M = m*C so we have to solve a second-order polynomial in C: + Mem = C * 2 * self.num_heads * self.head_dim * self.num_layers * cache_dtype.itemsize + + C * m * (hidden_size + vocab_size) * activation_dtype.itemsize + + C * m * 8 * input_dtype.itemsize + C^2 * m * activation_dtype.itemsize + + We solve for C and then M = m*C. + """ cache_memory = self.get_available_memory(max_memory_percent) logger.info(f"Cache memory: {cache_memory}") - # Compute memory footprints # TODO: check and explain better + # Compute memory footprints mem_per_activation_token = ( - self._activation_dtype.itemsize * (self.hidden_size + self.vocab_size) * self._activation_safety_factor + m * self._activation_dtype.itemsize * (self.hidden_size + self.vocab_size) ) mem_per_cache_token = 2 * self.num_heads * self.head_dim * self.num_layers * cache_dtype.itemsize mem_per_input_token = 8 * m * self._input_dtype.itemsize @@ -279,8 +298,8 @@ def compute_num_blocks_and_max_batch_tokens( # Compute second-degree polynomial coefficients a = m * self._activation_dtype.itemsize - b = mem_per_input_token + mem_per_cache_token - c = mem_per_activation_token + 2 * self._input_dtype.itemsize - cache_memory + b = mem_per_input_token + mem_per_cache_token + mem_per_activation_token + c = - cache_memory # Compute discriminant and greatest solution discriminant = b**2 - 4 * a * c @@ -307,15 +326,20 @@ def compute_max_batch_tokens( max_memory_percent: float = 0.9, cache_dtype: torch.dtype = torch.float16, ) -> int: + """ + If C is given, we have a formula for M: + num = (Mem - C * 2 * num_heads * head_dim * num_layers * cache_dtype.itemsize) + denum = (8 * input_dtype.itemsize + C * activation_dtype.itemsize + (hidden_size + vocab_size) * activation_dtype.itemsize) + M = num / denum + """ cache_memory = self.get_available_memory(max_memory_percent) cache_size = num_blocks * self.block_size # Compute numerator num = cache_memory - num -= self._activation_dtype.itemsize * (self.hidden_size + self.vocab_size) * self._activation_safety_factor - num -= 2 * self._input_dtype.itemsize num -= cache_size * 2 * self.num_heads * self.head_dim * self.num_layers * cache_dtype.itemsize # Compute denominator denum = 8 * self._input_dtype.itemsize + cache_size * self._activation_dtype.itemsize + denum += (self.hidden_size + self.vocab_size) * self._activation_dtype.itemsize # Compute max batch tokens and return return int(num / denum) @@ -325,12 +349,17 @@ def compute_num_blocks( max_memory_percent: float = 0.9, cache_dtype: torch.dtype = torch.float16, ) -> int: + """ + If M is given, we have a formula for C: + num = Mem - M * (hidden_size + vocab_size) * activation_dtype.itemsize - 8 * M * input_dtype.itemsize + denum = 2 * num_heads * head_dim * num_layers * cache_dtype.itemsize + M * activation_dtype.itemsize + C = num / denum + """ cache_memory = self.get_available_memory(max_memory_percent) # Compute numerator num = cache_memory - num -= 
self._activation_dtype.itemsize * (self.hidden_size + self.vocab_size) * self._activation_safety_factor + num -= self._activation_dtype.itemsize * (self.hidden_size + self.vocab_size) * max_batch_tokens num -= 8 * max_batch_tokens * self._input_dtype.itemsize - num -= 2 * self._input_dtype.itemsize # Compute denominator denum = 2 * self.num_heads * self.head_dim * self.num_layers * cache_dtype.itemsize denum += max_batch_tokens * self._activation_dtype.itemsize @@ -346,7 +375,7 @@ def compute_memory_footprint( ) -> tuple[int, int, int]: # Compute activation memory footprint activation_memory_footprint = self._activation_dtype.itemsize * (self.hidden_size + self.vocab_size) - activation_memory_footprint *= self._activation_safety_factor + activation_memory_footprint *= max_batch_tokens # Compute cache memory footprint if num_blocks is provided if num_blocks is not None: cache_size = num_blocks * self.block_size @@ -360,7 +389,7 @@ def compute_memory_footprint( [ 3 * max_batch_tokens * self._input_dtype.itemsize, # input_ids, position_ids, output_ids max_batch_tokens * cache_size * self._activation_dtype.itemsize, # attention_mask - 2 * (max_batch_tokens + 1) * self._input_dtype.itemsize, # cumulative_seqlens_qk + 2 * max_batch_tokens * self._input_dtype.itemsize, # cumulative_seqlens_qk (we remove the +1 to M) 3 * max_batch_tokens * self._input_dtype.itemsize, # write_index, read_index, logits_indices ] ) From 0de06e30564b1d9d2de225ffd6472ab66ed18fe6 Mon Sep 17 00:00:00 2001 From: remi-or Date: Mon, 25 Aug 2025 16:18:14 +0000 Subject: [PATCH 23/26] Added a simple example --- .../pytorch/continuous_batching_simple.py | 110 ++++++++++++++++++ 1 file changed, 110 insertions(+) create mode 100644 examples/pytorch/continuous_batching_simple.py diff --git a/examples/pytorch/continuous_batching_simple.py b/examples/pytorch/continuous_batching_simple.py new file mode 100644 index 000000000000..5bb7c6b20070 --- /dev/null +++ b/examples/pytorch/continuous_batching_simple.py @@ -0,0 +1,110 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import argparse +import time + +import datasets +import torch + +from transformers import AutoModelForCausalLM, AutoTokenizer +from transformers.generation import GenerationConfig + + +MODEL_ID = "Qwen/Qwen3-4B-Instruct-2507" +DISPLAYED_SAMPLES = 3 + + +if __name__ == "__main__": + # Parse args + parser = argparse.ArgumentParser() + parser.add_argument("--num-blocks", "-n", type=int, default=None) + parser.add_argument("--max-batch-tokens", "-b", type=int, default=None) + parser.add_argument( + "--attn", type=str, default="paged_attention|kernels-community/flash-attn", help="Attention implementation" + ) + parser.add_argument("--samples", type=int, default=500) + args = parser.parse_args() + + # Prepare model + model = AutoModelForCausalLM.from_pretrained( + MODEL_ID, + attn_implementation=args.attn, + dtype=torch.bfloat16, + ) + model = model.cuda().eval() + + # Prepare tokenizer and dataset + tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, padding_side="left") + dataset = datasets.load_dataset("openai/gsm8k", "socratic", split="test") + dataset = dataset.select(range(args.samples)) + tokenized_datasets = dataset.map(lambda x: tokenizer(x["question"]), batched=True) + simple_batch_inputs = [item["input_ids"] for item in tokenized_datasets] + + # Prepare generation config + generation_config = GenerationConfig( + max_new_tokens=512, + use_cuda_graph=False, # Not supported for simple version + eos_token_id=tokenizer.eos_token_id, + pad_token_id=tokenizer.pad_token_id, + do_sample=False, + num_blocks=args.num_blocks, + max_batch_tokens=args.max_batch_tokens, + ) + + # Warmup iterations + _ = model.generate_batch( + inputs=simple_batch_inputs[: min(5, args.samples)], + generation_config=generation_config, + slice_inputs=True, + ) + + # Actual batch generation + print("--- Running CB Generation Example ---") + start_time = time.time() + batch_outputs = model.generate_batch( + inputs=simple_batch_inputs, + generation_config=generation_config, + slice_inputs=True, + ) + end_time = time.time() + print("Done with batch generation.") + + # Decode outputs + token_count = 0 + for i, request in enumerate(batch_outputs): + input_text = tokenizer.decode(batch_outputs[request].prompt_ids, skip_special_tokens=True) + # Try to decode the output + try: + output_text = tokenizer.decode(batch_outputs[request].generated_tokens, skip_special_tokens=True) + token_count += len(batch_outputs[request].generated_tokens[1:]) + except Exception as e: + print(f"Decoding failed for request {request}: {e}") + continue + + # Display sample if asked + if i < DISPLAYED_SAMPLES: + print("-" * 20) + print(f"{request} Input: {input_text}") + if len(output_text) > 0: + print(f"{request} Output: {output_text}") + else: + print(f"[WARN] {request} Output was empty!") + + # Compute stats and maybe print them + gen_time = end_time - start_time + tok_per_sec = token_count / gen_time + print("-" * 20) + print("--- Finished CB Generation Example ---\n") + print(f"CB generation took: {gen_time:.2f} seconds for {token_count} tokens. 
{tok_per_sec:.2f}tok/s") From 8325b3768adc8cd5fb45d910505673098be56f44 Mon Sep 17 00:00:00 2001 From: remi-or Date: Mon, 25 Aug 2025 16:21:26 +0000 Subject: [PATCH 24/26] Renamed core to classes --- src/transformers/generation/continuous_batching/__init__.py | 2 +- src/transformers/generation/continuous_batching/cache.py | 2 +- .../generation/continuous_batching/{core.py => classes.py} | 0 .../generation/continuous_batching/continuous_api.py | 2 +- src/transformers/generation/continuous_batching/scheduler.py | 2 +- 5 files changed, 4 insertions(+), 4 deletions(-) rename src/transformers/generation/continuous_batching/{core.py => classes.py} (100%) diff --git a/src/transformers/generation/continuous_batching/__init__.py b/src/transformers/generation/continuous_batching/__init__.py index 1d939594d64b..0c0978bce0da 100644 --- a/src/transformers/generation/continuous_batching/__init__.py +++ b/src/transformers/generation/continuous_batching/__init__.py @@ -14,7 +14,7 @@ # limitations under the License. from .cache import PagedAttentionCache from .continuous_api import ContinuousBatchingManager, ContinuousMixin -from .core import RequestState, RequestStatus +from .classes import RequestState, RequestStatus __all__ = ["PagedAttentionCache", "RequestState", "RequestStatus", "ContinuousMixin", "ContinuousBatchingManager"] diff --git a/src/transformers/generation/continuous_batching/cache.py b/src/transformers/generation/continuous_batching/cache.py index cdc228075c22..fb67c5ed5166 100644 --- a/src/transformers/generation/continuous_batching/cache.py +++ b/src/transformers/generation/continuous_batching/cache.py @@ -21,7 +21,7 @@ from ...configuration_utils import PretrainedConfig from ...generation.configuration_utils import GenerationConfig from ...utils.metrics import attach_tracer, traced -from .core import RequestState, get_device_and_memory_breakdown, logger +from .classes import RequestState, get_device_and_memory_breakdown, logger @attach_tracer() diff --git a/src/transformers/generation/continuous_batching/core.py b/src/transformers/generation/continuous_batching/classes.py similarity index 100% rename from src/transformers/generation/continuous_batching/core.py rename to src/transformers/generation/continuous_batching/classes.py diff --git a/src/transformers/generation/continuous_batching/continuous_api.py b/src/transformers/generation/continuous_batching/continuous_api.py index 9e629f30f316..00bc43af7157 100644 --- a/src/transformers/generation/continuous_batching/continuous_api.py +++ b/src/transformers/generation/continuous_batching/continuous_api.py @@ -30,7 +30,7 @@ from ...utils.logging import logging from ...utils.metrics import ContinuousBatchProcessorMetrics, attach_tracer, traced from .cache import PagedAttentionCache -from .core import GenerationOutput, RequestState, RequestStatus, get_device_and_memory_breakdown, logger +from .classes import GenerationOutput, RequestState, RequestStatus, get_device_and_memory_breakdown, logger from .scheduler import SCHEDULER_MAPPING, FIFOScheduler, Scheduler diff --git a/src/transformers/generation/continuous_batching/scheduler.py b/src/transformers/generation/continuous_batching/scheduler.py index b08bd812fe33..9f612c9380ff 100644 --- a/src/transformers/generation/continuous_batching/scheduler.py +++ b/src/transformers/generation/continuous_batching/scheduler.py @@ -17,7 +17,7 @@ from ...utils.metrics import attach_tracer, traced from .cache import PagedAttentionCache -from .core import RequestState, RequestStatus +from .classes import 
RequestState, RequestStatus class Scheduler(ABC): From 7dee44e47bc8d824e16c1cef31187b925648eeb2 Mon Sep 17 00:00:00 2001 From: remi-or Date: Tue, 26 Aug 2025 07:49:33 +0000 Subject: [PATCH 25/26] Made allocation of attention mask optionnal --- .../continuous_batching/continuous_api.py | 31 +++++++++++++------ 1 file changed, 22 insertions(+), 9 deletions(-) diff --git a/src/transformers/generation/continuous_batching/continuous_api.py b/src/transformers/generation/continuous_batching/continuous_api.py index 00bc43af7157..e05a9a5096e5 100644 --- a/src/transformers/generation/continuous_batching/continuous_api.py +++ b/src/transformers/generation/continuous_batching/continuous_api.py @@ -37,7 +37,7 @@ @dataclass class PagedAttentionArgs: input_ids: torch.Tensor - attention_mask: torch.Tensor + attention_mask: Optional[torch.Tensor] position_ids: torch.Tensor cumulative_seqlens_q: torch.Tensor cumulative_seqlens_k: torch.Tensor @@ -105,6 +105,9 @@ def __init__( self.tokenizer = PreTrainedTokenizerFast.from_pretrained(self.config._name_or_path) self.decode_stream = DecodeStream(skip_special_tokens=True) + def return_attention_mask(self) -> bool: + return self.config._attn_implementation != "paged_attention" # we set `is_causal` to True in paged call + @traced(standalone=True) def setup_static_tensors(self): T = self.max_batch_tokens @@ -114,9 +117,6 @@ def setup_static_tensors(self): self.tensor_metadata = tensor_metadata self.input_ids = torch.empty((1, T), **tensor_metadata) self.position_ids = torch.empty((1, T), **tensor_metadata) - self.attention_mask = torch.empty( - (1, 1, T, max_token_budget), dtype=self.model_dtype, device=self.model_device - ) self.cumulative_seqlens_q = torch.empty((T + 1,), **tensor_metadata) self.cumulative_seqlens_k = torch.empty((T + 1,), **tensor_metadata) self.write_index = torch.empty((T,), **tensor_metadata) @@ -125,6 +125,13 @@ def setup_static_tensors(self): self.max_seqlen_q = 0 self.max_seqlen_k = 0 self.output_ids = torch.empty((1, T), **tensor_metadata) + # Since attenention_mask is not always needed, we only allocate it if it is needed + if self.return_attention_mask(): + self.attention_mask = torch.empty( + (1, 1, T, max_token_budget), dtype=self.model_dtype, device=self.model_device + ) + else: + self.attention_mask = None # Initialize the tensors by pretending they are in full use self.actual_tokens = T self.cache_used = max_token_budget @@ -143,7 +150,6 @@ def reset_static_tensors(self): # Reset the tensors self.input_ids[:, :t].zero_() self.position_ids[:, :t].zero_() - self.attention_mask[:, :, :t, :c].fill_(torch.finfo(self.model_dtype).min) self.cumulative_seqlens_q[: t + 1].zero_() self.cumulative_seqlens_k[: t + 1].zero_() self.write_index[:t].fill_(-1) @@ -152,17 +158,20 @@ def reset_static_tensors(self): self.max_seqlen_q = 0 self.max_seqlen_k = 0 self.output_ids[:, :t].fill_(-1) + if self.attention_mask is not None: + self.attention_mask[:, :, :t, :c].fill_(torch.finfo(self.model_dtype).min) + def get_model_kwargs(self) -> PagedAttentionArgs: """Get model keyword arguments for the current batch.""" # Compute the slice to return t = self.actual_tokens if self.slice_inputs else self.write_index.size(0) c = self.cache_used if self.slice_inputs else self.read_index.size(0) - # Return the tensors - return { + # Prepare the kwargs + kwargs = { "input_ids": self.input_ids[:, :t], + "attention_mask": self.attention_mask, "position_ids": self.position_ids[:, :t], - "attention_mask": self.attention_mask[:, :, :t, :c], # NOTE: this is probably not 
used for paged attention "cu_seq_lens_q": self.cumulative_seqlens_q[:t+1], "cu_seq_lens_k": self.cumulative_seqlens_k[:t+1], "write_index": self.write_index[:t], @@ -174,6 +183,10 @@ def get_model_kwargs(self) -> PagedAttentionArgs: "cache": self.cache, "use_cache": False, } + # If the attention mask is not None, we slice it as the others + if self.attention_mask is not None: + kwargs["attention_mask"] = self.attention_mask[:, :, :t, :c] + return kwargs def __repr__(self): return ( @@ -303,7 +316,7 @@ def _build_tensors( self.cache_used = len(read_index) min_value = torch.finfo(self.model_dtype).min - if self.config._attn_implementation != "paged_attention": # we set `is_causal` to True in paged call` + if self.attention_mask is not None: for i in range(len(cumulative_seqlens_q) - 1): if ( cumulative_seqlens_q[i + 1] - cumulative_seqlens_q[i] From 3f17daf89935cfac6c1b70500cf9bd5720f2c14e Mon Sep 17 00:00:00 2001 From: remi-or Date: Tue, 26 Aug 2025 07:51:10 +0000 Subject: [PATCH 26/26] Style --- examples/pytorch/continuous_batching.py | 4 +++- examples/pytorch/continuous_batching_simple.py | 4 ++-- .../generation/continuous_batching/__init__.py | 2 +- src/transformers/generation/continuous_batching/cache.py | 8 +++----- .../generation/continuous_batching/continuous_api.py | 7 +++---- 5 files changed, 12 insertions(+), 13 deletions(-) diff --git a/examples/pytorch/continuous_batching.py b/examples/pytorch/continuous_batching.py index 731916dbd1b3..b5ad94ed3f11 100644 --- a/examples/pytorch/continuous_batching.py +++ b/examples/pytorch/continuous_batching.py @@ -241,7 +241,9 @@ def batch_generate( if args.output_file is None: os.makedirs("runs/cb", exist_ok=True) attn = args.attn.replace("|", "_").replace("/", "_") - args.output_file = f"runs/cb/{args.num_blocks}_{args.max_batch_tokens}_{attn}_{args.matmul_precision}_{args.samples}.json" + args.output_file = ( + f"runs/cb/{args.num_blocks}_{args.max_batch_tokens}_{attn}_{args.matmul_precision}_{args.samples}.json" + ) # Run warmup batch generation batch_generate( diff --git a/examples/pytorch/continuous_batching_simple.py b/examples/pytorch/continuous_batching_simple.py index 5bb7c6b20070..3ae5e3d83870 100644 --- a/examples/pytorch/continuous_batching_simple.py +++ b/examples/pytorch/continuous_batching_simple.py @@ -48,14 +48,14 @@ # Prepare tokenizer and dataset tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, padding_side="left") dataset = datasets.load_dataset("openai/gsm8k", "socratic", split="test") - dataset = dataset.select(range(args.samples)) + dataset = dataset.select(range(args.samples)) tokenized_datasets = dataset.map(lambda x: tokenizer(x["question"]), batched=True) simple_batch_inputs = [item["input_ids"] for item in tokenized_datasets] # Prepare generation config generation_config = GenerationConfig( max_new_tokens=512, - use_cuda_graph=False, # Not supported for simple version + use_cuda_graph=False, # Not supported for simple version eos_token_id=tokenizer.eos_token_id, pad_token_id=tokenizer.pad_token_id, do_sample=False, diff --git a/src/transformers/generation/continuous_batching/__init__.py b/src/transformers/generation/continuous_batching/__init__.py index 0c0978bce0da..11d15b6468e2 100644 --- a/src/transformers/generation/continuous_batching/__init__.py +++ b/src/transformers/generation/continuous_batching/__init__.py @@ -13,8 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
from .cache import PagedAttentionCache -from .continuous_api import ContinuousBatchingManager, ContinuousMixin from .classes import RequestState, RequestStatus +from .continuous_api import ContinuousBatchingManager, ContinuousMixin __all__ = ["PagedAttentionCache", "RequestState", "RequestStatus", "ContinuousMixin", "ContinuousBatchingManager"] diff --git a/src/transformers/generation/continuous_batching/cache.py b/src/transformers/generation/continuous_batching/cache.py index fb67c5ed5166..dfc10859b41e 100644 --- a/src/transformers/generation/continuous_batching/cache.py +++ b/src/transformers/generation/continuous_batching/cache.py @@ -14,7 +14,7 @@ # limitations under the License. from collections import deque from math import floor, sqrt -from typing import Any, Optional, TypeVar, Union +from typing import Optional, Union import torch @@ -287,9 +287,7 @@ def compute_num_blocks_and_max_batch_tokens( logger.info(f"Cache memory: {cache_memory}") # Compute memory footprints - mem_per_activation_token = ( - m * self._activation_dtype.itemsize * (self.hidden_size + self.vocab_size) - ) + mem_per_activation_token = m * self._activation_dtype.itemsize * (self.hidden_size + self.vocab_size) mem_per_cache_token = 2 * self.num_heads * self.head_dim * self.num_layers * cache_dtype.itemsize mem_per_input_token = 8 * m * self._input_dtype.itemsize logger.info(f"Memory per activation token: {mem_per_activation_token}") @@ -299,7 +297,7 @@ def compute_num_blocks_and_max_batch_tokens( # Compute second-degree polynomial coefficients a = m * self._activation_dtype.itemsize b = mem_per_input_token + mem_per_cache_token + mem_per_activation_token - c = - cache_memory + c = -cache_memory # Compute discriminant and greatest solution discriminant = b**2 - 4 * a * c diff --git a/src/transformers/generation/continuous_batching/continuous_api.py b/src/transformers/generation/continuous_batching/continuous_api.py index e05a9a5096e5..4b6775141362 100644 --- a/src/transformers/generation/continuous_batching/continuous_api.py +++ b/src/transformers/generation/continuous_batching/continuous_api.py @@ -106,7 +106,7 @@ def __init__( self.decode_stream = DecodeStream(skip_special_tokens=True) def return_attention_mask(self) -> bool: - return self.config._attn_implementation != "paged_attention" # we set `is_causal` to True in paged call + return self.config._attn_implementation != "paged_attention" # we set `is_causal` to True in paged call @traced(standalone=True) def setup_static_tensors(self): @@ -161,7 +161,6 @@ def reset_static_tensors(self): if self.attention_mask is not None: self.attention_mask[:, :, :t, :c].fill_(torch.finfo(self.model_dtype).min) - def get_model_kwargs(self) -> PagedAttentionArgs: """Get model keyword arguments for the current batch.""" # Compute the slice to return @@ -172,8 +171,8 @@ def get_model_kwargs(self) -> PagedAttentionArgs: "input_ids": self.input_ids[:, :t], "attention_mask": self.attention_mask, "position_ids": self.position_ids[:, :t], - "cu_seq_lens_q": self.cumulative_seqlens_q[:t+1], - "cu_seq_lens_k": self.cumulative_seqlens_k[:t+1], + "cu_seq_lens_q": self.cumulative_seqlens_q[: t + 1], + "cu_seq_lens_k": self.cumulative_seqlens_k[: t + 1], "write_index": self.write_index[:t], "read_index": self.read_index[:c], "logits_indices": self.logits_indices[:t],
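
A quick sanity check of the sizing math from PATCH 22, outside the library. The sketch below re-derives C (cache size in tokens, C = num_blocks * block_size) and M (max batch tokens) from the same formulas: the quadratic in C when neither value is given (with M = m * C), and the closed form for M when C is fixed. All model dimensions, dtype sizes and the memory budget are illustrative assumptions (roughly Llama-3.2-3B-sized), not values read from any config, and the helper names are hypothetical and do not exist in transformers.

    from math import floor, sqrt

    # Assumed dimensions and dtype sizes (illustrative, not from a real config)
    num_kv_heads, head_dim, num_layers = 8, 128, 28
    hidden_size, vocab_size = 3072, 128256
    cache_itemsize = 2       # fp16 / bf16 KV cache
    activation_itemsize = 2  # bf16 activations
    input_itemsize = 4       # int32 ids and indices
    block_size = 32


    def solve_num_blocks_and_max_batch_tokens(mem: float, m: float = 0.01) -> tuple[int, int]:
        """Neither C nor M given: solve a*C^2 + b*C + c = 0 with M = m*C, keep the greatest root."""
        mem_per_cache_token = 2 * num_kv_heads * head_dim * num_layers * cache_itemsize
        mem_per_activation_token = m * activation_itemsize * (hidden_size + vocab_size)
        mem_per_input_token = 8 * m * input_itemsize
        a = m * activation_itemsize  # C^2 term, coming from the (M x C) attention mask
        b = mem_per_input_token + mem_per_cache_token + mem_per_activation_token
        c = -mem
        greatest_root = (-b + sqrt(b * b - 4 * a * c)) / (2 * a)
        num_blocks = floor(greatest_root / block_size)  # round C down to whole blocks
        return num_blocks, max(1, floor(m * num_blocks * block_size))


    def max_batch_tokens_given_cache(mem: float, num_blocks: int) -> int:
        """C given: M = (Mem - C * mem_per_cache_token) / (8*input + C*act + (hidden+vocab)*act)."""
        cache_size = num_blocks * block_size
        num = mem - cache_size * 2 * num_kv_heads * head_dim * num_layers * cache_itemsize
        denum = 8 * input_itemsize + cache_size * activation_itemsize
        denum += (hidden_size + vocab_size) * activation_itemsize
        return int(num / denum)


    if __name__ == "__main__":
        budget = 22 * 1024**3  # e.g. ~22 GiB left after weights on a 24 GiB card at max_memory_percent=0.9
        nb, mbt = solve_num_blocks_and_max_batch_tokens(budget)
        print(f"num_blocks={nb}, cache tokens={nb * block_size}, max_batch_tokens={mbt}")
        print(f"closed-form check: M={max_batch_tokens_given_cache(budget, nb)}")

With the numbers above, both paths land on roughly the same M, which is the consistency the docstrings added in PATCH 22 describe; swap in the dimensions of the model you actually load to estimate sensible --num-blocks and --max-batch-tokens values for the example scripts.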