
[CB] Better overall script and decode bucketing #45653

Open
remi-or wants to merge 9 commits into main from cb-overall-fixes

Conversation


@remi-or remi-or commented Apr 27, 2026

Summary

This PR updates the `continuous_batching_overall` script to better reflect how users actually use continuous batching (i.e. with simpler or default configs) on realistic tasks (GSM8K, or RL rollouts of various lengths).
It also changes how decode-only batches are bucketed: instead of linear intervals, batch sizes are now rounded up to powers of 2. This led to better latencies.
Decode-only batching is also enabled automatically, based on workload hints (i.e. max sequence length), unless the user explicitly disables it.
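The power-of-2 bucketing idea can be sketched as follows. The actual helper name and placement in the PR are unknown; this is a minimal illustration, assuming the bucket for a decode batch is the smallest power of 2 that is at least the batch size:

```python
def decode_bucket_size(batch_size: int) -> int:
    """Round a decode batch size up to the next power of 2 (hypothetical sketch,
    not the PR's actual implementation)."""
    if batch_size <= 1:
        return 1
    # (n - 1).bit_length() gives the exponent of the smallest power of 2 >= n
    return 1 << (batch_size - 1).bit_length()
```

Compared to linear intervals, this keeps the number of distinct bucket shapes small (and thus the number of compiled graphs small) while bounding padding waste to at most 2x.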

Performance

New script, before PR:

| label | samples | avg_in | max_new | time (s) | tokens | tok/s | mem (GB) |
|---|---|---|---|---|---|---|---|
| gsm8k_default | 1319 | 94 | 256 | 30.34 | 337664 | 11128 | 72.53 |
| gsm8k_compile | 1319 | 94 | 256 | 30.35 | 337664 | 11126.1 | 71.85 |
| gsm8k_no_fast_decode | 1319 | 94 | 256 | 30.38 | 337664 | 11114 | 71.88 |
| rollouts_1024 | 32 | 256 | 1024 | 14.87 | 32768 | 2204.13 | 71.85 |
| rollouts_2048 | 32 | 256 | 2048 | 33.6 | 65536 | 1950.19 | 71.88 |
| rollouts_4096 | 32 | 256 | 4096 | 83.62 | 131072 | 1567.45 | 71.94 |
| rollouts_8192 | 32 | 256 | 8192 | 235.02 | 262144 | 1115.42 | 71.87 |
| rollouts_16384 | 32 | 256 | 16384 | 785.44 | 524288 | 667.51 | 71.84 |
| few_blocks | 20 | 256 | 256 | 8.26 | 5120 | 619.85 | 16.82 |
| multi_return_seq | 50 | 256 | 256 | 12.22 | 12800 | 1047.16 | 71.93 |

New script, after PR:

| label | samples | avg_in | max_new | time (s) | tokens | tok/s | mem (GB) |
|---|---|---|---|---|---|---|---|
| gsm8k_default | 1319 | 94 | 256 | 27.16 | 337664 | 12432.9 | 72.53 |
| gsm8k_sampling | 1319 | 94 | 256 | 27.97 | 337664 | 12072.9 | 71.08 |
| gsm8k_compile | 1319 | 94 | 256 | 26.71 | 337664 | 12640.2 | 70.72 |
| gsm8k_no_fast_decode | 1319 | 94 | 256 | 26.86 | 337664 | 12571.7 | 70.01 |
| rollouts_1024 | 32 | 256 | 1024 | 9.98 | 32768 | 3282.96 | 69.3 |
| rollouts_2048 | 32 | 256 | 2048 | 20.93 | 65536 | 3131.43 | 68.59 |
| rollouts_4096 | 32 | 256 | 4096 | 46.96 | 131072 | 2791.37 | 67.88 |
| rollouts_8192 | 32 | 256 | 8192 | 115.64 | 262144 | 2266.85 | 67.17 |
| rollouts_16384 | 32 | 256 | 16384 | 369.07 | 524288 | 1420.55 | 66.49 |
| few_blocks | 20 | 256 | 256 | 7.32 | 5120 | 699.1 | 16.88 |
| multi_return_seq | 50 | 256 | 256 | 8.25 | 12800 | 1551.81 | 65.68 |

The improvement mostly comes from the auto-activation of the decode path, IMO.
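As a quick sanity check on the two tables above, the throughput ratios can be computed directly from the tok/s columns (numbers copied from the tables; the gain grows with rollout length):

```python
# tok/s before and after the PR, taken from the two benchmark tables above
before = {"gsm8k_default": 11128.0, "rollouts_1024": 2204.13, "rollouts_16384": 667.51}
after = {"gsm8k_default": 12432.9, "rollouts_1024": 3282.96, "rollouts_16384": 1420.55}

# speedup factor per benchmark label
speedup = {label: after[label] / before[label] for label in before}
```

This gives roughly 1.12x on gsm8k_default, 1.49x on rollouts_1024, and 2.13x on rollouts_16384.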

Tests

  • RUN_SLOW=1 pytest tests/generation/test_continuous_batching.py
  • RUN_SLOW=1 pytest tests/cli/test_serve.py
  • RUN_SLOW=1 pytest tests/generation/test_paged_attention.py

@remi-or remi-or requested a review from ArthurZucker April 27, 2026 07:40
@remi-or remi-or self-assigned this Apr 27, 2026
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.


@ArthurZucker ArthurZucker left a comment


sounds good!

Comment on lines +95 to +96

```python
model = AutoModelForCausalLM.from_pretrained(self.model_id, attn_implementation=self.attn_impl)
model = model.to(device="cuda").eval()  # type: ignore
```

why not use device map auto?

```python
    self,
    generation_config: GenerationConfig | None = None,
    continuous_batching_config: ContinuousBatchingConfig | None = None,
    workload_hints: dict[str, int] | None = None,
```

we can use a TypedDict here to make explicit which keys are accepted
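The suggestion could look like the sketch below. The key names are hypothetical (the PR only mentions max sequence length as a hint); `total=False` makes every hint optional, matching the `dict[str, int] | None` default:

```python
from typing import TypedDict


class WorkloadHints(TypedDict, total=False):
    """Hypothetical hint keys -- only a max-sequence-length hint is
    mentioned in the PR description; the rest are illustrative."""
    max_seq_len: int
    num_sequences: int


# Type checkers now flag unknown keys, unlike with a plain dict[str, int]
hints: WorkloadHints = {"max_seq_len": 4096}
```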

