
fix(continuous-batching): apply logits processors in packed batches #43457

Closed

floor-licker wants to merge 7 commits into huggingface:main from floor-licker:fix/cb-logits-processors

Conversation

@floor-licker

Summary

This appears to be a known issue: the continuous batching implementation packs multiple sequences into a single token stream, while most generation-time logits processors assume [batch, seq_len] / [batch, vocab] shapes. In practice, processors are either broken or disabled in transformers serve and in the tests. This change applies logits processors per request at the correct next-token logits positions and removes the workarounds.

Logic:

  • Apply logits processors only at the “next token” positions indicated by
    logits_indices, and only for requests that are in the decode phase
  • For each decoding request, rebuild the per-request token history
    (state.initial_tokens + state.generated_tokens) and apply the processor to
    logits[0, position] (the next-token scores) in place, as in the sketch
    below
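
For illustration, a minimal sketch of that per-request loop, assuming hypothetical names for the batch bookkeeping (a requests_in_decode collection of (state, position) pairs and RequestState fields initial_tokens / generated_tokens); the actual change lives in continuous_api.py:

import torch

def apply_processors_per_request(logits, requests_in_decode, processors):
    # `logits` is the packed stream of shape [1, packed_len, vocab_size];
    # `requests_in_decode` yields (state, position) pairs, where `position`
    # is the request's next-token index taken from logits_indices.
    # (Names here are illustrative, not the exact ones used in the PR.)
    for state, position in requests_in_decode:
        # Rebuild the [1, seq_len] token history that stock processors expect.
        input_ids = torch.tensor(
            state.initial_tokens + state.generated_tokens, device=logits.device
        ).unsqueeze(0)
        scores = logits[0, position].unsqueeze(0)  # [1, vocab_size] next-token scores
        for processor in processors:
            scores = processor(input_ids, scores)
        logits[0, position] = scores.squeeze(0)  # write back in place
    return logits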

Tests run

  • python -m pytest tests/generation/test_continuous_batching.py -q
  • python -m pytest tests/generation/test_logits_process.py -k
    repetition_penalty_continuous_batching -q
  • python -m pytest tests/cli/test_serve.py -q
  • RUN_SLOW=1 python -m pytest tests/cli/test_serve.py -k
    ServeCompletionsContinuousBatchingIntegrationTest -q

@floor-licker force-pushed the fix/cb-logits-processors branch from 203d9d7 to af1fed6 on January 23, 2026 23:32
@github-actions
Contributor

View the CircleCI Test Summary for this PR:

https://huggingface.co/spaces/transformers-community/circle-ci-viz?pr=43457&sha=af1fed

@floor-licker
Author

Some of these CI failures seem a bit odd relative to the code changes; it may take me a minute to track them down.

@floor-licker
Author

Hey @LysandreJik, just wondering if I can get some feedback on this. As far as I can tell, the failing tests are just flaky; I can only intermittently reproduce the failures locally.

@Rocketknight1
Member

cc @McPatate @remi-or for continuous batching

Collaborator

@remi-or left a comment


Thanks for the contribution! There are some large points to address here, as this is not a trivial part of the continuous batching code. It would be good if we can support this in the most optimized way possible. Let me know if my comments are clear!

@traced(span_name="logit_processing")
def _process_logit(self, batch_data: dict, logits: torch.Tensor, logit_processor: LogitsProcessor) -> torch.Tensor:
    # Pass continuous batching context to logits processor if it supports it.
    if isinstance(logit_processor, list) and len(logit_processor) == 0:

This should always be a list: you can check _get_logits_processor
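
If _get_logits_processor indeed always returns a list-like LogitsProcessorList (possibly empty), the guard could drop the isinstance check entirely; a hedged sketch of what that might look like in place of the check above:

# Sketch only: assumes `logit_processor` is always a LogitsProcessorList
# (possibly empty), as produced by `_get_logits_processor`.
if not logit_processor:
    return logits  # nothing to apply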

Comment thread src/transformers/generation/continuous_batching/continuous_api.py (outdated)
)
if isinstance(self.logit_processor, list) and len(self.logit_processor) > 0:
    # Processors need eager.
    if self.use_cuda_graph:

IMO this is not the optimal solution: the process_logits phase and what happens afterwards should be outside of the cuda graph / compile if there are processors. Do you think you could rewrite it that way?
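
A rough sketch of the structure being suggested, with illustrative method names (_forward_captured, _sample) rather than the actual continuous_api.py API: capture/compile only the model forward, and run logits processing eagerly on its output when processors are configured.

def _generation_step(self, batch_data):
    # The forward pass stays inside the CUDA graph / torch.compile region.
    logits = self._forward_captured(batch_data)
    # Logits processing (and anything downstream of it) runs eagerly,
    # outside the captured region, only when processors are configured.
    if self.logit_processor:
        logits = self._process_logit(batch_data, logits, self.logit_processor)
    return self._sample(batch_data, logits)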

Comment thread tests/cli/test_serve.py
# System prompt applied.
# Expect "sports" mention.
self.assertTrue(full_text.strip())
self.assertIn("sports", full_text.lower())

Not sure why we should change this test!

logits = torch.ones((1, 4, vocab_size), device=torch_device, dtype=torch.float)

# Req0 next-token @2.
logits[0, 2, 1] = -2.0

Seems like you can hardcode this whole part as one tensor declaration
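
One way to fold the setup into a single vectorized assignment, roughly in that spirit (only the [0, 2, 1] entry comes from the diff above; other positions and values would follow the same pattern):

positions = torch.tensor([2], device=torch_device)   # next-token rows (from logits_indices)
token_ids = torch.tensor([1], device=torch_device)   # token ids to pre-set
values = torch.tensor([-2.0], device=torch_device)   # their preset scores
logits = torch.ones((1, 4, vocab_size), device=torch_device, dtype=torch.float)
logits[0, positions, token_ids] = values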

self.assertTrue(torch.equal(rep_penalty.logits_indices, batch_data["logits_indices"]))

# Req0 token penalties.
self.assertAlmostEqual(processed_logits[0, 2, 1].item(), -4.0)

Why not do batched checks?
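
A batched variant of those assertions, assuming an `expected` tensor of shape [num_next_token_positions, vocab_size] is built alongside the inputs (a sketch, not the exact test code):

# Compare all processed next-token rows against precomputed expectations in one call.
torch.testing.assert_close(
    processed_logits[0, batch_data["logits_indices"]],
    expected,
)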

architecture = getattr(transformers, config.architectures[0])
model = architecture.from_pretrained(model_id, **model_kwargs)

# Default `max_length` is 20.

These comments are weird: did you use some AI assistant for this?

@floor-licker
Author

@remi-or Thanks a lot for your feedback; it's my first contribution to transformers, so I'll take a minute to review your comments and get back to you.

@JLenzy

JLenzy commented Mar 4, 2026

Hello! I just discovered this PR as I'm looking to add per-request logit processors in a production continuous batching setting. @remi-or do you know if this is currently supported?

@remi-or
Collaborator

remi-or commented Mar 5, 2026

> Hello! I just discovered this PR as I'm looking to add per-request logit processors in a production continuous batching setting. @remi-or do you know if this is currently supported?

I am not sure logits processors are supported in the first place, and I am sure per-request is not supported. Not sure it will ever be in production settings, where you need CUDA graphs, which makes per-request treatment hard in (most) cases. What logits processors do you have in mind? I can take a look.

@JLenzy

JLenzy commented Mar 9, 2026

> Hello! I just discovered this PR as I'm looking to add per-request logit processors in a production continuous batching setting. @remi-or do you know if this is currently supported?

> I am not sure logits processors are supported in the first place, and I am sure per-request is not supported. Not sure it will ever be in production settings, where you need CUDA graphs, which makes per-request treatment hard in (most) cases. What logits processors do you have in mind? I can take a look.

Thanks for your attention to this!
My team is preparing to use continuous batching to deploy a model for widespread use. During internal testing, we have received a lot of requests for 'weirder' outputs - we work in generative music. In research environments (single-batch generation), we have been able to expose top-p/top-k/temperature inference settings, which has allowed for a lot of creative exploration. But my current understanding is that we need to instantiate a single GenerationConfig and then run that across the board for all generations?

@remi-or
Collaborator

remi-or commented Mar 13, 2026

> But my current understanding is that we need to instantiate a single GenerationConfig and then run that across the board for all generations?

Yes, that's the case currently. Maybe we can complexify this for some logits processors, if PyTorch operators permit it. Would that be only those three parameters?

@JLenzy

JLenzy commented Mar 13, 2026

In an 'ideal world' we could modify per generation:

  • top_p
  • min_p
  • top_k
  • temperature
  • repetition_penalty
  • no_repeat_ngram_size

But top-p/k/temp as a starting point would already be immensely helpful. I would be happy to help contribute to this effort, bearing in mind that it'd be my first contribution to the HF ecosystem :)
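
For context, these knobs all live on a single GenerationConfig today, which per the discussion above currently applies to every request served by continuous batching rather than per request (the values below are only examples):

from transformers import GenerationConfig

generation_config = GenerationConfig(
    do_sample=True,
    temperature=0.9,
    top_p=0.95,
    min_p=0.05,
    top_k=50,
    repetition_penalty=1.1,
    no_repeat_ngram_size=3,
)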

@remi-or
Collaborator

remi-or commented Mar 26, 2026

Hey @JLenzy, we are adding this feature in #45026.
The draft will probably be removed tomorrow. Don't hesitate to let us know if you have early feedback.

@remi-or
Collaborator

remi-or commented Apr 7, 2026

Closed thanks to #45026

@remi-or closed this Apr 7, 2026