Continuous batching thread safety #44924

Merged
ArthurZucker merged 20 commits into huggingface:main from Qubitium:continuos-batching-paged-attention-threads
Mar 23, 2026

Conversation

@Qubitium
Contributor

What does this PR do?

Fix two nogil threading bugs (reproduced on 3.14):

  1. Continuous Batching crashes with torch graph errors when 2 threads run 2 separate model instances (same model path, but two distinct instances). The cause is a missing capture_error_mode="thread_local", which lets the graph capture operate safely in thread_local mode:
with torch.cuda.graph(
    graph,
    stream=compute_stream,
    pool=self.graph_pool,
    capture_error_mode="thread_local",
):
  2. Model loading of the same model path from 2 different threads (2 different instances) hits meta-device tensor issues because the tie_weights() skipping logic/state is not thread safe, so tie_weights() fails to execute. Fixed using contextvar-scoped state (sketched below). Triggered with 2 threads loading Llama 3.2 1B Instruct at the same time.
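
A minimal sketch of the contextvar idea (the names here are hypothetical and illustrative, not the actual transformers internals):

# Sketch only: scope the "skip tie_weights" decision per thread/context via contextvars
# instead of shared mutable state, so two threads loading the same model path no longer
# observe each other's skip flag.
import contextvars

_skip_tie_weights: contextvars.ContextVar[bool] = contextvars.ContextVar(
    "_skip_tie_weights", default=False
)


class skip_tie_weights_scope:
    """Set the skip flag only for the current context, and restore it on exit."""

    def __enter__(self) -> None:
        self._token = _skip_tie_weights.set(True)

    def __exit__(self, *exc) -> None:
        _skip_tie_weights.reset(self._token)


def should_skip_tie_weights() -> bool:
    # Each loading thread sees only the value set in its own context.
    return _skip_tie_weights.get()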

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@ArthurZucker @remi @McPatate

Please trigger CI to run the new unit tests. My own tests that exercise this PR's code are passing.

@Qubitium Qubitium changed the title Continuos batching paged attention threads Continuous batching paged attention thread safety Mar 22, 2026
@Qubitium
Contributor Author

Qubitium commented Mar 23, 2026

There is a related issue, outside the scope of this PR, where Continuous Batching generation can bind to the wrong CUDA context.

The current design assumes the traditional one-process, one-model approach, so the following generation code works without issue. Things break down with one process, multiple threads, and multiple model instances. At that point, the following piece of code will deadlock/stall due to the wrong CUDA context.

The fix is easy if each model only uses 1 GPU: just set the context to the model's device. But if 2 GPUs are bound to a model, I am not sure there is a torch API that can assign a range of device contexts to a thread.

So a potential partial fix for the following block of code is to check whether the model has weights on more than one device or tp > 1; if not, we can safely wrap the code below in a single cuda:device context, which lets multi-model, multi-threaded, single-GPU-per-model setups execute correctly using continuous batching. For tp > 1 and multi-GPU models, warn users via documentation that this piece of code is unsuitable for nogil (regardless of thread safety) and multi-model instances. Let me know if this partial fix is worthwhile for me to PR. In the meantime, I will just monkey-patch from my end.

    def _run_generation_loop(self) -> None:
        """Main processing loop running in the background thread."""
        batch_processor: ContinuousBatchProcessor | None = None
        try:
            t0 = perf_counter()
            paged_attention_cache = PagedAttentionCache(
                self.model.config,
                self.continuous_batching_config,
                self.model.device,
                self.model.dtype,
                tp_size=getattr(self.model, "_tp_size", None),  # Use model's actual TP setting
            )
            self._use_prefix_sharing = paged_attention_cache.use_prefix_sharing  # update the approximation
            logger.debug(f"PagedAttentionCache created in {perf_counter() - t0} seconds")

            scheduler = SCHEDULER_MAPPING.get(self.continuous_batching_config.scheduler, None)
            if scheduler is None:
                logger.warning(
                    f"Scheduler '{self.continuous_batching_config.scheduler}' not found. Defaulting to FIFO."
                )
                scheduler = FIFOScheduler

            t1 = perf_counter()
            batch_processor = ContinuousBatchProcessor(
                cache=paged_attention_cache,
                config=self.model.config,

My monkey patch wrapper:

import threading
from contextlib import nullcontext
from functools import wraps
from typing import Any

_CONTINUOUS_BATCHING_CUDA_CONTEXT_PATCH_LOCK = threading.Lock()


def _patch_continuous_batching_manager_cuda_context_once(ContinuousBatchingManager: Any) -> None:
    run_generation_loop = getattr(ContinuousBatchingManager, "_run_generation_loop", None)
    if not callable(run_generation_loop):
        return

    with _CONTINUOUS_BATCHING_CUDA_CONTEXT_PATCH_LOCK:
        current = getattr(ContinuousBatchingManager, "_run_generation_loop", None)
        # Another session may have patched the class first; avoid wrapping the same method twice.
        if not callable(current) or getattr(current, "__evalution_cuda_context_patch__", False):
            return

        @wraps(current)
        def _wrapped_run_generation_loop(self: Any, *args: Any, **kwargs: Any) -> Any:
            import torch

            model_device = getattr(getattr(self, "model", None), "device", None)
            # The manager loop is the first code that runs on the background generation thread.
            # Enter the model's CUDA device here so any current-device CUDA calls made deeper in
            # transformers resolve against the manager's model instead of whatever device happened
            # to be current on that thread previously.
            maybe_device = (
                torch.cuda.device(model_device)
                if getattr(model_device, "type", None) == "cuda"
                else nullcontext()
            )
            with maybe_device:
                return current(self, *args, **kwargs)

        # Mark the wrapper so repeated manager construction in the same process stays idempotent.
        _wrapped_run_generation_loop.__evalution_cuda_context_patch__ = True
        ContinuousBatchingManager._run_generation_loop = _wrapped_run_generation_loop
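
The patch is applied once per process, before any manager threads are created. The import path below is my assumption, inferred from the continuous_api.py path in the traceback; adjust it to wherever ContinuousBatchingManager is exposed in your version:

# Assumed import path; verify against your transformers version.
from transformers.generation.continuous_batching.continuous_api import ContinuousBatchingManager

_patch_continuous_batching_manager_cuda_context_once(ContinuousBatchingManager)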

@ArthurZucker ArthurZucker requested a review from remi-or March 23, 2026 09:14
Collaborator

@ArthurZucker ArthurZucker left a comment


Leaving the CB review to @remi-or but

Model loading of the same model path from 2 different threads (2 different instances) hits meta-device tensor issues because the tie_weights() skipping logic/state is not thread safe, so tie_weights() fails to execute. Fixed using contextvar-scoped state. Triggered with 2 threads loading Llama 3.2 1B Instruct at the same time.

Let's do this in a separate PR please 🤗

@Qubitium Qubitium changed the title Continuous batching paged attention thread safety Continuous batching thread safety Mar 23, 2026
@Qubitium
Contributor Author

Leaving the CB review to @remi-or but

Model loading of the same model path from 2 different threads (2 different instances) hits meta-device tensor issues because the tie_weights() skipping logic/state is not thread safe, so tie_weights() fails to execute. Fixed using contextvar-scoped state. Triggered with 2 threads loading Llama 3.2 1B Instruct at the same time.

Let's do this in a separate PR please 🤗

Done. #44940

@remi-or
Collaborator

remi-or commented Mar 23, 2026

Hello @Qubitium ! Thanks for the "capture-error-mode" fix, I was not aware this was an issue, and the fix is clear and concise.
I think I understand the test a little less: I am not sure it is really testing for the issue that transpired before your fix. Could you replace your test with something based on the error you witnessed without your fix? Hence the test should fail without the fix and pass with the fix.
If this is too much work, I get it, you can just add a minimal reproducible example and I will adapt the test. Thanks!

@Qubitium
Contributor Author

Qubitium commented Mar 23, 2026

Hello @Qubitium ! Thanks for the "capture-error-mode" fix, I was not aware this was an issue, and the fix is clear and concise. I think I understand the test a little less: I am not sure it is really testing for the issue that transpired before your fix. Could you replace your test with something based on the error you witnessed without your fix? Hence the test should fail without the fix and pass with the fix. If this is too much work, I get it, you can just add a minimal reproducible example and I will adapt the test. Thanks!

@remi-or I have cleaned up the unit test so it no longer tests things we don't really care about. The unit test now only checks that our error_mode was correctly applied; it does not actually test that concurrent continuous-batching operations on two threads with two GPUs (2 models) run correctly without crashing. That part I have proved works in my own tests with a real application (not fake unit tests).

Do you want me to wire up a dual-thread, dual-model, dual-GPU unit test? This requires 3.14t or 3.13t with nogil enabled. I don't even know how to write CI in this framework with those hw/sw dependencies.

For this CI unit test to work I need the following:

  1. 3.13t or 3.14t
  2. two CUDA/NVIDIA GPUs
  3. the ability to set env flags like PYTHON_GIL=0.

Is the above possible? If not, I cannot create a true "crashes if the fix is not applied" unit test.

@Qubitium
Contributor Author

@remi-or Another problem: even if I cook up the test, you need to apply this patch AND #44940 (since that fixes model loading). lol. Chicken-and-egg issue. I will still cook up the test so you can run it locally and figure something out with the CI team.

# re-enable the GIL during import unless the interpreter itself starts with `PYTHON_GIL=0`.


MODEL_PATH = "/monster/data/model/Llama-3.2-1B-Instruct" # FIXME
Contributor Author

@Qubitium Qubitium Mar 23, 2026


@remi-or I have validated on my own end with 3.14t + 2 NVIDIA GPUs + 2 threads + PYTHON_GIL=0 using this script. Please modify this unit test so that it:

  1. has proper access to Llama 3.2 1B Instruct, or a similar model if you can reproduce the crash with another one (it should be model agnostic).
  2. is fixed/adapted for the CircleCI setup you run, which I have no clue about, since it needs to select 2 GPUs, have FA2 available (maybe not needed), and the env needs 3.13t/3.14t.
  3. has CI set PYTHON_GIL=0 before Python startup, since the tokenizer would otherwise override nogil without this env flag.
PYTHON_GIL=0 /root/vm314t/bin/python -m unittest \
tests.generation.test_continuous_batching_llama32_nogil_repro.ContinuousBatchingLlama32NogilReproTest.test_two_thread_llama32_continuous_batching_cuda_graph_repro

With the patch on the current branch:

Result: OK in 8.472s.

Without patch:

Result: FAILED in 5.554s.

2026-03-23 14:36:52,685 - ContinuousBatchingLogger - ERROR - Error in generation loop: CUDA error: operation not permitted when stream is capturing
 Traceback (most recent call last):
   File "/tmp/transformers-no-threadlocal-MxpTSt/src/transformers/generation/continuous_batching/continuous_api.py", line 843, in _run_generation_loop
     self._inner_generation_loop(batch_processor)
   File "/tmp/transformers-no-threadlocal-MxpTSt/src/transformers/generation/continuous_batching/continuous_api.py", line 863, in _inner_generation_loop
     self._generation_step()
   File "/tmp/transformers-no-threadlocal-MxpTSt/src/transformers/generation/continuous_batching/continuous_api.py", line 794, in _generation_step
     self.batch_processor._generation_step(self.model, self.logit_processor, self.do_sample)
   File "/tmp/transformers-no-threadlocal-MxpTSt/src/transformers/generation/continuous_batching/continuous_api.py", line 470, in _generation_step
     with torch.cuda.graph(
   File "/root/vm314t/lib/python3.14t/site-packages/torch/cuda/graphs.py", line 265, in __exit__
     self.cuda_graph.capture_end()
   File "/root/vm314t/lib/python3.14t/site-packages/torch/cuda/graphs.py", line 126, in capture_end
     super().capture_end()
 torch.AcceleratorError: CUDA error: operation not permitted when stream is capturing

 Then the second manager sees the invalidated capture:

 torch.AcceleratorError: CUDA error: operation failed due to a previous error during capture
 ...
   File "/tmp/transformers-no-threadlocal-MxpTSt/src/transformers/generation/continuous_batching/continuous_api.py", line 470, in _generation_step
     with torch.cuda.graph(
 ...
 torch.AcceleratorError: CUDA error: operation failed due to a previous error during capture

 And the test-level failure is:

 AssertionError: Repro failed on cuda:0, cuda:1:

 Traceback (most recent call last):
   File "/root/transformers/tests/generation/test_continuous_batching_llama32_nogil_repro.py", line 182, in worker
     self.assertEqual(output.status, RequestStatus.FINISHED)
 AssertionError: <RequestStatus.FAILED: 4> != <RequestStatus.FINISHED: 3>
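
For anyone else trying to reproduce locally, the shape of the repro is roughly the following. This is a sketch only, not the actual test file; the model id and the generate_batch call shape are assumptions, so check them against the real continuous-batching API:

# Sketch of the dual-thread / dual-GPU repro described above. Requires a free-threaded
# interpreter started with PYTHON_GIL=0 and at least two CUDA devices.
import threading

from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_PATH = "meta-llama/Llama-3.2-1B-Instruct"  # assumption: any small causal LM should reproduce it


def worker(device: str, errors: list) -> None:
    try:
        tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
        model = AutoModelForCausalLM.from_pretrained(MODEL_PATH).to(device)
        token_ids = tokenizer("Hello, world")["input_ids"]
        # Assumed call shape for the continuous-batching entry point; both threads drive
        # their own manager, so CUDA graphs are captured concurrently -- the crash site
        # without capture_error_mode="thread_local".
        model.generate_batch(inputs=[token_ids])
    except Exception as exc:  # the repro only cares whether either thread blew up
        errors.append((device, exc))


errors: list = []
threads = [threading.Thread(target=worker, args=(f"cuda:{i}", errors)) for i in range(2)]
for t in threads:
    t.start()
for t in threads:
    t.join()
assert not errors, errors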

@remi-or
Collaborator

remi-or commented Mar 23, 2026

I think we can add the test with the right environment variable and multiple GPUs using a decorator.
I am happy to just have this test, not even in the CI, to figure out the issue or have an MRE on my side.
I think we can drop the current test, which only checks that the "capture-error-mode" is the right one, because it is passed as a kwarg, so it cannot really fail.
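
Something like the following gating could work (a sketch; require_torch_multi_gpu and slow are the existing decorators from transformers.testing_utils, the free-threading check is plain unittest/sys):

# Sketch of how the nogil repro test could be gated so it only runs on suitable workers.
import sys
import unittest

from transformers.testing_utils import require_torch_multi_gpu, slow

requires_free_threading = unittest.skipUnless(
    not getattr(sys, "_is_gil_enabled", lambda: True)(),
    "requires a free-threaded (nogil) interpreter with the GIL disabled",
)


@slow
@require_torch_multi_gpu
@requires_free_threading
class ContinuousBatchingNogilTest(unittest.TestCase):
    def test_two_threads_two_gpus(self):
        ...  # dual-thread, dual-model repro goes here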

@Qubitium
Contributor Author

I think we can add the test with the right environment variable and multiple GPUs using a decorator. I am happy to just have this test, not even in the CI, to figure out the issue or have an MRE on my side. I think we can drop the current test, which only checks that the "capture-error-mode" is the right one, because it is passed as a kwarg, so it cannot really fail.

I have removed the error_mode test and uploaded the verified script with comments. Please modify the unit test with the appropriate model path and decorators so the HF CI system can do its magic; I have no idea how.

@remi-or
Collaborator

remi-or commented Mar 23, 2026

@Qubitium would it be ok with you to just merge the changes to src/transformers/generation/continuous_batching/continuous_api.py in this PR and I will add the main test in a subsequent PR? I have one in the works refactoring CB tests (#44858)

@Qubitium
Contributor Author

Qubitium commented Mar 23, 2026

@Qubitium would it be ok with you to just merge the changes to src/transformers/generation/continuous_batching/continuous_api.py in this PR and I will add the main test in a subsequent PR? I have one in the works refactoring CB tests (#44858)

Done. Unit test removed.

@Qubitium
Contributor Author


@remi-or Please, when you have time, also check the _run_generation_loop block from my earlier comment above. That part blindly assumes the execution context has all GPU contexts when it may only have cuda:0. I'm not sure what the best way to fix this is. I know how to fix it for 1 model with 1 GPU, but what happens when it needs to hold the context for cuda:0 + cuda:1 in tensor parallel, there are 4 GPUs, and the second set of GPUs (cuda:2, cuda:3) belongs to another thread/model?

@github-actions
Contributor

View the CircleCI Test Summary for this PR:

https://huggingface.co/spaces/transformers-community/circle-ci-viz?pr=44924&sha=f5accd

@remi-or
Collaborator

remi-or commented Mar 23, 2026

Good point on your last comment. CB has been developed targeting vLLM-like performance on a single GPU so far, so multi-GPU has been on the back burner for a while now. But these are questions we want to tackle soon. Thanks!

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@ArthurZucker ArthurZucker merged commit dda5468 into huggingface:main Mar 23, 2026
26 of 28 checks passed
@Qubitium Qubitium deleted the continuos-batching-paged-attention-threads branch March 23, 2026 16:16
@Qubitium Qubitium restored the continuos-batching-paged-attention-threads branch March 24, 2026 05:42