[Megatron-FSDP] Add dtype customization to Megatron-FSDP. #3067

cspades wants to merge 11 commits into NVIDIA:main
Conversation
```python
"""If true, reuse the grad buffer for param AG when using mxfp8 recipe. Should be
set to True only when fp8_recipe is mxfp8 and fp8_param_gather is True."""

use_megatron_fsdp: bool = False
```
Wouldn't this break a lot of things in other codebases?
So this is the megatron-fsdp DDPConfig for the PyPI package, and this entry in the dataclass is not used anywhere in the Megatron-FSDP source code, it's just used in Megatron-LM!
It is possible that someone is using DDPConfig outside of Megatron-LM's API, which is definitely not recommended since there is generally no good reason to do so, but they would have to use their own local booleans and flags instead of depending on configs that aren't even used by our source code.
Should I be worried? (It shouldn't be particularly difficult to migrate out of... basically don't use this dataclass as a placeholder for your framework arguments/parameters.)
cc @akoumpa or @ananthsub maybe to chime in on the Automodel or MBridge side.
- Automodel would have to deprecate these attributes and just use their own arguments without storing them in the Megatron-FSDP dataclass or entangling these attributes with their ML code.
- MBridge should not be affected whatsoever except for `preserve_fp32_weights` (migrate to `main_params_dtype=torch.float32`); any mappings from MBridge configs to DDPConfig will not break, because `megatron.core.distributed.distributed_data_parallel_config` has not changed otherwise! `fully_shard()` users don't need to worry about it much, just migrate to the new and super clean `MixedPrecisionPolicy` dataclass API.
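The migration described above can be sketched as a tiny config-mapping helper. This is a hypothetical illustration (the `migrate` function and the string `"float32"`, which stands in for `torch.float32`, are ours, not MBridge's or Megatron-FSDP's actual API):

```python
def migrate(cfg: dict) -> dict:
    """Hypothetical sketch: map the deprecated preserve_fp32_weights flag
    onto the new main_params_dtype field ("float32" stands in for
    torch.float32; not the actual MBridge/Megatron-FSDP migration code)."""
    out = dict(cfg)
    if out.pop("preserve_fp32_weights", False):
        # preserve_fp32_weights=True corresponds to FP32 main weights.
        out.setdefault("main_params_dtype", "float32")
    return out
```

For example, `migrate({"preserve_fp32_weights": True})` yields `{"main_params_dtype": "float32"}`, while configs that never set the deprecated flag pass through unchanged.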
```diff
 if dtype == torch.float16 or dtype == torch.bfloat16 or dtype == torch.int16:
     return 2
-elif dtype == torch.float32 or dtype == torch.int32:
+elif dtype == torch.float32 or dtype == torch.int32 or dtype == torch.complex32:
```
Are we really supporting these dtypes? Would it be better to raise an error?
It's not necessary at all TBH, but it just adds some known dtypes so we don't need to allocate an empty Tensor to find out! Just a quality-of-life change that I used to use for the `grad_accum_dtype` impl.
It's kind of sad that `torch.dtype` doesn't have an API that tells you this; you HAVE to allocate a tensor or know beforehand...
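The known-dtype table discussed here can be sketched torch-free; string keys stand in for `torch.dtype`, and the byte widths follow the standard dtype sizes (not the actual Megatron-FSDP code). As a hedged aside, newer PyTorch versions expose `torch.dtype.itemsize`, and a zero-dim `torch.empty((), dtype=dtype).element_size()` avoids a large allocation, though availability depends on the PyTorch version.

```python
# Torch-free sketch of a known-dtype element-size lookup (illustrative;
# string keys stand in for torch.dtype).
_DTYPE_ELEMENT_SIZE = {
    "float16": 2, "bfloat16": 2, "int16": 2,
    "float32": 4, "int32": 4, "complex32": 4,
    "float64": 8, "int64": 8, "complex64": 8,
}

def element_size(dtype_name: str) -> int:
    """Look up the per-element byte size without allocating a tensor."""
    if dtype_name not in _DTYPE_ELEMENT_SIZE:
        raise ValueError(f"Unknown dtype: {dtype_name}")
    return _DTYPE_ELEMENT_SIZE[dtype_name]
```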
```python
the native model weights become the main weights. Defaults to torch.float32.
"""

main_grads_dtype: Optional[torch.dtype] = torch.float32
```
Synced offline; the default should definitely be FP32 (as before) for large-scale model accuracy when training. This also controls gradient accumulation, which should be type-promoted to FP32!
Users are free to control this, but quoting @deepakn94, we shouldn't let them shoot themselves in the foot on gradient precision; they should make this decision depending on their model size and dataset complexity.
So TL;DR:
- Compute params are in whatever the native model params are in.
- Main weights are controlled by `main_params_dtype`, which should be FP32 for optimization.
- Main gradients are controlled by `main_grads_dtype`, which should be FP32 for high-precision gradient accumulation.
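The TL;DR above can be summarized in a small sketch of the policy dataclass. The field names follow this PR, but strings stand in for `torch.dtype`, and the fallback helper `resolve_grad_comm_dtype` is our own illustration of the stated "if not specified, `main_grads_dtype` is the communication dtype" behavior, not the actual Megatron-FSDP implementation:

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class MixedPrecisionPolicy:
    """Illustrative sketch of the mixed-precision policy fields discussed
    in this PR (strings stand in for torch.dtype)."""
    # Main (optimizer) weights: FP32 by default for optimization accuracy.
    main_params_dtype: Optional[str] = "float32"
    # Main gradients: FP32 by default for high-precision accumulation.
    main_grads_dtype: Optional[str] = "float32"
    # Gradient communication dtype; None means "same as main_grads_dtype".
    grad_comm_dtype: Optional[str] = None

def resolve_grad_comm_dtype(policy: MixedPrecisionPolicy) -> Optional[str]:
    """If unspecified, gradients are communicated in the accumulation dtype."""
    return policy.grad_comm_dtype or policy.main_grads_dtype
```

With the defaults, gradients are communicated in FP32; setting `grad_comm_dtype="bfloat16"` opts into low-precision communication while keeping FP32 accumulation.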
…tron_fsdp to remove unnecessary attributes. Signed-off-by: Cory Ye <cye@nvidia.com>
…heck. Signed-off-by: Cory Ye <cye@nvidia.com>
…m_dtype by deactivating SymMem for gradients. Signed-off-by: Cory Ye <cye@nvidia.com>
…t freed, both used to setup NCCL UB communication buckets. Signed-off-by: Cory Ye <cye@nvidia.com>
…sharded buffers. Signed-off-by: Cory Ye <cye@nvidia.com>
shjwudp left a comment:
Thanks for helping clarify the necessity of implementing reduce-scatter based on A2A and the trade-offs with NCCL’s native reduce-scatter.
I think there’s still room to simplify this MR. If we keep the main goal in focus, perhaps the changes related to the gradient reduce pipeline and the bucket fetch/free operations could be reverted. A well-scoped PR focusing on a few clear objectives will make it easier to review, trace, and maintain.
```python
if custom_grad_comm_dtype:
    # Create a custom communication buffer with gbuf.
    # Introduces copy and memory overhead.
    unreduced_grad_bucket = gbuf.allocate_bucket_storage(
```
Could you explain the GPU memory allocation mechanism here? Both line 3395 and line 3381 call the temporary buffer allocator, which seems to trigger additional memory allocations.
Yes, for `no_shard` and `optim`, we will require a secondary un-sharded communication buffer (in addition to the main gradient un-sharded buffer) to support low-precision communication.
This is only required when `custom_grad_comm_dtype=True`, so users are (supposed to be) able to turn off this feature and consider the memory / communication trade-off. The only `no_shard` and `optim` cases where the memory would be allocated are if `main_grads_dtype=torch.float32` and `grad_comm_dtype=torch.bfloat16`, or vice versa.
I agree the logic is complex and potentially difficult to maintain. Not supporting custom gradient communication for `no_shard` and `optim` is definitely something we can consider, as it is not clear to me whether the memory cost would ever be worth the communication speedup, but this allocation enables it.
Currently, the Megatron-LM default is `main_grads_dtype=torch.float32` and `grad_comm_dtype=torch.float32`, derived from the original `grad_reduce_in_fp32=True`. This means that:
```python
unreduced_grad_bucket = gbuf.fetch_bucket(
    dtype=mp_policy.grad_comm_dtype if gbuf.is_data_distributed else None
)
```

is an FP32 main gradient buffer, and:

```python
# While un-reduced gradients are directly installed into communication
# buffers of grad_comm_dtype for sharded gradient buffers, un-sharded
# gradient buffers accumulate un-reduced gradients locally and may need
# dtype-custom communication buffers!
custom_grad_comm_dtype = (
    mp_policy.grad_comm_dtype is not None
    and unreduced_grad_bucket.data.dtype != mp_policy.grad_comm_dtype
)
```

is `False`, because `unreduced_grad_bucket.data.dtype == mp_policy.grad_comm_dtype == torch.float32`, so no extra memory is allocated in this case.
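The dtype check quoted above can be exercised in isolation. This is a string-based stand-in for the real `torch.dtype` comparison (illustrative, not Megatron-FSDP code):

```python
def needs_custom_comm_bucket(bucket_dtype, grad_comm_dtype):
    """Mirrors the custom_grad_comm_dtype condition: a separate communication
    bucket is only needed when a comm dtype is set and differs from the
    un-reduced gradient bucket's dtype (strings stand in for torch.dtype)."""
    return grad_comm_dtype is not None and bucket_dtype != grad_comm_dtype
```

With the Megatron-LM defaults (FP32 accumulation, FP32 communication) this is `False`, so no extra memory is allocated; BF16 communication over FP32 main gradients makes it `True` and triggers the secondary buffer.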
```python
"""
self.temporary_bucket_allocator.free(self.bucket_index.bucket_id)

def reset_param_main_grad(self):
```
Why is a separate reset_param_main_grad function necessary?
It's not strictly necessary, but it seems like a good refactor to be more precise about when we dereference `param.main_grad`. There are a lot of use cases of `free_bucket_storage` that are not related to the main gradient, and this separation permits more specificity about when we dereference it.
I believe I can revert this, i.e. put `reset_param_main_grad` back into `free_bucket_storage`, with no change in behavior if you'd like, but it seems much more readable / transparent as separate functions, since these two functions accomplish correlated but distinct tasks. (The only required change is to remove `if not is_data_distributed: return`, since there are a few cases where I need to allocate and deallocate temporary communication buffers regardless of whether the buffer was sharded or not.)
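The separation being discussed can be shown with a toy sketch, with illustrative names (not the actual Megatron-FSDP buffer classes): freeing the bucket storage and dereferencing `param.main_grad` become independent steps, so callers unrelated to the main gradient can free storage without touching the gradient reference.

```python
class ToyGradBuffer:
    """Toy stand-in for a gradient buffer (illustrative names only)."""

    def __init__(self):
        self.bucket_storage = [0.0] * 8              # stand-in for bucket memory
        self.param_main_grad = self.bucket_storage   # stand-in for param.main_grad

    def free_bucket_storage(self):
        # Releases only the bucket memory; callers with no relation to the
        # main gradient can use this without side effects on the reference.
        self.bucket_storage = None

    def reset_param_main_grad(self):
        # Separately drops the (possibly stale) main-gradient reference.
        self.param_main_grad = None
```

In the combined version, every `free_bucket_storage` call would also drop the reference; splitting them lets the caller decide.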
```python
if custom_grad_comm_dtype:
    # Free the bucket allocated for the DP-Shard reduction,
    # which has been waited on already.
    fsdp_grad_buffer.free_bucket_storage()
    # Create a custom communication buffer with fsdp_grad_buffer.
    # Introduces copy and memory overhead.
    unreduced_grad = fsdp_grad_buffer.allocate_bucket_storage(
        size=unreduced_grad.numel(),
        dtype=mp_policy.grad_comm_dtype,
        device=unreduced_grad.device,
        init_values=unreduced_grad,
    ).data
```
@shjwudp Also, I believe I need to `del param.main_grad` / `reset_param_main_grad` here with `fsdp_grad_buffer.free_bucket_storage()`. I was on the fence before on the technicality of whether we need to do this, or if any other APIs were using this property during the gradient reduction, but if we are freeing its storage, we should dereference it. Agree?
(This has no bearing on my previous comment about the `reset_param_main_grad` refactor; I simply was not sure in this case. 😵)
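As context for why this reduction path is delicate, the gradient-duplication bug fixed by this PR (described in the PR summary) can be illustrated with plain floats. This toy reduce is ours, not the actual buffer code:

```python
def reduce_without_zeroing(local_shard, peer_shard):
    """Buggy pattern (toy): reduce into the source shard without zeroing it,
    so the local gradient is counted twice: g1 + (g1 + g2)."""
    reduced = local_shard + peer_shard
    return local_shard + reduced

def reduce_with_copy(local_shard, peer_shard):
    """Fixed pattern (toy): reduce into a temporary and copy it back,
    yielding the expected (g1 + g2)."""
    return local_shard + peer_shard

g1, g2 = 1.0, 2.0
assert reduce_with_copy(g1, g2) == 3.0        # expected sum of shards
assert reduce_without_zeroing(g1, g2) == 4.0  # duplicated local gradient
```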
What does this PR do?
- Adds `main_params_dtype`, `main_grads_dtype`, and `grad_comm_dtype` to Megatron-FSDP.
- `grad_accum_dtype` (high-precision gradient reduction) will be handled by NCCL UBR / SymMem for v2.27+.

All planned performance benchmarks completed with no bugs. Ready for expert & final review!
Details
Mixed-Precision Support (`megatron_fsdp.MixedPrecisionPolicy`)

- `main_params_dtype` / `--megatron-fsdp-main-params-dtype` (🍀 NEW! 🍀) and `main_grads_dtype` / `--megatron-fsdp-main-grads-dtype` (🍀 NEW! 🍀) are simple generalizations of `preserve_fp32_weights` (⛔ DEPRECATED ⛔) and `grad_reduce_in_fp32`.
- `grad_comm_dtype` / `--megatron-fsdp-grad-comm-dtype` (🍀 NEW! 🍀) controls the data-type used for gradient communication (all-reduce & reduce-scatter).
- If `main_grads_dtype` is not equivalent to `grad_comm_dtype`, a communication bucket with the communication data-type will be allocated. Otherwise, and if not specified, the `main_grads_dtype` will be the communication data-type.

Megatron-FSDP Gradient Lifecycle
To summarize the gradient pipeline of Megatron-FSDP for the uninitiated:
- In `main`, the communication bucket matches the main gradient buffer data-type, so we cannot have low-precision communication buckets with high-precision main gradients.
- This PR allocates the communication bucket in `grad_comm_dtype`, to support low-precision communication and high-precision reduction with NCCL (v2.27+).
- Reduced gradients are accumulated into the main gradient buffer in `main_grads_dtype`, typically FP32.
- For `no_shard` and `optim`, this is a local all-reduce or reduce-scatter that can only be called once per optimization cycle to avoid corrupt gradients.
- The `no_shard` and `optim` sharding strategies definitively do not permit a second un-sharded memory allocation to maintain both communication and accumulation buffers for the gradient (one for BF16 communication, another for FP32 accumulation) until we finally perform the only DP-reduction right before the optimization step. Thus, we temporarily allocate / deallocate a BF16 communication buffer right before gradient reduction, while persistently allocating an FP32 main gradient bucket.
- For `optim_grads` and `optim_grads_params`, this is a reduce-scatter into the allocated communication bucket, and shards of the result are accumulated into the main gradient buffer. Because we reduce every layer of every step, we only persistently hold onto a reduced and accumulated shard of the gradient.

🚨 Bug Fixes 🚨
- `optim` had corrupted gradients: the main gradient would be reduce-scattered into a temporary shard, but the reduced shard would be accumulated back into the source main gradient shard (without zero'ing the buffer), leading to duplicate gradients.
- This PR adds `copy` and `+=` cases to the DP-Shard gradient reduction.
- With `(...)` representing the reduced gradient and `gN` representing the pre-reduce accumulated gradient, the expected result is `(g1 + g2)`, but the bug produced `g1 + (g1 + g2)`.
- Without the `torch.empty_like` temporary shard, the bug would have doubled the gradient when using `optim`, i.e. `(g1 + g2) += (g1 + g2)`!
- Now for `optim`, we copy the reduced gradient shard into the main gradient buffer if a communication buffer was allocated; otherwise, the reduce-scatter directly updates the shard of the main gradient buffer. (Same for `no_shard` as well, but using all-reduce and copying the reduced un-sharded gradient.)

Minor Edits
- Modified `free_bucket_storage()` to remove the criteria that only deallocates buckets for sharded buffers, and factored out the `param.main_grad` reset into `reset_param_main_grad()`.
- `fetch_bucket()` will only allocate temporary buckets if the data-type is different, or if the buffer is sharded. So there is a loophole where a custom data-type allocation will not be deallocated if the buffer is sharded (see `AllGatherPipeline.release_bucket()`).
- `reset_param_main_grad()` only needs to be called when the FSDP gradient buffer on DP-Shard has completed its collectives and installed the reduced gradient in local data. `param.main_grad` will first point to the unreduced gradient bucket, and then point to the DP-Shard reduced main gradient buffer data (or a custom data-type variant of the aforementioned values).
- Adds `check_for_nan_in_grad` for Megatron-LM (called in `start_grad_sync`) and `report_nan_in_param_grad` for `fully_shard`, which both default to `False` in `MegatronFSDP`. `report_nan_in_param_grad` in particular is an expensive operation that can degrade performance by around 5%, but can be extremely useful for quickly debugging the source of NaNs, whether they come from Megatron-FSDP or user models.
- Supports `no_sync` in Megatron-LM and an even simpler `sync()` / `MegatronFSDP.set_model_auto_sync()` for Megatron-external use (the opposite of `no_sync` that basically calls all the necessary functions to make Megatron-FSDP low-code in a vanilla training loop).

Tests
All performance tests below use the following configuration (unless otherwise specified):
- `optim_grads_params`
- `--outer-dp-sharding-strategy optim` and `--num-distributed-optimizer-instances 2`
- `--use-nccl-ub`, `--fsdp-manual-registration`, and `--fsdp-double-buffer` for NCCL UB perf experiments

Performance & Accuracy Parity with FP32 Gradient Communication + Accumulation (Reduce-Scatter)
- Loss and performance parity with the `main` branch.

Mixed-Precision BF16 Gradient Communication + FP32 Gradient Reduction / Accumulation
- `--megatron-fsdp-grad-comm-dtype bf16` enables BF16 communication and FP32 reduction / accumulation if NCCL `2.27+` is used with NCCL UBR for pure FSDP.
- Loss parity with the `main` branch!

HFSDP Performance & Accuracy Tests (BF16 Gradient Communication + FP32 Gradient Reduction / Accumulation)
- `--num-distributed-optimizer-instances 2` and `--outer-dp-sharding-strategy optim` has parity on loss after 100 steps, and is just shy of 4x faster (3.62x) per global batch from 4 Nodes on Llama 8B.
- Reduce-scatter with NCCL UBR is supported over IB from `2.29U1+`, while all-reduce is not currently supported for IB. The result below was generated with `v2.27.7`, so DP-Outer reductions were NOT performed in FP32! This explains the slight loss disparity vs. other Llama 8B benchmarks.

Extra Tests
- With the `optim` gradient fix, and GBS 128 / MBS 1, we have improved loss (`5.48` vs. `5.56`) and reduced gradient norm (`19.143` vs. `22.110`), as we are no longer duplicating the gradient on the local rank, i.e. `grad_i + sum(grad_i)` instead of the expected `sum(grad_i)`.
- Mixed precision (`--bf16` argument) works without any issues with NCCL UBR.
- Checking for `NaN` for all weight gradients with `fully_shard(report_nan_in_param_grad=True)` costs a slight performance regression of +5% global step time. Should only be turned on for debugging!

Future Work
- `param_comm_dtype` doesn't have much use right now outside of the already supported TransformerEngine FP8 AG, so we will defer this to the future when we have plans for quantized AG for non-FP8 parameters, which itself requires some research into the effect of extra quantization operations on sharded parameters vs. un-sharded parameters in model training.
- Add a `MegatronFSDP.__init__(debug=False)` argument for improved unit tests.

Appendix
Type-Promotion Examples
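As a minimal sketch of floating-point type promotion for the dtypes used in this PR: the table below is our simplification, not torch's full promotion lattice, but it reflects that in PyTorch promoting `bfloat16` with `float16` yields `float32` (there is no common 16-bit super-type), while mixing with a wider float promotes to the wider float.

```python
# Simplified float promotion (strings stand in for torch.dtype; a toy model
# of torch.promote_types restricted to float dtypes).
_FLOAT_RANK = {"float16": 1, "bfloat16": 1, "float32": 2, "float64": 3}

def promote(a: str, b: str) -> str:
    """Return the promoted dtype of two float dtypes (illustrative only)."""
    ra, rb = _FLOAT_RANK[a], _FLOAT_RANK[b]
    if ra == rb and a != b:
        # bf16 and fp16 have no common 16-bit super-type; widen to fp32.
        return "float32"
    return a if ra >= rb else b
```

For example, BF16 gradients accumulated into an FP32 main gradient buffer land in FP32, which is the high-precision accumulation behavior this PR defaults to.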
Contribution process
```mermaid
flowchart LR
    A[Pre-checks] --> B[PR Tests]
    subgraph Code Review/Approval
        C1[Expert Review] --> C2[Final Review]
    end
    B --> C1
    C2 --> D[Merge]
```

Pre-checks
Code review
The following process is enforced via the CODEOWNERS file for changes into `megatron/core`. For changes outside of `megatron/core`, it is up to the PR author whether or not to tag the Final Reviewer team.

For MRs into `main` branch
Feel free to message or comment the @mcore-oncall to help accelerate your merge into main. The less complex your PR is, the faster it will be approved and merged!
(Step 1): Add PR label `Expert Review`

(Step 2): Collect the expert reviewers' reviews

Add the `Expert Review` label when your PR is ready for review. Final Review might get declined if these requirements are not fulfilled.
(Step 3): Final Review

Add the `Final Review` label.

(Optional Step 4): Cherry-pick into release branch
If this PR also needs to be merged into `core_r*` release branches, after this PR has been merged, select `Cherry-pick` to open a new PR into the release branch.

For MRs into `dev` branch
The proposed review process for the `dev` branch is under active discussion. MRs are mergeable after one approval by either eharper@nvidia.com or zijiey@nvidia.com.

Merging your PR
Any member of core-adlr and core-nemo will be able to merge your PR.