[PyTorch] Integration test for Megatron-LM #1329
Merged
timmoon10 merged 7 commits into NVIDIA:main on Nov 21, 2024
Conversation
timmoon10 (Collaborator, Author): Pipeline 20338114
timmoon10 (Collaborator, Author): Pipeline 20444324 is green
timmoon10 commented on Nov 20, 2024, on lines +138 to +142:
```python
if requires_grad != x.requires_grad:
    if requires_grad:
        x.requires_grad_()
    else:
        x = x.detach()
```
timmoon10 (Collaborator, Author): This fixes a `te.Sequential` bug that was exposed by Mcore. When running in eval mode, we want `x.requires_grad=False` so that the op knows it doesn't need to prepare for that grad. However, PyTorch sometimes complains if you change a tensor's `requires_grad` from True to False (i.e. when the tensor is not a leaf in the autograd graph). Detaching the tensor works around this case.
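For context, a minimal standalone sketch (toy tensors, not TE code) of the PyTorch behavior the comment describes:

```python
import torch

x = torch.randn(4, requires_grad=True)
y = 2 * x  # y is a non-leaf tensor in the autograd graph

try:
    y.requires_grad_(False)  # PyTorch only allows changing requires_grad on leaf tensors
except RuntimeError as err:
    print(err)  # "you can only change requires_grad flags of leaf variables..."

y = y.detach()  # workaround: detach returns a tensor with requires_grad=False
assert not y.requires_grad
```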
sudhakarsingh27 approved these changes on Nov 20, 2024
```python
# Check tensor dims
weight = self.weight
weight_dims = tuple(weight.size())
input_dims = tuple(input_.size())
```
Collaborator: Apparently `torch.Size` is a subclass of `tuple`, so the `tuple(...)` conversion is probably not needed.
```python
# Check tensor dims
weight = self.weight
weight_dims = tuple(weight.size())
```
Collaborator: No need to tupleize here either.
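A quick illustration of the reviewer's point, assuming standard PyTorch behavior: `torch.Size` already subclasses `tuple`, so comparisons and unpacking work without the extra conversion:

```python
import torch

size = torch.empty(2, 3).size()
print(type(size))               # <class 'torch.Size'>
print(isinstance(size, tuple))  # True: torch.Size subclasses tuple
print(size == (2, 3))           # True: compares like a plain tuple
```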
timmoon10 added a commit that referenced this pull request on Nov 21, 2024:
* Handle deprecated `hidden_size` arg in norm modules
* Support initializing norm ops on CPU
* Add integration test for Megatron-LM
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* Rename Mcore integration test
* Handle case in RMSNorm where hidden dim is not provided

Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Description
#1033 broke Megatron-LM's wrappers for the LayerNorm and RMSNorm modules: it renamed the `hidden_size` arg to `normalized_shape` in order to match `torch.nn.LayerNorm`, but Megatron-LM treats `hidden_size` as a kwarg: https://github.com/NVIDIA/Megatron-LM/blob/aded519cfb1de2abf96f36ca059f992294b7876f/megatron/core/extensions/transformer_engine.py#L65.
This PR adds logic to handle the `hidden_size` arg and print a deprecation warning. To help detect these issues in the future, I've also added an integration test that runs Megatron-LM to train a very small GPT model.
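A hedged sketch of what such deprecation handling can look like in a norm module's constructor; the helper name and messages below are assumptions for illustration, not the PR's exact code:

```python
import warnings

# Hypothetical helper: map a deprecated `hidden_size` kwarg onto
# `normalized_shape`, warning the caller along the way.
def _resolve_normalized_shape(normalized_shape=None, hidden_size=None):
    if hidden_size is not None:
        warnings.warn(
            "`hidden_size` is deprecated, please use `normalized_shape` instead",
            DeprecationWarning,
        )
        if normalized_shape is None:
            normalized_shape = hidden_size
    if normalized_shape is None:
        raise ValueError("Either `normalized_shape` or `hidden_size` must be provided")
    return normalized_shape
```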
Type of change
Changes
- Handle deprecated `hidden_size` arg in LayerNorm and RMSNorm modules

Checklist: