[PyTorch] Debug CUDA graph support with operation-based API #1117
timmoon10 merged 13 commits into NVIDIA:main from
Conversation
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
for more information, see https://pre-commit.ci
/te-ci pytorch
transformer_engine/pytorch/ops/op.py (Outdated)
    if fp8_recipe is None:
        fp8_recipe = FP8GlobalStateManager.get_fp8_recipe()
    if fp8_recipe is None:
        fp8_recipe = get_default_fp8_recipe()
Hmmmm, this second if looks like logic that should be inside get_fp8_recipe in the FP8GlobalStateManager.
Also, since this is an internal function, couldn't we just always ask for a valid recipe here and just deal with getting it in the caller?
This case shouldn't happen in any of our current use-cases (the recipe returned by FP8GlobalStateManager.get_fp8_recipe() is set within fp8_autocast, and fp8_recipe is provided within make_graphed_callables), but it seems delicate to rely on that assumption.
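For illustration, a minimal sketch of the suggested consolidation, assuming the fallback moves into the getter (names follow the snippet above; the actual FP8GlobalStateManager in transformer_engine/pytorch/fp8.py differs in detail):

    # Hypothetical sketch: fold the default-recipe fallback into the getter,
    # so internal callers always receive a valid recipe.
    class FP8GlobalStateManager:
        _fp8_recipe = None  # set on entry to fp8_autocast

        @classmethod
        def get_fp8_recipe(cls):
            # Return the active recipe, falling back to the default
            if cls._fp8_recipe is not None:
                return cls._fp8_recipe
            return get_default_fp8_recipe()

With this, the call site above reduces to a single fallback: fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() when fp8_recipe is None.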
    if curr_len == amax_history_len:
        continue

    # Reallocate amax history
Could this be its own function?
I've tried to keep this logic similar to how it's handled in the modules. I think it would be nice to consolidate this logic in fp8.py and reuse it for both modules and operations, but that's probably best done in a pure refactor PR.
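As a sketch of the "own function" suggestion, a hypothetical helper that pads or truncates an amax history tensor to a new length (not the actual TE implementation; this assumes a history tensor of shape (history_len, num_scales), and the zero-padding matches the pad call in the diff hunk that follows):

    import torch
    import torch.nn.functional as F

    def resize_amax_history(amax_history: torch.Tensor, amax_history_len: int) -> torch.Tensor:
        # Hypothetical helper: grow (zero-pad) or shrink an amax history
        # of shape (history_len, num_scales) to the requested length.
        curr_len = amax_history.size(0)
        if curr_len == amax_history_len:
            return amax_history
        if curr_len > amax_history_len:
            return amax_history[:amax_history_len].clone()
        # Zero-pad new entries at the end of the history dimension
        return F.pad(amax_history, pad=(0, 0, 0, amax_history_len - curr_len))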
        pad=(0, 0, 0, amax_history_len - curr_len),
    )

    # Update global buffers for amax reductions
This does not look like a graph-specific thing. Was the lack of this in the previous code a bug?
Yep, if the amax history length changes then I don't expect amax reductions to be handled correctly.
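To illustrate the failure mode with hypothetical structures (not TE's actual bookkeeping): if the global reduction buffer holds references to the old history tensors, reductions keep operating on stale storage after reallocation, so the resized tensor has to be swapped in:

    # Hypothetical global buffer mapping reduction-group keys to the amax
    # history tensors that participate in each reduction.
    _global_amax_buffer: dict[str, list[torch.Tensor]] = {}

    def replace_in_global_buffer(key: str, old: torch.Tensor, new: torch.Tensor) -> None:
        # After reallocating an amax history, re-register the new tensor so
        # amax reductions no longer see the stale, old-length tensor.
        tensors = _global_amax_buffer[key]
        for i, t in enumerate(tensors):
            if t is old:
                tensors[i] = new
                break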
Return default recipe from FP8GlobalStateManager.get_fp8_recipe if needed. Expand error message when failing to load FP8 state after capturing CUDA graph. Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
/te-ci pytorch
Signed-off-by: Tim Moon <tmoon@nvidia.com>
for more information, see https://pre-commit.ci
/te-ci pytorch
/te-ci pytorch
/te-ci pytorch
Signed-off-by: Tim Moon <tmoon@nvidia.com>
/te-ci pytorch
Description
This PR debugs CUDA graph support with the operation-based API (see #707). The CUDA graph logic is similar to that of the module-based API.
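For context, a minimal usage sketch of capturing an operation-based model into a CUDA graph, assuming te.ops.Sequential and te.ops.Linear from the operation-based API (#707) and the make_graphed_callables helper; exact names and argument spellings may differ by TE version:

    import torch
    import transformer_engine.pytorch as te

    # Build a model from fusible operations (operation-based API, #707)
    model = te.ops.Sequential(
        te.ops.Linear(1024, 1024),
        te.ops.Linear(1024, 1024),
    )
    sample_input = torch.randn(32, 1024, device="cuda", requires_grad=True)

    # Capture forward/backward into a CUDA graph, with FP8 enabled
    graphed_model = te.make_graphed_callables(
        model,
        (sample_input,),
        fp8_enabled=True,  # an explicit recipe can be passed via fp8_recipe
    )

    output = graphed_model(torch.randn(32, 1024, device="cuda", requires_grad=True))
    output.sum().backward()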
Type of change
Changes
Checklist: