
Configure FSDP to keep module params #12074

Merged: maanug-nv merged 1 commit into NVIDIA-NeMo:main from timmoon10:debug-te-fsdp on Mar 6, 2025

Conversation

@timmoon10 (Collaborator)

What does this PR do?

This PR aims to fix a bug introduced in Transformer Engine 1.14 when running with FSDP.

Transformer Engine 1.14 changed its LayerNorm and RMSNorm implementations to use a prototype operation-based API (see NVIDIA/TransformerEngine#1033). This is a generic API intended for kernel fusion, and its forward pass must call torch.nn.Module.parameters to interface with PyTorch's autograd infrastructure. However, PyTorch FSDP's default behavior is to remove all module parameters and manage them via FlatParameters. Since TE modules could no longer access their parameters, we got the following error:

  File "/opt/NeMo/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py", line 757, in fwd_bwd_step
    losses_reduced_per_micro_batch = fwd_bwd_function(
                                     ^^^^^^^^^^^^^^^^^
  File "/opt/megatron-lm/megatron/core/pipeline_parallel/schedules.py", line 467, in forward_backward_no_pipelining
    backward_step(input_tensor, output_tensor, output_tensor_grad, model_type, config)
  File "/opt/megatron-lm/megatron/core/pipeline_parallel/schedules.py", line 366, in backward_step
    custom_backward(output_tensor[0], output_tensor_grad[0])
  File "/opt/megatron-lm/megatron/core/pipeline_parallel/schedules.py", line 150, in custom_backward
    Variable._execution_engine.run_backward(
  File "/usr/local/lib/python3.12/dist-packages/torch/autograd/function.py", line 307, in apply
    return user_fn(self, *args)
           ^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/autograd/function.py", line 600, in wrapper
    outputs = fn(ctx, *args)
              ^^^^^^^^^^^^^^
  File "/src/transformerengine/transformer_engine/pytorch/ops/fuser.py", line 267, in backward
    raise RuntimeError(
RuntimeError: Expected op 0 to generate 0 param grads, but got 2

This PR sets use_orig_params=True in the FSDP configuration so that TE modules can still access their parameters.
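
For reference, here is a minimal standalone sketch of the relevant FSDP setting. It uses plain PyTorch rather than NeMo's wrapping code; the single-process NCCL group and the stock torch.nn.LayerNorm stand-in for the TE module are illustrative assumptions only.

    # Minimal sketch (not the NeMo code path): wrap a model with FSDP using
    # use_orig_params=True so submodules such as TE's op-based LayerNorm can
    # still reach their own parameters via torch.nn.Module.parameters().
    # Assumes one GPU and a single-process NCCL group, purely for illustration.
    import os
    import torch
    import torch.distributed as dist
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy

    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("nccl", rank=0, world_size=1)
    torch.cuda.set_device(0)

    # Plain LayerNorm stands in for transformer_engine.pytorch.LayerNorm/RMSNorm.
    model = torch.nn.Sequential(
        torch.nn.Linear(1024, 1024),
        torch.nn.LayerNorm(1024),
    ).cuda()

    # With the default use_orig_params=False, FSDP manages weights as
    # FlatParameters and the original per-module parameters are no longer
    # visible, which is what broke the op-based TE forward pass.
    # use_orig_params=True keeps the original parameters registered as views
    # into the flat parameter, so module.parameters() remains usable.
    fsdp_model = FSDP(
        model,
        sharding_strategy=ShardingStrategy.FULL_SHARD,
        use_orig_params=True,
    )

    loss = fsdp_model(torch.randn(8, 1024, device="cuda")).sum()
    loss.backward()
    dist.destroy_process_group()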

Collection: NLP

Changelog

  • Configure FSDP to keep module params

Usage

Run GPT, e.g. with the config at https://github.com/NVIDIA/NeMo/blob/main/examples/nlp/language_modeling/conf/megatron_gpt_config.yaml.

Some relevant options (an example invocation follows the list):

  • Transformer Engine: model.transformer_engine=True
  • FSDP: model.fsdp=True, fsdp_sharding_strategy=full
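
A hypothetical invocation with these overrides might look like the following. The megatron_gpt_pretraining.py entry point and the placement of fsdp_sharding_strategy under model are assumptions based on the linked config, not something introduced by this PR.

    python examples/nlp/language_modeling/megatron_gpt_pretraining.py \
        model.transformer_engine=True \
        model.fsdp=True \
        model.fsdp_sharding_strategy=full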

GitHub Actions CI

The Jenkins CI system has been replaced by GitHub Actions self-hosted runners.

The GitHub Actions CI will run automatically when the "Run CICD" label is added to the PR.
To re-run CI, remove and re-add the label.
To run CI on an untrusted fork, a NeMo user with write access must first click "Approve and run".

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you add or update any necessary documentation?
  • Does the PR affect components that are optional to install? (Ex: Numba, Pynini, Apex etc)
    • Reviewer: Does the PR have correct import guards for all optional libraries?

PR Type:

  • New Feature
  • Bugfix
  • Documentation

Who can review?

Anyone in the community is free to review the PR once the checks have passed.
The Contributor guidelines list specific people who can review PRs to various areas.

Additional Information

timmoon10 added the "bug" and "NLP" labels on Feb 6, 2025
@github-actions (Contributor)

This PR is stale because it has been open for 14 days with no activity. Remove the stale label, comment, or update the PR, or it will be closed in 7 days.

Needed to avoid bug with Transformer Engine LayerNorm, which needs to access module parameters.

Signed-off-by: Tim Moon <tmoon@nvidia.com>
@maanug-nv (Collaborator) left a comment

LGTM, thanks!

maanug-nv enabled auto-merge (squash) on March 5, 2025 at 23:57
maanug-nv merged commit 028d43f into NVIDIA-NeMo:main on Mar 6, 2025
493 of 503 checks passed
ko3n1g pushed a commit that referenced this pull request on Mar 7, 2025
chtruong814 pushed a commit that referenced this pull request on Mar 7, 2025
BoxiangW pushed a commit that referenced this pull request on Mar 10, 2025
timmoon10 deleted the debug-te-fsdp branch on May 7, 2025 at 17:44