
Fix loss scaling and token aggregation to use only data parallel group #39674

Open
Krish0909 wants to merge 3 commits into huggingface:main from Krish0909:fix/loss-scaling-tp-cp

Conversation

@Krish0909
Contributor

What does this PR do?

This PR fixes a bug in the Trainer where loss and token counts were scaled across all Accelerate processes, including tensor parallel (TP) and context parallel (CP) ranks, which inflated training losses when using composable parallelism. After this change, loss scaling and token aggregation only consider the data parallel group, aligning TP/CP runs with pure DDP behavior.

Fixes #39648

Changes

Loss scaling: Replaced self.accelerator.num_processes with self.accelerator.state.num_data_parallel_processes when applying average_tokens_across_devices.

Token aggregation: Updated batching logic to use accelerator.reduce(..., group_type="data") for summing tokens only across the data parallel group.
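
Roughly, the intent of the two changes above looks like the sketch below. It is not runnable as-is: `num_data_parallel_processes` and the `group_type` keyword are placeholder names this PR assumes will land in Accelerate, and they are not part of its public API at the time of writing.

```python
# Sketch only. `num_data_parallel_processes` and `group_type` are placeholder
# APIs proposed by this PR; they do not exist in Accelerate yet.

# Before: the loss was scaled by every process, including TP/CP ranks.
#   loss *= self.accelerator.num_processes

# After: scale only by the size of the data parallel group.
if self.args.average_tokens_across_devices:
    loss *= self.accelerator.state.num_data_parallel_processes

# Token aggregation: sum token counts across the data parallel group only,
# rather than across the full (DP x TP x CP) world.
num_items_in_batch = self.accelerator.reduce(
    num_items_in_batch, reduction="sum", group_type="data"
)
```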

Motivation and Context

When using Accelerate's composable parallelism (TP/CP), the original implementation erroneously multiplied the loss by the total number of processes (DP × TP × CP). This resulted in losses that were N× larger (where N = TP × CP), making training logs and LR schedulers behave incorrectly. By restricting scaling to the data parallel group, we restore consistency with pure DDP runs.
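
As a concrete, hypothetical illustration, assume 8 GPUs laid out as dp=2, tp=2, cp=2 (the dp size is an assumption; the TP/CP sizes match the manual verification run below):

```python
# Hypothetical 8-GPU layout: dp=2, tp=2, cp=2 (the dp size is assumed; the
# TP/CP sizes match the manual verification run described below).
dp, tp, cp = 2, 2, 2
old_scale = dp * tp * cp      # 8: accelerator.num_processes (all ranks)
new_scale = dp                # 2: data parallel ranks only
print(old_scale / new_scale)  # 4.0 -> logged losses were 4x too large
```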

Testing

All existing Trainer integration tests pass (no regressions).

Manual verification:

Ran run_glue.py on MRPC with --tensor_parallel_size 2 --context_parallel_size 2. Logged losses every 10 steps.

Compared against a pure DDP run (no TP/CP flags). Loss trajectories matched within floating-point tolerance.


Who can review?

Trainer: @zach-huggingface, @SunMarc

Accelerate integration: @SunMarc, @zach-huggingface

@Krish0909 changed the title from "Fix loss scaling and token aggregation to use only data parallel group #39648" to "Fix loss scaling and token aggregation to use only data parallel group" on Jul 25, 2025
@S1ro1
Contributor

S1ro1 commented Jul 26, 2025

We aren't 100% sure of the API we'll go with in the PR you mentioned, so it's subject to change. Thank you for the contribution though! Also, mind if I ask: AFAIK we don't have num_data_parallel_processes in Accelerate (yet), so how did you come up with that?

@Krish0909
Contributor Author

We aren't 100% sure of the API we'll go with in the PR you mentioned, so it's subject to change. Thank you for the contribution though! Also, mind if I ask: AFAIK we don't have num_data_parallel_processes in Accelerate (yet), so how did you come up with that?

You're absolutely right—num_data_parallel_processes isn't currently in accelerate. I added it as part of a forward-looking design to align with how AcceleratorState handles other parallelism dimensions like num_processes and num_mixed_precision_processes. I thought having explicit separation could be useful in scenarios with hybrid parallelism setups.

That said, I completely understand that the API is still evolving. I'm happy to adapt this PR once there's a clearer direction from the core team or if you'd prefer me to refactor to avoid the placeholder for now.

Let me know how you'd like me to proceed!
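
For illustration only, here is one way such a placeholder could be derived. It assumes hypothetical tp_size/cp_size attributes on the accelerator state; the real Accelerate API may end up looking quite different.

```python
# Purely illustrative: derive the data parallel world size from the total
# process count and hypothetical tp_size / cp_size attributes. The final
# Accelerate API may differ once the referenced PR lands.
def num_data_parallel_processes(accelerator) -> int:
    tp = getattr(accelerator.state, "tp_size", 1) or 1
    cp = getattr(accelerator.state, "cp_size", 1) or 1
    return accelerator.num_processes // (tp * cp)
```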

@srrk-GreenMan

srrk-GreenMan commented Jul 27, 2025

Sorry to interrupt. If you are planning to add a new attribute to Accelerator, why not use new attributes on the model instead (e.g. model.dp_size, model.tp_size, etc.)? I think that when model parallelism is applied by the fully_shard function, we can use the device mesh names and their shapes.
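
As an illustration of this suggestion, the group sizes could be read from a named DeviceMesh rather than stored on the model; the dim names "dp"/"tp"/"cp" below are assumptions about how the mesh was created.

```python
# Illustrative only: read parallel group sizes from a named DeviceMesh.
# The mesh dim names and shape are assumptions; this needs a distributed
# launch with world_size == 8 to actually run.
from torch.distributed.device_mesh import init_device_mesh

mesh = init_device_mesh("cuda", (2, 2, 2), mesh_dim_names=("dp", "tp", "cp"))
dp_size = mesh["dp"].size()  # ranks in the data parallel sub-mesh
tp_size = mesh["tp"].size()
cp_size = mesh["cp"].size()
```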

@S1ro1
Contributor

S1ro1 commented Jul 27, 2025

@Krish0909 It's totally fine, we aim to add properties like this in the PR you mentioned anyway, so this is probably going to be very close to the final form.

@srrk-GreenMan we'd like to avoid adding this to the model itself, as it becomes transformers-specific. We'll probably allow users to read such properties from the ParallelismConfig instead.
