[PyTorch] Refactor FP8 workspaces in linear modules #820
timmoon10 merged 22 commits into NVIDIA:main from
Conversation
Tensor base class functions in Float8Tensor have significant overhead. Signed-off-by: Tim Moon <tmoon@nvidia.com>
/te-ci pytorch
Previous caching behavior (always fill cache) incorrectly filled cache during CUDA graph warmup steps. Signed-off-by: Tim Moon <tmoon@nvidia.com>
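For illustration only, a minimal sketch of the caching-policy idea (hypothetical helper and cache, not Transformer Engine's actual code): fill the cache only when the caller explicitly opts in, so warmup passes don't populate it as a side effect.

```python
import torch

# Hypothetical FP8-weight cache keyed by parameter id (illustrative only).
_fp8_weight_cache: dict = {}

def get_fp8_weight(weight: torch.Tensor, fill_cache: bool) -> torch.Tensor:
    """Return a cached 'FP8' copy of ``weight``, caching it only on request."""
    cached = _fp8_weight_cache.get(id(weight))
    if cached is not None:
        return cached
    fp8_weight = weight.to(torch.uint8)  # stand-in for a real FP8 cast
    # Previously the cache was filled unconditionally, so CUDA-graph warmup
    # iterations populated it too; now the caller must opt in explicitly.
    if fill_cache:
        _fp8_weight_cache[id(weight)] = fp8_weight
    return fp8_weight
```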
/te-ci pytorch
ONNX FP8 cast ops assumed that FP8 scales were created during model export (i.e. not initialized during training). Signed-off-by: Tim Moon <tmoon@nvidia.com>
/te-ci pytorch
This PR has grown in scope as I've identified bugs:
/te-ci pytorch
/te-ci pytorch
/te-ci pytorch
Work around ONNX test failures by filling FP8 scale tensors instead of copying into them. Signed-off-by: Tim Moon <tmoon@nvidia.com>
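For reference, a tiny PyTorch sketch of the difference between filling a scale tensor in place and copying into it from another tensor (my reading of the workaround, not the exact TE change):

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
scale = torch.empty(1, device=device)

# copy_ needs a source tensor to already exist, which is what the ONNX export
# path assumed about FP8 scales; fill_ only needs a Python scalar.
src = torch.ones(1, device=device)
scale.copy_(src)   # copy values from another tensor
scale.fill_(1.0)   # write a scalar value in place, no source tensor required
```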
/te-ci pytorch
)
-        data = data.cuda()
+        if not data.is_cuda:
+            data = data.cuda()
Is this not a no-op if the tensor is already on device?
Yep, tensor conversion functions (to, cuda, contiguous, float) will be no-ops if possible. I prefer being as paranoid as possible until CPU overheads start becoming a problem.
Ah, I missed the "not" in your question. The conversion functions add some CPU overhead, so this is a micro-optimization.
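For illustration, a minimal sketch of the two variants being discussed (assumed shapes and names, not the actual TE code):

```python
import torch

if torch.cuda.is_available():
    data = torch.randn(4, 4)  # starts on the CPU in this example

    # Unconditional conversion: returns the same tensor if it already lives on
    # the GPU, but still pays the Python/dispatcher cost of the call.
    data = data.cuda()

    # Guarded conversion: a cheap attribute check skips the call entirely when
    # the tensor is already on the GPU.
    if not data.is_cuda:
        data = data.cuda()
```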
@timmoon10 Have you verified identical numerics with this change?
For testing CUDA graphs with FP8 caching, did you use the
With the tests and bugfixes in #869, this PR seems to handle
/te-ci pytorch
ksivaman left a comment
Approving, given that numerics for all cases remain identical in e2e tests.
Training GPT for 100 steps (175B params, TP=2, PP=4), I don't see significant differences in the loss curves with and without this PR. I think this is ready to go.
* Add basic infrastructure for Sequential module
* Add linear op
* Add FP8 support in linear op. Runs, but need to validate. Runtime errors with non-FP8 params and FP8 compute, or FP8 params and non-FP8 compute.
* Add reshape op and unit test
* Add bias op
* Add unfused linear op. Test does not pass with FP8.
* Debug unfused linear op
* Add test for linear+bias op
* Add separate abstract classes for unfused and fused ops
* Consolidate unfused ops in submodule
* Add linear-bias fused op
* Use fused cast-transpose in linear ops
* Disable GEMM+bias fusion with FP32 activations. Not supported by cuBLAS.
* Add parallel unit test for unfused linear op
* Refactor parallel tests to reduce job launches
* Add all-reduce, all-gather, and reduce-scatter ops
* Remove unused file
* Debug multi-GPU FP8 test
* Add support for FP8 scale updates. Still need to implement amax reductions.
* Add license boilerplate
* Fuse GEMM+bias in row TP. Add documentation for unfused ops.
* Rename pipeline to fuser. Expand documentation.
* Tweak documentation
* Preserve cached FP8 transpose between ops
* Add option for fused wgrad accumulation
* Directly output FP8 from linear if needed
* Fix cuDNN front-end commit
* Use updated FP8 tensor API for transpose caching
* Use updated API for FP8 scale updates
* Add tests for non-default FP8 recipes
* Rename UnfusedOperation to BasicOperation
* Add unit test to check amax reduction with fusable op
* Operator autograd state no longer needs to be initialized
* Initial functional implementation of linear op
* Debug fused linear+bias op
* Remove autograd context from functional linear impl
* Use functional linear impl in fused linear+bias op
* Rename subdirectory from "fuser" to "ops". Avoid confusion with kernel fusers and graph compilers.
* Update with Float8Tensor changes in #820
* Remove unnecessary CPU overheads
* Correctly pass FP8 metadata from next op
* Fix linter errors
* Add convenience functions to manipulate Sequential class
* Update name of PyTorch extensions module
* Clear saved tensor data in linear op after bprop
* Fix Pylint error
* Update name of PyTorch extensions module
* Fix test name in QA script
* Update name of PyTorch extensions module
* Run distributed tests even when only 1 GPU is available
* Only run distributed tests with 2 GPUs if there are >=2 GPUs
* [pre-commit.ci] auto fixes from pre-commit.com hooks. For more information, see https://pre-commit.ci
* Review suggestions from @sudhakarsingh27 and @ksivaman. Fix spelling of "fusible". Avoid "input" name in internal APIs.
* Update transformer_engine/pytorch/ops/__init__.py

Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This PR refactors the logic for FP8 weight workspaces in te.Linear, te.LayerNormLinear, and te.LayerNormMLP. The existing logic is somewhat convoluted since it was designed to pass around raw UINT8 buffers, and Float8Tensor support was kludged in #452. For example, when te.Linear has FP32 params, it maintains two workspace Float8Tensors (FP8 weight and transpose), and in each forward pass it extracts out the buffers to create another temporary Float8Tensor. This PR streamlines the process so that it maintains a single workspace Float8Tensor.

Motivations:

* A Float8Tensor has a private copy of the FP8 scale-inverse, so it won't be affected by scale updates until its values are updated.
* The Float8Tensor constructor requires a CUDA kernel launch to initialize the FP8 scale-inverse, so creating unnecessary Float8Tensors adds non-trivial CPU overhead. Benchmarking the forward pass of small te.Linears, this PR gives a 1.12x speedup.
* Let Float8Tensor internally handle things like FP8 casting and interacting with fp8_meta.
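As a rough before/after sketch of the workspace handling described above (using a stand-in workspace class with assumed names, not TE's actual Float8Tensor API):

```python
import torch

class Fp8WorkspaceSketch:
    """Illustrative stand-in for a Float8Tensor-style weight workspace.

    It owns a payload buffer plus a private copy of the scale-inverse; in the
    real Float8Tensor, initializing that scale-inverse launches a CUDA kernel,
    which is why constructing temporary tensors every forward pass is costly.
    """

    def __init__(self, shape, device):
        self.data = torch.empty(shape, device=device)  # stands in for FP8 bytes
        self.scale_inv = torch.ones(1, device=device)

    def update_(self, weight: torch.Tensor, scale: torch.Tensor) -> None:
        # Refresh the payload and keep the private scale-inverse in sync with
        # the current recipe scale, so later scale updates are not ignored.
        self.data.copy_(weight * scale)
        self.scale_inv.copy_(scale.reciprocal())

if torch.cuda.is_available():
    weight = torch.randn(8, 8, device="cuda")
    scale = torch.full((1,), 2.0, device="cuda")

    # One persistent workspace, created once and updated in place each forward
    # pass, instead of two workspaces plus a temporary wrapper per pass.
    workspace = Fp8WorkspaceSketch(weight.shape, weight.device)
    for _ in range(3):
        workspace.update_(weight, scale)
```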