
[CB] Changes for long generation #45530

Merged
ArthurZucker merged 17 commits into main from cb-very-long-gen on Apr 23, 2026

Conversation

@remi-or (Collaborator) commented Apr 20, 2026

Summary

This PR fixes several memory-related issues to make long generation (16K+ tokens) easier.

  • Fix KV dedup for decode batches (scheduler.py): Decode-only batches don't consume the read_indices cache budget, so don't reject them on that basis. Also gate the decode fast path on max_blocks_per_request > 0 instead of unconditionally enabling it (see the first sketch after this list).
  • Fix memory estimation (requests.py): Use torch.cuda.mem_get_info instead of device_properties().total_memory, which ignored CUDA context/driver overhead (~0.5 GiB) and caused overcommit/OOM.
  • Raise max_memory_percent default 0.8 → 0.9 (configuration_utils.py): Now safe with the corrected memory accounting above.
  • Write-only fast path (cache.py, input_outputs.py, scheduler.py): When a batch has no past-cache reads (pure prefills), skip the index_select read-back, avoid allocating/transferring read_index, and return the input KV states directly (see the second sketch after this list). Also adjusts the CUDA-graph key to depend on the block-table path rather than max_kv_read > 0.
  • Two-peak memory model (cache.py): Replace the single peak_activation_per_token with two activation peaks: the LM head (hidden + logits, N-independent) and attention (hidden + Q + new K/V + cache K/V reads, which grows with N). Solve the memory polynomial for each peak independently and take the most restrictive (num_blocks, max_batch_tokens); see the third sketch after this list. Bumps _upper_bound_max_batch_tokens 256 → 1024.
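
First, a minimal sketch of the scheduler-side gating, assuming each request exposes a count of prompt tokens still to prefill. Only read_indices and max_blocks_per_request are names from this PR; everything else is illustrative:

```python
# Illustrative only: `batch` is assumed to be an iterable of requests with a
# hypothetical `remaining_prompt_tokens` field; read_indices and
# max_blocks_per_request are the names used in the PR description.

def is_decode_only(batch) -> bool:
    return all(req.remaining_prompt_tokens == 0 for req in batch)

def consumes_read_indices_budget(batch) -> bool:
    # Decode-only batches don't consume the read_indices cache budget,
    # so they must never be rejected on that basis.
    return not is_decode_only(batch)

def use_decode_fast_path(batch, max_blocks_per_request: int) -> bool:
    # Gate the fast path on max_blocks_per_request > 0 instead of
    # enabling it unconditionally.
    return is_decode_only(batch) and max_blocks_per_request > 0
```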
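
Second, a sketch of the write-only fast path, reusing the read_index / max_kv_read names from the description; the read_kv helper and the key_blocks / value_blocks layout are hypothetical stand-ins, not the actual cache.py API:

```python
import torch

def read_kv(cache, key_states, value_states, read_index, max_kv_read):
    # Hypothetical helper: `cache.key_blocks` / `cache.value_blocks` stand in
    # for however the paged cache stores its K/V blocks.
    if max_kv_read == 0:
        # Pure-prefill batch: no past-cache reads, so skip the index_select
        # read-back (and the read_index allocation/transfer that feeds it)
        # and hand the freshly written K/V states straight to attention.
        return key_states, value_states
    # Otherwise gather every K/V position the batch needs, past and new.
    keys = cache.key_blocks.index_select(0, read_index)
    values = cache.value_blocks.index_select(0, read_index)
    return keys, values
```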
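
Third, a sketch of the two-peak solve. The memory polynomial is kept linear here for clarity, and every name and number is an illustrative stand-in for the cache.py logic:

```python
def max_blocks_for_peak(budget, block_bytes, const_bytes, per_block_bytes):
    """Solve budget >= n*block_bytes + const_bytes + n*per_block_bytes for n.

    const_bytes is the part of the activation peak that does not grow with
    the cache (for the LM-head peak: the hidden + logits buffers);
    per_block_bytes is the part that grows with the number of cache blocks
    read (zero for the LM-head peak, nonzero for the attention peak).
    """
    return max((budget - const_bytes) // (block_bytes + per_block_bytes), 0)

# Solve each peak independently and keep the most restrictive answer
# (illustrative values only):
budget = 8 * 1024**3          # bytes available for cache + activations
block_bytes = 2 * 1024**2     # KV bytes per cache block
lm_head_peak = 512 * 1024**2  # hidden + logits, N-independent
attn_const = 256 * 1024**2    # hidden + Q + new K/V
attn_per_block = 64 * 1024    # cache K/V read per block

num_blocks = min(
    max_blocks_for_peak(budget, block_bytes, lm_head_peak, 0),
    max_blocks_for_peak(budget, block_bytes, attn_const, attn_per_block),
)
```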

Performance

Pretty good: many workloads benefit from the 80% → 90% raise in cache space.

| Arguments | Main (tok/s) | Current (tok/s) | Diff (%) |
| --- | --- | --- | --- |
| --samples 10 | 869.0 | 890.37 | +2.5% |
| --samples 20 --num-blocks 20 | 517.95 | 520.16 | +0.4% |
| --samples 50 | 3629.88 | 3638.6 | +0.2% |
| --samples 100 | 5375.41 | 5522.83 | +2.7% |
| --samples 100 --attn flash_attention_2 | 3666.82 | 3743.47 | +2.1% |
| --samples 100 --attn sdpa | 1030.21 | 1053.57 | +2.3% |
| --samples 500 --no-use-async | 6621.78 | 8020.64 | +21.1% |
| --samples 500 --use-async | 7963.66 | 9332.71 | +17.2% |
| --samples 32 --max-new-tokens 2048 --use-async | 2033.87 | 2064.29 | +1.5% |
| --samples 32 --max-new-tokens 2048 --use-async --block-table 32 | 2716.64 | 2734.81 | +0.7% |
| --samples 500 --add-prefix --compile | 7649.48 | 8882.62 | +16.1% |
| --samples 50 --num-return-sequences 8 --do-sample | 869.94 | 980.13 | +12.7% |
| --samples 100 --num-return-sequences 4 --do-sample | 1708.88 | 1925.25 | +12.7% |

Tests

  • make style and make typing pass
  • RUN_SLOW=1 pytest tests/generation/test_continuous_batching.py
  • RUN_SLOW=1 pytest tests/cli/test_serve.py
  • RUN_SLOW=1 pytest tests/generation/test_paged_attention.py

@HuggingFaceDocBuilderDev commented

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

remi-or marked this pull request as ready for review April 21, 2026 15:40
remi-or requested a review from ArthurZucker April 21, 2026 15:40
@ArthurZucker (Collaborator) left a comment

some of the names are a tad bit unfamiliar to me, but LGTM!

On the memory-estimation change in requests.py:

```diff
-total_memory = torch.cuda.get_device_properties(device).total_memory
+# Use mem_get_info to get actual free memory: device_properties().total_memory returns the physical device
+# total which ignores CUDA context and driver overhead (~0.5 GiB), leading to overcommit.
+free_memory, total_memory = torch.cuda.mem_get_info(device)
```
nice!
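
For context, a minimal sketch of how the value returned by mem_get_info can feed the cache budget. The cache_budget line is an assumption for illustration, not the exact requests.py code, though max_memory_percent and its 0.9 default come from this PR:

```python
import torch

device = torch.device("cuda:0")
max_memory_percent = 0.9  # new default from this PR

# mem_get_info reports (free, total) as seen by the CUDA context, so the
# ~0.5 GiB of context/driver overhead is already excluded from `free`,
# unlike device_properties().total_memory.
free_memory, total_memory = torch.cuda.mem_get_info(device)

# Illustrative budget computation (assumed, not the exact requests.py code):
cache_budget = int(free_memory * max_memory_percent)
```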

ArthurZucker merged commit 07e3831 into main Apr 23, 2026
27 of 29 checks passed
ArthurZucker deleted the cb-very-long-gen branch April 23, 2026 09:34
tarekziade pushed a commit that referenced this pull request Apr 23, 2026
* Fix KV dedup for decode batches
* Fix memory estimation
* Change default
* Added write-only fast path
* Take both peaks into account
* Revert unused config field
* Review 1
* Fix p1s
* Fix p2s and p3s that needed it
* Added a TODO
* Fix test, lower max cached graph, add TODO
* Fix fragmentation with big warmup
* Add more space for logits processors
* Fix