[graph_trainer] Fix CUDAGraph warmup to stay on current stream#2922

Closed
bobrenjc93 wants to merge 2 commits into gh/bobrenjc93/40/base from gh/bobrenjc93/40/head

Conversation


@bobrenjc93 bobrenjc93 commented Apr 9, 2026

Stack from ghstack (oldest at bottom):

The warmup phase was running on _cg_manager.stream via
_use_cuda_memory_pool_manager, which caused NCCL collectives to
execute on a non-default stream. On multi-node IB/RoCE setups this
leads to illegal memory access errors because NCCL expects to run
on the stream it was initialized with.

Fix by passing torch.cuda.current_stream() to
_use_cuda_memory_pool_manager instead of _cg_manager.stream. This
keeps NCCL collectives on the correct stream while still directing
warmup allocations into the graph memory pool (avoiding fragmentation
between eager and graph pools).

Also remove a redundant torch.cuda.synchronize() before graph
recording — torch.cuda.graph.__enter__ already performs a full
device sync. Add capture_error_mode="thread_local" for better
error reporting during graph capture.
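The stream-routing issue described above can be modeled without CUDA. The sketch below is purely illustrative: `use_memory_pool` is a stand-in for `_use_cuda_memory_pool_manager`, and the strings stand in for real CUDA streams; none of this is the actual graph_trainer source.

```python
from contextlib import contextmanager

executed_on = []

@contextmanager
def use_memory_pool(pool, stream):
    # The pool argument controls where allocations land; the stream
    # argument controls which stream enqueued work (e.g. NCCL
    # collectives) actually runs on.
    yield stream

def nccl_all_reduce(stream):
    # NCCL expects to run on the stream it was initialized with
    # (the current/default stream in this scenario).
    executed_on.append(stream)

# Buggy warmup: work lands on the capture stream.
with use_memory_pool("graph_pool", "capture_stream") as s:
    nccl_all_reduce(s)

# Fixed warmup: work stays on the caller's current stream.
with use_memory_pool("graph_pool", "current_stream") as s:
    nccl_all_reduce(s)

print(executed_on)  # ['capture_stream', 'current_stream']
```

In both cases the allocations go to the same pool; only the execution stream differs, which is exactly the one-argument change the fix makes.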

[ghstack-poisoned]
bobrenjc93 added a commit that referenced this pull request Apr 10, 2026

ghstack-source-id: 18aa00a
Pull-Request: #2922
bobrenjc93 (Contributor, Author) commented:

@claude


claude bot commented Apr 10, 2026

Claude finished @bobrenjc93's task in 3m 9s.


Review of PR #2922


Overall: This is a well-motivated fix. The net change (vs. base at HEAD~2) is small and focused — switching the warmup stream parameter from _cg_manager.stream to torch.cuda.current_stream(), plus adding capture_error_mode="thread_local" to graph capture. The comments explaining the rationale are clear and helpful.

Changes look correct. Here are some detailed observations:

Warmup stream fix (cudagraph.py:192-194)

The core fix — passing torch.cuda.current_stream() instead of _cg_manager.stream to _use_cuda_memory_pool_manager — makes sense. _use_cuda_memory_pool_manager directs allocations into the graph memory pool, but the stream argument controls which stream the operations actually execute on. Using _cg_manager.stream (a non-default stream) during warmup would cause NCCL collectives to run on a stream they weren't initialized with, which explains the illegal memory access errors on multi-node setups.

One thing to verify: torch.cuda.current_stream() is evaluated at call time, so this will correctly capture whatever stream is active when warmup runs. This is the right behavior — if the user is on the default stream (typical), warmup stays on the default stream.
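The call-time-evaluation point can be checked with a CUDA-free stand-in: the stream passed to warmup is whatever is current when warmup runs, not whatever was current at setup time. Here `current_stream` mimics `torch.cuda.current_stream()`; everything else is illustrative.

```python
# Stack of active streams; the last entry is the "current" one.
active = ["default"]

def current_stream():
    # Stand-in for torch.cuda.current_stream(): looks up the
    # active stream at the moment of the call.
    return active[-1]

def warmup():
    # The stream is resolved here, when warmup actually runs.
    return current_stream()

print(warmup())  # default

active.append("user_stream")  # user switched streams before warmup
print(warmup())  # user_stream
```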

Removed torch.cuda.synchronize() before graph recording (cudagraph.py:204-209)

The PR description says torch.cuda.graph.__enter__ already performs a full device sync. This is accurate based on the PyTorch implementation — the torch.cuda.graph context manager synchronizes the stream before starting capture. Removing the redundant sync is a clean improvement.

capture_error_mode="thread_local" (cudagraph.py:209)

This is a good addition for debugging. "thread_local" restricts error checking to the current thread, which gives more precise error attribution during graph capture compared to the default "global" mode. This was already applied to the dummy graph at line 59, so it's consistent to use it here too.

Minor notes

  • The intermediate commit (ffc45b0) removed _use_cuda_memory_pool_manager entirely and replaced warmup with bare synchronize() + self._runnable(*args) + synchronize(). The final commit (9c5b2b6) brings the pool manager back but with the correct stream. The net diff is clean — only the stream argument changed and the redundant sync/comment removed. The two-commit history is a bit noisy (the intermediate commit is effectively reverted then improved), but since this is a ghstack PR, the squashed result will be fine.

  • The comment block at lines 187-191 is well-written and explains the "why" clearly. It documents both the memory pool motivation and the stream choice, which will help future readers.

LGTM — the fix is targeted, well-explained, and the risk surface is small.


Contributor commented on these lines of the capture call:

    self._cudagraph,
    pool=_cg_manager.graph_pool,
    stream=_cg_manager.stream,
    capture_error_mode="thread_local",

Would this line be sufficient on its own, if we don't have the change above in:

            with _use_cuda_memory_pool_manager(
                device, _cg_manager.graph_pool, torch.cuda.current_stream()
            ):

bobrenjc93 (Contributor, Author) replied:

I don't think so? This just gives better error reporting but the underlying issue is IMAs stemming from the warmup run on the wrong stream (which gets solved with the quoted code).

Is there a reason you think this one line would be sufficient?

Contributor replied:

  1. Warmup on the capture stream has been working with NCCL for a while. Can we find the root cause of why torch.cuda.current_stream() is needed to avoid the IMA?
  2. If using torch.cuda.current_stream(), there would be high memory overhead due to fragmentation between the capture stream and the default stream in the private pool.
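The fragmentation concern in point 2 can be illustrated with a toy model of a caching allocator that, like PyTorch's, only reuses a cached block on the stream it was last used on. This is a simplified stand-in, not the real allocator.

```python
class ToyCachingAllocator:
    """Toy model: cached free blocks are keyed by stream."""
    def __init__(self):
        self.free = {}      # stream -> list of free block sizes
        self.reserved = 0   # total memory held by the pool

    def alloc(self, stream, size):
        blocks = self.free.get(stream, [])
        if size in blocks:
            blocks.remove(size)    # reuse a block cached for this stream
        else:
            self.reserved += size  # otherwise grow the pool
        return (stream, size)

    def free_block(self, block):
        stream, size = block
        self.free.setdefault(stream, []).append(size)

# Warmup on the default stream allocates and frees a 100-byte buffer...
pool = ToyCachingAllocator()
b = pool.alloc("default", 100)
pool.free_block(b)
# ...but capture on the capture stream cannot reuse it: the pool grows.
pool.alloc("capture", 100)
print(pool.reserved)  # 200

# Warming up on the capture stream lets capture reuse the cached block.
pool2 = ToyCachingAllocator()
b = pool2.alloc("capture", 100)
pool2.free_block(b)
pool2.alloc("capture", 100)
print(pool2.reserved)  # 100
```

Under this model, warming up on the current stream roughly doubles the private pool's reserved memory for warmup-sized buffers, which is the overhead the reviewer is pointing at.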

@bobrenjc93 bobrenjc93 requested a review from BoyuanFeng April 10, 2026 21:15
bobrenjc93 (Contributor, Author) commented:

Talked offline; will try to get precise profile data first, and then we can figure out the best workaround.

@bobrenjc93 bobrenjc93 closed this Apr 13, 2026

Labels: ciflow/8gpu, CLA Signed
