
Add num_generations_eval parameter for efficient evaluation#4458

Merged
qgallouedec merged 8 commits into huggingface:main from
mingxuetian:feature/add-num-generations-eval
Nov 24, 2025

Conversation

@mingxuetian
Contributor

@mingxuetian mingxuetian commented Nov 5, 2025

Add num_generations_eval parameter for efficient evaluation

This is my first open-source PR, and I would greatly appreciate feedback on areas for improvement. Please don't hesitate to suggest changes; I'm eager to learn and make this contribution as good as possible!

What does this PR do?

This PR adds support for using a different number of generations during evaluation compared to training in GRPOTrainer. This allows users to save computation time during evaluation while maintaining training quality.

Fixes

Closes #3539
Closes #3566

Motivation

During training, multiple generations per prompt are often needed for better exploration and diversity. However, during evaluation, fewer generations are typically sufficient to assess model performance. This feature enables more efficient evaluation without compromising training effectiveness.

For example, users can train with 16 generations per prompt but evaluate with only 2 generations, reducing evaluation time by 8x.
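The arithmetic behind that claim can be sketched as follows (a back-of-envelope estimate assuming evaluation time scales linearly with the number of generations per prompt; the 16/2 split is just the example above):

```python
# Hypothetical settings from the example above; eval cost is assumed to
# scale linearly with the number of generations per prompt.
train_gens = 16  # num_generations used during training
eval_gens = 2    # num_generations_eval used during evaluation

# Relative reduction in evaluation generation work.
speedup = train_gens / eval_gens
print(speedup)  # 8.0
```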

Changes Made

1. Added num_generations_eval parameter to GRPOConfig

File: trl/trainer/grpo_config.py

Added a new optional parameter after num_generations:

num_generations_eval: int | None = field(
    default=None,
    metadata={
        "help": "Number of generations to sample during evaluation. If `None`, uses the value of "
        "`num_generations`. This allows using fewer generations during evaluation to save computation. "
        "Maintains backward compatibility with previous configuration files."
    },
)
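As a sanity check, the fallback semantics described in the help string can be sketched like this (a standalone mimic, not TRL's actual resolution code; `resolve_eval_gens` is a name invented for illustration):

```python
def resolve_eval_gens(num_generations, num_generations_eval=None):
    # None (the default) means: reuse the training value, so configs written
    # before this parameter existed behave exactly as before.
    return num_generations_eval if num_generations_eval is not None else num_generations

print(resolve_eval_gens(8))     # 8 -> backward-compatible default
print(resolve_eval_gens(8, 2))  # 2 -> cheaper evaluation
```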

2. Modified GRPOTrainer.__init__ to store the parameter

File: trl/trainer/grpo_trainer.py

Added a line (at line 383) to store the new parameter:

self.num_generations = args.num_generations  # = G in the GRPO paper
self.num_generations_eval = args.num_generations_eval  # NEW LINE
self.chat_template_kwargs = args.chat_template_kwargs or {}

3. Updated _get_eval_sampler method

File: trl/trainer/grpo_trainer.py

Modified the eval sampler to use num_generations_eval when available:

def _get_eval_sampler(self, eval_dataset) -> Sampler:
    # See _get_train_sampler for an explanation of the sampler.
    # If None, use num_generations for backward compatibility with previous config files
    num_gens = self.num_generations_eval or self.num_generations
    return RepeatSampler(
        data_source=eval_dataset,
        mini_repeat_count=num_gens,
        seed=self.args.seed,
    )
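The effect of mini_repeat_count can be illustrated with a tiny stand-in for RepeatSampler (illustration only; the real sampler also handles batching, shuffling, and seeding, and `repeat_indices` is a made-up helper):

```python
def repeat_indices(n_items, mini_repeat_count):
    # Yield each dataset index mini_repeat_count times in a row, so every
    # eval prompt is paired with that many generations.
    return [i for i in range(n_items) for _ in range(mini_repeat_count)]

print(repeat_indices(3, 2))  # [0, 0, 1, 1, 2, 2]
```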

4. Updated vLLM server mode generation logic

File: trl/trainer/grpo_trainer.py (lines 1166-1173)

Modified to dynamically select the correct number of generations based on mode:

# Determine num_generations based on mode
mode = "train" if self.model.training else "eval"
num_gens = (
    self.num_generations_eval
    if mode == "eval" and self.num_generations_eval is not None
    else self.num_generations
)
ordered_set_of_prompts = all_prompts[::num_gens]
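To see why the `[::num_gens]` stride recovers one copy of each prompt: upstream, each prompt appears num_gens times consecutively (toy data below, not actual trainer state):

```python
num_gens = 2
# Each prompt is repeated num_gens times in a row by the sampler.
all_prompts = ["p0", "p0", "p1", "p1", "p2", "p2"]
# Taking every num_gens-th entry yields one copy per unique prompt.
ordered_set_of_prompts = all_prompts[::num_gens]
print(ordered_set_of_prompts)  # ['p0', 'p1', 'p2']
```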

5. Updated prompt repetition logic in server mode

File: trl/trainer/grpo_trainer.py (lines 1223-1231)

Modified to repeat prompts the correct number of times:

# Determine repeat count based on mode
mode = "train" if self.model.training else "eval"
num_gens = (
    self.num_generations_eval
    if mode == "eval" and self.num_generations_eval is not None
    else self.num_generations
)
# At this point, we only get 1 copy of each prompt, so repeat each one num_gens times
all_prompt_ids = [ids for ids in all_prompt_ids for _ in range(num_gens)]
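This expansion is the inverse of the `[::num_gens]` dedup done before generation; a toy round trip (invented data, not trainer state) shows the shapes line up:

```python
num_gens = 2
unique_prompt_ids = [[101], [102]]  # one token-id list per unique prompt
# Expand each unique prompt back to num_gens consecutive copies.
all_prompt_ids = [ids for ids in unique_prompt_ids for _ in range(num_gens)]
print(all_prompt_ids)  # [[101], [101], [102], [102]]
# Striding by num_gens recovers the original unique prompts.
print(all_prompt_ids[::num_gens] == unique_prompt_ids)  # True
```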

6. Updated reward computation logic

File: trl/trainer/grpo_trainer.py (lines 1616-1621)

Modified to handle different generation counts for train/eval modes:

# Determine num_generations based on mode before computing group-wise rewards.
# If num_generations_eval is None, fall back to num_generations for backward
# compatibility with previous config files.
mode = "train" if self.model.training else "eval"
num_gens = (self.num_generations_eval or self.num_generations) if mode == "eval" else self.num_generations
# Compute group-wise rewards
mean_grouped_rewards = rewards.view(-1, num_gens).mean(dim=1)
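The view/mean step reduces to a per-group average; here is a pure-Python mimic of `rewards.view(-1, num_gens).mean(dim=1)` (no torch dependency, toy rewards, and `grouped_means` is a name invented for illustration):

```python
def grouped_means(rewards, num_gens):
    # Rewards arrive as num_gens consecutive entries per prompt, so the
    # list length must divide evenly into groups.
    assert len(rewards) % num_gens == 0
    return [sum(rewards[i:i + num_gens]) / num_gens
            for i in range(0, len(rewards), num_gens)]

print(grouped_means([1.0, 3.0, 2.0, 4.0], 2))  # [2.0, 3.0]
```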

Summary of Modified Files

  1. trl/trainer/grpo_config.py: Added num_generations_eval parameter definition
  2. trl/trainer/grpo_trainer.py: Modified 5 locations:
    • Line 383: Store the parameter in __init__
    • Lines 760-768: Updated _get_eval_sampler method
    • Lines 1166-1173: Updated vLLM server mode generation
    • Lines 1223-1231: Updated prompt repetition logic
    • Lines 1616-1621: Updated reward computation

Backward Compatibility

Fully backward compatible: When num_generations_eval is None (default), the trainer falls back to using num_generations, ensuring existing configurations work without any changes.

Example Usage

args = GRPOConfig(
    num_generations=8,        # Use 8 generations during training
    num_generations_eval=2,   # Use only 2 generations during evaluation (4x faster)
    ...
)

trainer = GRPOTrainer(
    model="Qwen/Qwen2-0.5B-Instruct",
    reward_funcs=reward_func,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)

trainer.train()

Benefits

  • Faster evaluation: Reduce evaluation time by using fewer generations
  • Cost savings: Lower computational costs during evaluation
  • Maintained quality: Training quality remains unchanged
  • Flexible: Users can choose the optimal trade-off between speed and evaluation accuracy

Who can review?

This PR is ready for review! Any community member is welcome to provide feedback.
A special thanks to @qgallouedec for considering this PR.
As my first open-source contribution, I'm excited to learn - please don't hesitate to suggest any enhancements!

Member

@qgallouedec qgallouedec left a comment


Thanks for your contribution, and welcome to the open-source community!
Regarding the PR — I don’t see a scenario where having a different number of generations during evaluation and training would be necessary. I’ll leave the PR open for now to see if the community expresses interest in this feature. If not, we can close it later.

@mingxuetian
Contributor Author

mingxuetian commented Nov 6, 2025

> Thanks for your contribution, and welcome to the open-source community! Regarding the PR — I don’t see a scenario where having a different number of generations during evaluation and training would be necessary. I’ll leave the PR open for now to see if the community expresses interest in this feature. If not, we can close it later.

Thank you @qgallouedec for the review!
I'd like to explain the practical motivation, based on my training experience and on issues #3539 (@SnorkelerVigi) and #3566 (@CasanovaLLL):

Community Need

Both issues #3539 and #3566 specifically request this feature because evaluation overhead was a major bottleneck in their training pipelines. @qgallouedec You also mentioned in your replies to these issues that this problem needs to be addressed, which confirms the necessity of this feature.

Why Different num_generations?

During Training:

  • Large num_generations (e.g., 16) is essential for accurate advantage estimation via group-wise reward normalization: mean_grouped_rewards = rewards.view(-1, num_gens).mean(dim=1)
  • More samples per prompt → more stable advantages → better training quality

During Evaluation:

  • We only need to monitor model performance, not compute gradients
  • Evaluation metrics are reliable with far fewer generations (e.g., 2)
  • Reusing training's large num_generations makes every evaluation pass significantly slower, stalling the overall run
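The stability argument above can be checked numerically: with G reward samples per prompt, the spread of the group-mean estimate shrinks roughly as 1/sqrt(G). A seeded simulation with synthetic unit-normal rewards (not real model outputs) illustrates this:

```python
import random
import statistics

random.seed(0)

def group_mean_std(group_size, n_groups=2000):
    # Std of the group-mean estimate across many synthetic reward groups.
    means = [statistics.mean(random.gauss(0, 1) for _ in range(group_size))
             for _ in range(n_groups)]
    return statistics.stdev(means)

small_group = group_mean_std(2)    # noisy advantage baseline
large_group = group_mean_std(16)   # markedly tighter (~sqrt(8)x in theory)
print(small_group > large_group)   # True
```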

Real Impact

From my experiments:

  • Setup: num_generations=16 (train), num_generations_eval=2 (eval)
  • Result: ~87.5% less evaluation time (1/8 of the generation work)
  • Evaluation metrics remained statistically equivalent

I would greatly appreciate it if you could carefully consider this PR. Thank you!

[Screenshots attached from issues #3539 and #3566]

@mingxuetian
Contributor Author

@qgallouedec, gentle ping. This PR directly addresses the problem you acknowledged in issue #3539, which is also a prerequisite for #3566. It provides a solution for the problem you confirmed needs fixing. A quick decision on this would be appreciated.

@qgallouedec
Member

Thanks for the PR, we have limited bandwidth, please be patient

@mingxuetian
Contributor Author

mingxuetian commented Nov 11, 2025

> Thanks for the PR, we have limited bandwidth, please be patient

Thanks for the update. No problem, I truly understand the bandwidth constraints — appreciate you and the team's hard work. I'll stay patient, and please don't hesitate to reach out if you have any questions. Look forward to your review when time permits.

@aristizabal95

aristizabal95 commented Nov 20, 2025

I second this. I would prefer having control over the number of generations for evaluation to cut down training time

Member

@qgallouedec qgallouedec left a comment


LGTM, just waiting for the CI to pass, and we can merge

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@qgallouedec qgallouedec changed the title Add num_generations_eval parameter for efficient evaluation Nov 24, 2025
@qgallouedec qgallouedec merged commit db4f6e5 into huggingface:main Nov 24, 2025
8 of 9 checks passed
@mingxuetian
Contributor Author

mingxuetian commented Nov 24, 2025

> LGTM, just waiting for the CI to pass, and we can merge

Thank you!

qgallouedec added a commit that referenced this pull request Dec 1, 2025
commit db4f6e5
Author: mingxuetian <108911581+mingxuetian@users.noreply.github.com>
Date:   Mon Nov 24 09:51:42 2025 +0800

    Add `num_generations_eval` parameter for efficient evaluation (#4458)

    Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
    Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>


Successfully merging this pull request may close these issues:

  • GRPO per_device_eval_batch_size can't be set as 1, when there is only 1 GPU
  • GRPOTrainer - Repeat Sampler - _get_eval_sampler

4 participants