[graph_trainer] Fix CUDAGraph warmup to stay on current stream #2922
bobrenjc93 wants to merge 2 commits into gh/bobrenjc93/40/base from
Conversation
The warmup phase was running on `_cg_manager.stream` via `_use_cuda_memory_pool_manager`, which caused NCCL collectives to execute on a non-default stream. On multi-node IB/RoCE setups this leads to illegal memory access errors because NCCL expects to run on the stream it was initialized with.

Fix by passing `torch.cuda.current_stream()` to `_use_cuda_memory_pool_manager` instead of `_cg_manager.stream`. This keeps NCCL collectives on the correct stream while still directing warmup allocations into the graph memory pool (avoiding fragmentation between eager and graph pools).

Also remove a redundant `torch.cuda.synchronize()` before graph recording: `torch.cuda.graph.__enter__` already performs a full device sync. Add `capture_error_mode="thread_local"` for better error reporting during graph capture.

ghstack-source-id: 18aa00a
Pull-Request: #2922
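A minimal sketch of the warmup-then-capture flow this fix describes. `_cg_manager` and `_use_cuda_memory_pool_manager` are repo internals, so this sketch uses only public `torch.cuda` APIs with a hypothetical `train_step` callable, and elides the detail of routing warmup allocations into the graph pool:

```python
import torch

def warmup_and_capture(train_step, num_warmup=3):
    """Warm up on the *current* stream (where NCCL collectives were
    initialized), then capture a CUDA graph. Sketch only; assumes a
    CUDA device and a replay-safe train_step()."""
    if not torch.cuda.is_available():
        return None  # nothing to do on CPU-only hosts

    # Shared private memory pool for graph allocations.
    pool = torch.cuda.graph_pool_handle()

    # Warmup stays on the current stream so NCCL runs on the stream
    # it was initialized with (the bug this PR fixes). In the real
    # code, warmup allocations are additionally directed into the
    # graph pool via an internal context manager, elided here.
    for _ in range(num_warmup):
        train_step()

    graph = torch.cuda.CUDAGraph()
    # No explicit torch.cuda.synchronize() before capture:
    # torch.cuda.graph.__enter__ already performs a full device sync.
    with torch.cuda.graph(graph, pool=pool,
                          capture_error_mode="thread_local"):
        train_step()
    return graph
```

On a CPU-only host the helper returns `None`; with a GPU it returns a `torch.cuda.CUDAGraph` that can be replayed with `graph.replay()`.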
Claude finished @bobrenjc93's task in 3m 9s (automated review of PR #2922)
Overall: this is a well-motivated fix, and the changes look correct. Some detailed observations on the warmup stream fix follow.
```python
self._cudagraph,
pool=_cg_manager.graph_pool,
stream=_cg_manager.stream,
capture_error_mode="thread_local",
```
Would this line be sufficient on its own, if we don't have the change above in:

```python
with _use_cuda_memory_pool_manager(
    device, _cg_manager.graph_pool, torch.cuda.current_stream()
):
```
I don't think so? `capture_error_mode="thread_local"` just gives better error reporting; the underlying issue is illegal memory accesses (IMAs) stemming from the warmup run executing on the wrong stream, which the quoted code fixes.
Is there a reason you think this one line would be sufficient?
- Warmup on the capture stream has been working with NCCL for a while. Can we find the root cause of why `torch.cuda.current_stream()` is needed to avoid IMAs?
- If we use `torch.cuda.current_stream()`, there would be high memory overhead due to fragmentation between the capture stream and the default stream in the private pool.
Talked offline; will try to get precise profile data first, and then we can figure out the best workaround.
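One plausible way to gather that profile data (an assumed approach, not taken from the PR): compare allocator counters before and after warmup and capture using the public `torch.cuda.memory_stats` API. A growing gap between reserved and allocated bytes across the two pools would point at fragmentation:

```python
import torch

def report_allocator_slack(tag):
    """Print reserved vs. allocated bytes and their gap ("slack"),
    a rough proxy for fragmentation in the caching allocator.
    Sketch only; returns the slack in bytes, or None without CUDA."""
    if not torch.cuda.is_available():
        print(f"[{tag}] no CUDA device; skipping")
        return None
    stats = torch.cuda.memory_stats()
    reserved = stats.get("reserved_bytes.all.current", 0)
    allocated = stats.get("allocated_bytes.all.current", 0)
    slack = reserved - allocated
    print(f"[{tag}] reserved={reserved} allocated={allocated} slack={slack}")
    return slack
```

Calling this before warmup, after warmup, and after capture (once per candidate stream choice) would show how much extra memory each variant strands in the private pool.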
Stack from ghstack (oldest at bottom):