Conversation
Force-pushed 46b509a to bd7fd0a
Force-pushed 5d5e52c to 8cb93ff
Force-pushed f4c8b6f to 374867a
```python
def _reset_caches(self) -> None:
    """Reset cached values

    Should be called after any in-place operation.
    """
    self._transpose = None
```
Removing the automatic cache clearing makes using the transpose cache a much more manual and dangerous process. Consider something like:

```python
matmul(x, w.transpose(0, 1))
w -= learning_rate * w.grad
matmul(x, w.transpose(0, 1))
```

Previously we could just set `update_cache="lazy"`. Now there needs to be manual logic to figure out the microbatch step, or else it will return stale values.
In this example, caching is not used, so a fresh transpose will be computed each time.
If caching is used, it is reasonable to expect the user to know when to reuse a cached value and when to force a recompute. This is consistent with the design of the `is_first_microbatch` argument to the forward pass in the module APIs.
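For reference, a minimal sketch of that module-level pattern, assuming the standard `transformer_engine.pytorch.Linear` and `fp8_autocast` APIs; the shapes and the microbatch loop are illustrative:

```python
import torch
import transformer_engine.pytorch as te

linear = te.Linear(1024, 1024)
microbatches = [torch.randn(16, 1024, device="cuda") for _ in range(4)]

for step, x in enumerate(microbatches):
    with te.fp8_autocast(enabled=True):
        # The caller, not the tensor, decides when cached FP8 weight data
        # (including transposes) may be reused across microbatches.
        y = linear(x, is_first_microbatch=(step == 0))
```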
Note: we use two arguments, `cache` and `update_cache`, to support this logic.
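The exact signature is not shown in this thread, so the following is a purely hypothetical sketch of how the two flags could interact; `get_transpose` and `compute_transpose` are made-up names standing in for the real Float8Tensor method and the fused cast/transpose kernel:

```python
import torch

def compute_transpose(tensor) -> torch.Tensor:
    """Stand-in for the real fused cast/transpose kernel (hypothetical)."""
    return tensor._data.t().contiguous()

def get_transpose(tensor, cache: bool = False, update_cache: bool = False):
    """Hypothetical helper showing the intended cache/update_cache logic."""
    if not cache:
        return compute_transpose(tensor)  # caching disabled: always recompute
    if update_cache or getattr(tensor, "_transpose", None) is None:
        tensor._transpose = compute_transpose(tensor)  # (re)fill the cache
    return tensor._transpose  # reuse the cached transpose
```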
I think we're overfitting to the Linear weight use-case. For example, in #707 I want to pass Float8Tensors between ops as inputs or dgrads:

```python
class DbiasCastTranspose:
    def backward(self, dy):
        db = dy.sum(dim=0)
        dx = cast_transpose(dy)  # Creates Float8Tensor with transpose cache
        return dx, db

class FP8Linear:  # Part of FP8 attention
    def backward(self, dy):
        if not isinstance(dy, Float8Tensor):
            dy = Float8Tensor.to_float8(dy)
        dx = Float8Tensor(...)  # No transpose cache
        fp8_gemm(w.transpose()._data, dy.transpose()._data, out=dx._data)
        dw = fp8_gemm(x, dy)
        return dx, dw
```
FP8Linear has no idea where its input came from. Maybe it's from DbiasCastTranspose (Float8Tensor with cached transpose), FP8Linear (Float8Tensor without cached transpose), or a non-FP8 op. Our current approach with lazy transpose caching gives us a lot of flexibility and I think we should abandon it only when really necessary.
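A minimal sketch of the lazy pattern being defended here, assuming a Float8Tensor-like object with `_data` and `_transpose` attributes (names taken from the snippets above; the helper itself is illustrative):

```python
import torch

def transpose_lazy(t) -> torch.Tensor:
    # Whichever op produced `t`, asking for the transpose either reuses the
    # cache filled upstream or fills it on first use.
    if getattr(t, "_transpose", None) is None:
        t._transpose = t._data.t().contiguous()  # stand-in for cast_transpose
    return t._transpose

# FP8Linear.backward could then simply call transpose_lazy(dy) without knowing
# whether DbiasCastTranspose already cached the transpose upstream.
```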
I suppose this is not precisely relevant since it doesn't involve in-place operations, but it is a more general point about the design of Float8Tensor.
transformer_engine/common/include/transformer_engine/cast_transpose_noop.h
Force-pushed d0aa61c to bb5b4d6
/te-ci
/te-ci pytorch
/te-ci pytorch
Force-pushed eff5d27 to 32e070c
/te-ci pytorch
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>
Co-authored-by: Charlene Yang <charleney@nvidia.com>
Force-pushed db6a812 to 31dc133
/te-ci pytorch
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
/te-ci pytorch
Force-pushed 9944150 to 3c50a17
/te-ci pytorch
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
/te-ci pytorch
* FP8 cuda graphs
  Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
  Co-authored-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>
  Co-authored-by: Charlene Yang <charleney@nvidia.com>
* Fix numerics
  Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
* exclude torch compile from numerics tests
  Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
* More numerics fixes
  Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
* Fix tests
  Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
* Fix CI
  Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
* rm fusion from unfused path
  Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>
Co-authored-by: Charlene Yang <charleney@nvidia.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
This PR adds the following features (high-level):
- `make_graphed_callables` API similar to the PyTorch API, with some additional arguments for FP8 usage. Support for FP8 weight caching via the existing `is_first_microbatch` argument is also retained (usage sketch below).
- `Float8Tensor` changes that make the transposes persistent for graph capture. Also fixes use cases for the vanilla optimizers (non fp8-distopt).
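A hedged usage sketch of the new API, assuming the argument names suggested by the description (`fp8_enabled`, `fp8_recipe`, `fp8_weight_caching`); the merged signature may differ:

```python
import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe

model = te.Linear(1024, 1024)
sample_input = torch.randn(16, 1024, device="cuda", requires_grad=True)

# Capture the module into a CUDA graph with FP8 execution baked in.
graphed_model = te.make_graphed_callables(
    model,
    (sample_input,),
    fp8_enabled=True,
    fp8_recipe=recipe.DelayedScaling(),
    fp8_weight_caching=True,  # keep the is_first_microbatch weight-caching path
)

microbatches = [torch.randn(16, 1024, device="cuda") for _ in range(4)]
for step, x in enumerate(microbatches):
    y = graphed_model(x, is_first_microbatch=(step == 0))
```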