
Add support for DGPO (ICLR 2026) to GRPO #5102

Open

YanqiDai wants to merge 8 commits into huggingface:main from YanqiDai:grpo-dgpo

Conversation

@YanqiDai YanqiDai commented Feb 15, 2026

What does this PR do?

Add DGPO (Difficulty-Aware Group Policy Optimization) support to GRPO.

References: MathForge (ICLR 2026) GitHub, Paper.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). No
  • Did you read the contributor guideline,
    Pull Request section? Yes
  • Was this discussed/approved via a GitHub issue? Please add a link
    to it if that's the case. No
  • Did you make sure to update the documentation with your changes? Yes, it's in docs/source/grpo_trainer.md and docs/source/paper_index.md.
  • Did you write any new necessary tests? Yes, it's test_training_dgpo in tests/test_grpo_trainer.py.

Motivation

This PR integrates DGPO (Difficulty-Aware Group Policy Optimization) from MathForge (ICLR 2026, paper) into the GRPO trainer. DGPO improves group-based RL by:

  1. Difficulty-aware advantage scaling — Using MAD (Mean Absolute Deviation) instead of standard deviation when normalizing advantages (DGAE), which can address the implicit imbalance where the update magnitudes are suppressed for both easier and harder questions and peak for those of moderate difficulty.
  2. Difficulty-aware question weighting (DQW) — Assigning higher weight to harder questions (lower mean accuracy) so that the policy focuses more on improving on difficult items while keeping a fixed total weight budget.
  3. Valid token-level loss averaging — Scaling advantages so that (a) only valid samples (where std_rewards != 0) contribute to the effective normalizer, and (b) multi-GPU training is balanced by accounting for per-process valid token counts. This yields a proper token-level average loss across valid data and devices.

These options are useful for mathematical reasoning and other settings where question difficulty is heterogeneous and reward variance can be zero for some groups.
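To make the DGAE change concrete, here is a minimal sketch of MAD-based advantage normalization; `dgae_advantages` is a hypothetical helper, not the PR's actual code, and the `eps` value is illustrative:

```python
# Minimal sketch of the DGAE idea: normalize group-relative advantages by the
# Mean Absolute Deviation (MAD) of rewards instead of their standard deviation.
import torch

def dgae_advantages(rewards: torch.Tensor, eps: float = 1e-4) -> torch.Tensor:
    """rewards: (num_questions, num_generations) rewards grouped per question."""
    mean = rewards.mean(dim=1, keepdim=True)
    centered = rewards - mean
    mad = centered.abs().mean(dim=1, keepdim=True)  # MAD = mean(|reward - mean|)
    return centered / (mad + eps)                   # (reward - mean) / (MAD + eps)

# Vanilla GRPO scaling would use rewards.std(dim=1, keepdim=True) in place of MAD.
```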

Changes

Configuration (grpo_config.py)

  • use_dgpo_dgae (bool, default False): When True and scale_rewards != "none", advantages are normalized by MAD instead of standard deviation (DGAE).
  • use_dgpo_dqw (bool, default False): When True, advantages are multiplied by per-question difficulty weights (DQW). Zero-variance questions get weight 1; others are weighted by a softmax over negative mean accuracy so that harder questions get higher weight; weights sum to num_questions.
  • dgpo_dqw_temp (float, default 2.0): Temperature for the DQW softmax.
  • dgpo_dqw_acc_reward_index (int, default 0): Index of the accuracy reward in reward_funcs used to compute per-question mean accuracy for DQW.

All new parameters are documented with a reference to the MathForge paper (ICLR 2026).
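Assuming the option names above, enabling DGPO is a matter of setting the new flags on the config. A minimal, illustrative sketch (the output directory is a placeholder):

```python
from trl import GRPOConfig

# Illustrative values for the new DGPO options added by this PR.
args = GRPOConfig(
    output_dir="grpo-dgpo",        # placeholder
    use_dgpo_dgae=True,            # normalize advantages by MAD instead of std (DGAE)
    use_dgpo_dqw=True,             # weight questions by difficulty (DQW)
    dgpo_dqw_temp=2.0,             # softmax temperature for DQW
    dgpo_dqw_acc_reward_index=0,   # accuracy reward is the first entry of reward_funcs
)
```

With `dgpo_dqw_acc_reward_index=0`, the accuracy reward (assumed to lie in [0, 1]) must be passed as the first function in `reward_funcs`.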

Trainer (grpo_trainer.py)

  1. DGAE
    In both sum_then_normalize and normalize_then_sum branches, when use_dgpo_dgae is True and rewards are scaled, the advantage denominator uses MAD instead of std:
    advantage = (reward - mean) / (MAD + eps).

  2. DQW
    After advantage computation (and after the valid-token scaling described below), when use_dgpo_dqw is True:

    • Per-question mean and std of the accuracy reward at dgpo_dqw_acc_reward_index are computed.
    • Zero-variance questions keep weight 1.
    • For the rest, weights are (num_questions - num_zero_variance) * softmax(-mean_acc / dgpo_dqw_temp); questions with mean accuracy 0 or NaN are treated as “easiest” (mean set to 1 in the softmax) so they receive less weight.
    • These weights are expanded to per-sample and multiplied onto advantages (no separate weighting in the loss).
  3. Valid token-level loss averaging (when use_dgpo_dgae or use_dgpo_dqw is True)

    • Valid sample: any sample for which std_rewards != 0 (i.e. ~is_std_zero).
    • Before the per-process slice:
      • completion_length = completion_mask.sum(dim=1) is gathered across processes to get gathered_completion_length.
      • global_balancing_ratio = num_processes * local_completion_length_sum / global_completion_length_sum is computed (used later).
      • If there is at least one valid sample, zero_mask_ratio = global_completion_length_sum / valid_completion_length_sum, where valid_completion_length_sum is the sum of completion lengths over valid samples only; otherwise zero_mask_ratio = 1.0.
      • Full advantages are multiplied by zero_mask_ratio so that the effective normalizer ignores invalid (zero-variance) samples.
    • After the slice that keeps only the local part of the data, local advantages are multiplied by global_balancing_ratio so that the loss is balanced across processes by valid token count (valid token-level averaging across devices). A standalone sketch of the DQW weighting and these two ratios is given after this list.
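The sketch below mirrors the DQW and valid-token descriptions above in standalone form; `dqw_weights` and `balancing_ratios` are hypothetical helpers with illustrative names, not the PR's implementation.

```python
import torch

def dqw_weights(mean_acc: torch.Tensor, std_acc: torch.Tensor, temp: float = 2.0) -> torch.Tensor:
    """Per-question DQW weights from the per-question mean/std of the accuracy reward."""
    num_questions = mean_acc.numel()
    zero_var = std_acc == 0
    weights = torch.ones(num_questions, dtype=mean_acc.dtype)  # zero-variance questions keep weight 1
    if (~zero_var).any():
        acc = mean_acc.clone()
        # mean accuracy == 0 (all wrong) or NaN is remapped to 1.0 before the softmax,
        # so those questions are treated as "easiest" and receive less weight.
        acc[(acc == 0) | acc.isnan()] = 1.0
        soft = torch.softmax(-acc[~zero_var] / temp, dim=0)
        # non-trivial weights sum to (num_questions - num_zero_variance),
        # so the total over all questions sums to num_questions.
        weights[~zero_var] = (num_questions - zero_var.sum()) * soft
    return weights  # expanded per-sample and multiplied onto advantages

def balancing_ratios(gathered_lengths: torch.Tensor, gathered_valid: torch.Tensor,
                     local_lengths: torch.Tensor, num_processes: int):
    """The two scaling factors for valid token-level averaging.

    gathered_lengths: completion lengths gathered across processes,
    gathered_valid:   boolean mask of valid samples (std_rewards != 0),
    local_lengths:    completion lengths on the current process.
    """
    global_sum = gathered_lengths.sum()
    valid_sum = gathered_lengths[gathered_valid].sum()
    zero_mask_ratio = global_sum / valid_sum if gathered_valid.any() else torch.tensor(1.0)
    global_balancing_ratio = num_processes * local_lengths.sum() / global_sum
    return zero_mask_ratio, global_balancing_ratio
```

As described above, zero_mask_ratio multiplies the full (pre-slice) advantages and global_balancing_ratio multiplies the local (post-slice) advantages.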

Tests

  • test_training_dgpo in tests/test_grpo_trainer.py: Runs a short training with use_dgpo_dgae=True, use_dgpo_dqw=True, dgpo_dqw_temp=2.0, and dgpo_dqw_acc_reward_index=0, and checks that train_loss is recorded.
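For reference, a rough sketch of what such a test could look like, modelled on the existing GRPO trainer tests in TRL; the tiny model, dataset, and reward function below are assumptions, not necessarily what the PR uses:

```python
# Hypothetical sketch of test_training_dgpo; model, dataset, and reward are assumptions.
from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer

def dummy_accuracy_reward(completions, **kwargs):
    # Toy binary "accuracy" reward in [0, 1] so DQW has a difficulty signal.
    return [1.0 if len(c) % 2 == 0 else 0.0 for c in completions]

def test_training_dgpo(tmp_path):
    dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
    args = GRPOConfig(
        output_dir=str(tmp_path),
        per_device_train_batch_size=3,
        num_generations=3,
        max_completion_length=8,
        use_dgpo_dgae=True,
        use_dgpo_dqw=True,
        dgpo_dqw_temp=2.0,
        dgpo_dqw_acc_reward_index=0,
        report_to="none",
    )
    trainer = GRPOTrainer(
        model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
        reward_funcs=dummy_accuracy_reward,
        args=args,
        train_dataset=dataset,
    )
    trainer.train()
    assert trainer.state.log_history[-1]["train_loss"] is not None
```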

Collaborator

@LeonEricsson LeonEricsson left a comment

a few comments.

also: _generate_and_score is getting too dense with this PR. DGAE/DQW + valid-token balancing logic and the existing multi-objective aggregation both add substantial branching/state. It's becoming hard to follow and validate each transformation in isolation.

I think it makes sense to pull most of these out into separate helpers?

Comment on lines 168 to 175
### DGPO (Difficulty-Aware Group Policy Optimization)

DGPO extends GRPO with difficulty-aware mechanisms to improve training on tasks with varying question difficulty (e.g., math reasoning). It is introduced in the [MathForge paper](https://huggingface.co/papers/2601.20614) (ICLR 2026) and is supported in [`GRPOTrainer`] via [`GRPOConfig`].

- **DGAE (Difficulty-balanced Group Advantage Estimation)**: When `use_dgpo_dgae=True`, advantages are scaled using Mean Absolute Deviation (MAD) instead of standard deviation, i.e. advantage = (reward - mean) / (MAD + eps), which can address the implicit imbalance
where the update magnitudes are suppressed for both easier and harder questions and peak for those of moderate difficulty.
- **DQW (Difficulty-aware Question-level Weighting)**: When `use_dgpo_dqw=True`, each question (prompt group) is assigned a weight based on its difficulty (e.g., mean accuracy reward). Harder questions get higher weight so the policy focuses more on them. Use `dgpo_dqw_temp` to control how sharp the weighting is (lower = more focus on hard questions) and `dgpo_dqw_acc_reward_index` to specify which reward in `reward_funcs` is used as the accuracy/difficulty signal.

Collaborator

Suggested change
### DGPO (Difficulty-Aware Group Policy Optimization)
DGPO extends GRPO with difficulty-aware mechanisms to improve training on tasks with varying question difficulty (e.g., math reasoning). It is introduced in the [MathForge paper](https://huggingface.co/papers/2601.20614) (ICLR 2026) and is supported in [`GRPOTrainer`] via [`GRPOConfig`].
- **DGAE (Difficulty-balanced Group Advantage Estimation)**: When `use_dgpo_dgae=True`, advantages are scaled using Mean Absolute Deviation (MAD) instead of standard deviation, i.e. advantage = (reward - mean) / (MAD + eps), which can address the implicit imbalance
where the update magnitudes are suppressed for both easier and harder questions and peak for those of moderate difficulty.
- **DQW (Difficulty-aware Question-level Weighting)**: When `use_dgpo_dqw=True`, each question (prompt group) is assigned a weight based on its difficulty (e.g., mean accuracy reward). Harder questions get higher weight so the policy focuses more on them. Use `dgpo_dqw_temp` to control how sharp the weighting is (lower = more focus on hard questions) and `dgpo_dqw_acc_reward_index` to specify which reward in `reward_funcs` is used as the accuracy/difficulty signal.

Extend the paper index section with this information instead

Author

I have already moved this information to the paper index section.

num_questions, device=advantages.device, dtype=advantages.dtype
)
if num_zero_variance_questions < num_questions:
# For mean accuracy 0 (all wrong) or NaN, set difficulty to -1 so they get less weight
Collaborator

Suggested change
# For mean accuracy 0 (all wrong) or NaN, set difficulty to -1 so they get less weight
# mean accuracy == 0 (all wrong) or NaN are remapped to 1.0 before softmax so they get less weight

Collaborator

style nit

but also, doesn't this imply rewards have to be >0?

Author

Thanks. This check and remapping apply only to the accuracy reward, which we assume lies in [0, 1] by default.

Whether to use difficulty-balanced group advantage estimation (DGAE). When `True`, the denominator when
scaling advantages uses the Mean Absolute Deviation (MAD) of rewards instead of the standard deviation, i.e.
advantage = (reward - mean) / (MAD + eps) with MAD = mean(|reward - mean|). Introduced in the [MathForge
paper](https://huggingface.co/papers/2601.20614) (ICLR 2026).
Collaborator

Suggested change
paper](https://huggingface.co/papers/2601.20614) (ICLR 2026).
paper](https://huggingface.co/papers/2601.20614).

applies to all instances below

Author

Modified

YanqiDai and others added 6 commits February 20, 2026 14:51
Updated the DGPO section to clarify its mechanisms and usage in TRL, including details on DGAE and DQW.
Removed DGPO section and its related details from the documentation.
Co-authored-by: LeonEricsson <70749762+LeonEricsson@users.noreply.github.com>
Updated comment to clarify handling of mean accuracy for questions.