[Megatron-FSDP] Add dtype customization to Megatron-FSDP. #3067

Open
cspades wants to merge 11 commits into NVIDIA:main from cspades:cye/mfsdp-custom-dtype

Conversation


@cspades cspades commented Jan 24, 2026

What does this PR do ?

  • Add customization options for main_params_dtype, main_grads_dtype, and grad_comm_dtype to Megatron-FSDP.
    • grad_accum_dtype (high-precision gradient reduction) will be handled by NCCL UBR / SymMem for v2.27+.
  • Various bug-fixes, refactors, documentation updates, commentary, etc.

All planned performance benchmarks completed with no bugs. Ready for expert & final review!

Details

Mixed-Precision Support (megatron_fsdp.MixedPrecisionPolicy)

  • main_params_dtype / --megatron-fsdp-main-params-dtype (🍀 NEW! 🍀) and main_grads_dtype / --megatron-fsdp-main-grads-dtype (🍀 NEW! 🍀) are simple generalizations of preserve_fp32_weights (⛔ DEPRECATED ⛔) and grad_reduce_in_fp32.
    • If not specified, the main weights come from the model weight buffer or, failing that, the model parameters (in that order of precedence), and the main gradient buffer's data-type matches the model compute weight data-type.
  • grad_comm_dtype / --megatron-fsdp-grad-comm-dtype (🍀 NEW! 🍀) controls the data-type used for gradient communication (all-reduce & reduce-scatter).
    • If main_grads_dtype differs from grad_comm_dtype, a communication bucket in the communication data-type is allocated. Otherwise (and when unspecified), main_grads_dtype is used as the communication data-type.
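
As a rough illustration, the policy surface described above can be sketched as a dataclass. This is a hypothetical sketch based on this description: field defaults and the fallback rule are assumptions, not the merged `megatron_fsdp.MixedPrecisionPolicy` API.

```python
import torch
from dataclasses import dataclass
from typing import Optional

@dataclass
class MixedPrecisionPolicy:
    # Assumed defaults per the PR description: FP32 main weights and grads.
    main_params_dtype: Optional[torch.dtype] = torch.float32
    main_grads_dtype: Optional[torch.dtype] = torch.float32
    # None -> fall back to main_grads_dtype for gradient communication.
    grad_comm_dtype: Optional[torch.dtype] = None

policy = MixedPrecisionPolicy(grad_comm_dtype=torch.bfloat16)
# Resolve the dtype actually used for the all-reduce / reduce-scatter.
comm_dtype = policy.grad_comm_dtype or policy.main_grads_dtype
```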

Megatron-FSDP Gradient Lifecycle

To summarize the gradient pipeline of Megatron-FSDP for the uninitiated:

# Gradient Memory Lifecycle
# (...) = HFSDP pre-optimizer steps only!
param.grad -> DP-Shard Grad Alloc -> Reduce -> (Wait -> Free -> DP-Outer Grad Alloc -> Reduce) -> Wait -> Free
              ^ These are communication buffers to hold un-sharded or partially-sharded gradients.
  • Autograd produces the raw, un-reduced model gradient.
  • Megatron-FSDP's gradient buffer allocates a communication bucket to temporarily hold the un-reduced gradient.
    • On main, this communication bucket always matches the main gradient buffer data-type, so low-precision communication buckets cannot be paired with high-precision main gradients.
    • With this PR, the bucket data-type is controllable via grad_comm_dtype, enabling low-precision communication with high-precision reduction over NCCL (v2.27+).
  • The raw un-reduced gradient is either copied (sharded gradients) or accumulated (un-sharded gradients) into the allocated gradient bucket or main gradient buffer.
  • The gradient communication bucket is retrieved, and passed to the reduce-scatter or all-reduce collective. Accumulation is performed via type-promotion with respect to the main_grads_dtype, typically FP32.
    • For no_shard and optim, this is a local all-reduce or reduce-scatter that may only be invoked once per optimization cycle to avoid corrupting gradients.
      • This is why we do not immediately allocate a communication buffer for these two cases, and instead require a per-unit allocation right before the collective when a custom communication data-type is used: the no_shard and optim sharding strategies cannot afford a second un-sharded memory allocation to hold both a communication buffer and an accumulation buffer for the gradient (one for BF16 communication, another for FP32 accumulation) until the single DP-reduction right before the optimization step. Thus, we temporarily allocate and deallocate a BF16 communication buffer just before gradient reduction, while persistently allocating an FP32 main gradient bucket.
    • For optim_grads and optim_grads_params, this is a reduce-scatter into the allocated communication bucket, and shards of the result are accumulated into the main gradient buffer. Because we reduce every layer of every step, we only persistently hold onto a reduced and accumulated shard of the gradient.
  • All previous gradient collectives are synchronized if overlapped, gradient shards are attached to the parameter shards of the optimizer state, the model parameters are re-referenced to the distributed optimizer state parameters, and the distributed optimizer step is performed.
  • Finally, the main gradient buffer is zero'd out, which establishes a clean slate for subsequent reduction and accumulation, and the optimized main weights are installed into Megatron-FSDP's compute weight buffer.
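
The lifecycle above can be emulated on a single process. This is an illustrative sketch only: `simulated_grad_reduce` and its signature are invented for this example and are not Megatron-FSDP API.

```python
import torch

def simulated_grad_reduce(per_rank_grads, comm_dtype, main_dtype):
    """Emulate the bucket lifecycle: cast each rank's un-reduced gradient
    into a communication bucket of grad_comm_dtype, reduce with
    type-promotion into main_grads_dtype, and return what would be
    installed into the main gradient buffer."""
    # Allocate communication buckets in the (possibly lower) comm dtype.
    buckets = [g.to(comm_dtype) for g in per_rank_grads]
    # Reduce in high precision (stands in for the NCCL collective).
    reduced = torch.stack(buckets).sum(dim=0, dtype=main_dtype)
    # At this point the communication buckets would be freed.
    return reduced

main_grad = simulated_grad_reduce(
    [torch.ones(4), torch.ones(4)],
    comm_dtype=torch.bfloat16,
    main_dtype=torch.float32,
)
```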

🚨 Bug Fixes 🚨

  • optim had corrupted gradients, where the main gradient would be reduce-scattered into a temporary shard, but the reduced shard would be accumulated back into the source main gradient shard (without zero'ing the buffer), leading to duplicate gradients.
    • Fixed by adding copy and += cases to the DP-Shard gradient reduction.
    • For example, with DP-Shard=2, and only 1 accumulation / optimization step for simplicity with (...) representing the reduced gradient and gN representing the pre-reduce accumulated gradient:
      • Rank 0 Expected: (g1 + g2)
      • Rank 0 Actual: g1 + (g1 + g2)
      • Without the torch.empty_like temporary shard, the bug would have doubled the gradient when using optim, i.e. (g1 + g2) += (g1 + g2)!
      • Causes main gradient disparity on all DP ranks.
    • With custom data-type buckets, the generalized logic also works correctly, where for optim we copy the reduced gradient shard into the main gradient buffer if a communication buffer was allocated, otherwise the reduce-scatter directly updates the shard of the main gradient buffer. (Same for no_shard as well, but using all-reduce and copying the reduced un-sharded gradient.)
  • Also discovered a broken gradient DP-scaling while working on this PR: [Megatron-FSDP] Fix incorrect gradient scaling target. #3023
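
The arithmetic of the fix can be reproduced with a toy example, with plain tensors standing in for the reduce-scatter on rank 0 (g1, g2 are the pre-reduce accumulated gradients on each rank, as above):

```python
import torch

# Toy reproduction of the optim duplicate-gradient bug with DP-Shard=2.
g1, g2 = torch.tensor([1.0]), torch.tensor([2.0])

reduced = g1 + g2                  # reduce-scatter result landing on rank 0
buggy_main_grad = g1 + reduced     # accumulated into the un-zeroed source: g1 + (g1 + g2)
fixed_main_grad = reduced.clone()  # copy semantics installs (g1 + g2) as expected
```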

Minor Edits

  • Refactored free_bucket_storage() to remove the criterion that only deallocates buckets for sharded buffers, and factored out the param.main_grad reset into reset_param_main_grad().
    • fetch_bucket() only allocates temporary buckets if the data-type differs or if the buffer is sharded, so there was a loophole where a custom data-type allocation would never be deallocated when the buffer is not sharded.
    • Modules that are not FSDP units should not have their buckets deallocated, but this is controlled by the post-forward and post-backward un-shard hooks that call AllGatherPipeline.release_bucket().
    • reset_param_main_grad() only needs to be called when the FSDP gradient buffer on DP-Shard has completed its collectives and installed the reduced gradient in local data.
      • param.main_grad will first point to the unreduced gradient bucket, and then point to the DP-Shard reduced main gradient buffer data (or a custom data-type variant of the aforementioned values).
  • Implemented check_for_nan_in_grad for Megatron-LM (called in start_grad_sync) and report_nan_in_param_grad for fully_shard, which both default to False in MegatronFSDP. report_nan_in_param_grad in particular is an expensive operation that can degrade performance by around 5%, but can be extremely useful for quickly debugging the source of NaNs, whether they come from Megatron-FSDP or user models.
  • Updated and removed all un-used config options in Megatron-FSDP's version of DDPConfig.
    • @BoxiangW and I discussed this during our meetings, but this is just the first baby step toward perhaps completely annihilating the second variant of the DDPConfig so it doesn't confuse users, and flattening all necessary arguments into Megatron-FSDP directly.
  • Updated documentation on DDP, where all-reduce should not be called repetitively during gradient accumulation, and added warning messages for the user to zero the gradient buffer every step when this kind of behavior is detected.
    • We have a variety of options to handle this, such as no_sync in Megatron-LM and an even simpler sync() / MegatronFSDP.set_model_auto_sync() for Megatron-external use (the opposite of no_sync, which calls all the necessary functions to make Megatron-FSDP low-code in a vanilla training loop).
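
For intuition, a NaN scan of this kind amounts to the following sketch; `find_nan_grads` is a hypothetical stand-in, not the actual report_nan_in_param_grad implementation, and scanning every gradient each step is what produces the ~5% overhead mentioned above.

```python
import torch

def find_nan_grads(named_grads):
    """Return the names of all parameters whose gradient contains a NaN."""
    return [
        name
        for name, grad in named_grads
        if grad is not None and torch.isnan(grad).any()
    ]

grads = [
    ("ok", torch.ones(2)),
    ("bad", torch.tensor([1.0, float("nan")])),
    ("frozen", None),  # e.g. a parameter with no gradient this step
]
```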

Tests

  • Added unit tests that cover all relevant sharding and mixed-precision strategies, as well as changing the gradient communication data-type mid-flight (which is not allowed for NCCL UBR).

All performance tests below use the following configuration (unless otherwise specified):

  • Llama 8B
    • The larger the model, the greater the performance improvement from low-precision communication, because the communication volume increases.
  • FP32 Main Parameters and Main Gradients
  • FP32 Gradient Accumulation and Reduction
  • FP8 Delayed Scaling + Parameter AG
  • optim_grads_params
    • For HFSDP, --outer-dp-sharding-strategy optim and --num-distributed-optimizer-instances 2.
  • Full Activation Recompute
  • GBS128 / MBS1
  • NCCL User Buffer Registration / FSDP Double Buffer
    • --use-nccl-ub --fsdp-manual-registration --fsdp-double-buffer for NCCL UB perf experiments.

Performance & Accuracy Parity with FP32 Gradient Communication + Accumulation (Reduce-Scatter)

  • With both communication and accumulation set to FP32 (still the default in Megatron-LM), the performance and accuracy are identical with main branch.
# Main Branch (FP32 Gradient Reduce-Scatter)
[2026-01-29 09:54:18.103917] iteration      100/15258789 | consumed samples:        12800 | elapsed time per iteration (ms): 10925.4 | throughput per GPU (TFLOP/s/GPU): 617.5 | learning rate: 4.915198E-07 | global batch size:   128 | lm loss: 1.154529E+00 | loss scale: 1.0 | grad norm: 5.414 | num zeros: 0 | number of skipped iterations:   0 | number of nan iterations:   0 |

# Mixed-Precision (FP32 Gradient Reduce-Scatter)
[2026-02-12 16:50:24.452656] iteration      100/15258789 | consumed samples:        12800 | elapsed time per iteration (ms): 10919.5 | throughput per GPU (TFLOP/s/GPU): 617.9 | learning rate: 4.915198E-07 | global batch size:   128 | lm loss: 1.154451E+00 | loss scale: 1.0 | grad norm: 5.412 | num zeros: 0 | number of skipped iterations:   0 | number of nan iterations:   0 |

Mixed-Precision BF16 Gradient Communication + FP32 Gradient Reduction / Accumulation

  • Setting --megatron-fsdp-grad-comm-dtype bf16 enables BF16 communication and FP32 reduction / accumulation if NCCL 2.27+ is used with NCCL UBR for pure FSDP.
    • Compared to FP32 communications, we observe a 6% speedup (nearly 700 ms saved per global step) even for model compute as small as Llama 8B.
    • Loss is equivalent to the FP32 communications case, which implies FP32 reduction over NCCL.
# Mixed Precision (BF16 Comm / FP32 Reduce-Accum via NCCL UBR SymMem)
[2026-02-12 17:10:29.641373] iteration      100/15258789 | consumed samples:        12800 | elapsed time per iteration (ms): 10274.1 | throughput per GPU (TFLOP/s/GPU): 656.7 | learning rate: 4.915198E-07 | global batch size:   128 | lm loss: 1.154832E+00 | loss scale: 1.0 | grad norm: 5.411 | num zeros: 0 | number of skipped iterations:   0 | number of nan iterations:   0 |
  • The performance improvement is even more apparent with a large model and smaller compute, such as with Llama 70B @ 1K SeqLen, where we get loss parity and a 22.5% speedup in communications compared to main branch!
# Llama 70B Main Branch (FP32 Gradient Reduce-Scatter + NCCL UB)
[2026-02-16 21:53:58.402418] iteration      100/15258789 | consumed samples:        12800 | elapsed time per iteration (ms): 3300.5 | throughput per GPU (TFLOP/s/GPU): 261.3 | learning rate: 4.915198E-07 | global batch size:   128 | lm loss: 2.233609E+00 | loss scale: 1.0 | grad norm: 26.389 | num zeros: 0 | number of skipped iterations:   0 | number of nan iterations:   0 |

# Llama 70B Mixed Precision (BF16 Comm / FP32 Reduce-Accum via NCCL UBR SymMem)
[2026-02-16 19:59:23.416987] iteration      100/15258789 | consumed samples:        12800 | elapsed time per iteration (ms): 2560.0 | throughput per GPU (TFLOP/s/GPU): 336.8 | learning rate: 4.915198E-07 | global batch size:   128 | lm loss: 2.234979E+00 | loss scale: 1.0 | grad norm: 26.475 | num zeros: 0 | number of skipped iterations:   0 | number of nan iterations:   0 |

HFSDP Performance & Accuracy Tests (BF16 Gradient Communication + FP32 Gradient Reduction / Accumulation)

  • --num-distributed-optimizer-instances 2 with --outer-dp-sharding-strategy optim has loss parity after 100 steps, and is just shy of 4x faster (3.62x) per global batch across 4 nodes on Llama 8B.
    • IB-domain reduce-scatter requires NCCL 2.29U1+, while all-reduce is not currently supported over IB. The result below was generated with v2.27.7, so DP-Outer reductions were NOT performed in FP32; this explains the slight loss disparity vs. the other Llama 8B benchmarks.
# HFSDP 4-Node (DP-Outer=2, DP-Shard=16) + BF16 Comm + FP32 Reduce / Accum + NCCL UBR
[2026-02-16 21:22:05.035121] iteration      100/15258789 | consumed samples:        12800 | elapsed time per iteration (ms): 2838.5 | throughput per GPU (TFLOP/s/GPU): 594.2 | learning rate: 4.915198E-07 | global batch size:   128 | lm loss: 1.152859E+00 | loss scale: 1.0 | grad norm: 3.607 | num zeros: 0 | number of skipped iterations:   0 | number of nan iterations:   0 |

Extra Tests

  • With the optim gradient fix, and GBS 128 / MBS 1, we have improved loss (5.48 vs. 5.56) and reduced gradient norm (19.143 vs. 22.110) as we are no longer duplicating the gradient on the local rank, i.e. grad_i + sum(grad_i) instead of the expected sum(grad_i).
    • This difference becomes harder to observe at higher DP degrees, since the DP-reduced sum dominates the local gradient in magnitude.
# Main (Optimizer Sharding / Llama 8B @ 1K / FP32 Main Params, Grads, Grad Comm)
[2026-02-17 08:41:03.817490] iteration      100/15258789 | consumed samples:        12800 | elapsed time per iteration (ms): 2793.4 | throughput per GPU (TFLOP/s/GPU): 268.8 | learning rate: 4.915198E-07 | global batch size:   128 | lm loss: 5.560552E+00 | loss scale: 1.0 | grad norm: 22.110 | num zeros: 0 | number of skipped iterations:   0 | number of nan iterations:   0 |

# Optim Bug-Fix (Optimizer Sharding / Llama 8B @ 1K / FP32 Main Params, Grads, Grad Comm)
[2026-02-17 08:52:45.746740] iteration      100/15258789 | consumed samples:        12800 | elapsed time per iteration (ms): 2776.0 | throughput per GPU (TFLOP/s/GPU): 270.5 | learning rate: 4.915198E-07 | global batch size:   128 | lm loss: 5.486691E+00 | loss scale: 1.0 | grad norm: 19.143 | num zeros: 0 | number of skipped iterations:   0 | number of nan iterations:   0 |
  • FP32 activations (no --bf16 argument) work without any issues with NCCL UBR.
# FP32 Activations (BF16 Grad Comms + NCCL UBR / SymMem)
[2026-02-12 18:09:30.972237] iteration      100/15258789 | consumed samples:        12800 | elapsed time per iteration (ms): 19961.1 | throughput per GPU (TFLOP/s/GPU): 158.4 | learning rate: 4.915198E-07 | global batch size:   128 | lm loss: 1.620105E+00 | loss scale: 1.0 | grad norm: 6.975 | num zeros: 0 | number of skipped iterations:   0 | number of nan iterations:   0 |
  • Checking for NaN for all weight gradients with fully_shard(report_nan_in_param_grad=True) costs a slight performance regression of +5% global step time. Should only be turned on for debugging!
[2026-01-30 09:51:37.769937] iteration        6/15258789 | consumed samples:          768 | elapsed time per iteration (ms): 11512.2 | throughput per GPU (TFLOP/s/GPU): 586.0 | learning rate: 2.949118E-08 | global batch size:   128 | lm loss: 1.213158E+01 | loss scale: 1.0 | grad norm: 28.643 | num zeros: 0 | number of skipped iterations:   0 | number of nan iterations:   0 |

Future Work

  • param_comm_dtype has little use right now beyond the already-supported TransformerEngine FP8 AG, so we defer it until we have plans for quantized AG of non-FP8 parameters, which itself requires research into the effect of extra quantization operations on sharded vs. un-sharded parameters during model training.
  • @cspades Add a MegatronFSDP.__init__(debug=False) argument for improved unit tests.

Appendix

Type-Promotion Examples

  • TL;DR: Type-promotion is equivalent to casting everything to the higher precision before the operation, and can affect the numerics through down-casts even when the output precision is lower than the type-promoted precision.
"""
Input DType: torch.float32
Output DType: torch.bfloat16

REDUCTION TESTS

Reduce with torch.float32, and cast to torch.bfloat16:
 tensor([ 11.8750, 252.0000, -31.0000], dtype=torch.bfloat16)
Reduce via type-promotion into torch.bfloat16:
 tensor([ 11.8750, 252.0000, -31.0000], dtype=torch.bfloat16)
Cast to torch.bfloat16, and then reduce:
 tensor([ 11.8750, 252.0000, -31.1250], dtype=torch.bfloat16)
Use torch.sum(torch.bfloat16) to reduce:
 tensor([ 11.8750, 252.0000, -31.1250], dtype=torch.bfloat16)

ACCUMULATION TESTS

torch.bfloat16.add_(torch.float32):
 tensor([   3.6094, -180.0000,   -0.7031], dtype=torch.bfloat16)
torch.sum(dtype=torch.float32).to(torch.bfloat16):
 tensor([   3.6094, -180.0000,   -0.7031], dtype=torch.bfloat16)
torch.sum(dtype=torch.bfloat16):
 tensor([   3.6094, -180.0000,   -0.7188], dtype=torch.bfloat16)

-------------------

Input DType: torch.bfloat16
Output DType: torch.float32

REDUCTION TESTS

Reduce with torch.bfloat16, and cast to torch.float32:
 tensor([ 11.8750, 252.0000, -31.1250])
Reduce via type-promotion into torch.float32:
 tensor([ 11.8750, 252.2500, -31.0781])
Cast to torch.float32, and then reduce:
 tensor([ 11.8750, 252.2500, -31.0781])
Use torch.sum(torch.float32) to reduce:
 tensor([ 11.8750, 252.2500, -31.0781])

ACCUMULATION TESTS

torch.float32.add_(torch.bfloat16):
 tensor([   3.6099, -179.7130,   -0.7188])
torch.sum(dtype=torch.bfloat16).to(torch.float32):
 tensor([   3.6094, -180.0000,   -0.7188])
torch.sum(dtype=torch.float32):
 tensor([   3.6099, -179.7130,   -0.7188])
"""

⚠️ For major changes (either in lines of code or in impact), please make sure to first share a design doc with the team. If you're unsure of the best way to do so, contact the @mcore-oncall.

Contribution process

flowchart LR
    A[Pre-checks] --> B[PR Tests]
    subgraph Code Review/Approval
        C1[Expert Review] --> C2[Final Review]
    end
    B --> C1
    C2 --> D[Merge]

Pre-checks

  • I want this PR in a versioned release and have added the appropriate Milestone (e.g., Core 0.8)
  • I have added relevant unit tests
  • I have added relevant functional tests
  • I have added proper typing to my code Typing guidelines
  • I have added relevant documentation
  • I have run the autoformatter.sh on my PR

Code review

The following process is enforced via the CODEOWNERS file for changes into megatron/core. For changes outside of megatron/core, it is up to the PR author whether or not to tag the Final Reviewer team.

For MRs into `main` branch

Feel free to message or comment the @mcore-oncall to help accelerate your merge into main. The less complex your PR is, the faster it will be approved and merged!

(Step 1): Add PR label Expert Review

(Step 2): Collect the expert reviewers reviews

  1. Attach the Expert Review label when your PR is ready for review.
  2. GitHub auto-assigns expert reviewers based on your changes. They will get notified and pick up your PR soon.

⚠️ Only proceed to the next step once all reviewers have approved, merge conflicts are resolved, and the CI is passing.
Final Review might get declined if these requirements are not fulfilled.

(Step 3): Final Review

  1. Add Final Review label
  2. GitHub auto-assigns final reviewers based on your changes. They will get notified and pick up your PR soon.

(Optional Step 4): Cherry-pick into release branch

If this PR also needs to be merged into core_r* release branches, after this PR has been merged, select Cherry-pick to open a new PR into the release branch.

For MRs into `dev` branch The proposed review process for `dev` branch is under active discussion.

MRs are mergeable after one approval by either eharper@nvidia.com or zijiey@nvidia.com.

Merging your PR

Any member of core-adlr and core-nemo will be able to merge your PR.

@cspades cspades requested a review from deepakn94 January 24, 2026 02:05
@cspades cspades self-assigned this Jan 24, 2026
@cspades cspades requested review from a team as code owners January 24, 2026 02:05
@ko3n1g ko3n1g requested a review from a team January 24, 2026 02:06
@ko3n1g ko3n1g added this to the Core 0.16 milestone Jan 24, 2026
@cspades cspades requested a review from shjwudp January 24, 2026 02:09
@cspades cspades added Expert Review Apply this label to indicate that your PR is ready for expert review. module: megatron-fsdp labels Jan 24, 2026
@cspades cspades force-pushed the cye/mfsdp-custom-dtype branch from 8714977 to c72e1d3 Compare January 24, 2026 02:50
@cspades cspades force-pushed the cye/mfsdp-custom-dtype branch from c72e1d3 to 107e81a Compare January 24, 2026 02:54
@cspades cspades added the bug Something isn't working label Feb 17, 2026
"""If true, reuse the grad buffer for param AG when using mxfp8 recipe. Should be
set to True only when fp8_recipe is mxfp8 and fp8_param_gather is True."""

use_megatron_fsdp: bool = False
Contributor

Wouldn't this break a lot of things in other codebases?

@cspades cspades Feb 17, 2026

So this is the megatron-fsdp DDPConfig for the PyPI package, and this entry in the dataclass is not used anywhere in the Megatron-FSDP source code, it's just used in Megatron-LM!

It is possible that someone is using DDPConfig outside of Megatron-LM's API, which is definitely not recommended since there is generally no good reason to do so, but they would have to use their own local booleans and flags instead of depending on configs that aren't even used by our source code.

Should I be worried? (It shouldn't be particularly difficult to migrate out of... basically don't use this dataclass as a placeholder for your framework arguments/parameters.)

@cspades cspades Feb 17, 2026

cc @akoumpa or @ananthsub maybe to chime in on the Automodel or MBridge side.

  • Automodel would have to deprecate these attributes and just use their own arguments without storing them in the Megatron-FSDP dataclass or entangling these attributes with their ML code.
  • MBridge should not be affected whatsoever except for preserve_fp32_weights (migrate to main_params_dtype=torch.float32); any mappings from MBridge configs to DDPConfig will not break, because megatron.core.distributed.distributed_data_parallel_config has not otherwise changed!
  • fully_shard() users don't need to worry about it much, just migrate to the new and super clean MixedPrecisionPolicy dataclass API.

@cspades cspades requested a review from jaredcasper February 17, 2026 18:59
if dtype == torch.float16 or dtype == torch.bfloat16 or dtype == torch.int16:
    return 2
elif dtype == torch.float32 or dtype == torch.int32 or dtype == torch.complex32:
Contributor

Are we really supporting these dtypes? Would it be better to raise error?

@cspades cspades Feb 17, 2026

It's not necessary at all TBH but just adds some known dtypes so we don't need to allocate an empty Tensor to find out! Just quality-of-life changes that I used to use for grad_accum_dtype impl.

It's kind of sad that torch.dtype doesn't have an API that tells you this, you HAVE to allocate a tensor or know beforehand...

the native model weights become the main weights. Defaults to torch.float32.
"""

main_grads_dtype: Optional[torch.dtype] = torch.float32
@cspades cspades Feb 17, 2026

Sync'd offline, default should definitely be FP32 (as before) for large scale model accuracy when training. This also controls the gradient accumulation, which should be type-promoted to FP32!

Users are free to control this, but quoting @deepakn94 we shouldn't let them shoot themselves in the foot on gradient precision, they should make this decision depending on their model size and dataset complexity.

So TL;DR:

  • Compute params are in whatever the native model params are in.
  • Main weights controlled by the main_params_dtype, which should be FP32 for optimization.
  • Main gradients controlled by the main_grads_dtype, which should be FP32 for high-precision gradient accumulation.

@cspades cspades added Final Review Apply this label to indicate that your PR is ready for final review. and removed Expert Review Apply this label to indicate that your PR is ready for expert review. labels Feb 17, 2026
@shjwudp shjwudp left a comment

Thanks for helping clarify the necessity of implementing reduce-scatter based on A2A and the trade-offs with NCCL’s native reduce-scatter.

I think there’s still room to simplify this MR. If we keep the main goal in focus, perhaps the changes related to the gradient reduce pipeline and the bucket fetch/free operations could be reverted. A well-scoped PR focusing on a few clear objectives will make it easier to review, trace, and maintain.

if custom_grad_comm_dtype:
    # Create a custom communication buffer with gbuf.
    # Introduces copy and memory overhead.
    unreduced_grad_bucket = gbuf.allocate_bucket_storage(
Contributor

Could you explain the GPU memory allocation mechanism here? Both line 3395 and line 3381 call the temporary buffer allocator, which seems to trigger additional memory allocations.

@cspades cspades Feb 25, 2026

Yes, for no_shard and optim only, we will require a secondary un-sharded communication buffer (in addition to the main gradient un-sharded buffer) to support low-precision communication.

This is only required when custom_grad_comm_dtype=True, so users are (supposed to be) able to turn off this feature, and consider the memory / communication trade-off. The only no_shard and optim cases where the memory would be allocated is if main_grads_dtype=torch.float32 and grad_comm_dtype=torch.bfloat16, or vice versa.

I agree the logic is complex and potentially difficult to maintain. Not supporting custom gradient communication for no_shard and optim is definitely something we can consider, as it is not clear to me whether the memory cost would ever be worth the communication speedup, but this allocation enables it.

Currently, the Megatron-LM default is main_grads_dtype=torch.float32 and grad_comm_dtype=torch.float32, derived from the original grad_reduce_in_fp32=True. This means that:

unreduced_grad_bucket = gbuf.fetch_bucket(
    dtype=mp_policy.grad_comm_dtype if gbuf.is_data_distributed else None
)

is an FP32 main gradient buffer, and:

# While un-reduced gradients are directly installed into communication
# buffers of grad_comm_dtype for sharded gradient buffers, un-sharded
# gradient buffers accumulate un-reduced gradients locally and may need
# dtype-custom communication buffers!
custom_grad_comm_dtype = (
    mp_policy.grad_comm_dtype is not None
    and unreduced_grad_bucket.data.dtype != mp_policy.grad_comm_dtype
)

is False, because unreduced_grad_bucket.data.dtype == mp_policy.grad_comm_dtype == torch.float32, so no extra memory is allocated in this case.
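
Concretely, the predicate evaluates as follows under those defaults (plain variables standing in for the fetched bucket and the policy in this illustration):

```python
import torch

grad_comm_dtype = torch.float32  # Megatron-LM default (grad_reduce_in_fp32=True)
bucket_dtype = torch.float32     # dtype of the fetched FP32 main gradient bucket

# Mirrors the quoted predicate: only a genuine dtype mismatch triggers
# the secondary communication buffer allocation.
custom_grad_comm_dtype = (
    grad_comm_dtype is not None and bucket_dtype != grad_comm_dtype
)

# e.g. BF16 communication over FP32 main gradients WOULD allocate:
mismatch = torch.bfloat16 is not None and torch.float32 != torch.bfloat16
```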

"""
self.temporary_bucket_allocator.free(self.bucket_index.bucket_id)

def reset_param_main_grad(self):
Contributor

Why is a separate reset_param_main_grad function necessary?

@cspades cspades Feb 25, 2026

It's not strictly necessary, but it seems like a good refactor to be more precise on when we dereference param.main_grad. There are a lot of use cases of free_bucket_storage that are not related to the main gradient, and this seems like something that permits more specificity on when we dereference it.

I believe I can revert this, i.e. putting reset_param_main_grad back into free_bucket_storage, with no change in behavior if you'd like, but it seems much more readable / transparent as separate functions, since the two accomplish correlated but not identical purposes. (The only required change is to remove if not is_data_distributed: return, since there are a few cases where I need to allocate and deallocate temporary communication buffers regardless of whether the buffer was sharded.)

Comment on lines +3490 to +3501
if custom_grad_comm_dtype:
    # Free the bucket allocated for the DP-Shard reduction,
    # which has been waited on already.
    fsdp_grad_buffer.free_bucket_storage()
    # Create a custom communication buffer with fsdp_grad_buffer.
    # Introduces copy and memory overhead.
    unreduced_grad = fsdp_grad_buffer.allocate_bucket_storage(
        size=unreduced_grad.numel(),
        dtype=mp_policy.grad_comm_dtype,
        device=unreduced_grad.device,
        init_values=unreduced_grad,
    ).data
@cspades cspades Feb 25, 2026

@shjwudp Also I believe I need to del param.main_grad / reset_param_main_grad here with fsdp_grad_buffer.free_bucket_storage(). I was on the fence before on the technicality of whether we need to do this, or if any other APIs were using this property during the gradient reduction, but if we are freeing its storage, we should dereference it. Agree?

(This has no bearing on my previous comment about reset_param_main_grad's refactor, I simply was not sure in this case. 😵 )
