fix(continuous-batching): apply logits processors in packed batches #43457
floor-licker wants to merge 7 commits into huggingface:main
Conversation
Force-pushed from 203d9d7 to af1fed6
View the CircleCI Test Summary for this PR: https://huggingface.co/spaces/transformers-community/circle-ci-viz?pr=43457&sha=af1fed
Some of these CI failures seem a bit odd relative to the code changes; it may take me a minute to track this down.
Hey @LysandreJik, just wondering if I can get some feedback on this. As far as I can tell, the failing tests are just flaky; I can only reproduce the failures intermittently locally.
Thanks for the contribution! There are some large points to address here, as this is not a trivial part of the continuous batching code. It would be good if we can support this in the most optimized way possible. Let me know if my comments are clear!
@traced(span_name="logit_processing")
def _process_logit(self, batch_data: dict, logits: torch.Tensor, logit_processor: LogitsProcessor) -> torch.Tensor:
    # Pass continuous batching context to logits processor if it supports it.
    if isinstance(logit_processor, list) and len(logit_processor) == 0:
This should always be a list: you can check `_get_logits_processor`.
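For illustration, a minimal sketch of what the simplified check could look like, assuming `_get_logits_processor` always hands over a (possibly empty) list; the standalone `process_logits` helper here is hypothetical:

```python
import torch
from transformers import LogitsProcessorList


def process_logits(input_ids: torch.LongTensor, logits: torch.Tensor, processors: LogitsProcessorList) -> torch.Tensor:
    # If the processors always arrive as a list, a truthiness check replaces
    # the isinstance/len combination used in the diff above.
    if not processors:
        return logits
    for processor in processors:
        logits = processor(input_ids, logits)
    return logits
```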
)
if isinstance(self.logit_processor, list) and len(self.logit_processor) > 0:
    # Processors need eager.
    if self.use_cuda_graph:
IMO this is not the optimal solution: the `process_logits` phase and what happens afterwards should be outside of the CUDA graph / compile if there are processors. Do you think you could rewrite it that way?
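A hedged sketch of the split the comment describes: only the model forward stays inside the captured/compiled region, while logits processing runs eagerly on the result. The class and attribute names here (`graphed_forward`, `logit_processors`) are illustrative, not the actual continuous batching internals:

```python
import torch


class DecodeStep:
    """Sketch: keep the forward graph-friendly, run processors eagerly afterwards."""

    def __init__(self, graphed_forward, logit_processors):
        self.graphed_forward = graphed_forward    # e.g. a CUDA-graph-captured or compiled callable
        self.logit_processors = logit_processors  # possibly empty list of processors

    def step(self, input_ids: torch.LongTensor, batch_data: dict) -> torch.Tensor:
        # 1) Inside the graph / compiled region: only the model forward.
        logits = self.graphed_forward(input_ids, batch_data)

        # 2) Outside the graph, in eager mode: logits processing, which may
        #    involve data-dependent control flow that graphs cannot capture.
        for processor in self.logit_processors:
            logits = processor(input_ids, logits)
        return logits
```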
# System prompt applied.
# Expect "sports" mention.
self.assertTrue(full_text.strip())
self.assertIn("sports", full_text.lower())
Not sure why we should change this test!
logits = torch.ones((1, 4, vocab_size), device=torch_device, dtype=torch.float)

# Req0 next-token @2.
logits[0, 2, 1] = -2.0
Seems like you can hardcode this whole part as one tensor declaration
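A rough sketch of the single-declaration version being suggested, with a placeholder `vocab_size` and only the one edited value from the snippet above; the remaining entries are just the ones of the original `torch.ones` call:

```python
import torch

vocab_size = 8  # placeholder; the test defines its own value

# One declaration instead of torch.ones(...) followed by scattered element edits.
logits = torch.tensor(
    [[
        [1.0] * vocab_size,                      # position 0
        [1.0] * vocab_size,                      # position 1
        [1.0, -2.0] + [1.0] * (vocab_size - 2),  # position 2: Req0 next-token logit at index 1
        [1.0] * vocab_size,                      # position 3
    ]],
    dtype=torch.float,
)
```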
self.assertTrue(torch.equal(rep_penalty.logits_indices, batch_data["logits_indices"]))

# Req0 token penalties.
self.assertAlmostEqual(processed_logits[0, 2, 1].item(), -4.0)
Why not do batched checks?
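For reference, a batched check could compare all penalized entries in one call; the tensor below is only a stand-in for the `processed_logits` produced by the test:

```python
import torch

# Stand-in for the logits returned by the processor under test.
processed_logits = torch.full((1, 4, 8), 1.0)
processed_logits[0, 2, 1] = -4.0

# Gather every checked (position, token) pair at once and compare against an
# expected tensor, instead of one assertAlmostEqual per element.
positions = torch.tensor([2])
token_ids = torch.tensor([1])
expected = torch.tensor([-4.0])
torch.testing.assert_close(processed_logits[0, positions, token_ids], expected)
```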
architecture = getattr(transformers, config.architectures[0])
model = architecture.from_pretrained(model_id, **model_kwargs)

# Default `max_length` is 20.
These comments are weird: did you use some AI assistant for this?
@remi-or Thanks a lot for your feedback! It's my first contribution to transformers, so I'll take a minute to review your comments and get back to you.
Hello! I just discovered this PR as I'm looking to add per-request logit processors in a production continuous batching setting. @remi-or do you know if this is currently supported?
I am not sure logits processors are supported in the first place, and I am sure per-request is not. I'm also not sure it ever will be in production settings, where you need CUDA graphs, which make per-request treatment hard in (most) cases. What logits processors do you have in mind? I can take a look.
Thanks for your attention to this!
Yes, that's the case currently. Maybe we can complexify this for some logits processors, if PyTorch operators permit it. Would that only be those three parameters?
In an 'ideal world' we could modify per generation:
But top-p/k/temp as a starting point would already be immensely helpful. I would be happy to help contribute to this effort, bearing in mind that it'd be my first contribution to the HF ecosystem :)
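For what it's worth, per-request temperature and top-k can in principle be expressed as fixed-shape batched tensor ops, which keeps them compatible with CUDA graphs; the sketch below assumes per-request parameters have already been gathered into tensors, and every name in it is hypothetical:

```python
import torch


def apply_per_request_params(
    next_token_logits: torch.Tensor,  # [num_requests, vocab]
    temperatures: torch.Tensor,       # [num_requests], one temperature per request
    top_k: torch.Tensor,              # [num_requests], one k per request (int64)
) -> torch.Tensor:
    # Temperature: one broadcasted division covers every request.
    logits = next_token_logits / temperatures.unsqueeze(-1)

    # Top-k: sort once, look up each request's k-th largest value, mask the rest.
    sorted_logits, _ = torch.sort(logits, dim=-1, descending=True)
    kth_values = sorted_logits.gather(-1, (top_k - 1).clamp(min=0).unsqueeze(-1))
    # Ties with the k-th value are kept, which is acceptable for a sketch.
    return logits.masked_fill(logits < kth_values, float("-inf"))
```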
Closed thanks to #45026 |
Summary
It seems like this was a known issue: the continuous batching implementation packs multiple sequences into a single token stream, even though most generation-time logits processors assume `[batch, seq_len]` / `[batch, vocab]` shapes. In practice, processors look either broken or disabled in `transformers serve` and in the tests. This change applies logits processors per-request at the correct next-token logits positions and removes the workarounds.

Logic:

- read each request's next-token position from `logits_indices`, and do so only for requests that are in decode
- rebuild the request's token history (`state.initial_tokens + state.generated_tokens`) and apply the processor to `logits[0, position]` (the next-token scores) in place

(A sketch of this flow follows the test commands below.)

Tests run
- `repetition_penalty_continuous_batching -q`
- `ServeCompletionsContinuousBatchingIntegrationTest -q`
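A hedged sketch of the per-request flow described in the summary; `logits_indices`, `state.initial_tokens`, and `state.generated_tokens` follow the description above, while the decode check, the loop structure, and the function name are illustrative only:

```python
import torch


def apply_processors_per_request(logits, batch_data, request_states, logit_processors):
    # `logits` is the packed [1, total_tokens, vocab] tensor produced for the whole batch.
    positions = batch_data["logits_indices"]  # next-token position of each request in the packed stream

    for state, position in zip(request_states, positions.tolist()):
        if not state.is_decoding:  # hypothetical flag: only decode-phase requests are processed
            continue
        # Rebuild the request's token history for history-dependent processors
        # (e.g. repetition penalty).
        input_ids = torch.tensor([state.initial_tokens + state.generated_tokens])
        for processor in logit_processors:
            # Apply to the next-token scores and write them back in place.
            logits[0, position] = processor(input_ids, logits[0, position].unsqueeze(0)).squeeze(0)
    return logits
```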