Skip to content

[CB] Add per-request logits processors#45026

Merged
remi-or merged 2 commits intomainfrom
cb-batched-logits
Apr 3, 2026
Merged

[CB] Add per-request logits processors#45026
remi-or merged 2 commits intomainfrom
cb-batched-logits

Conversation

@remi-or
Copy link
Copy Markdown
Collaborator

@remi-or remi-or commented Mar 26, 2026

Summary

This PR adds per-request logits processors and overalls the way CB handles logits processors.
It introduces batched logits processing with per-request parameters for continuous batching, enabling each request in a batch to use different sampling parameters (temperature, top_k, top_p). This is essential for serving scenarios where different users may request different generation configurations within the same batch.

The main changes are:

  • cb_logits_processors.py (new): Adds ContinuousBatchingLogitsProcessorList to manage logits processors and three per-request processor implementations (Temperature, TopK, TopP) that operate on the batched tensor format using vectorized operations
  • continuous_api.py: Integrates the new processor list into the batch processor, passes it through the forward/sampling pipeline, and validates per-request kwargs at request ingestion
  • input_outputs.py: Extends the bulk input tensor to store processor arguments (one row per processor), passes them to PagedAttentionArgs, and prepares tensor args during batch preparation
  • requests.py: Adds logit_processor_kwargs field to RequestState for per-request overrides, reorganizes dataclass fields, and simplifies fork() using deepcopy
  • configuration_utils.py: Adds per_request_processors and drop_unsupported_processors flags to ContinuousBatchingConfig
  • logits_process.py: Adds supports_continuous_batching flag to base processors for compatibility checking
  • Tests: Extends test coverage for per-request temperature, top_k, and top_p sampling

The processor list validates compatibility at init time, warns about unsupported processors, and efficiently prepares tensor arguments by storing them as views into the bulk input tensor to minimize memory transfers.

@HuggingFaceDocBuilderDev
Copy link
Copy Markdown

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@remi-or remi-or force-pushed the cb-batched-logits branch from c8539d8 to fe01b1c Compare March 26, 2026 22:29
@remi-or remi-or marked this pull request as ready for review March 27, 2026 11:12
@remi-or remi-or requested a review from ArthurZucker March 27, 2026 11:12
@remi-or remi-or self-assigned this Mar 27, 2026
Copy link
Copy Markdown
Collaborator

@ArthurZucker ArthurZucker left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice! Missing examples in doc? but happy to have otherwise lets reduce bload in ContinuousBatchingLogitsProcessorList


# Abstract base class for all continuous batching logits processors
class ContinuousBatchingLogitsProcessor(ABC):
supported_kwargs: tuple[str, ...] # Kwargs that this processor actively uses
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should be a typedict

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why? Do you want to add type checking when passing args?

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I mean the kwargs you can list which ones are supporterd / processor is what I mean

Comment thread src/transformers/generation/continuous_batching/cb_logits_processors.py Outdated
Comment thread src/transformers/generation/continuous_batching/cb_logits_processors.py Outdated
Comment thread src/transformers/generation/continuous_batching/cb_logits_processors.py Outdated
Comment thread src/transformers/generation/continuous_batching/cb_logits_processors.py Outdated
Comment thread src/transformers/generation/continuous_batching/cb_logits_processors.py Outdated
Comment on lines +265 to +270
def __init__(self, top_p: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
top_p = float(top_p)
if top_p < 0 or top_p > 1.0:
raise ValueError(f"`top_p` has to be a float > 0 and < 1, but is {top_p}")
if not isinstance(min_tokens_to_keep, int) or (min_tokens_to_keep < 1):
raise ValueError(f"`min_tokens_to_keep` has to be a positive integer, but is {min_tokens_to_keep}")
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@strict decorator validation for this

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@strict is for config I think? I dont see any logits processor with strict.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for dataclasses not config necessarily no?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But logits processors are not dataclasses? Anyway, we initialize it from the classic logit processor now, so an error in parameters will be caught by the classic logits processors before CB one is even initialized

Comment thread src/transformers/generation/continuous_batching/input_outputs.py Outdated
Comment thread tests/generation/test_continuous_batching.py Outdated
```
"""

supports_continuous_batching: bool = False
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should not be needed if we have a mapping / list (we don't "pollute" unrelated classes.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But any time a new logit processor is added, its support is set to None that way. A mapping we will have to manually keep updating which is less redundant. wdyt?

@remi-or remi-or force-pushed the cb-batched-logits branch from db1c672 to aace16a Compare April 3, 2026 09:03
@remi-or remi-or force-pushed the cb-batched-logits branch from 323e792 to 3a8bba1 Compare April 3, 2026 13:29
@remi-or remi-or added this pull request to the merge queue Apr 3, 2026
Merged via the queue into main with commit 82a73ad Apr 3, 2026
30 checks passed
@remi-or remi-or deleted the cb-batched-logits branch April 3, 2026 16:44
marvinzh pushed a commit to marvinzh/transformers that referenced this pull request Apr 3, 2026
SangbumChoi pushed a commit to SangbumChoi/transformers that referenced this pull request Apr 4, 2026
sirzechs66 pushed a commit to sirzechs66/transformers that referenced this pull request Apr 18, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants