Add support for DGPO (ICLR 2026) to GRPO #5102
YanqiDai wants to merge 8 commits into huggingface:main from
Conversation
a few comments.
also: _generate_and_score is getting too dense with this PR. DGAE/DQW + valid-token balancing logic and the existing multi-objective aggregation both add substantial branching/state. It's becoming hard to follow and validate each transformation in isolation.
I think it makes sense to pull most of these out into separate helpers?
docs/source/grpo_trainer.md
Outdated
### DGPO (Difficulty-Aware Group Policy Optimization)

DGPO extends GRPO with difficulty-aware mechanisms to improve training on tasks with varying question difficulty (e.g., math reasoning). It is introduced in the [MathForge paper](https://huggingface.co/papers/2601.20614) (ICLR 2026) and is supported in [`GRPOTrainer`] via [`GRPOConfig`].

- **DGAE (Difficulty-balanced Group Advantage Estimation)**: When `use_dgpo_dgae=True`, advantages are scaled using Mean Absolute Deviation (MAD) instead of standard deviation, i.e. `advantage = (reward - mean) / (MAD + eps)`, which can address the implicit imbalance where the update magnitudes are suppressed for both easier and harder questions and peak for those of moderate difficulty.
- **DQW (Difficulty-aware Question-level Weighting)**: When `use_dgpo_dqw=True`, each question (prompt group) is assigned a weight based on its difficulty (e.g., mean accuracy reward). Harder questions get higher weight so the policy focuses more on them. Use `dgpo_dqw_temp` to control how sharp the weighting is (lower = more focus on hard questions) and `dgpo_dqw_acc_reward_index` to specify which reward in `reward_funcs` is used as the accuracy/difficulty signal.
Suggested change: remove this section.
Extend the paper index section with this information instead
I have already moved this information to the paper index section.
trl/trainer/grpo_trainer.py
Outdated
    num_questions, device=advantages.device, dtype=advantages.dtype
)
if num_zero_variance_questions < num_questions:
    # For mean accuracy 0 (all wrong) or NaN, set difficulty to -1 so they get less weight
Suggested change:
-    # For mean accuracy 0 (all wrong) or NaN, set difficulty to -1 so they get less weight
+    # mean accuracy == 0 (all wrong) or NaN are remapped to 1.0 before softmax so they get less weight
style nit
but also, doesn't this imply rewards have to be >0?
Thanks. This check and remapping only apply to the accuracy reward; we assume the accuracy reward is in the range [0, 1] by default.
trl/trainer/grpo_config.py
Outdated
    Whether to use difficulty-balanced group advantage estimation (DGAE). When `True`, the denominator when
    scaling advantages uses the Mean Absolute Deviation (MAD) of rewards instead of the standard deviation, i.e.
    advantage = (reward - mean) / (MAD + eps) with MAD = mean(|reward - mean|). Introduced in the [MathForge
    paper](https://huggingface.co/papers/2601.20614) (ICLR 2026).
Suggested change:
-    paper](https://huggingface.co/papers/2601.20614) (ICLR 2026).
+    paper](https://huggingface.co/papers/2601.20614).
applies to all instances below
Updated the DGPO section to clarify its mechanisms and usage in TRL, including details on DGAE and DQW.
Removed DGPO section and its related details from the documentation.
Co-authored-by: LeonEricsson <70749762+LeonEricsson@users.noreply.github.com>
Updated comment to clarify handling of mean accuracy for questions.
What does this PR do?
Add DGPO (Difficulty-Aware Group Policy Optimization) support to GRPO.
References: MathForge (ICLR 2026) GitHub, Paper.
Before submitting

- Did you read the contributor guideline, Pull Request section? Yes
- Was this discussed/approved via a GitHub issue? Please add a link to it if that's the case. No
- Did you make sure to update the documentation with your changes? Yes: `docs/source/grpo_trainer.md` and `docs/source/paper_index.md`.
- Did you write any new necessary tests? Yes: `test_training_dgpo` in `tests/test_grpo_trainer.py`.

Motivation
This PR integrates DGPO (Difficulty-Aware Group Policy Optimization) from MathForge (ICLR 2026, paper) into the GRPO trainer. DGPO improves group-based RL by:
- DGAE: scaling group advantages by the Mean Absolute Deviation (MAD) of rewards instead of the standard deviation.
- DQW: weighting each question by its difficulty so that harder questions contribute more to the update.
- Valid token-level loss averaging: ensuring that (a) only valid samples (`std_rewards != 0`) contribute to the effective normalizer, and (b) multi-GPU training is balanced by accounting for per-process valid token counts. This yields a proper token-level average loss across valid data and devices.

These options are useful for mathematical reasoning and other settings where question difficulty is heterogeneous and reward variance can be zero for some groups.
Changes
Configuration (`grpo_config.py`)

- `use_dgpo_dgae` (bool, default `False`): When True and `scale_rewards != "none"`, advantages are normalized by MAD instead of standard deviation (DGAE).
- `use_dgpo_dqw` (bool, default `False`): When True, advantages are multiplied by per-question difficulty weights (DQW). Zero-variance questions get weight 1; others are weighted by a softmax over negative mean accuracy so that harder questions get higher weight; weights sum to `num_questions`.
- `dgpo_dqw_temp` (float, default `2.0`): Temperature for the DQW softmax.
- `dgpo_dqw_acc_reward_index` (int, default `0`): Index of the accuracy reward in `reward_funcs` used to compute per-question mean accuracy for DQW.

All new parameters are documented with a reference to the MathForge paper (ICLR 2026). A minimal usage sketch is shown below.
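For orientation, a minimal usage sketch of these options. The four DGPO flags come from this PR and are not on `main` yet; the dataset, model, and reward function below are placeholders for illustration, not part of the PR.

```python
from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer

dataset = load_dataset("trl-lib/tldr", split="train")

def accuracy_reward(completions, **kwargs):
    # Placeholder accuracy reward in [0, 1]; DQW reads this reward via
    # dgpo_dqw_acc_reward_index=0 to estimate per-question difficulty.
    return [float("TL;DR" in completion) for completion in completions]

training_args = GRPOConfig(
    output_dir="Qwen2.5-0.5B-GRPO-DGPO",
    use_dgpo_dgae=True,           # DGAE: MAD-based advantage scaling
    use_dgpo_dqw=True,            # DQW: difficulty-aware question weighting
    dgpo_dqw_temp=2.0,            # softmax temperature for DQW
    dgpo_dqw_acc_reward_index=0,  # accuracy_reward is the first (and only) reward func
)
trainer = GRPOTrainer(
    model="Qwen/Qwen2.5-0.5B-Instruct",
    reward_funcs=[accuracy_reward],
    args=training_args,
    train_dataset=dataset,
)
trainer.train()
```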
Trainer (`grpo_trainer.py`)

DGAE

In both `sum_then_normalize` and `normalize_then_sum` branches, when `use_dgpo_dgae` is True and rewards are scaled, the advantage denominator uses MAD instead of std: `advantage = (reward - mean) / (MAD + eps)`.
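For reference, a minimal sketch of the MAD-based scaling described above; the function name and tensor shapes are illustrative, not the PR's actual code.

```python
import torch

def dgae_advantages(rewards: torch.Tensor, eps: float = 1e-4) -> torch.Tensor:
    """Illustrative DGAE scaling: normalize by Mean Absolute Deviation instead of std.

    rewards: shape (num_questions, num_generations), one row per prompt group.
    """
    mean = rewards.mean(dim=1, keepdim=True)
    mad = (rewards - mean).abs().mean(dim=1, keepdim=True)  # MAD = mean(|reward - mean|)
    return (rewards - mean) / (mad + eps)
```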
DQW

After advantage computation (and after the valid-token scaling described below), when `use_dgpo_dqw` is True:

- Per-question mean accuracies of the reward at index `dgpo_dqw_acc_reward_index` are computed.
- Questions with non-zero reward variance are weighted by `(num_questions - num_zero_variance) * softmax(-mean_acc / dgpo_dqw_temp)`; questions with mean accuracy 0 or NaN are treated as "easiest" (mean set to 1 in the softmax) so they receive less weight.

Valid token-level loss averaging (when `use_dgpo_dgae` or `use_dgpo_dqw` is True)

- Valid samples are those with `std_rewards != 0` (i.e. `~is_std_zero`).
- `completion_length = completion_mask.sum(dim=1)` is gathered across processes to get `gathered_completion_length`.
- `global_balancing_ratio = num_processes * local_completion_length_sum / global_completion_length_sum` is computed (used later).
- When zero-variance samples are present, `zero_mask_ratio = global_completion_length_sum / valid_completion_length_sum`, where `valid_completion_length_sum` is the sum of completion lengths over valid samples only; otherwise `zero_mask_ratio = 1.0`.
- The factor `zero_mask_ratio` is applied so that the effective normalizer ignores invalid (zero-variance) samples.
- The factor `global_balancing_ratio` is applied so that the loss is balanced across processes by valid token count (valid token-level averaging across devices).
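A minimal sketch of the DQW weighting just described, assuming the accuracy reward lies in [0, 1]; the helper name and masking details are illustrative rather than the PR's implementation, and the multi-process `zero_mask_ratio` / `global_balancing_ratio` bookkeeping is omitted.

```python
import torch

def dqw_weights(mean_acc: torch.Tensor, is_std_zero: torch.Tensor, temp: float = 2.0) -> torch.Tensor:
    """Illustrative per-question DQW weights.

    mean_acc: (num_questions,) mean accuracy reward per prompt group, assumed in [0, 1].
    is_std_zero: (num_questions,) bool mask marking zero-variance groups.
    """
    weights = torch.ones_like(mean_acc)  # zero-variance questions keep weight 1
    valid = ~is_std_zero
    num_valid = int(valid.sum())
    if num_valid > 0:
        acc = mean_acc[valid].clone()
        # all-wrong (mean accuracy 0) or NaN groups are remapped to 1.0 ("easiest")
        # before the softmax so they receive less weight
        acc[(acc == 0) | torch.isnan(acc)] = 1.0
        weights[valid] = num_valid * torch.softmax(-acc / temp, dim=0)
    return weights  # sums to num_questions
```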
Tests

`test_training_dgpo` in `tests/test_grpo_trainer.py`: Runs a short training with `use_dgpo_dgae=True`, `use_dgpo_dqw=True`, `dgpo_dqw_temp=2.0`, and `dgpo_dqw_acc_reward_index=0`, and checks that `train_loss` is recorded.
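For context, roughly what such a smoke test could look like; the tiny model, dataset, and dummy reward below are assumptions for illustration, not the exact fixtures used in `tests/test_grpo_trainer.py`.

```python
from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer

def test_training_dgpo(tmp_path):
    dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")

    def accuracy_reward(completions, **kwargs):
        return [float(len(completion) % 2) for completion in completions]  # dummy reward in [0, 1]

    args = GRPOConfig(
        output_dir=str(tmp_path),
        per_device_train_batch_size=4,
        num_generations=4,
        max_completion_length=8,
        max_steps=3,
        use_dgpo_dgae=True,
        use_dgpo_dqw=True,
        dgpo_dqw_temp=2.0,
        dgpo_dqw_acc_reward_index=0,
        report_to="none",
    )
    trainer = GRPOTrainer(
        model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
        reward_funcs=[accuracy_reward],
        args=args,
        train_dataset=dataset,
    )
    trainer.train()
    assert trainer.state.log_history[-1]["train_loss"] is not None
```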