
[PyTorch] Refactor FP8 workspaces in linear modules #820

Merged

timmoon10 merged 22 commits into NVIDIA:main from timmoon10:fp8-workspace-refactor on May 30, 2024

Conversation

timmoon10 (Collaborator) commented Apr 27, 2024

This PR refactors the logic for FP8 weight workspaces in te.Linear, te.LayerNormLinear, and te.LayerNormMLP. The existing logic is somewhat convoluted since it was designed to pass around raw UINT8 buffers, and Float8Tensor support was kludged in #452. For example, when te.Linear has FP32 params, it maintains two workspace Float8Tensors (FP8 weight and transpose) and in each forward pass it extracts their buffers to create yet another temporary Float8Tensor. This PR streamlines the process so that each module maintains just a single workspace Float8Tensor.
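As a rough sketch of the idea (a hypothetical toy, not the actual implementation: FakeFloat8Tensor and ToyLinear stand in for Float8Tensor and te.Linear, and the cast below is a placeholder rather than a real FP8 kernel):

```python
import torch

class FakeFloat8Tensor:
    """Toy stand-in for Float8Tensor: data plus a private scale-inverse copy."""

    def __init__(self, data, scale_inv):
        self._data = data              # placeholder for FP8-encoded bytes
        self._scale_inv = scale_inv    # private copy, decoupled from later fp8_meta scale updates

    @classmethod
    def cast_from(cls, tensor, scale):
        # Quantize with the current scale and snapshot its reciprocal, so later
        # scale updates cannot desynchronize the data and the scale-inverse.
        return cls(tensor * scale, scale.reciprocal().clone())

class ToyLinear(torch.nn.Linear):
    """Keeps a single FP8 weight workspace, refreshed only when the weights change."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._fp8_workspace = None     # one workspace instead of separate weight/transpose buffers

    def forward(self, x, is_first_microbatch=True):
        if self._fp8_workspace is None or is_first_microbatch:
            scale = torch.ones((), device=self.weight.device)  # stand-in for the fp8_meta scale
            self._fp8_workspace = FakeFloat8Tensor.cast_from(self.weight.detach(), scale)
        # A real implementation would feed self._fp8_workspace to the FP8 GEMM here.
        return super().forward(x)
```

In steady state the workspace is simply reused, so no temporary Float8Tensor (with its scale-inverse initialization kernel) is constructed in each forward pass.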

Motivations:

  • This fixes an FP8 recipe bug introduced in [PyTorch] cuda graph support #575 (see Handle the scaling factor when amax is too tiny that leads to an infinite scale #786 (review)). The FP8 scale update kernel updates the scales for weights in every forward pass, even in passes where the FP8 weights are not updated, so we can run into situations where the FP8 scales don't match the FP8 weights. This PR fixes the problem by taking advantage of the fact that Float8Tensor keeps a private copy of the FP8 scale-inverse, which isn't affected by scale updates until the tensor's values are updated.
  • FP8 compute can sometimes result in performance degradation due to CPU overheads (see CPU Overhead of te.Linear FP8 Layers #761). The Float8Tensor constructor requires a CUDA kernel launch to initialize the FP8 scale-inverse, so creating unnecessary Float8Tensors adds non-trivial CPU overhead. Benchmarking the forward pass of small te.Linear layers, this PR gives a 1.12x speedup (a minimal timing sketch follows this list).
  • I find the logic in this PR a bit easier to reason about, although I'd appreciate feedback. It feels nicer to let Float8Tensor internally handle things like FP8 casting and interacting with fp8_meta.
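For the CPU-overhead point above, a minimal timing sketch looks something like this (the layer sizes, batch size, and iteration count are illustrative assumptions, not the exact benchmark run for this PR):

```python
import torch
import transformer_engine.pytorch as te

layer = te.Linear(256, 256)
x = torch.randn(32, 256, device="cuda")

def time_forward(iters=1000):
    # Warm up so one-time initialization is excluded from the measurement.
    with te.fp8_autocast(enabled=True):
        layer(x)
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        with te.fp8_autocast(enabled=True):
            layer(x)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters  # average ms per forward pass

print(f"{time_forward():.3f} ms per forward pass")
```

With layers this small, the forward pass is dominated by CPU launch overhead rather than GEMM time, which is what the 1.12x figure above reflects.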

Tensor base class functions in Float8Tensor have significant overhead.
timmoon10 added the bug ("Something isn't working") and enhancement ("New feature or request") labels on Apr 27, 2024
timmoon10 requested review from ksivaman and ptrendx on April 27, 2024 01:44
timmoon10 (Collaborator, Author)

/te-ci pytorch

timmoon10 added 3 commits May 14, 2024 21:40
Previous caching behavior (always fill cache) incorrectly filled cache during CUDA graph warmup steps.
timmoon10 (Collaborator, Author)

/te-ci pytorch

timmoon10 added 3 commits May 15, 2024 21:32
ONNX FP8 cast ops assumed that FP8 scales were created during model export (i.e. not initialized during training).
timmoon10 (Collaborator, Author)

/te-ci pytorch

timmoon10 (Collaborator, Author) commented May 16, 2024

This PR has grown in scope as I've identified bugs:

  • Our current API for transpose caching in Float8Tensor doesn't work well with CUDA graphs. At the point we perform FP8 GEMMs, the FP8 weights may or may not have the transpose cache filled (e.g. it's not filled when the model is newly created). The only way to access the cache is to call transpose_2d(cache=True), which fills the cache. But if you capture a CUDA graph with is_first_microbatch=None, this means that the cache is filled during the warmup steps and you never actually capture the transpose kernel. The easiest fix is to modify Float8Tensor to support lazy transpose caching (see [PyTorch] cuda graph support #575 (comment), [PyTorch] cuda graph support #575 (review), and [PyTorch] Fix Float8Tensor transpose caching in #575 #735).
  • The ONNX export tests assume that FP8 scales can be represented as ONNX constant operations, which requires that the scales are initialized during export. They were failing because we copy an FP8 scale with Tensor.copy_, which is translated into an ONNX Expand operation (I think to handle array broadcasting), breaking that assumption. The fix is to use Tensor.fill_ on the FP8 scales instead of Tensor.copy_ during the ONNX export process (see the sketch below). See [PyTorch] Handle non-constant FP8 scales in ONNX export #861 (comment).
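A minimal sketch of the copy_-versus-fill_ distinction (the variable names are illustrative, not TE's actual code):

```python
import torch

scale = torch.empty(1, device="cuda")
default = torch.ones(1, device="cuda")

# Tensor.copy_ takes another tensor, so (per the observation above) the ONNX
# exporter traces it with broadcasting support as an Expand op and the scale
# no longer folds into a constant.
scale.copy_(default)

# Tensor.fill_ takes a Python scalar, which keeps the initialized scale
# representable as an ONNX constant during export.
scale.fill_(1.0)
```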

timmoon10 marked this pull request as ready for review on May 16, 2024 04:26
timmoon10 (Collaborator, Author)

/te-ci pytorch

timmoon10 (Collaborator, Author)

/te-ci pytorch

timmoon10 (Collaborator, Author)

/te-ci pytorch

timmoon10 and others added 2 commits May 22, 2024 21:31
Work around ONNX test failures by filling FP8 scale tensors instead of copying into them.

timmoon10 (Collaborator, Author)

/te-ci pytorch

        )
-       data = data.cuda()
+       if not data.is_cuda:
+           data = data.cuda()
Member

Is this not a noop if tensor is already on device?

timmoon10 (Collaborator, Author)
Yep, tensor conversion functions (to, cuda, contiguous, float) will be no-ops if possible. I prefer being as paranoid as possible until CPU overheads start becoming a problem.

timmoon10 (Collaborator, Author)

Ah, I missed the "not" in your question. The conversion functions add some CPU overhead, so this is a micro-optimization.
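A minimal demonstration of the point being discussed (the tensor here is illustrative):

```python
import torch

data = torch.ones(4, device="cuda")

# Tensor.cuda() returns the same object when the tensor is already on the
# right device, so the conversion itself is a no-op...
assert data.cuda() is data

# ...and the explicit guard only skips the small Python-level overhead of
# making that call at all.
if not data.is_cuda:
    data = data.cuda()
```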

ksivaman (Member)

@timmoon10 Have you verified identical numerics with this change?

ksivaman (Member)

For testing CUDA graphs with FP8 caching, did you use the noop_flag in transpose and the fp8_weight_caching flag in make_graphed_callables?

timmoon10 added 2 commits May 28, 2024 19:15
timmoon10 (Collaborator, Author)

With the tests and bugfixes in #869, this PR seems to handle make_graphed_callables with fp8_weight_caching=True correctly.
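For reference, a hedged usage sketch of the path being tested (layer sizes, the sample input, and the fp8_enabled argument are assumptions for illustration; only make_graphed_callables, fp8_weight_caching, and is_first_microbatch are taken from the discussion above):

```python
import torch
import transformer_engine.pytorch as te

model = te.Linear(1024, 1024)
sample_input = torch.randn(32, 1024, device="cuda")

graphed_model = te.make_graphed_callables(
    model,
    (sample_input,),
    fp8_enabled=True,          # assumed flag: capture the FP8 path
    fp8_weight_caching=True,   # reuse cached FP8 weights across microbatches
)

for microbatch in range(4):
    # Only the first microbatch refreshes the cached FP8 weights.
    out = graphed_model(sample_input, is_first_microbatch=(microbatch == 0))
```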

timmoon10 (Collaborator, Author)

/te-ci pytorch

ksivaman (Member) left a comment

Approving, given that numerics for all cases remain identical in e2e tests.

timmoon10 (Collaborator, Author)

Training GPT for 100 steps (175B params, TP=2, PP=4), I don't see significant differences in the loss curves with and without this PR. I think this is ready to go.

timmoon10 merged commit b1a0e0a into NVIDIA:main on May 30, 2024
timmoon10 added a commit to timmoon10/TransformerEngine that referenced this pull request May 30, 2024
timmoon10 added a commit that referenced this pull request Jul 9, 2024
* Add basic infrastructure for Sequential module
* Add linear op
* Add FP8 support in linear op
  Runs, but need to validate. Runtime errors with non-FP8 params and FP8 compute, or FP8 params and non-FP8 compute.
* Add reshape op and unit test
* Add bias op
* Add unfused linear op
  Test does not pass with FP8.
* Debug unfused linear op
* Add test for linear+bias op
* Add separate abstract classes for unfused and fused ops
* Consolidate unfused ops in submodule
* Add linear-bias fused op
* Use fused cast-transpose in linear ops
* Disable GEMM+bias fusion with FP32 activations
  Not supported by cuBLAS.
* Add parallel unit test for unfused linear op
* Refactor parallel tests to reduce job launches
* Add all-reduce, all-gather, and reduce-scatter ops
* Remove unused file
* Debug multi-GPU FP8 test
* Add support for FP8 scale updates
  Still need to implement amax reductions.
* Add license boilerplate
* Fuse GEMM+bias in row TP
  Add documentation for unfused ops
* Rename pipeline to fuser
  Expand documentation
* Tweak documentation
* Preserve cached FP8 transpose between ops
* Add option for fused wgrad accumulation
* Directly output FP8 from linear if needed
* Fix cuDNN front-end commit
* Use updated FP8 tensor API for transpose caching
* Use updated API for FP8 scale updates
* Add tests for non-default FP8 recipes
* Rename UnfusedOperation to BasicOperation
* Add unit test to check amax reduction with fusable op
* Operator autograd state no longer needs to be initialized
* Initial functional implementation of linear op
* Debug fused linear+bias op
* Remove autograd context from functional linear impl
* Use functional linear impl in fused linear+bias op
* Rename subdirectory from "fuser" to "ops"
  Avoid confusion with kernel fusers and graph compilers.
* Update with Float8Tensor changes in #820
* Remove unnecessary CPU overheads
* Correctly pass FP8 metadata from next op
* Fix linter errors
* Add convenience functions to manipulate Sequential class
* Update name of PyTorch extensions module
* Clear saved tensor data in linear op after bprop
* Fix Pylint error
* Update name of PyTorch extensions module
* Fix test name in QA script
* Update name of PyTorch extensions module
* Run distributed tests even when only 1 GPU is available
* Only run distributed tests with 2 GPUs if there are >=2 GPUs
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* Review suggestions from @sudhakarsingh27 and @ksivaman
  Fix spelling of "fusible". Avoid "input" name in internal APIs.
* Update transformer_engine/pytorch/ops/__init__.py

---------

Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>