
Enable reuse of dummy wgrad tensor#1651

Merged
ksivaman merged 5 commits into NVIDIA:main from vasunvidia:dummy_wgrads on Apr 8, 2025

Conversation

@vasunvidia (Collaborator)

Description

Please include a brief summary of the changes, relevant motivation and context.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@ksivaman (Member) left a comment


LGTM

ksivaman and others added 5 commits on April 7, 2025 at 17:57
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>
for more information, see https://pre-commit.ci

Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>
@ksivaman (Member) commented Apr 7, 2025

/te-ci pytorch L0 L1

Comment on lines 691 to 701

```diff
 if getattr(weight, "zero_out_wgrad", False):
-    wgrad = torch.zeros(
-        weight.main_grad.shape,
-        dtype=weight.dtype,
-        device=torch.cuda.current_device(),
-        requires_grad=False,
+    wgrad = get_dummy_wgrad(
+        list(weight.main_grad.shape),
+        weight.dtype,
+        zero=True,
     )
 else:
-    wgrad = torch.empty(
-        weight.main_grad.shape,
-        dtype=weight.dtype,
-        device=torch.cuda.current_device(),
-        requires_grad=False,
+    wgrad = get_dummy_wgrad(
+        list(weight.main_grad.shape),
+        weight.dtype,
     )
```
(Collaborator)

We could clean this up:

Suggested change

```diff
-if getattr(weight, "zero_out_wgrad", False):
-    wgrad = get_dummy_wgrad(
-        list(weight.main_grad.shape),
-        weight.dtype,
-        zero=True,
-    )
-else:
-    wgrad = get_dummy_wgrad(
-        list(weight.main_grad.shape),
-        weight.dtype,
-    )
+wgrad = get_dummy_wgrad(
+    list(weight.main_grad.shape),
+    weight.dtype,
+    zero=getattr(weight, "zero_out_wgrad", False),
+)
```

We could do a similar change in LayerNormLinear.
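For orientation, here is a minimal sketch of what a reusable dummy-wgrad helper along these lines could look like. This is an illustrative assumption, not the code merged in the PR: the module-level cache dict, the cache key, and the `device` parameter are invented for the example. The idea is that a single placeholder buffer per (shape, dtype) is handed out repeatedly instead of allocating a fresh throwaway tensor on every backward pass.

```python
import torch

# Hypothetical module-level cache; the actual implementation may differ.
_dummy_wgrad_cache: dict = {}


def get_dummy_wgrad(
    shape: list, dtype: torch.dtype, zero: bool = False, device: str = "cuda"
) -> torch.Tensor:
    """Return a shared placeholder wgrad buffer, reused across calls."""
    key = (tuple(shape), dtype, device)
    if key not in _dummy_wgrad_cache:
        _dummy_wgrad_cache[key] = torch.empty(
            shape, dtype=dtype, device=device, requires_grad=False
        )
    wgrad = _dummy_wgrad_cache[key]
    if zero:
        # The buffer is shared, so callers that expect zeros (zero_out_wgrad)
        # must re-zero it: a previous user may have written into it.
        wgrad.zero_()
    return wgrad
```

Because the buffer is shared, any caller that reads its contents must request `zero=True`; callers that only need a correctly-shaped placeholder can skip the memset.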

```python
    return _multi_stream_cublas_workspace


def get_dummy_wgrad(shape: list, dtype: torch.dtype, zero=False) -> torch.Tensor:
```
(Collaborator)

This could be simplified with lru_cache.
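A hedged sketch of the reviewer's `lru_cache` idea; this is not the code in the PR, and the `device` parameter is an assumption added so the example is self-contained. Note that `lru_cache` requires hashable arguments, so the list shape must be converted to a tuple before hitting the cached function:

```python
from functools import lru_cache

import torch


@lru_cache(maxsize=None)
def _cached_dummy_wgrad(shape: tuple, dtype: torch.dtype, device: str) -> torch.Tensor:
    # One buffer per (shape, dtype, device); cached for the process lifetime.
    return torch.empty(shape, dtype=dtype, device=device, requires_grad=False)


def get_dummy_wgrad(
    shape: list, dtype: torch.dtype, zero: bool = False, device: str = "cuda"
) -> torch.Tensor:
    wgrad = _cached_dummy_wgrad(tuple(shape), dtype, device)
    if zero:
        # The cached buffer is shared, so it must be re-zeroed on each request.
        wgrad.zero_()
    return wgrad
```

One caveat with this approach: an unbounded `lru_cache` pins every distinct (shape, dtype) buffer for the lifetime of the process, which is usually fine for a fixed set of layer shapes but worth keeping in mind.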

```diff
 send_dst = cp_global_ranks[(rank + 1) % cp_size * cp_size_a2a + rank_a2a]
 recv_src = cp_global_ranks[(rank - 1) % cp_size * cp_size_a2a + rank_a2a]
-batch_p2p_comm = int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0")) or (cp_size == 2)
+batch_p2p_comm = int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0"))
```
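To make the behavioral difference in that one-line change concrete, here is an illustrative before/after sketch (the helper function names are invented; only the expressions come from the diff). Before the change, batched P2P communication was forced on whenever `cp_size == 2` regardless of the environment variable; after it, the feature is purely opt-in via `NVTE_BATCH_MHA_P2P_COMM`:

```python
import os


def batch_p2p_before(cp_size: int) -> bool:
    # Old behavior: auto-enabled for context-parallel size 2.
    return bool(int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0"))) or (cp_size == 2)


def batch_p2p_after(cp_size: int) -> bool:
    # New behavior: controlled solely by the environment variable.
    return bool(int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0")))
```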
@timmoon10 (Collaborator) commented Apr 7, 2025

What's the motivation for this test change? It seems orthogonal to the functional changes.

@ksivaman ksivaman merged commit ba5dc5d into NVIDIA:main Apr 8, 2025
11 of 12 checks passed
wdykas pushed a commit to wdykas/TransformerEngine that referenced this pull request Apr 14, 2025
* Use dummy wgrads for lower memory consumption

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Bug fix to avoid sharing gradients.

Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Disable automatic use of batch_p2p_comm for CP2

Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Change weight to origin_weight for LN_LINEAR

Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

---------

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Peter Dykas <wdykas@nvidia.com>
ptrendx pushed a commit that referenced this pull request May 1, 2025
(same commit message as above)
3 participants