
[CB] [Major] Add CPU request offloading #45184

Merged
remi-or merged 14 commits into main from cb-cpu-offload on Apr 27, 2026
Conversation

@remi-or
Collaborator

@remi-or remi-or commented Apr 2, 2026

Summary

This PR adds CPU offloading to continuous batching. It's in draft until perf and test status are reported.

When the GPU KV cache is full and a request must be evicted, we check whether there is enough room in a pre-allocated pinned CPU buffer to hold the request's KV cache instead of discarding it. If there is not, we fall back to the legacy "soft reset" path, which forces a full re-prefill on restore. Offloading avoids redundant compute when cache pressure is temporary.
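A minimal sketch of this decision path (the method names has_room_for, offload, and soft_reset are illustrative, not this PR's actual API):

```python
# Hypothetical sketch of the eviction path described above; method names
# are illustrative and do not come from this PR.
def evict_request(self, state):
    if self.offloading_manager.has_room_for(state):
        # Copy the request's GPU KV blocks into the pinned CPU pool so a
        # later restore is a cheap CPU->GPU copy instead of a re-prefill.
        self.offloading_manager.offload(state)
    else:
        # Legacy "soft reset": discard the KV cache; the request fully
        # re-prefills when it is restored.
        state.soft_reset()
```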

Main additions:

  • New OffloadingManager class (offloading_manager.py, ~300 lines) owns a static CPU swap pool of pinned tensors, performs GPU↔CPU block copies, and falls back to soft reset when the pool is full.
  • Two new ContinuousBatchingConfig parameters: cpu_offload_space (GiB, default 0 = disabled) and cpu_offload_space_safety_threshold (cap on the fraction of system RAM the pool may use, default 0.8); see the usage sketch after this list.
  • The scheduler prioritizes restoring CPU-offloaded requests over prefilling fresh ones (_get_waiting_candidates), since a CPU→GPU copy is cheaper than a full prefill.
  • When prefix sharing is on, restored blocks are marked incomplete so mark_shareable_blocks_as_complete re-hashes and deduplicates them against existing shared blocks, avoiding permanent memory waste.
  • Compatible with async batching (the update_batch guard skips PENDING requests that were offloaded between scheduling and update), sliding window models (per-group block counts are tracked and the rolling-buffer arithmetic is preserved), and forking.
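A usage sketch for the new parameters (the parameter names and defaults come from this PR; the import path and exact constructor signature are assumptions):

```python
from transformers.generation.continuous_batching import ContinuousBatchingConfig

# Reserve an 8 GiB pinned CPU swap pool for offloaded KV caches. The pool
# is additionally capped by the safety threshold below.
config = ContinuousBatchingConfig(
    cpu_offload_space=8.0,                   # GiB; 0 (default) disables offloading
    cpu_offload_space_safety_threshold=0.8,  # max fraction of system RAM the pool may use
)
```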

Performance

As expected, no major difference for generation without offloading:

| Arguments | Main (tok/s) | Current (tok/s) | Diff (%) |
| --- | --- | --- | --- |
| `--samples 10` | 854.03 | 853.33 | -0.1% |
| `--samples 20 --num-blocks 20` | 508.44 | 508.62 | +0.0% |
| `--samples 50` | 3543.77 | 3623.23 | +2.2% |
| `--samples 100` | 5324.99 | 5277.82 | -0.9% |
| `--samples 100 --attn flash_attention_2` | 3647.27 | 3649.99 | +0.1% |
| `--samples 100 --attn sdpa` | 1027.38 | 1035.21 | +0.8% |
| `--samples 500 --no-use-async` | 6609.49 | 6587.37 | -0.3% |
| `--samples 500 --use-async` | 7903.19 | 7913.05 | +0.1% |
| `--samples 32 --max-new-tokens 2048 --use-async` | 2039.83 | 2039.52 | -0.0% |
| `--samples 32 --max-new-tokens 2048 --use-async --block-table 32` | 2653.75 | 2654.51 | +0.0% |
| `--samples 500 --add-prefix --compile` | 8521.59 | 8686.75 | +1.9% |
| `--samples 50 --num-return-sequences 8 --do-sample` | 916.43 | 912.93 | -0.4% |
| `--samples 100 --num-return-sequences 4 --do-sample` | 1808.55 | 1801.04 | -0.4% |

Tests

The following tests pass:

tests/generation/test_continuous_batching.py 
tests/cli/test_serve.py
tests/generation/test_paged_attention.py

@remi-or remi-or self-assigned this Apr 2, 2026
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@remi-or remi-or force-pushed the cb-cpu-offload branch 2 times, most recently from 298b74e to fc4bafb on April 15, 2026 04:14
@remi-or remi-or marked this pull request as ready for review April 21, 2026 07:18
@remi-or remi-or requested a review from ArthurZucker April 21, 2026 07:18
Collaborator

@ArthurZucker ArthurZucker left a comment


Nice! Is this benched as well?

Comment on lines +344 to +346
# Also free CPU-offloaded cache for cancelled states
for state in cancelled_states:
    self.offloading_manager.free_request_cpu_cache(state)
Collaborator


Suggested change
- # Also free CPU-offloaded cache for cancelled states
- for state in cancelled_states:
-     self.offloading_manager.free_request_cpu_cache(state)
+ self.offloading_manager.free_requests(cancelled_states)

API-wise, isolate this code?

Collaborator Author


Not sure I understand, sorry!

Collaborator


I mean: can you wrap the looping on the cancelled states inside the free_requests API? Re my comments about memory transfer, you are better off with the full list.

Collaborator Author


Does not change much: this is a CPU-only operation, not even torch intervenes. Added a note about this.
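For reference, a sketch of the wrapper suggested above (free_request_cpu_cache exists in the PR; the wrapper body is hypothetical):

```python
def free_requests(self, states) -> None:
    """Free the pinned CPU blocks held by each state.

    Pure CPU-side bookkeeping: no device transfer is involved, so batching
    the loop brings no memory-transfer benefit.
    """
    for state in states:
        self.free_request_cpu_cache(state)
```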

Comment on lines +154 to +156
def _stream_ctx(self):
    """Returns a context manager that runs enclosed ops on the compute stream, or a no-op when none is set."""
    return torch.cuda.stream(self._compute_stream) if self._compute_stream is not None else nullcontext()
Collaborator


I don't see this or the offloading manager being gated for this capability (for MPS you have shared memory, and IDK if Neuron, TPU, etc. have the same API).

Collaborator Author


Not sure what you mean: this is only not a no-op if there is already a compute stream around.

Collaborator


I just mean that torch.cuda will error out on MPS!

Collaborator Author


No, I'm pretty sure we are safe: the ternary only evaluates the CUDA part if there is a compute stream that's not None, and that will not happen on MPS.

Collaborator


ok!
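To spell out the resolution of this thread: conditional expressions evaluate lazily, so torch.cuda.stream(...) is never reached when the stream is None. A self-contained sketch, assuming the compute stream is only ever created on CUDA backends (that assumption is mine, not shown in the diff):

```python
import torch
from contextlib import nullcontext

# Hypothetical initialization: a compute stream exists only on CUDA.
compute_stream = torch.cuda.Stream() if torch.cuda.is_available() else None

# Lazy evaluation: on MPS/CPU, compute_stream is None, so the
# torch.cuda.stream(...) branch is never evaluated and cannot error out.
ctx = torch.cuda.stream(compute_stream) if compute_stream is not None else nullcontext()
```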

"""Returns a context manager that runs enclosed ops on the compute stream, or a no-op when none is set."""
return torch.cuda.stream(self._compute_stream) if self._compute_stream is not None else nullcontext()

def offload_one_request(self) -> None:
Collaborator


Makes sense to copy in bulk after we compute which pages we need to copy, imo!

Collaborator Author


That's what happens, I think.

Collaborator


I don't think so: each copy_ is a DMA / copy request, not async unless requested, no?

Collaborator Author


Ha yes, I thought this was about restore. For the async part, I think it's OK because there is no async mode yet. As for the "bulk" part, we currently only offload one request at a time, so we group as much as possible, but it's still the cache for one request only.
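To illustrate the point about copy_ above, a self-contained sketch (assumes a CUDA device; shapes are arbitrary): a pinned destination plus non_blocking=True turns the copy into an asynchronous DMA that the host must synchronize before reading.

```python
import torch

gpu_blocks = torch.randn(4, 256, 128, device="cuda")
cpu_pool = torch.empty_like(gpu_blocks, device="cpu").pin_memory()

# With a pinned destination, non_blocking=True enqueues an async DMA and
# returns immediately; without it, each copy_ blocks the host thread.
cpu_pool.copy_(gpu_blocks, non_blocking=True)
torch.cuda.synchronize()  # required before the CPU copy can be read safely
```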

@remi-or remi-or added this pull request to the merge queue Apr 27, 2026
Merged via the queue into main with commit 5d75528 Apr 27, 2026
29 checks passed
@remi-or remi-or deleted the cb-cpu-offload branch April 27, 2026 08:35
@stevhliu stevhliu mentioned this pull request Apr 27, 2026
ArthurZucker pushed a commit that referenced this pull request Apr 28, 2026
* Stacked commits

* Fix compile and processors

* Review compliance

* Added None for cpu pool swap size

* TODO and disable

* First fixes

* Small fixes

* Fixes for 2s (1/n)

* Fixes for 2s (2/n)

* Fix for 2s (3/n)

* End of fixes (4/4)

* review compliance

* nits