From f614e4ab99a532681df751289fb65ed39650c9a2 Mon Sep 17 00:00:00 2001
From: Yuki Huang
Date: Tue, 3 Mar 2026 06:24:36 -0800
Subject: [PATCH 01/10] squash (remove dependency update)

Signed-off-by: Yuki Huang
---
 examples/configs/gdpo_math_1B.yaml | 51 ++++
 examples/prompts/gsm8k.txt | 17 ++
 examples/run_gdpo_gsm8k.py | 258 +++++++++++++++++
 nemo_rl/algorithms/advantage_estimator.py | 75 +++++
 nemo_rl/algorithms/grpo.py | 143 +++++++--
 .../datasets/response_datasets/__init__.py | 3 +
 .../data/datasets/response_datasets/gsm8k.py | 82 ++++++
 nemo_rl/data/processors.py | 70 +++++
 .../ray_actor_environment_registry.py | 1 +
 nemo_rl/environments/interfaces.py | 2 +-
 nemo_rl/environments/math_environment.py | 272 ++++++++++++++++++
 nemo_rl/experience/rollouts.py | 157 +++++++---
 tests/functional/L1_Functional_Tests_GPU.sh | 1 +
 tests/functional/gdpo.sh | 46 +++
 14 files changed, 1102 insertions(+), 76 deletions(-)
 create mode 100644 examples/configs/gdpo_math_1B.yaml
 create mode 100644 examples/prompts/gsm8k.txt
 create mode 100644 examples/run_gdpo_gsm8k.py
 create mode 100644 nemo_rl/data/datasets/response_datasets/gsm8k.py
 create mode 100644 tests/functional/gdpo.sh

diff --git a/examples/configs/gdpo_math_1B.yaml b/examples/configs/gdpo_math_1B.yaml
new file mode 100644
index 0000000000..3536e608f8
--- /dev/null
+++ b/examples/configs/gdpo_math_1B.yaml
@@ -0,0 +1,51 @@
+# GDPO: inherits from grpo_math_1B.yaml and overrides only what differs.
+defaults: grpo_math_1B.yaml
+
+grpo:
+  adv_estimator:
+    name: "gdpo"
+    normalize_rewards: true
+    use_leave_one_out_baseline: false
+
+checkpointing:
+  checkpoint_dir: "results/gdpo"
+
+policy:
+  model_name: "Qwen/Qwen2.5-1.5B-Instruct"
+  logprob_batch_size: 4
+  max_total_sequence_length: 1024
+  megatron_cfg:
+    optimizer:
+      weight_decay: 0.0
+    scheduler:
+      lr_decay_style: "cosine"
+      lr_warmup_iters: 10
+
+# GDPO uses a single flat data config (GSM8K + math_gdpo_data_processor); replace parent's train/validation/default.
+data:
+  _override_: true
+  max_input_seq_length: ${policy.max_total_sequence_length}
+  prompt_file: "examples/prompts/cot.txt"
+  system_prompt_file: "examples/prompts/gsm8k.txt"
+  shuffle: true
+  num_workers: 1
+  processor: "math_gdpo_data_processor"
+  env_name: "math"
+  dataset_name: "gsm8k"
+
+env:
+  math:
+    num_workers: 8
+    math_verify_impl: "hf_math_verify"
+
+logger:
+  wandb_enabled: true
+  wandb:
+    project: "gdpo-dev"
+    name: "gdpo-dev-logger"
+  swanlab:
+    project: "gdpo-dev"
+    name: "gdpo-dev-logger"
+  mlflow:
+    experiment_name: "gdpo-dev"
+    run_name: "gdpo-dev-logger"
diff --git a/examples/prompts/gsm8k.txt b/examples/prompts/gsm8k.txt
new file mode 100644
index 0000000000..3c31977100
--- /dev/null
+++ b/examples/prompts/gsm8k.txt
@@ -0,0 +1,17 @@
+You are a helpful AI assistant.
+
+For every request, you should carefully think through the math problem step by step, then provide the final answer in integer format.
+
+Steps for Each Request:
+1. Think: Provide detailed, step-by-step reasoning, calculations, or derivations.
+2. Produce Final Answer: After step-by-step reasoning, output the final answer in integer format.
+
+Output Format:
+<think>Your thoughts and reasoning</think>
+<answer>Final answer in integer format</answer>
+
+Important Notes:
+1. You must include your reasoning steps inside <think></think>.
+2. You must always output the Final Answer within <answer></answer> after the reasoning steps are done.
+3. You should consistently work through the solution step by step before giving the final answer.
+4. The final answer can only be an integer. 
\ No newline at end of file diff --git a/examples/run_gdpo_gsm8k.py b/examples/run_gdpo_gsm8k.py new file mode 100644 index 0000000000..14bc55cf3e --- /dev/null +++ b/examples/run_gdpo_gsm8k.py @@ -0,0 +1,258 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import os +import pprint +from collections import defaultdict +from typing import Any, Optional + +from omegaconf import OmegaConf +from transformers import PreTrainedTokenizerBase + +from nemo_rl.algorithms.grpo import MasterConfig, grpo_train, setup +from nemo_rl.algorithms.utils import get_tokenizer +from nemo_rl.data import DataConfig +from nemo_rl.data.datasets import AllTaskProcessedDataset, load_response_dataset +from nemo_rl.data.interfaces import ( + TaskDataProcessFnCallable, + TaskDataSpec, +) +from nemo_rl.data.processors import math_gdpo_data_processor +from nemo_rl.distributed.ray_actor_environment_registry import ( + get_actor_python_env, +) +from nemo_rl.distributed.virtual_cluster import init_ray +from nemo_rl.environments.interfaces import EnvironmentInterface +from nemo_rl.environments.math_environment import MathMultiRewardEnvironment +from nemo_rl.models.generation import configure_generation_config +from nemo_rl.utils.config import load_config, parse_hydra_overrides +from nemo_rl.utils.logger import get_next_experiment_dir + +OmegaConf.register_new_resolver("mul", lambda a, b: a * b) + + +def parse_args() -> tuple[argparse.Namespace, list[str]]: + """Parse command line arguments.""" + parser = argparse.ArgumentParser(description="Run GRPO training with configuration") + parser.add_argument( + "--config", type=str, default=None, help="Path to YAML config file" + ) + + # Parse known args for the script + args, overrides = parser.parse_known_args() + + return args, overrides + + +# =============================================================================== +# Math Data Processor +# =============================================================================== +TokenizerType = PreTrainedTokenizerBase + + +def setup_data( + tokenizer: TokenizerType, + data_config: DataConfig, + env_configs: dict[str, Any], + seed: int, +) -> tuple[ + AllTaskProcessedDataset, + Optional[AllTaskProcessedDataset], + dict[str, EnvironmentInterface], + dict[str, EnvironmentInterface], +]: + print("\nā–¶ Setting up data...") + math_task_spec = TaskDataSpec( + task_name="math", + prompt_file=data_config["prompt_file"], + system_prompt_file=data_config["system_prompt_file"], + ) + + # load dataset + data: Any = load_response_dataset(data_config) + task_name = ( + data.task_name if hasattr(data, "task_name") else data.task_spec.task_name + ) + + # data processor + task_data_processors: dict[str, tuple[TaskDataSpec, TaskDataProcessFnCallable]] = ( + defaultdict(lambda: (math_task_spec, math_gdpo_data_processor)) + ) + task_data_processors[task_name] = (math_task_spec, math_gdpo_data_processor) + + # setup math environment + math_env = 
MathMultiRewardEnvironment.options( # type: ignore # it's wrapped with ray.remote + runtime_env={ + "py_executable": get_actor_python_env( + "nemo_rl.environments.math_environment.MathMultiRewardEnvironment" + ), + "env_vars": dict(os.environ), # Pass thru all user environment variables + } + ).remote(env_configs["math"]) + + dataset = AllTaskProcessedDataset( + data.dataset, + tokenizer, + math_task_spec, + task_data_processors, + max_seq_length=data_config["max_input_seq_length"], + ) + + val_dataset: Optional[AllTaskProcessedDataset] = None + if data.val_dataset is not None: + val_dataset = AllTaskProcessedDataset( + data.val_dataset, + tokenizer, + math_task_spec, + task_data_processors, + max_seq_length=data_config["max_input_seq_length"], + ) + + task_to_env: dict[str, EnvironmentInterface] = defaultdict(lambda: math_env) + task_to_env[task_name] = math_env + return dataset, val_dataset, task_to_env, task_to_env + + +def main() -> None: + """Main entry point.""" + # Parse arguments + args, overrides = parse_args() + + if not args.config: + args.config = os.path.join( + os.path.dirname(__file__), "configs", "gdpo_math_1B.yaml" + ) + + config = load_config(args.config) + print(f"Loaded configuration from: {args.config}") + + if overrides: + print(f"Overrides: {overrides}") + config = parse_hydra_overrides(config, overrides) + + config: MasterConfig = OmegaConf.to_container(config, resolve=True) + print("Applied CLI overrides") + + # Print config + print("Final config:") + pprint.pprint(config) + + # Get the next experiment directory with incremented ID + config["logger"]["log_dir"] = get_next_experiment_dir(config["logger"]["log_dir"]) + print(f"šŸ“Š Using log directory: {config['logger']['log_dir']}") + if config["checkpointing"]["enabled"]: + print( + f"šŸ“Š Using checkpoint directory: {config['checkpointing']['checkpoint_dir']}" + ) + + init_ray() + + # setup tokenizer + tokenizer = get_tokenizer(config["policy"]["tokenizer"]) + assert config["policy"]["generation"] is not None, ( + "A generation config is required for GRPO" + ) + config["policy"]["generation"] = configure_generation_config( + config["policy"]["generation"], tokenizer + ) + + # setup data + ( + dataset, + val_dataset, + task_to_env, + val_task_to_env, + ) = setup_data(tokenizer, config["data"], config["env"], config["grpo"]["seed"]) + + ( + policy, + policy_generation, + cluster, + dataloader, + val_dataloader, + loss_fn, + logger, + checkpointer, + grpo_state, + master_config, + ) = setup(config, tokenizer, dataset, val_dataset) + + # Check if async mode is enabled + if "async_grpo" in config["grpo"] and config["grpo"]["async_grpo"]["enabled"]: + # Async GRPO does not support dynamic sampling, reward scaling, or reward shaping (DAPO features) + unsupported_features = [ + "use_dynamic_sampling", + "reward_scaling", + "reward_shaping", + ] + + for feature in unsupported_features: + if feature not in config["grpo"]: + continue + + if feature == "use_dynamic_sampling": + if config["grpo"][feature]: + raise NotImplementedError( + f"{feature} is not supported with async GRPO" + ) + else: + if config["grpo"][feature]["enabled"]: + raise NotImplementedError( + f"{feature} is not supported with async GRPO" + ) + + from nemo_rl.algorithms.grpo import async_grpo_train + + print("šŸš€ Running async GRPO training") + + async_config = config["grpo"]["async_grpo"] + # Run async GRPO training + async_grpo_train( + policy=policy, + policy_generation=policy_generation, + dataloader=dataloader, + val_dataloader=val_dataloader, + 
tokenizer=tokenizer, + loss_fn=loss_fn, + task_to_env=task_to_env, + val_task_to_env=val_task_to_env, + logger=logger, + checkpointer=checkpointer, + grpo_save_state=grpo_state, + master_config=master_config, + max_trajectory_age_steps=async_config["max_trajectory_age_steps"], + ) + else: + print("šŸš€ Running synchronous GRPO training") + + # Run standard GRPO training + grpo_train( + policy, + policy_generation, + dataloader, + val_dataloader, + tokenizer, + loss_fn, + task_to_env, + val_task_to_env, + logger, + checkpointer, + grpo_state, + master_config, + ) + + +if __name__ == "__main__": + main() diff --git a/nemo_rl/algorithms/advantage_estimator.py b/nemo_rl/algorithms/advantage_estimator.py index 6d14d8637a..d8e9767a4b 100644 --- a/nemo_rl/algorithms/advantage_estimator.py +++ b/nemo_rl/algorithms/advantage_estimator.py @@ -16,17 +16,25 @@ This module provides different advantage estimation strategies: - GRPOAdvantageEstimator: Standard GRPO advantage with leave-one-out baseline +- GDPOAdvantageEstimator: Multi-reward GDPO (per-component baselines, sum then normalize) - ReinforcePlusPlusAdvantageEstimator: Reinforce++ with optional baseline subtraction (minus_baseline) and KL penalty in reward Reference papers: - ProRLv2: https://developer.nvidia.com/blog/scaling-llm-reinforcement-learning-with-prolonged-training-using-prorl-v2/ - Reinforce++: https://arxiv.org/abs/2501.03262 """ +import re import torch from nemo_rl.algorithms.utils import calculate_baseline_and_std_per_prompt, calculate_kl +def _get_reward_component_keys(batch) -> list: + """Return batch keys that are reward components (reward1, reward2, ...) in sorted order.""" + keys = [k for k in batch.keys() if re.match(r"reward\d+$", str(k))] + return sorted(keys, key=lambda k: int(re.search(r"\d+", str(k)).group())) + + class GRPOAdvantageEstimator: """GRPO-style advantage estimator with leave-one-out baseline. @@ -69,6 +77,73 @@ def compute_advantage(self, prompt_ids, rewards, mask, **kwargs): return advantages.expand(mask.shape) +class GDPOAdvantageEstimator: + """GDPO-style advantage estimator with leave-one-out baseline. + + Note: GDPO computes advantages for each reward separately over all responses for each prompt. + """ + + def __init__(self, estimator_config: dict, loss_config: dict): + self.use_leave_one_out_baseline = estimator_config["use_leave_one_out_baseline"] + self.normalize_rewards = estimator_config["normalize_rewards"] + + def compute_advantage(self, repeated_batch, mask, **kwargs): + """Compute GDPO advantages. + + Args: + repeated_batch: Batch containing _input_ids_for_baseline and reward1, reward2, ... keys. + mask: Response token mask of shape [batch_size, seq_len], 1 for valid response tokens, 0 for padding. + Used only for expanding advantages to token-level shape. + **kwargs: Additional arguments (unused). + + Returns: + Advantages tensor of shape [batch_size, seq_len]. + """ + reward_component_keys = _get_reward_component_keys(repeated_batch) + if not reward_component_keys: + raise ValueError( + "GDPOAdvantageEstimator requires reward component keys (reward1, reward2, ...) 
in repeated_batch" + ) + current_input_ids = repeated_batch["_input_ids_for_baseline"] + valid = torch.ones_like( + repeated_batch[reward_component_keys[0]] + ) + leave_one_out = self.use_leave_one_out_baseline + assert current_input_ids.shape[0] == valid.shape[0], ( + "_input_ids_for_baseline must match reward batch size after dynamic_sampling; " + f"got {current_input_ids.shape[0]} vs {valid.shape[0]}" + ) + advantage_parts = [] + for key in reward_component_keys: + r = repeated_batch[key] + base, std_k = calculate_baseline_and_std_per_prompt( + current_input_ids, + r, + valid, + leave_one_out_baseline=leave_one_out, + ) + adv_k = (r - base).unsqueeze(-1) + if self.normalize_rewards: + + epsilon = 1e-6 + non_zero_std_mask = std_k > 0 + adv_k[non_zero_std_mask] = adv_k[non_zero_std_mask] / ( + std_k.unsqueeze(-1)[non_zero_std_mask] + epsilon + ) + + advantage_parts.append(adv_k) + + advantages = sum(advantage_parts) + # Normalize combined advantage to zero mean and unit std + adv_std = advantages.std() + if adv_std > 0: + advantages = (advantages - advantages.mean()) / adv_std + else: + advantages = advantages - advantages.mean() + + return advantages.expand(mask.shape) + + class ReinforcePlusPlusAdvantageEstimator: """Reinforce++ advantage estimator with optional baseline subtraction and KL penalty in reward. diff --git a/nemo_rl/algorithms/grpo.py b/nemo_rl/algorithms/grpo.py index c060a05a50..027bfb3262 100644 --- a/nemo_rl/algorithms/grpo.py +++ b/nemo_rl/algorithms/grpo.py @@ -13,6 +13,7 @@ # limitations under the License. import gc import os +import re import time import warnings from concurrent.futures import ThreadPoolExecutor @@ -29,6 +30,7 @@ from nemo_rl.algorithms.advantage_estimator import ( GRPOAdvantageEstimator, + GDPOAdvantageEstimator, ReinforcePlusPlusAdvantageEstimator, ) from nemo_rl.algorithms.loss import ( @@ -121,9 +123,9 @@ class AsyncGRPOConfig(TypedDict): class AdvEstimatorConfig(TypedDict): - """Configuration for advantage estimator (GRPO or Reinforce++).""" + """Configuration for advantage estimator (GRPO, GDPO, or Reinforce++).""" - name: str # "grpo" or "reinforce_plus_plus" + name: str # "grpo", "gdpo", or "reinforce_plus_plus" # GRPO specific normalize_rewards: NotRequired[bool] use_leave_one_out_baseline: NotRequired[bool] @@ -934,6 +936,15 @@ def dynamic_sampling( return batch_to_return, is_batch_complete, batch_cache, dynamic_sampling_metrics +def _get_reward_component_keys(batch: BatchedDataDict[Any]) -> list[str]: + """Return batch keys that are reward components (reward1, reward2, ...) in sorted order. + + Enables environments to expose any number of rewards without code changes elsewhere. 
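+
+    Illustrative example (keys assumed, not taken from a real run): a batch with keys
+    {"reward1", "reward2", "reward10", "total_reward"} yields ["reward1", "reward2", "reward10"];
+    keys are ordered by their numeric suffix rather than lexicographically.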
+ """ + keys = [k for k in batch.keys() if re.match(r"reward\d+$", str(k))] + return sorted(keys, key=lambda k: int(re.search(r"\d+", str(k)).group())) + + def scale_rewards( repeated_batch: BatchedDataDict[DatumSpec], reward_scaling_cfg: RewardScalingConfig ) -> BatchedDataDict[DatumSpec]: @@ -966,11 +977,19 @@ def scale_rewards( ) # Clamp and scale + def _scale(reward_tensor: torch.Tensor) -> torch.Tensor: + r = torch.clamp(reward_tensor, min=source_min, max=source_max) + return target_min + (r - source_min) / ( + source_max - source_min + ) * (target_max - target_min) + rewards = torch.clamp(rewards, min=source_min, max=source_max) scaled_rewards = target_min + (rewards - source_min) / ( source_max - source_min ) * (target_max - target_min) repeated_batch["total_reward"] = scaled_rewards + for key in _get_reward_component_keys(repeated_batch): + repeated_batch[key] = _scale(repeated_batch[key]) return repeated_batch @@ -1024,24 +1043,25 @@ def _should_log_nemo_gym_responses(master_config: MasterConfig) -> bool: return should_log_nemo_gym_responses -def _create_advantage_estimator(master_config: MasterConfig): +def _create_advantage_estimator( + master_config: MasterConfig, use_multi_reward_advantages: bool = False +): """Create and return an advantage estimator based on configuration. Args: master_config: The master configuration dictionary. + use_multi_reward_advantages: If True and name is "gdpo", use GDPO. + When False and name is "gdpo", use GRPO (single-reward fallback). Returns: - An advantage estimator instance (GRPOAdvantageEstimator or ReinforcePlusPlusAdvantageEstimator). + An advantage estimator instance (GRPO, GDPO, or ReinforcePlusPlus). Raises: ValueError: If the advantage estimator name is not recognized. """ grpo_config = master_config["grpo"] loss_config = master_config["loss_fn"] - # Provide backward-compatible defaults when adv_estimator is not in config. - # Fall back to top-level grpo.normalize_rewards / grpo.use_leave_one_out_baseline - # which older configs still use. 
adv_estimator_config = grpo_config.get( "adv_estimator", { @@ -1055,7 +1075,15 @@ def _create_advantage_estimator(master_config: MasterConfig): ) adv_estimator_name = adv_estimator_config["name"] - if adv_estimator_name == "grpo": + # GDPO only when we have multi-reward data; otherwise "gdpo" config uses GRPO + if use_multi_reward_advantages and adv_estimator_name == "gdpo": + adv_estimator = GDPOAdvantageEstimator(adv_estimator_config, loss_config) + print(" āœ“ Using GDPO advantage estimator (multi-reward)") + elif adv_estimator_name == "gdpo": + # GDPO config but single-reward batch: use GRPO + adv_estimator = GRPOAdvantageEstimator(adv_estimator_config, loss_config) + print(" āœ“ Using GRPO advantage estimator (gdpo config, single-reward)") + elif adv_estimator_name == "grpo": adv_estimator = GRPOAdvantageEstimator(adv_estimator_config, loss_config) print(" āœ“ Using GRPO advantage estimator") elif adv_estimator_name == "reinforce_plus_plus": @@ -1364,9 +1392,6 @@ def grpo_train( val_period = master_config["grpo"]["val_period"] colocated_inference = master_config["policy"]["generation"]["colocated"]["enabled"] - # Initialize advantage estimator - adv_estimator = _create_advantage_estimator(master_config) - # Run validation at the start if configured # TODO: Add validation with kv scales if needed if val_at_start and current_step == 0: @@ -1587,9 +1612,25 @@ def grpo_train( # Calculate rewards & advantages memory_tracker.snapshot_start_of_stage("Processing rewards", dir()) print("ā–¶ Processing rewards...,", flush=True) + # GDPO with timer.time("reward_calculation"): - # Extract rewards from final_batch + # Use GDPO when adv_estimator name is "gdpo" and batch has + # multiple reward components (reward1, reward2, ...). rewards = repeated_batch["total_reward"] + adv_estimator_config = master_config["grpo"].get( + "adv_estimator", {"name": "grpo"} + ) + adv_estimator_name = adv_estimator_config.get("name", "grpo") + reward_component_keys = _get_reward_component_keys(repeated_batch) + use_multi_reward_advantages = ( + adv_estimator_name == "gdpo" + and len(reward_component_keys) >= 2 + ) + + # Store input_ids in batch so that after dynamic_sampling it stays aligned with + # the (possibly filtered) batch: select_indices / from_batches / slice all + # apply to this key, so per-reward baselines use the same prompts as reward components. + repeated_batch["_input_ids_for_baseline"] = input_ids print("ā–¶ Computing advantages...", flush=True) if master_config["grpo"].get("calculate_advantages_on_gpu"): @@ -1644,10 +1685,16 @@ def grpo_train( # If the current batch is not enough to fill the buffer during dynamic sampling, we update the cache and process the next batch. 
if not is_batch_complete: continue + + # Create advantage estimator for this batch (GDPO when multi-reward, else GRPO/Reinforce++) + adv_estimator = _create_advantage_estimator( + master_config, + use_multi_reward_advantages=use_multi_reward_advantages, + ) + gen_step_metrics = {} if hasattr(policy_generation, "get_step_metrics"): gen_step_metrics = policy_generation.get_step_metrics() - advantages = (rewards - baseline).unsqueeze(-1) # Save baseline for logging (before deletion) baseline_for_log = baseline.clone() @@ -1775,13 +1822,29 @@ def grpo_train( sample_mask = train_data["sample_mask"] mask = token_mask * sample_mask.unsqueeze(-1) - train_data["advantages"] = adv_estimator.compute_advantage( - prompt_ids=prompt_ids_for_adv, - rewards=rewards, - mask=mask, - logprobs_policy=train_data["prev_logprobs"], - logprobs_reference=train_data.get("reference_policy_logprobs"), - ) + + + if use_multi_reward_advantages: + # GDPO adv_estimation + train_data["advantages"] = adv_estimator.compute_advantage( + repeated_batch=repeated_batch, + mask=mask, + logprobs_policy=train_data["prev_logprobs"], + logprobs_reference=train_data.get("reference_policy_logprobs"), + ) + + else: + + train_data["advantages"] = adv_estimator.compute_advantage( + prompt_ids=prompt_ids_for_adv, + rewards=rewards, + mask=mask, + logprobs_policy=train_data["prev_logprobs"], + logprobs_reference=train_data.get("reference_policy_logprobs"), + ) + + + del prompt_ids_for_adv # Log rewards and advantages information @@ -2431,9 +2494,6 @@ def async_grpo_train( val_at_end = master_config["grpo"]["val_at_end"] colocated_inference = master_config["policy"]["generation"]["colocated"]["enabled"] - # Initialize advantage estimator - adv_estimator = _create_advantage_estimator(master_config) - assert not colocated_inference, ( "Colocated inference is not supported for async GRPO. Please use non-colocated inference." 
) @@ -2724,6 +2784,23 @@ def async_grpo_train( del prompt_batched_flat rewards = repeated_batch["total_reward"] + adv_estimator_config = master_config["grpo"].get( + "adv_estimator", {"name": "grpo"} + ) + adv_estimator_name = adv_estimator_config.get("name", "grpo") + reward_component_keys = _get_reward_component_keys(repeated_batch) + use_multi_reward_advantages = ( + adv_estimator_name == "gdpo" + and len(reward_component_keys) >= 2 + ) + if use_multi_reward_advantages: + repeated_batch["_input_ids_for_baseline"] = prompt_ids_for_adv + + # Create advantage estimator (GDPO when name is "gdpo" and batch has multi-reward) + adv_estimator = _create_advantage_estimator( + master_config, + use_multi_reward_advantages=use_multi_reward_advantages, + ) print( f" šŸ“Š Rewards stats: min={rewards.min():.4f}, max={rewards.max():.4f}, mean={rewards.mean():.4f}, std={rewards.std():.4f}" @@ -2806,13 +2883,19 @@ def async_grpo_train( sample_mask = train_data["sample_mask"] mask = token_mask * sample_mask.unsqueeze(-1) - train_data["advantages"] = adv_estimator.compute_advantage( - prompt_ids=prompt_ids_for_adv, - rewards=rewards, - mask=mask, - logprobs_policy=train_data["prev_logprobs"], - logprobs_reference=train_data.get("reference_policy_logprobs"), - ) + if use_multi_reward_advantages: + train_data["advantages"] = adv_estimator.compute_advantage( + repeated_batch=repeated_batch, + mask=mask, + ) + else: + train_data["advantages"] = adv_estimator.compute_advantage( + prompt_ids=prompt_ids_for_adv, + rewards=rewards, + mask=mask, + logprobs_policy=train_data["prev_logprobs"], + logprobs_reference=train_data.get("reference_policy_logprobs"), + ) del prompt_ids_for_adv # Log advantages stats diff --git a/nemo_rl/data/datasets/response_datasets/__init__.py b/nemo_rl/data/datasets/response_datasets/__init__.py index eb48bb5204..d53a422bbb 100644 --- a/nemo_rl/data/datasets/response_datasets/__init__.py +++ b/nemo_rl/data/datasets/response_datasets/__init__.py @@ -20,6 +20,7 @@ DAPOMath17KDataset, DAPOMathAIME2024Dataset, ) +from nemo_rl.data.datasets.response_datasets.gsm8k import GSM8KDataset from nemo_rl.data.datasets.response_datasets.deepscaler import DeepScalerDataset from nemo_rl.data.datasets.response_datasets.general_conversations_dataset import ( GeneralConversationsJsonlDataset, @@ -55,6 +56,7 @@ "refcoco": RefCOCODataset, "squad": SquadDataset, "tulu3_sft_mixture": Tulu3SftMixtureDataset, + "gsm8k": GSM8KDataset, # load from local JSONL file or HuggingFace "openai_format": OpenAIFormatDataset, "NemoGymDataset": NemoGymDataset, @@ -94,6 +96,7 @@ def load_response_dataset(data_config: ResponseDatasetConfig): "GeneralConversationsJsonlDataset", "DAPOMath17KDataset", "DAPOMathAIME2024Dataset", + "GSM8KDataset", "DeepScalerDataset", "Geometry3KDataset", "HelpSteer3Dataset", diff --git a/nemo_rl/data/datasets/response_datasets/gsm8k.py b/nemo_rl/data/datasets/response_datasets/gsm8k.py new file mode 100644 index 0000000000..970e8a076e --- /dev/null +++ b/nemo_rl/data/datasets/response_datasets/gsm8k.py @@ -0,0 +1,82 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from typing import Any + +from datasets import load_dataset + +from nemo_rl.data.datasets.raw_dataset import RawDataset + + +def _load_system_prompt(system_prompt_file: str | None) -> str: + """Load system prompt from file. Returns empty string if path is None or missing.""" + if not system_prompt_file: + return "" + if os.path.exists(system_prompt_file): + with open(system_prompt_file, "r", encoding="utf-8") as f: + return f.read() + raise FileNotFoundError(f"System prompt file {system_prompt_file!r} not found.") + + +def _extract_hash_answer(text: str) -> str | None: + if "####" not in text: + return None + return text.split("####")[1].strip() + + +class GSM8KDataset(RawDataset): + """Simple wrapper around the GSM8K dataset with train and validation splits. + + Args: + seed: Random seed for shuffling the training set (default 42). + system_prompt_file: Optional path to a text file containing the system prompt + (e.g. examples/prompts/gsm8k.txt). If not provided, system prompt is empty. + """ + + def __init__( + self, + seed: int = 42, + system_prompt_file: str | None = None, + **kwargs, + ) -> None: + self.task_name = "gsm8k" + self._system_prompt = _load_system_prompt(system_prompt_file) + + # Load from HuggingFace + train_ds = load_dataset("openai/gsm8k", "main")["train"] + val_ds = load_dataset("openai/gsm8k", "main")["test"] + + # Shuffle training with seed + train_ds = train_ds.shuffle(seed=seed) + + # Format the datasets + self.dataset = train_ds.map( + self.format_data, + remove_columns=train_ds.column_names, + ) + self.val_dataset = val_ds.map( + self.format_data, + remove_columns=val_ds.column_names, + ) + + def format_data(self, data: dict[str, Any]) -> dict[str, Any]: + return { + "messages": [ + {"role": "system", "content": self._system_prompt}, + {"role": "user", "content": data["question"]}, + {"role": "assistant", "content": _extract_hash_answer(data["answer"])}, + ], + "task_name": self.task_name, + } diff --git a/nemo_rl/data/processors.py b/nemo_rl/data/processors.py index 52ac9bf67d..57a88ba644 100644 --- a/nemo_rl/data/processors.py +++ b/nemo_rl/data/processors.py @@ -381,6 +381,75 @@ def math_data_processor( return output +def math_gdpo_data_processor( + datum_dict: dict[str, Any], + task_data_spec: TaskDataSpec, + tokenizer: TokenizerType, + max_seq_length: int, + idx: int, +) -> DatumSpec: + """Process a datum dictionary (directly loaded from data/hf_datasets/openmathinstruct2.py) into a DatumSpec for the Reward Model Environment.""" + + user_message = datum_dict["messages"] + + # print(f"user_message {user_message}") + + problem = user_message[1]["content"] + + extra_env_info = {"ground_truth": user_message[2]["content"]} + + message_log: LLMMessageLogType = [] + + + system_message = { + "role": "system", + "content": user_message[0]["content"] + + } + + user_message = { + "role": "user", + "content": problem, + } + + + message: list[str] = tokenizer.apply_chat_template( # type: ignore + [system_message, user_message], + tokenize=False, + add_generation_prompt=True, + add_special_tokens=False, + ) + + user_message["token_ids"] = tokenizer( 
+ message, + return_tensors="pt", + add_special_tokens=False, + )["input_ids"][0] + user_message["content"] = message + message_log.append(user_message) + + length = sum(len(m["token_ids"]) for m in message_log) + + loss_multiplier = 1.0 + if length > max_seq_length: + # make smaller and mask out + for chat_message in message_log: + chat_message["token_ids"] = chat_message["token_ids"][ + : min(4, max_seq_length // len(message_log)) + ] + loss_multiplier = 0.0 + + output: DatumSpec = { + "message_log": message_log, + "length": length, + "extra_env_info": extra_env_info, + "loss_multiplier": loss_multiplier, + "idx": idx, + "task_name": datum_dict["task_name"], + } + return output + + def math_hf_data_processor( datum_dict: dict[str, Any], task_data_spec: TaskDataSpec, @@ -698,6 +767,7 @@ def nemo_gym_data_processor( "helpsteer3_data_processor": helpsteer3_data_processor, "math_data_processor": math_data_processor, "math_hf_data_processor": math_hf_data_processor, + "math_gdpo_data_processor": math_gdpo_data_processor, "multichoice_qa_processor": multichoice_qa_processor, "sft_processor": sft_processor, "vlm_hf_data_processor": vlm_hf_data_processor, diff --git a/nemo_rl/distributed/ray_actor_environment_registry.py b/nemo_rl/distributed/ray_actor_environment_registry.py index 90d52fe76e..3f02acb4e1 100644 --- a/nemo_rl/distributed/ray_actor_environment_registry.py +++ b/nemo_rl/distributed/ray_actor_environment_registry.py @@ -35,6 +35,7 @@ "nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2": PY_EXECUTABLES.AUTOMODEL, "nemo_rl.models.policy.workers.megatron_policy_worker.MegatronPolicyWorker": MCORE_EXECUTABLE, "nemo_rl.environments.math_environment.MathEnvironment": PY_EXECUTABLES.SYSTEM, + "nemo_rl.environments.math_environment.MathMultiRewardEnvironment" : PY_EXECUTABLES.SYSTEM, "nemo_rl.environments.vlm_environment.VLMEnvironment": PY_EXECUTABLES.SYSTEM, "nemo_rl.environments.code_environment.CodeEnvironment": PY_EXECUTABLES.SYSTEM, "nemo_rl.environments.reward_model_environment.RewardModelEnvironment": PY_EXECUTABLES.SYSTEM, diff --git a/nemo_rl/environments/interfaces.py b/nemo_rl/environments/interfaces.py index b869c32df7..a8167581e5 100644 --- a/nemo_rl/environments/interfaces.py +++ b/nemo_rl/environments/interfaces.py @@ -44,7 +44,7 @@ class EnvironmentReturn(NamedTuple, Generic[MetadataT]): observations: list[dict[str, str]] metadata: list[MetadataT] next_stop_strings: list[list[str] | None] | list[None] - rewards: Tensor + rewards: Tensor ## This could be of different shape terminateds: Tensor answers: list[str | None] | None diff --git a/nemo_rl/environments/math_environment.py b/nemo_rl/environments/math_environment.py index 8de2da805a..9f0acf655f 100644 --- a/nemo_rl/environments/math_environment.py +++ b/nemo_rl/environments/math_environment.py @@ -230,6 +230,120 @@ def verify( return results +@ray.remote # pragma: no cover +class HFMultiRewardVerifyWorker: + def __init__(self) -> None: + logging.getLogger("math_multi_reward_verify").setLevel(logging.CRITICAL) + + # Use Latex and plain math extraction from predictions + # https://github.com/huggingface/Math-Verify?tab=readme-ov-file#extraction-targets + self.verify_func = math_metric( + gold_extraction_target=(LatexExtractionConfig(),), + pred_extraction_target=( + ExprExtractionConfig(), + LatexExtractionConfig(), + ), + ) + + def verify( + self, + pred_responses: list[str], + ground_truths: list[str], + return_extracted_answer: bool = False, + **kwargs, + ) -> Union[list[float], 
tuple[list[float], list[str | None]]]:
+        """Verify the correctness of the predicted responses against the ground truth.
+
+        Args:
+            pred_responses: list[str]. The predicted responses from the LLM.
+            ground_truths: list[str]. The ground truth responses.
+
+        Returns:
+            Union[list[float], tuple[list[float], list[str | None]]].
+                If return_extracted_answer is False, returns only the scores.
+                If return_extracted_answer is True, returns (scores, extracted_answers).
+        """
+        def extract_xml_answer(text: str) -> str:
+            answer = text.split("<answer>")[-1]
+            answer = answer.split("</answer>")[0]
+            return answer.strip()
+
+        def correctness_reward_func(completions, answer, **kwargs) -> list[float]:
+            extracted_responses = [extract_xml_answer(r) for r in completions]
+            return [1.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]
+
+        def int_reward_func(completions, **kwargs) -> list[float]:
+            extracted_responses = [extract_xml_answer(r) for r in completions]
+            return [1.0 if r.isdigit() else 0.0 for r in extracted_responses]
+
+        def format_reward_func(completions, **kwargs) -> list[float]:
+            """Reward function that checks if the completion follows the <think>/<answer> format."""
+            rewards = []
+            for response in completions:
+                pattern = r"^<think>.*?</think>\n<answer>.*?</answer>$"
+                if (
+                    re.search(pattern, response, re.DOTALL)
+                    and response.count("<answer>") == 1
+                    and response.count("</answer>") == 1
+                ):
+                    rewards.append(1.0)
+                else:
+                    rewards.append(0.0)
+            return rewards
+
+        number_of_rewards = 3
+        results = [[] for _ in range(number_of_rewards)]
+        extracted_answers: list[str | None] = []
+
+        for response, ground_truth in zip(pred_responses, ground_truths):
+            try:
+                # with _mute_output():
+                math_verify_impl = kwargs.get("math_verify_impl", "hf_math_verify")
+                if math_verify_impl == "hf_math_verify":
+                    cor_reward = correctness_reward_func([response], [ground_truth])
+                    int_reward = int_reward_func([response])
+                    format_reward = format_reward_func([response])
+                    extracted_answer = extract_xml_answer(response)
+                else:
+                    raise ValueError(
+                        f"Unknown math_verify_impl: {math_verify_impl}. Expected 'hf_math_verify'"
+                    )
+
+                results[0].extend(cor_reward)
+                results[1].extend(int_reward)
+                results[2].extend(format_reward)
+
+                if return_extracted_answer:
+                    # extract_xml_answer already returns the answer string parsed from the response
+                    extracted_answers.append(extracted_answer)
+
+            # It's possible to emit a TimeoutException and that wouldn't be caught since
+            # it actually subclasses from BaseException and math-verify itself does not
+            # catch it.
+            except (Exception, TimeoutException):
+                results[0].append(0.0)
+                results[1].append(0.0)
+                results[2].append(0.0)
+                extracted_answers.append(None)
+
+        if return_extracted_answer:
+            return results, extracted_answers
+        else:
+            return results
+        # results layout: results[component_index][sample_index], e.g. [[1, 0, ...], [1, 0, ...], [0, 1, ...]]
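+        # Worked example (illustrative only; inputs are made up): for a 2-sample chunk,
+        #   verify(["<think>6+1=7</think>\n<answer>7</answer>", "the answer is 7"], ["7", "3"])
+        #   -> [[1.0, 0.0],   # correctness
+        #       [1.0, 0.0],   # integer-formatted answer
+        #       [1.0, 0.0]]   # <think>/<answer> format
+        # MathMultiRewardEnvironment.step() then transposes this into a (batch_size, 3) reward tensor.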
+ + + class MathEnvironmentMetadata(TypedDict): ground_truth: str extracted_answer: str | None @@ -391,3 +505,161 @@ def global_post_process_and_metrics( } return batch, metrics + + + +@ray.remote(max_restarts=-1, max_task_retries=-1) # pragma: no cover +class MathMultiRewardEnvironment(EnvironmentInterface[MathEnvironmentMetadata]): + def __init__(self, cfg: MathEnvConfig): + self.cfg = cfg + self.num_workers = cfg["num_workers"] + # TODO: split out this environment since it's doing more than just math + verifier_type = cfg.get("verifier_type", "math") + assert isinstance(verifier_type, str), ( + f"{verifier_type=} must be a string but was {type(verifier_type)}" + ) + + worker_cls = { + "math": HFMultiRewardVerifyWorker, + }[verifier_type] + self.workers = [ + worker_cls.options( # type: ignore # (decorated with @ray.remote) + runtime_env={"py_executable": PY_EXECUTABLES.SYSTEM} + ).remote() + for _ in range(self.num_workers) + ] + + def shutdown(self) -> None: + # shutdown all workers + for worker in self.workers: + ray.kill(worker) + + def step( + self, + message_log_batch: list[LLMMessageLogType], + metadata: list[MathEnvironmentMetadata], + return_extracted_answer: bool = False, + ) -> EnvironmentReturn[MathEnvironmentMetadata]: + """Runs a step in the math environment. + + Args: + message_log: list[list[dict[str, str]]]. A batch of OpenAI-API-like message logs that represent interactions with the LLM. + metadata: list[MathEnvironmentMetadata]. The grader will use the 'ground_truth' key to evaluate correctness. The extracted answer will be stored to caculate cons@k. + + Returns: + EnvironmentReturn: A tuple containing: + - list[dict[str, str]]: Observations/responses batch + - list[dict]: Updated metadata + - list[str]: Next stop strings for the next turn + - Tensor: Rewards tensor + - Tensor: Done flags tensor + """ + # Extract the assistant's responses from the message history + # Each message list should have at least one assistant response + assistant_response_batch = [] + for conversation in message_log_batch: + assistant_responses = [ + str(interaction["content"]) + for interaction in conversation + if interaction["role"] == "assistant" + ] + assistant_response_batch.append("".join(assistant_responses)) + + ground_truths = [g["ground_truth"] for g in metadata] + + chunked_assistant_response_batch = chunk_list_to_workers( + assistant_response_batch, self.num_workers + ) + chunked_ground_truths = chunk_list_to_workers(ground_truths, self.num_workers) + + # Process each chunk in parallel + futures = [ + self.workers[i].verify.remote( + chunk, + ground_truth_chunk, + return_extracted_answer, + math_verify_impl=self.cfg.get("math_verify_impl", "hf_math_verify"), + ) + for i, (chunk, ground_truth_chunk) in enumerate( + zip(chunked_assistant_response_batch, chunked_ground_truths) + ) + ] + + worker_results = ray.get(futures) + + # Flatten the results and extract both scores and answers + number_of_rewards = 3 + results = [[]for i in range(number_of_rewards)] + extracted_answers: list[str | None] | None = ( + [] if return_extracted_answer else None + ) + + for worker_result in worker_results: + if return_extracted_answer: + raise NotImplementedError("Skip return_extracted_answer handling") + else: + for i in range(number_of_rewards): + results[i].extend(worker_result[i]) + + observations = [ + { + "role": "environment", + "content": "Environment: correct" + if result + else "Environment: incorrect", + } + for result in results[0] ## index 0 always store corretness reward + ] + + # 
create a tensor of rewards and done flags + rewards = torch.tensor(results).T.cpu() ## Shape Batch_size, Number_rewards + ## hard fixed this done to + done = torch.ones(rewards.shape[0]).cpu() + next_stop_strings = [None] * len(message_log_batch) + + return EnvironmentReturn( + observations=observations, + metadata=metadata, + next_stop_strings=next_stop_strings, + rewards=rewards, + terminateds=done, + answers=extracted_answers, + ) + + def global_post_process_and_metrics( + self, batch: BatchedDataDict[Any] + ) -> tuple[BatchedDataDict[Any], dict[str, float | int]]: + """Computes metrics for this environment given a global rollout batch. + + Every rank will run this function, so you're free to use distributed + calculations if you'd prefer for heavy metrics. + """ + batch["rewards"] = ( + batch["rewards"] * batch["is_end"] + ) # set a reward of 0 for any incorrectly ended sequences + if (batch["rewards"] == 1).float().sum() > 0: + correct_solution_generation_lengths = ( + (batch["generation_lengths"] - batch["prompt_lengths"])[ + batch["rewards"] == 1 + ] + .float() + .mean() + .item() + ) + else: + correct_solution_generation_lengths = 0 + + metrics = { + # "table": table, TODO @sahilj WIP + "accuracy": batch["rewards"].mean().item(), + "pass@samples_per_prompt": calculate_pass_rate_per_prompt( + batch["text"], batch["rewards"] + ), + "fraction_of_samples_properly_ended": batch["is_end"].float().mean().item(), + "num_problems_in_batch": batch["is_end"].shape[0], + "generation_lengths": batch["generation_lengths"].float().mean().item(), + "prompt_lengths": batch["prompt_lengths"].float().mean().item(), + "correct_solution_generation_lengths": correct_solution_generation_lengths, + } + + return batch, metrics diff --git a/nemo_rl/experience/rollouts.py b/nemo_rl/experience/rollouts.py index 603a972095..92c571f380 100644 --- a/nemo_rl/experience/rollouts.py +++ b/nemo_rl/experience/rollouts.py @@ -327,7 +327,9 @@ def calculate_rewards( sorted_indices = sorted( range(len(all_indices_order)), key=lambda k: all_indices_order[k] ) - rewards = torch.tensor([all_rewards[i] for i in sorted_indices]) + # Stack rewards: each element may be scalar (single-reward env) or 1d (multi-reward env). + # torch.stack preserves shape: scalars -> (N,), shape (K,) -> (N, K). 
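+    # For example (hypothetical shapes): three samples from MathMultiRewardEnvironment, each a
+    # (3,)-shaped reward vector, stack into a (3, 3) tensor, while three scalar rewards from the
+    # single-reward MathEnvironment stack into a (3,) tensor; callers can branch on rewards.ndim.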
+ rewards = torch.stack([all_rewards[i] for i in sorted_indices]) env_observations = [all_env_observations[i] for i in sorted_indices] terminateds = torch.tensor([all_terminateds[i] for i in sorted_indices]) next_stop_strings = [all_next_stop_strings[i] for i in sorted_indices] @@ -374,6 +376,10 @@ def run_multi_turn_rollout( active_indices = torch.arange(batch_size) total_rewards = torch.zeros(batch_size, dtype=torch.float32) + # Multi_rewards: number of components inferred from first env_output (1 for single-reward envs) + number_of_rewards: int | None = None + multi_rewards: torch.Tensor | None = None + # Initialize stop_strings from the initial batch if present current_stop_strings = current_batch.get("stop_strings", [None] * batch_size) @@ -459,7 +465,25 @@ def run_multi_turn_rollout( # Calculate rewards and get environment feedback env_output: EnvironmentReturn = calculate_rewards(active_batch, task_to_env) - total_rewards[active_indices] += env_output.rewards + # Infer number of reward components on first turn (supports single- and multi-reward envs) + if number_of_rewards is None: + number_of_rewards = ( + int(env_output.rewards.shape[1]) + if env_output.rewards.ndim >= 2 + else 1 + ) + multi_rewards = torch.zeros( + batch_size, number_of_rewards, dtype=torch.float32 + ) + # Accumulate rewards: env may return shape (N,) or (N, K) + if env_output.rewards.ndim >= 2: + multi_rewards[active_indices] += env_output.rewards + total_rewards[active_indices] += env_output.rewards.sum(dim=1) + else: + multi_rewards[active_indices, 0] += env_output.rewards + total_rewards[active_indices] += env_output.rewards + + # Update message log for ALL active samples with env observation # This must happen BEFORE filtering based on done flags @@ -536,6 +560,11 @@ def run_multi_turn_rollout( # Add total rewards to the final batch current_batch["total_reward"] = total_rewards current_batch["truncated"] = sample_truncated + # Expose per-component rewards (reward1, reward2, ... rewardN); single-reward envs get reward1 only + if multi_rewards is not None: + num_reward_components = multi_rewards.shape[1] + for i in range(num_reward_components): + current_batch[f"reward{i + 1}"] = multi_rewards[:, i].clone() # Calculate aggregate metrics rollout_metrics = { @@ -666,6 +695,8 @@ async def run_sample_multi_turn_rollout( # Sample-level metrics total_reward = 0.0 + reward_acc_list: list[float] = [] # per-component rewards, length set on first multi-reward + multi_reward_seen = False turn_count = 0 token_count = 0 assistant_token_count = 0 @@ -738,8 +769,17 @@ async def run_sample_multi_turn_rollout( # Get environment feedback env_output = calculate_rewards(sample_batch, task_to_env) - # Update total reward - total_reward += float(env_output.rewards[0].item()) + # Update total reward and optional per-reward signals (reward1, reward2, ... 
rewardN) + if env_output.rewards.ndim == 2 and env_output.rewards.shape[1] >= 1: + multi_reward_seen = True + n = env_output.rewards.shape[1] + if len(reward_acc_list) == 0: + reward_acc_list = [0.0] * n + total_reward += float(env_output.rewards[0].sum().item()) + for j in range(n): + reward_acc_list[j] += float(env_output.rewards[0, j].item()) + else: + total_reward += float(env_output.rewards[0].item()) # Check termination terminated = env_output.terminateds[0].item() env_obs_content = env_output.observations[0]["content"] @@ -789,6 +829,9 @@ async def run_sample_multi_turn_rollout( "stop_strings": current_stop_strings, "idx": sample_idx, } + if multi_reward_seen: + for j in range(len(reward_acc_list)): + final_sample_state[f"reward{j + 1}"] = torch.tensor(reward_acc_list[j]) # Sample metrics sample_metrics = { @@ -892,25 +935,35 @@ async def run_single_sample_with_error_handling(i, sample_state): # Reconstruct batch from sample results batch_size = len(final_sample_states) - final_batch = BatchedDataDict[DatumSpec]( - { - "message_log": [state["message_log"] for state in final_sample_states], - "extra_env_info": [ - state["extra_env_info"] for state in final_sample_states - ], - "task_name": [state["task_name"] for state in final_sample_states], - "total_reward": torch.stack( - [state["total_reward"] for state in final_sample_states] - ), - "idx": [ - state.get("idx", i) for i, state in enumerate(final_sample_states) - ], - "truncated": torch.tensor( - [metrics["truncated"] for metrics in all_sample_metrics], - dtype=torch.bool, - ), - } - ) + final_batch_dict = { + "message_log": [state["message_log"] for state in final_sample_states], + "extra_env_info": [ + state["extra_env_info"] for state in final_sample_states + ], + "task_name": [state["task_name"] for state in final_sample_states], + "total_reward": torch.stack( + [state["total_reward"] for state in final_sample_states] + ), + "idx": [ + state.get("idx", i) for i, state in enumerate(final_sample_states) + ], + "truncated": torch.tensor( + [metrics["truncated"] for metrics in all_sample_metrics], + dtype=torch.bool, + ), + } + + # Add any reward component keys (reward1, reward2, ...) 
from the first state + reward_keys = [ + k for k in final_sample_states[0] + if k.startswith("reward") and k[6:].isdigit() + ] + reward_keys = sorted(reward_keys, key=lambda k: int(k[6:])) + for key in reward_keys: + final_batch_dict[key] = torch.stack( + [state[key] for state in final_sample_states] + ) + final_batch = BatchedDataDict[DatumSpec](final_batch_dict) # Preserve additional fields from the original input_batch for key in input_batch.keys(): @@ -1185,28 +1238,42 @@ def run_async_nemo_gym_rollout( ) input_ids = batched_flat["token_ids"] - final_batch = BatchedDataDict[DatumSpec]( - { - "agent_ref": [r["agent_ref"] for r in results], - "message_log": [r["message_log"] for r in results], - # length is used downstream for mean_prompt_length - "length": torch.tensor( - [len(r["input_message_log"][0]["token_ids"]) for r in results] - ), - "loss_multiplier": input_batch["loss_multiplier"], - # Unnecessary parts of the DatumSpec unused by the GRPO algorithm - # extra_env_info: dict[str, Any] - # idx: int - # task_name: NotRequired[str] - # stop_strings: NotRequired[list[str]] # Optional stop strings for generation - # Extra information not in the DatumSpec used by the GRPO algorithm - "total_reward": torch.tensor([r["full_result"]["reward"] for r in results]), - # Add truncated field to match other rollout paths (reusing hit_max_tokens logic) - "truncated": torch.tensor( - [m["hit_max_tokens"] for m in all_sample_metrics], dtype=torch.bool - ), - } - ) + final_batch_dict = { + "agent_ref": [r["agent_ref"] for r in results], + "message_log": [r["message_log"] for r in results], + # length is used downstream for mean_prompt_length + "length": torch.tensor( + [len(r["input_message_log"][0]["token_ids"]) for r in results] + ), + "loss_multiplier": input_batch["loss_multiplier"], + # Unnecessary parts of the DatumSpec unused by the GRPO algorithm + # extra_env_info: dict[str, Any] + # idx: int + # task_name: NotRequired[str] + # stop_strings: NotRequired[list[str]] # Optional stop strings for generation + # Extra information not in the DatumSpec used by the GRPO algorithm + "total_reward": torch.tensor([r["full_result"]["reward"] for r in results]), + # Add truncated field to match other rollout paths (reusing hit_max_tokens logic) + "truncated": torch.tensor( + [m["hit_max_tokens"] for m in all_sample_metrics], dtype=torch.bool + ), + } + + # Add any reward component keys (reward1, reward2, ...) 
from full_result + if results: + full_result = results[0].get("full_result", {}) + reward_keys = sorted( + [ + k for k in full_result + if isinstance(k, str) and k.startswith("reward") and k[6:].isdigit() + ], + key=lambda k: int(k[6:]), + ) + for key in reward_keys: + final_batch_dict[key] = torch.tensor( + [r["full_result"][key] for r in results] + ) + final_batch = BatchedDataDict[DatumSpec](final_batch_dict) return AsyncNemoGymRolloutResult( input_ids=input_ids, diff --git a/tests/functional/L1_Functional_Tests_GPU.sh b/tests/functional/L1_Functional_Tests_GPU.sh index bee4d8d2eb..af8919df5e 100644 --- a/tests/functional/L1_Functional_Tests_GPU.sh +++ b/tests/functional/L1_Functional_Tests_GPU.sh @@ -45,6 +45,7 @@ run_test uv run --no-sync bash ./tests/functional/dpo_automodel_lora.sh run_test uv run --no-sync bash ./tests/functional/dpo_megatron.sh run_test uv run --no-sync bash ./tests/functional/eval.sh run_test uv run --no-sync bash ./tests/functional/eval_async.sh +run_test uv run --no-sync bash ./tests/functional/gdpo.sh run_test fast uv run --no-sync bash ./tests/functional/grpo.sh run_test fast uv run --no-sync bash ./tests/functional/grpo_async_gym.sh run_test uv run --no-sync bash ./tests/functional/grpo_automodel_lora.sh diff --git a/tests/functional/gdpo.sh b/tests/functional/gdpo.sh new file mode 100644 index 0000000000..80d57405ee --- /dev/null +++ b/tests/functional/gdpo.sh @@ -0,0 +1,46 @@ +#!/bin/bash + +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd) +PROJECT_ROOT=$(realpath $SCRIPT_DIR/../..) +# Mark the current repo as safe, since wandb fetches metadata about the repo +git config --global --add safe.directory $PROJECT_ROOT + +set -eou pipefail + +EXP_NAME=$(basename $0 .sh) +EXP_DIR=$SCRIPT_DIR/$EXP_NAME +LOG_DIR=$EXP_DIR/logs +JSON_METRICS=$EXP_DIR/metrics.json +RUN_LOG=$EXP_DIR/run.log +export PYTHONPATH=${PROJECT_ROOT}:${PYTHONPATH:-} + +rm -rf $EXP_DIR $LOG_DIR +mkdir -p $EXP_DIR $LOG_DIR + +cd $PROJECT_ROOT +uv run coverage run -a --data-file=$PROJECT_ROOT/tests/.coverage --source=$PROJECT_ROOT/nemo_rl \ + $PROJECT_ROOT/examples/run_gdpo_gsm8k.py \ + policy.model_name=Qwen/Qwen3-0.6B \ + grpo.num_prompts_per_step=2 \ + grpo.num_generations_per_prompt=4 \ + policy.train_global_batch_size=4 \ + policy.train_micro_batch_size=1 \ + cluster.gpus_per_node=2 \ + cluster.num_nodes=1 \ + grpo.max_num_steps=2 \ + logger.tensorboard_enabled=true \ + logger.log_dir=$LOG_DIR \ + logger.wandb_enabled=false \ + logger.monitor_gpus=true \ + checkpointing.enabled=false \ + $@ \ + 2>&1 | tee $RUN_LOG + +uv run tests/json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS + +uv run tests/check_metrics.py $JSON_METRICS \ + 'max(data["train/gen_kl_error"]) < 0.001' \ + 'min(data["train/probs_ratio_clamped_min"]) > 0.79' \ + 'max(data["train/probs_ratio_clamped_min"]) < 1.21' \ + 'min(data["train/probs_ratio_clamped_max"]) > 0.79' \ + 'max(data["train/probs_ratio_clamped_max"]) < 1.21' From 295e31568e51279a3751f163369d1bf9d36f563d Mon Sep 17 00:00:00 2001 From: Yuki Huang Date: Tue, 3 Mar 2026 07:12:55 -0800 Subject: [PATCH 02/10] update dataset to new structure Signed-off-by: Yuki Huang --- examples/configs/gdpo_math_1B.yaml | 21 +++++++--- .../data/datasets/response_datasets/gsm8k.py | 39 +++++++++---------- 2 files changed, 34 insertions(+), 26 deletions(-) diff --git a/examples/configs/gdpo_math_1B.yaml b/examples/configs/gdpo_math_1B.yaml index 3536e608f8..a974e6ffc3 100644 --- a/examples/configs/gdpo_math_1B.yaml +++ 
b/examples/configs/gdpo_math_1B.yaml @@ -24,14 +24,25 @@ policy: # GDPO uses a single flat data config (GSM8K + math_gdpo_data_processor); replace parent's train/validation/default. data: _override_: true + max_input_seq_length: ${policy.max_total_sequence_length} - prompt_file: "examples/prompts/cot.txt" - system_prompt_file: "examples/prompts/gsm8k.txt" shuffle: true num_workers: 1 - processor: "math_gdpo_data_processor" - env_name: "math" - dataset_name: "gsm8k" + + use_multiple_dataloader: false + + train: + dataset_name: "gsm8k" + split: train + validation: + dataset_name: "gsm8k" + split: test + + default: + prompt_file: "examples/prompts/cot.txt" + system_prompt_file: "examples/prompts/gsm8k.txt" + processor: "math_gdpo_data_processor" + env_name: "math" env: math: diff --git a/nemo_rl/data/datasets/response_datasets/gsm8k.py b/nemo_rl/data/datasets/response_datasets/gsm8k.py index 970e8a076e..f1b4dac89b 100644 --- a/nemo_rl/data/datasets/response_datasets/gsm8k.py +++ b/nemo_rl/data/datasets/response_datasets/gsm8k.py @@ -37,46 +37,43 @@ def _extract_hash_answer(text: str) -> str | None: class GSM8KDataset(RawDataset): - """Simple wrapper around the GSM8K dataset with train and validation splits. + """Simple wrapper around the GSM8K dataset. Args: - seed: Random seed for shuffling the training set (default 42). - system_prompt_file: Optional path to a text file containing the system prompt - (e.g. examples/prompts/gsm8k.txt). If not provided, system prompt is empty. + split: Split name for the dataset, default is "train" + extract_answer: Whether to extract the answer from the dataset, default is True """ - def __init__( - self, - seed: int = 42, + def __init__(self, + split: str = "train", + extract_answer: bool = True, system_prompt_file: str | None = None, **kwargs, ) -> None: self.task_name = "gsm8k" + self.extract_answer = extract_answer self._system_prompt = _load_system_prompt(system_prompt_file) - # Load from HuggingFace - train_ds = load_dataset("openai/gsm8k", "main")["train"] - val_ds = load_dataset("openai/gsm8k", "main")["test"] + # load from huggingface + self.dataset = load_dataset("openai/gsm8k", "main")[split] - # Shuffle training with seed - train_ds = train_ds.shuffle(seed=seed) - - # Format the datasets - self.dataset = train_ds.map( - self.format_data, - remove_columns=train_ds.column_names, - ) - self.val_dataset = val_ds.map( + # format the dataset + self.dataset = self.dataset.map( self.format_data, - remove_columns=val_ds.column_names, + remove_columns=self.dataset.column_names, ) def format_data(self, data: dict[str, Any]) -> dict[str, Any]: + if self.extract_answer: + answer = _extract_hash_answer(data["answer"]) + else: + answer = data["answer"] + return { "messages": [ {"role": "system", "content": self._system_prompt}, {"role": "user", "content": data["question"]}, - {"role": "assistant", "content": _extract_hash_answer(data["answer"])}, + {"role": "assistant", "content": answer}, ], "task_name": self.task_name, } From 578698e57d301a5ed9e2ca2f9a20a115536f16f9 Mon Sep 17 00:00:00 2001 From: Yuki Huang Date: Tue, 3 Mar 2026 07:36:16 -0800 Subject: [PATCH 03/10] register env and remove run_gdpo_gsm8k.py Signed-off-by: Yuki Huang --- examples/configs/gdpo_math_1B.yaml | 4 +- examples/run_gdpo_gsm8k.py | 258 ----------------------------- nemo_rl/environments/utils.py | 3 + tests/functional/gdpo.sh | 3 +- 4 files changed, 7 insertions(+), 261 deletions(-) delete mode 100644 examples/run_gdpo_gsm8k.py diff --git a/examples/configs/gdpo_math_1B.yaml 
b/examples/configs/gdpo_math_1B.yaml index a974e6ffc3..4c1674b441 100644 --- a/examples/configs/gdpo_math_1B.yaml +++ b/examples/configs/gdpo_math_1B.yaml @@ -42,10 +42,10 @@ data: prompt_file: "examples/prompts/cot.txt" system_prompt_file: "examples/prompts/gsm8k.txt" processor: "math_gdpo_data_processor" - env_name: "math" + env_name: "math_multi_reward" env: - math: + math_multi_reward: num_workers: 8 math_verify_impl: "hf_math_verify" diff --git a/examples/run_gdpo_gsm8k.py b/examples/run_gdpo_gsm8k.py deleted file mode 100644 index 14bc55cf3e..0000000000 --- a/examples/run_gdpo_gsm8k.py +++ /dev/null @@ -1,258 +0,0 @@ -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import argparse -import os -import pprint -from collections import defaultdict -from typing import Any, Optional - -from omegaconf import OmegaConf -from transformers import PreTrainedTokenizerBase - -from nemo_rl.algorithms.grpo import MasterConfig, grpo_train, setup -from nemo_rl.algorithms.utils import get_tokenizer -from nemo_rl.data import DataConfig -from nemo_rl.data.datasets import AllTaskProcessedDataset, load_response_dataset -from nemo_rl.data.interfaces import ( - TaskDataProcessFnCallable, - TaskDataSpec, -) -from nemo_rl.data.processors import math_gdpo_data_processor -from nemo_rl.distributed.ray_actor_environment_registry import ( - get_actor_python_env, -) -from nemo_rl.distributed.virtual_cluster import init_ray -from nemo_rl.environments.interfaces import EnvironmentInterface -from nemo_rl.environments.math_environment import MathMultiRewardEnvironment -from nemo_rl.models.generation import configure_generation_config -from nemo_rl.utils.config import load_config, parse_hydra_overrides -from nemo_rl.utils.logger import get_next_experiment_dir - -OmegaConf.register_new_resolver("mul", lambda a, b: a * b) - - -def parse_args() -> tuple[argparse.Namespace, list[str]]: - """Parse command line arguments.""" - parser = argparse.ArgumentParser(description="Run GRPO training with configuration") - parser.add_argument( - "--config", type=str, default=None, help="Path to YAML config file" - ) - - # Parse known args for the script - args, overrides = parser.parse_known_args() - - return args, overrides - - -# =============================================================================== -# Math Data Processor -# =============================================================================== -TokenizerType = PreTrainedTokenizerBase - - -def setup_data( - tokenizer: TokenizerType, - data_config: DataConfig, - env_configs: dict[str, Any], - seed: int, -) -> tuple[ - AllTaskProcessedDataset, - Optional[AllTaskProcessedDataset], - dict[str, EnvironmentInterface], - dict[str, EnvironmentInterface], -]: - print("\nā–¶ Setting up data...") - math_task_spec = TaskDataSpec( - task_name="math", - prompt_file=data_config["prompt_file"], - system_prompt_file=data_config["system_prompt_file"], - ) - - # load dataset - data: Any = load_response_dataset(data_config) - 
task_name = ( - data.task_name if hasattr(data, "task_name") else data.task_spec.task_name - ) - - # data processor - task_data_processors: dict[str, tuple[TaskDataSpec, TaskDataProcessFnCallable]] = ( - defaultdict(lambda: (math_task_spec, math_gdpo_data_processor)) - ) - task_data_processors[task_name] = (math_task_spec, math_gdpo_data_processor) - - # setup math environment - math_env = MathMultiRewardEnvironment.options( # type: ignore # it's wrapped with ray.remote - runtime_env={ - "py_executable": get_actor_python_env( - "nemo_rl.environments.math_environment.MathMultiRewardEnvironment" - ), - "env_vars": dict(os.environ), # Pass thru all user environment variables - } - ).remote(env_configs["math"]) - - dataset = AllTaskProcessedDataset( - data.dataset, - tokenizer, - math_task_spec, - task_data_processors, - max_seq_length=data_config["max_input_seq_length"], - ) - - val_dataset: Optional[AllTaskProcessedDataset] = None - if data.val_dataset is not None: - val_dataset = AllTaskProcessedDataset( - data.val_dataset, - tokenizer, - math_task_spec, - task_data_processors, - max_seq_length=data_config["max_input_seq_length"], - ) - - task_to_env: dict[str, EnvironmentInterface] = defaultdict(lambda: math_env) - task_to_env[task_name] = math_env - return dataset, val_dataset, task_to_env, task_to_env - - -def main() -> None: - """Main entry point.""" - # Parse arguments - args, overrides = parse_args() - - if not args.config: - args.config = os.path.join( - os.path.dirname(__file__), "configs", "gdpo_math_1B.yaml" - ) - - config = load_config(args.config) - print(f"Loaded configuration from: {args.config}") - - if overrides: - print(f"Overrides: {overrides}") - config = parse_hydra_overrides(config, overrides) - - config: MasterConfig = OmegaConf.to_container(config, resolve=True) - print("Applied CLI overrides") - - # Print config - print("Final config:") - pprint.pprint(config) - - # Get the next experiment directory with incremented ID - config["logger"]["log_dir"] = get_next_experiment_dir(config["logger"]["log_dir"]) - print(f"šŸ“Š Using log directory: {config['logger']['log_dir']}") - if config["checkpointing"]["enabled"]: - print( - f"šŸ“Š Using checkpoint directory: {config['checkpointing']['checkpoint_dir']}" - ) - - init_ray() - - # setup tokenizer - tokenizer = get_tokenizer(config["policy"]["tokenizer"]) - assert config["policy"]["generation"] is not None, ( - "A generation config is required for GRPO" - ) - config["policy"]["generation"] = configure_generation_config( - config["policy"]["generation"], tokenizer - ) - - # setup data - ( - dataset, - val_dataset, - task_to_env, - val_task_to_env, - ) = setup_data(tokenizer, config["data"], config["env"], config["grpo"]["seed"]) - - ( - policy, - policy_generation, - cluster, - dataloader, - val_dataloader, - loss_fn, - logger, - checkpointer, - grpo_state, - master_config, - ) = setup(config, tokenizer, dataset, val_dataset) - - # Check if async mode is enabled - if "async_grpo" in config["grpo"] and config["grpo"]["async_grpo"]["enabled"]: - # Async GRPO does not support dynamic sampling, reward scaling, or reward shaping (DAPO features) - unsupported_features = [ - "use_dynamic_sampling", - "reward_scaling", - "reward_shaping", - ] - - for feature in unsupported_features: - if feature not in config["grpo"]: - continue - - if feature == "use_dynamic_sampling": - if config["grpo"][feature]: - raise NotImplementedError( - f"{feature} is not supported with async GRPO" - ) - else: - if config["grpo"][feature]["enabled"]: - raise 
NotImplementedError( - f"{feature} is not supported with async GRPO" - ) - - from nemo_rl.algorithms.grpo import async_grpo_train - - print("šŸš€ Running async GRPO training") - - async_config = config["grpo"]["async_grpo"] - # Run async GRPO training - async_grpo_train( - policy=policy, - policy_generation=policy_generation, - dataloader=dataloader, - val_dataloader=val_dataloader, - tokenizer=tokenizer, - loss_fn=loss_fn, - task_to_env=task_to_env, - val_task_to_env=val_task_to_env, - logger=logger, - checkpointer=checkpointer, - grpo_save_state=grpo_state, - master_config=master_config, - max_trajectory_age_steps=async_config["max_trajectory_age_steps"], - ) - else: - print("šŸš€ Running synchronous GRPO training") - - # Run standard GRPO training - grpo_train( - policy, - policy_generation, - dataloader, - val_dataloader, - tokenizer, - loss_fn, - task_to_env, - val_task_to_env, - logger, - checkpointer, - grpo_state, - master_config, - ) - - -if __name__ == "__main__": - main() diff --git a/nemo_rl/environments/utils.py b/nemo_rl/environments/utils.py index c4227b8631..df82c7d1af 100644 --- a/nemo_rl/environments/utils.py +++ b/nemo_rl/environments/utils.py @@ -35,6 +35,9 @@ class EnvRegistryEntry(TypedDict, total=False): "math": { "actor_class_fqn": "nemo_rl.environments.math_environment.MathEnvironment", }, + "math_multi_reward": { + "actor_class_fqn": "nemo_rl.environments.math_environment.MathMultiRewardEnvironment", + }, "code": { "actor_class_fqn": "nemo_rl.environments.code_environment.CodeEnvironment", }, diff --git a/tests/functional/gdpo.sh b/tests/functional/gdpo.sh index 80d57405ee..ee95645e28 100644 --- a/tests/functional/gdpo.sh +++ b/tests/functional/gdpo.sh @@ -19,7 +19,8 @@ mkdir -p $EXP_DIR $LOG_DIR cd $PROJECT_ROOT uv run coverage run -a --data-file=$PROJECT_ROOT/tests/.coverage --source=$PROJECT_ROOT/nemo_rl \ - $PROJECT_ROOT/examples/run_gdpo_gsm8k.py \ + $PROJECT_ROOT/examples/run_grpo.py \ + --config $PROJECT_ROOT/examples/configs/gdpo_math_1B.yaml \ policy.model_name=Qwen/Qwen3-0.6B \ grpo.num_prompts_per_step=2 \ grpo.num_generations_per_prompt=4 \ From 16b0d53611e0a5da4570f262f51c299cf160473d Mon Sep 17 00:00:00 2001 From: Shih-Yang Liu Date: Wed, 4 Mar 2026 02:04:21 -0800 Subject: [PATCH 04/10] align compute_advantage interface and move the initialization forward --- .../Megatron-Bridge-workspace/Megatron-Bridge | 2 +- 3rdparty/Megatron-LM-workspace/Megatron-LM | 2 +- nemo_rl/algorithms/advantage_estimator.py | 39 ++---- nemo_rl/algorithms/grpo.py | 120 ++++-------------- nemo_rl/algorithms/utils.py | 4 + nemo_rl/experience/rollouts.py | 19 ++- tests/unit/algorithms/test_grpo.py | 21 ++- 7 files changed, 69 insertions(+), 138 deletions(-) diff --git a/3rdparty/Megatron-Bridge-workspace/Megatron-Bridge b/3rdparty/Megatron-Bridge-workspace/Megatron-Bridge index f91542b909..15398e08fc 160000 --- a/3rdparty/Megatron-Bridge-workspace/Megatron-Bridge +++ b/3rdparty/Megatron-Bridge-workspace/Megatron-Bridge @@ -1 +1 @@ -Subproject commit f91542b90908ad08b7e13672feea03e27bedee27 +Subproject commit 15398e08fc86be3de084c7382116527246ab1852 diff --git a/3rdparty/Megatron-LM-workspace/Megatron-LM b/3rdparty/Megatron-LM-workspace/Megatron-LM index 23dd639cf3..193463c4f8 160000 --- a/3rdparty/Megatron-LM-workspace/Megatron-LM +++ b/3rdparty/Megatron-LM-workspace/Megatron-LM @@ -1 +1 @@ -Subproject commit 23dd639cf3de30f3b9d8d0fae71ee31180be9ddd +Subproject commit 193463c4f8414e6906a40dd527a450bca50706b1 diff --git a/nemo_rl/algorithms/advantage_estimator.py 
b/nemo_rl/algorithms/advantage_estimator.py index d8e9767a4b..95588aeac1 100644 --- a/nemo_rl/algorithms/advantage_estimator.py +++ b/nemo_rl/algorithms/advantage_estimator.py @@ -26,14 +26,7 @@ import re import torch -from nemo_rl.algorithms.utils import calculate_baseline_and_std_per_prompt, calculate_kl - - -def _get_reward_component_keys(batch) -> list: - """Return batch keys that are reward components (reward1, reward2, ...) in sorted order.""" - keys = [k for k in batch.keys() if re.match(r"reward\d+$", str(k))] - return sorted(keys, key=lambda k: int(re.search(r"\d+", str(k)).group())) - +from nemo_rl.algorithms.utils import calculate_baseline_and_std_per_prompt, calculate_kl, get_gdpo_reward_component_keys class GRPOAdvantageEstimator: """GRPO-style advantage estimator with leave-one-out baseline. @@ -45,12 +38,11 @@ def __init__(self, estimator_config: dict, loss_config: dict): self.use_leave_one_out_baseline = estimator_config["use_leave_one_out_baseline"] self.normalize_rewards = estimator_config["normalize_rewards"] - def compute_advantage(self, prompt_ids, rewards, mask, **kwargs): + def compute_advantage(self, repeated_batch, mask, **kwargs): """Compute GRPO advantages. Args: - prompt_ids: Tensor of shape [batch_size] identifying which prompt each sample belongs to. - rewards: Tensor of shape [batch_size] containing reward for each sample. + repeated_batch: Batch containing _input_ids_for_baseline and total_reward. mask: Response token mask of shape [batch_size, seq_len], 1 for valid response tokens, 0 for padding. Used only for expanding advantages to token-level shape. **kwargs: Additional arguments (unused). @@ -58,6 +50,8 @@ def compute_advantage(self, prompt_ids, rewards, mask, **kwargs): Returns: Advantages tensor of shape [batch_size, seq_len]. """ + prompt_ids = repeated_batch["_input_ids_for_baseline"] + rewards = repeated_batch["total_reward"] baseline, std = calculate_baseline_and_std_per_prompt( prompt_ids, rewards, @@ -99,10 +93,12 @@ def compute_advantage(self, repeated_batch, mask, **kwargs): Returns: Advantages tensor of shape [batch_size, seq_len]. """ - reward_component_keys = _get_reward_component_keys(repeated_batch) - if not reward_component_keys: + reward_component_keys = get_gdpo_reward_component_keys(repeated_batch) + if len(reward_component_keys) < 2: raise ValueError( - "GDPOAdvantageEstimator requires reward component keys (reward1, reward2, ...) in repeated_batch" + f"GDPO requires multiple reward components (reward1, reward2, ...). " + f"This batch has {len(reward_component_keys)} component(s). " + "Switch to GRPO by setting grpo.adv_estimator.name to 'grpo' in your config." ) current_input_ids = repeated_batch["_input_ids_for_baseline"] valid = torch.ones_like( @@ -158,20 +154,11 @@ def __init__(self, estimator_config: dict, loss_config: dict): self.kl_coef = loss_config["reference_policy_kl_penalty"] self.kl_type = loss_config["reference_policy_kl_type"] - def compute_advantage( - self, - prompt_ids, - rewards, - mask, - logprobs_policy=None, - logprobs_reference=None, - **kwargs, - ): + def compute_advantage(self, repeated_batch, mask, logprobs_policy=None, logprobs_reference=None, **kwargs): """Compute Reinforce++ advantages with optional KL penalty. Args: - prompt_ids: Tensor of shape [batch_size] identifying which prompt each sample belongs to. - rewards: Tensor of shape [batch_size] containing reward for each sample. + repeated_batch: Batch containing _input_ids_for_baseline and total_reward. 
mask: Response token mask of shape [batch_size, seq_len], 1 for valid response tokens, 0 for padding. Used for: (1) expanding advantages to token-level shape, (2) global normalization that only considers valid tokens. @@ -182,6 +169,8 @@ def compute_advantage( Returns: Advantages tensor of shape [batch_size, seq_len], globally normalized across valid tokens. """ + prompt_ids = repeated_batch["_input_ids_for_baseline"] + rewards = repeated_batch["total_reward"] # minus baseline if self.minus_baseline: mean, _ = calculate_baseline_and_std_per_prompt( diff --git a/nemo_rl/algorithms/grpo.py b/nemo_rl/algorithms/grpo.py index 027bfb3262..588c0c9db2 100644 --- a/nemo_rl/algorithms/grpo.py +++ b/nemo_rl/algorithms/grpo.py @@ -48,6 +48,7 @@ log_generation_metrics_to_wandb, print_performance_metrics, set_seed, + get_gdpo_reward_component_keys ) from nemo_rl.data import DataConfig from nemo_rl.data.collate_fn import rl_collate_fn @@ -936,15 +937,6 @@ def dynamic_sampling( return batch_to_return, is_batch_complete, batch_cache, dynamic_sampling_metrics -def _get_reward_component_keys(batch: BatchedDataDict[Any]) -> list[str]: - """Return batch keys that are reward components (reward1, reward2, ...) in sorted order. - - Enables environments to expose any number of rewards without code changes elsewhere. - """ - keys = [k for k in batch.keys() if re.match(r"reward\d+$", str(k))] - return sorted(keys, key=lambda k: int(re.search(r"\d+", str(k)).group())) - - def scale_rewards( repeated_batch: BatchedDataDict[DatumSpec], reward_scaling_cfg: RewardScalingConfig ) -> BatchedDataDict[DatumSpec]: @@ -983,12 +975,9 @@ def _scale(reward_tensor: torch.Tensor) -> torch.Tensor: source_max - source_min ) * (target_max - target_min) - rewards = torch.clamp(rewards, min=source_min, max=source_max) - scaled_rewards = target_min + (rewards - source_min) / ( - source_max - source_min - ) * (target_max - target_min) + scaled_rewards = _scale(rewards) repeated_batch["total_reward"] = scaled_rewards - for key in _get_reward_component_keys(repeated_batch): + for key in get_gdpo_reward_component_keys(repeated_batch): repeated_batch[key] = _scale(repeated_batch[key]) return repeated_batch @@ -1043,15 +1032,11 @@ def _should_log_nemo_gym_responses(master_config: MasterConfig) -> bool: return should_log_nemo_gym_responses -def _create_advantage_estimator( - master_config: MasterConfig, use_multi_reward_advantages: bool = False -): +def _create_advantage_estimator(master_config: MasterConfig): """Create and return an advantage estimator based on configuration. Args: master_config: The master configuration dictionary. - use_multi_reward_advantages: If True and name is "gdpo", use GDPO. - When False and name is "gdpo", use GRPO (single-reward fallback). Returns: An advantage estimator instance (GRPO, GDPO, or ReinforcePlusPlus). 
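
A minimal, self-contained sketch (not part of the patch) of the group-relative advantage the GRPO estimator above computes, matching the expectations in this series' unit tests: per-prompt mean and Bessel-corrected std, normalization skipped when a group's std is zero, and broadcast to token level through the response mask. The real path goes through calculate_baseline_and_std_per_prompt and also supports a leave-one-out baseline; the helper below is only an illustration.

import torch

def group_advantages(prompt_ids: torch.Tensor,   # (B,) group id per sample
                     rewards: torch.Tensor,      # (B,) scalar reward per sample
                     mask: torch.Tensor,         # (B, T) response-token mask
                     normalize: bool = True) -> torch.Tensor:
    adv = torch.zeros_like(rewards)
    for pid in prompt_ids.unique():
        idx = prompt_ids == pid
        group = rewards[idx]
        centered = group - group.mean()
        std = group.std()                        # unbiased (Bessel), as in the tests
        if normalize and std > 0:                # skip normalization for zero-variance groups
            centered = centered / std
        adv[idx] = centered
    return adv.unsqueeze(-1) * mask              # expand per-sample advantages to token level

prompt_ids = torch.tensor([0, 0, 1, 1])
rewards = torch.tensor([2.0, 2.0, 1.0, 3.0])
mask = torch.ones(4, 5)
print(group_advantages(prompt_ids, rewards, mask)[:, 0])
# tensor([ 0.0000,  0.0000, -0.7071,  0.7071])
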
@@ -1075,14 +1060,9 @@ def _create_advantage_estimator( ) adv_estimator_name = adv_estimator_config["name"] - # GDPO only when we have multi-reward data; otherwise "gdpo" config uses GRPO - if use_multi_reward_advantages and adv_estimator_name == "gdpo": + if adv_estimator_name == "gdpo": adv_estimator = GDPOAdvantageEstimator(adv_estimator_config, loss_config) print(" āœ“ Using GDPO advantage estimator (multi-reward)") - elif adv_estimator_name == "gdpo": - # GDPO config but single-reward batch: use GRPO - adv_estimator = GRPOAdvantageEstimator(adv_estimator_config, loss_config) - print(" āœ“ Using GRPO advantage estimator (gdpo config, single-reward)") elif adv_estimator_name == "grpo": adv_estimator = GRPOAdvantageEstimator(adv_estimator_config, loss_config) print(" āœ“ Using GRPO advantage estimator") @@ -1392,6 +1372,10 @@ def grpo_train( val_period = master_config["grpo"]["val_period"] colocated_inference = master_config["policy"]["generation"]["colocated"]["enabled"] + # Create advantage estimator + adv_estimator = _create_advantage_estimator(master_config) + + # Run validation at the start if configured # TODO: Add validation with kv scales if needed if val_at_start and current_step == 0: @@ -1614,19 +1598,7 @@ def grpo_train( print("ā–¶ Processing rewards...,", flush=True) # GDPO with timer.time("reward_calculation"): - # Use GDPO when adv_estimator name is "gdpo" and batch has - # multiple reward components (reward1, reward2, ...). rewards = repeated_batch["total_reward"] - adv_estimator_config = master_config["grpo"].get( - "adv_estimator", {"name": "grpo"} - ) - adv_estimator_name = adv_estimator_config.get("name", "grpo") - reward_component_keys = _get_reward_component_keys(repeated_batch) - use_multi_reward_advantages = ( - adv_estimator_name == "gdpo" - and len(reward_component_keys) >= 2 - ) - # Store input_ids in batch so that after dynamic_sampling it stays aligned with # the (possibly filtered) batch: select_indices / from_batches / slice all # apply to this key, so per-reward baselines use the same prompts as reward components. 
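
The reward-component convention referenced in the comment above can be shown with a small standalone example (not part of the patch): multi-reward environments produce per-component rewards that the rollout sums into total_reward while also exposing reward1, reward2, ...; the key helper below mirrors get_gdpo_reward_component_keys from nemo_rl.algorithms.utils, and the clamp-then-rescale mirrors the affine map scale_rewards applies to the total reward and to every component. The source/target ranges and the meaning of the components are illustrative assumptions, not defaults.

import re
import torch

def reward_component_keys(batch: dict) -> list[str]:
    # keys named reward1, reward2, ... returned in numeric order
    keys = [k for k in batch if re.match(r"reward\d+$", str(k))]
    return sorted(keys, key=lambda k: int(re.search(r"\d+", str(k)).group()))

def scale(r: torch.Tensor, source_min=0.0, source_max=2.0,
          target_min=0.0, target_max=1.0) -> torch.Tensor:
    # clamp into the source range, then map linearly onto the target range
    r = torch.clamp(r, min=source_min, max=source_max)
    return target_min + (r - source_min) / (source_max - source_min) * (target_max - target_min)

repeated_batch = {
    "reward1": torch.tensor([1.0, 0.0, 1.0, 0.0]),  # e.g. answer correctness
    "reward2": torch.tensor([1.0, 1.0, 0.0, 0.0]),  # e.g. format compliance
}
repeated_batch["total_reward"] = repeated_batch["reward1"] + repeated_batch["reward2"]

print(reward_component_keys(repeated_batch))         # ['reward1', 'reward2']
for key in ["total_reward", *reward_component_keys(repeated_batch)]:
    repeated_batch[key] = scale(repeated_batch[key])
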
@@ -1686,12 +1658,6 @@ def grpo_train( if not is_batch_complete: continue - # Create advantage estimator for this batch (GDPO when multi-reward, else GRPO/Reinforce++) - adv_estimator = _create_advantage_estimator( - master_config, - use_multi_reward_advantages=use_multi_reward_advantages, - ) - gen_step_metrics = {} if hasattr(policy_generation, "get_step_metrics"): gen_step_metrics = policy_generation.get_step_metrics() @@ -1824,27 +1790,12 @@ def grpo_train( - if use_multi_reward_advantages: - # GDPO adv_estimation - train_data["advantages"] = adv_estimator.compute_advantage( - repeated_batch=repeated_batch, - mask=mask, - logprobs_policy=train_data["prev_logprobs"], - logprobs_reference=train_data.get("reference_policy_logprobs"), - ) - - else: - - train_data["advantages"] = adv_estimator.compute_advantage( - prompt_ids=prompt_ids_for_adv, - rewards=rewards, - mask=mask, - logprobs_policy=train_data["prev_logprobs"], - logprobs_reference=train_data.get("reference_policy_logprobs"), - ) - - - + train_data["advantages"] = adv_estimator.compute_advantage( + repeated_batch=repeated_batch, + mask=mask, + logprobs_policy=train_data["prev_logprobs"], + logprobs_reference=train_data.get("reference_policy_logprobs"), + ) del prompt_ids_for_adv # Log rewards and advantages information @@ -2494,6 +2445,9 @@ def async_grpo_train( val_at_end = master_config["grpo"]["val_at_end"] colocated_inference = master_config["policy"]["generation"]["colocated"]["enabled"] + # Create advantage estimator + adv_estimator = _create_advantage_estimator(master_config) + assert not colocated_inference, ( "Colocated inference is not supported for async GRPO. Please use non-colocated inference." ) @@ -2784,23 +2738,8 @@ def async_grpo_train( del prompt_batched_flat rewards = repeated_batch["total_reward"] - adv_estimator_config = master_config["grpo"].get( - "adv_estimator", {"name": "grpo"} - ) - adv_estimator_name = adv_estimator_config.get("name", "grpo") - reward_component_keys = _get_reward_component_keys(repeated_batch) - use_multi_reward_advantages = ( - adv_estimator_name == "gdpo" - and len(reward_component_keys) >= 2 - ) - if use_multi_reward_advantages: - repeated_batch["_input_ids_for_baseline"] = prompt_ids_for_adv - - # Create advantage estimator (GDPO when name is "gdpo" and batch has multi-reward) - adv_estimator = _create_advantage_estimator( - master_config, - use_multi_reward_advantages=use_multi_reward_advantages, - ) + # All estimators read _input_ids_for_baseline from repeated_batch + repeated_batch["_input_ids_for_baseline"] = prompt_ids_for_adv print( f" šŸ“Š Rewards stats: min={rewards.min():.4f}, max={rewards.max():.4f}, mean={rewards.mean():.4f}, std={rewards.std():.4f}" @@ -2883,19 +2822,12 @@ def async_grpo_train( sample_mask = train_data["sample_mask"] mask = token_mask * sample_mask.unsqueeze(-1) - if use_multi_reward_advantages: - train_data["advantages"] = adv_estimator.compute_advantage( - repeated_batch=repeated_batch, - mask=mask, - ) - else: - train_data["advantages"] = adv_estimator.compute_advantage( - prompt_ids=prompt_ids_for_adv, - rewards=rewards, - mask=mask, - logprobs_policy=train_data["prev_logprobs"], - logprobs_reference=train_data.get("reference_policy_logprobs"), - ) + train_data["advantages"] = adv_estimator.compute_advantage( + repeated_batch=repeated_batch, + mask=mask, + logprobs_policy=train_data["prev_logprobs"], + logprobs_reference=train_data.get("reference_policy_logprobs"), + ) del prompt_ids_for_adv # Log advantages stats diff --git 
a/nemo_rl/algorithms/utils.py b/nemo_rl/algorithms/utils.py index 8e632ca5ee..c12366a0a8 100644 --- a/nemo_rl/algorithms/utils.py +++ b/nemo_rl/algorithms/utils.py @@ -30,6 +30,10 @@ from nemo_rl.models.policy import TokenizerConfig from nemo_rl.utils.logger import Logger +def get_gdpo_reward_component_keys(batch) -> list: + """Return batch keys that are reward components (reward1, reward2, ...) in sorted order.""" + keys = [k for k in batch.keys() if re.match(r"reward\d+$", str(k))] + return sorted(keys, key=lambda k: int(re.search(r"\d+", str(k)).group())) def calculate_kl( logprobs: torch.Tensor, diff --git a/nemo_rl/experience/rollouts.py b/nemo_rl/experience/rollouts.py index 92c571f380..52fc560881 100644 --- a/nemo_rl/experience/rollouts.py +++ b/nemo_rl/experience/rollouts.py @@ -467,20 +467,19 @@ def run_multi_turn_rollout( # Infer number of reward components on first turn (supports single- and multi-reward envs) if number_of_rewards is None: - number_of_rewards = ( - int(env_output.rewards.shape[1]) - if env_output.rewards.ndim >= 2 - else 1 - ) - multi_rewards = torch.zeros( - batch_size, number_of_rewards, dtype=torch.float32 - ) + if env_output.rewards.ndim >= 2: + number_of_rewards = int(env_output.rewards.shape[1]) + multi_rewards = torch.zeros( + batch_size, number_of_rewards, dtype=torch.float32 + ) + else: + number_of_rewards = 1 + # multi_rewards left None: GRPO uses total_reward only; reward1 unused # Accumulate rewards: env may return shape (N,) or (N, K) if env_output.rewards.ndim >= 2: multi_rewards[active_indices] += env_output.rewards total_rewards[active_indices] += env_output.rewards.sum(dim=1) else: - multi_rewards[active_indices, 0] += env_output.rewards total_rewards[active_indices] += env_output.rewards @@ -560,7 +559,7 @@ def run_multi_turn_rollout( # Add total rewards to the final batch current_batch["total_reward"] = total_rewards current_batch["truncated"] = sample_truncated - # Expose per-component rewards (reward1, reward2, ... rewardN); single-reward envs get reward1 only + # Expose per-component rewards (reward1, reward2, ...) 
for multi-reward envs only; GRPO uses total_reward if multi_rewards is not None: num_reward_components = multi_rewards.shape[1] for i in range(num_reward_components): diff --git a/tests/unit/algorithms/test_grpo.py b/tests/unit/algorithms/test_grpo.py index 7a0783f132..e0bdb8d2b1 100644 --- a/tests/unit/algorithms/test_grpo.py +++ b/tests/unit/algorithms/test_grpo.py @@ -1642,8 +1642,9 @@ def test_grpo_advantage_estimator_zero_std(): [2.0, 2.0, 1.0, 3.0] ) # prompt 0: std=0; prompt 1: std=sqrt(2) mask = torch.ones(4, 5) + repeated_batch = {"_input_ids_for_baseline": prompt_ids, "total_reward": rewards} - result = estimator.compute_advantage(prompt_ids, rewards, mask) + result = estimator.compute_advantage(repeated_batch, mask) # prompt 0: std=0 -> skip normalization, advantage=0 (reward - mean = 0) # prompt 1: With Bessel correction for 2 samples, std = sqrt(2), normalized = ±1/sqrt(2) ā‰ˆ ±0.7071 @@ -1673,8 +1674,9 @@ def test_grpo_advantage_estimator_tensor_shapes(): prompt_ids = torch.tensor([[0], [0]]) rewards = torch.tensor([1.0, 3.0]) # mean=2, std=sqrt(2) with Bessel mask = torch.ones(2, 3) + repeated_batch = {"_input_ids_for_baseline": prompt_ids, "total_reward": rewards} - result = estimator.compute_advantage(prompt_ids, rewards, mask) + result = estimator.compute_advantage(repeated_batch, mask) assert result.shape == (2, 3) # Verify normalized values: (reward - mean) / std @@ -1687,8 +1689,9 @@ def test_grpo_advantage_estimator_tensor_shapes(): prompt_ids = torch.tensor([[0]] * 10) rewards = torch.arange(10, dtype=torch.float32) # 0, 1, 2, ..., 9 mask = torch.ones(10, 5) + repeated_batch = {"_input_ids_for_baseline": prompt_ids, "total_reward": rewards} - result = estimator.compute_advantage(prompt_ids, rewards, mask) + result = estimator.compute_advantage(repeated_batch, mask) assert result.shape == (10, 5) # After normalization, mean should be ~0 @@ -1712,8 +1715,9 @@ def test_grpo_advantage_estimator_negative_advantages(): prompt_ids = torch.tensor([[0], [0], [0]]) rewards = torch.tensor([0.0, 2.0, 4.0]) # mean=2, deviations: -2, 0, +2 mask = torch.ones(3, 4) + repeated_batch = {"_input_ids_for_baseline": prompt_ids, "total_reward": rewards} - result = estimator.compute_advantage(prompt_ids, rewards, mask) + result = estimator.compute_advantage(repeated_batch, mask) # Verify ordering: first should be negative, middle ~0, last positive assert result[0, 0] < 0 # below mean -> negative advantage @@ -1742,8 +1746,9 @@ def test_grpo_advantage_estimator_zero_std_and_zero_advantage(): prompt_ids = torch.tensor([[0], [0], [0], [0]]) rewards = torch.tensor([5.0, 5.0, 5.0, 5.0]) # all same mask = torch.ones(4, 3) + repeated_batch = {"_input_ids_for_baseline": prompt_ids, "total_reward": rewards} - result = estimator.compute_advantage(prompt_ids, rewards, mask) + result = estimator.compute_advantage(repeated_batch, mask) # All advantages should be exactly 0 expected = torch.zeros(4, 3) @@ -1768,8 +1773,9 @@ def test_grpo_advantage_estimator_small_nonzero_std(): prompt_ids = torch.tensor([[0], [0]]) rewards = torch.tensor([1.0, 1.01]) # small but detectable difference mask = torch.ones(2, 3) + repeated_batch = {"_input_ids_for_baseline": prompt_ids, "total_reward": rewards} - result = estimator.compute_advantage(prompt_ids, rewards, mask) + result = estimator.compute_advantage(repeated_batch, mask) # Even with small std, normalization should still happen # After normalization, the values should be ±1/sqrt(2) (for 2 samples with Bessel) @@ -1808,8 +1814,9 @@ def 
test_reinforce_plus_plus_global_normalization(): ) # Shape (4, 1) for unique prompt matching rewards = torch.tensor([0.0, 1.0, 2.0, 3.0]) # mean=1.5 mask = torch.ones(4, 5) + repeated_batch = {"_input_ids_for_baseline": prompt_ids, "total_reward": rewards} - result = estimator.compute_advantage(prompt_ids, rewards, mask) + result = estimator.compute_advantage(repeated_batch, mask) # After global normalization, mean should be ~0 result_mean = (result * mask).sum() / mask.sum() From 6e1fd96d3620de5723c5e5016d2cfb4d28e0ac13 Mon Sep 17 00:00:00 2001 From: Shih-Yang Liu Date: Wed, 4 Mar 2026 02:29:00 -0800 Subject: [PATCH 05/10] fix a small bug forgot to import re --- nemo_rl/algorithms/advantage_estimator.py | 1 - nemo_rl/algorithms/utils.py | 1 + pyproject.toml | 13 +++++++------ uv.lock | 4 +++- 4 files changed, 11 insertions(+), 8 deletions(-) diff --git a/nemo_rl/algorithms/advantage_estimator.py b/nemo_rl/algorithms/advantage_estimator.py index 95588aeac1..976c142a44 100644 --- a/nemo_rl/algorithms/advantage_estimator.py +++ b/nemo_rl/algorithms/advantage_estimator.py @@ -23,7 +23,6 @@ - Reinforce++: https://arxiv.org/abs/2501.03262 """ -import re import torch from nemo_rl.algorithms.utils import calculate_baseline_and_std_per_prompt, calculate_kl, get_gdpo_reward_component_keys diff --git a/nemo_rl/algorithms/utils.py b/nemo_rl/algorithms/utils.py index c12366a0a8..bcc2281544 100644 --- a/nemo_rl/algorithms/utils.py +++ b/nemo_rl/algorithms/utils.py @@ -29,6 +29,7 @@ from nemo_rl.data.chat_templates import COMMON_CHAT_TEMPLATES from nemo_rl.models.policy import TokenizerConfig from nemo_rl.utils.logger import Logger +import re def get_gdpo_reward_component_keys(batch) -> list: """Return batch keys that are reward components (reward1, reward2, ...) 
in sorted order.""" diff --git a/pyproject.toml b/pyproject.toml index 5ba0782a0e..bb00194adb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,8 +17,8 @@ requires-python = ">=3.12" license = { text = "Apache 2.0" } dependencies = [ "setuptools", - "pip", # Required for frozen environments; uv venv --seed may not reliably install pip - "ninja", # for flash-attn parallel build + "pip", # Required for frozen environments; uv venv --seed may not reliably install pip + "ninja", # for flash-attn parallel build "torch==2.9.0", "triton; sys_platform == 'linux' and (platform_machine == 'x86_64' or platform_machine == 'aarch64')", "colored==2.2.3", @@ -44,16 +44,17 @@ dependencies = [ "sympy>=1.14.0", "pillow>=11.3.0", "torchvision>=0.22.0", - "num2words>=0.5.14", # for SmolVLM + "num2words>=0.5.14", # for SmolVLM "mlflow>=3.5.0,<3.6.0", "nvidia-nvshmem-cu12; sys_platform == 'linux' and (platform_machine == 'x86_64' or platform_machine == 'aarch64')", # for deep_ep build "swanlab", "pyzmq", "decord2", "nvidia-resiliency-ext", - "nccl4py", # for non-colocated refit - "cuda-bindings", # for non-colocated refit - "pybase64", # for sglang refit + "nccl4py", # for non-colocated refit + "cuda-bindings", # for non-colocated refit + "pybase64", # for sglang refit + "decord>=0.6.0", ] [project.optional-dependencies] diff --git a/uv.lock b/uv.lock index ade34b2f1b..e1c73b270e 100644 --- a/uv.lock +++ b/uv.lock @@ -1556,7 +1556,7 @@ name = "decord" version = "0.6.0" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "numpy", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux') or (sys_platform == 'darwin' and extra == 'extra-7-nemo-rl-automodel' and extra == 'extra-7-nemo-rl-sglang') or (sys_platform == 'darwin' and extra == 'extra-7-nemo-rl-fsdp' and extra == 'extra-7-nemo-rl-sglang') or (sys_platform == 'darwin' and extra == 'extra-7-nemo-rl-mcore' and extra == 'extra-7-nemo-rl-sglang') or (sys_platform == 'darwin' and extra == 'extra-7-nemo-rl-sglang' and extra == 'extra-7-nemo-rl-vllm') or (sys_platform == 'linux' and extra == 'extra-7-nemo-rl-automodel' and extra == 'extra-7-nemo-rl-sglang') or (sys_platform == 'linux' and extra == 'extra-7-nemo-rl-fsdp' and extra == 'extra-7-nemo-rl-sglang') or (sys_platform == 'linux' and extra == 'extra-7-nemo-rl-mcore' and extra == 'extra-7-nemo-rl-sglang') or (sys_platform == 'linux' and extra == 'extra-7-nemo-rl-sglang' and extra == 'extra-7-nemo-rl-vllm')" }, + { name = "numpy" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/11/79/936af42edf90a7bd4e41a6cac89c913d4b47fa48a26b042d5129a9242ee3/decord-0.6.0-py3-none-manylinux2010_x86_64.whl", hash = "sha256:51997f20be8958e23b7c4061ba45d0efcd86bffd5fe81c695d0befee0d442976", size = 13602299, upload-time = "2021-06-14T21:30:55.486Z" }, @@ -4696,6 +4696,7 @@ dependencies = [ { name = "cuda-bindings" }, { name = "datasets" }, { name = "debugpy" }, + { name = "decord" }, { name = "decord2" }, { name = "hydra-core" }, { name = "math-verify" }, @@ -4836,6 +4837,7 @@ requires-dist = [ { name = "cuda-python", marker = "extra == 'vllm'" }, { name = "datasets", specifier = ">=4.0.0" }, { name = "debugpy" }, + { name = "decord", specifier = ">=0.6.0" }, { name = "decord2" }, { name = "deep-ep", marker = "extra == 'automodel'", git = "https://github.com/deepseek-ai/DeepEP.git?rev=bfded34800dfec415b71503f8205181de90b2480" }, { name = "deep-ep", marker = "extra == 'mcore'", git = 
"https://github.com/deepseek-ai/DeepEP.git?rev=bfded34800dfec415b71503f8205181de90b2480" }, From e6a28ad0011276c4c252337b1a2f127e7e2c4d3e Mon Sep 17 00:00:00 2001 From: Yuki Huang Date: Wed, 4 Mar 2026 07:07:01 -0800 Subject: [PATCH 06/10] revert dependency Signed-off-by: Yuki Huang --- 3rdparty/Megatron-Bridge-workspace/Megatron-Bridge | 2 +- 3rdparty/Megatron-LM-workspace/Megatron-LM | 2 +- pyproject.toml | 13 ++++++------- uv.lock | 4 +--- 4 files changed, 9 insertions(+), 12 deletions(-) diff --git a/3rdparty/Megatron-Bridge-workspace/Megatron-Bridge b/3rdparty/Megatron-Bridge-workspace/Megatron-Bridge index 15398e08fc..f91542b909 160000 --- a/3rdparty/Megatron-Bridge-workspace/Megatron-Bridge +++ b/3rdparty/Megatron-Bridge-workspace/Megatron-Bridge @@ -1 +1 @@ -Subproject commit 15398e08fc86be3de084c7382116527246ab1852 +Subproject commit f91542b90908ad08b7e13672feea03e27bedee27 diff --git a/3rdparty/Megatron-LM-workspace/Megatron-LM b/3rdparty/Megatron-LM-workspace/Megatron-LM index 193463c4f8..23dd639cf3 160000 --- a/3rdparty/Megatron-LM-workspace/Megatron-LM +++ b/3rdparty/Megatron-LM-workspace/Megatron-LM @@ -1 +1 @@ -Subproject commit 193463c4f8414e6906a40dd527a450bca50706b1 +Subproject commit 23dd639cf3de30f3b9d8d0fae71ee31180be9ddd diff --git a/pyproject.toml b/pyproject.toml index bb00194adb..5ba0782a0e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,8 +17,8 @@ requires-python = ">=3.12" license = { text = "Apache 2.0" } dependencies = [ "setuptools", - "pip", # Required for frozen environments; uv venv --seed may not reliably install pip - "ninja", # for flash-attn parallel build + "pip", # Required for frozen environments; uv venv --seed may not reliably install pip + "ninja", # for flash-attn parallel build "torch==2.9.0", "triton; sys_platform == 'linux' and (platform_machine == 'x86_64' or platform_machine == 'aarch64')", "colored==2.2.3", @@ -44,17 +44,16 @@ dependencies = [ "sympy>=1.14.0", "pillow>=11.3.0", "torchvision>=0.22.0", - "num2words>=0.5.14", # for SmolVLM + "num2words>=0.5.14", # for SmolVLM "mlflow>=3.5.0,<3.6.0", "nvidia-nvshmem-cu12; sys_platform == 'linux' and (platform_machine == 'x86_64' or platform_machine == 'aarch64')", # for deep_ep build "swanlab", "pyzmq", "decord2", "nvidia-resiliency-ext", - "nccl4py", # for non-colocated refit - "cuda-bindings", # for non-colocated refit - "pybase64", # for sglang refit - "decord>=0.6.0", + "nccl4py", # for non-colocated refit + "cuda-bindings", # for non-colocated refit + "pybase64", # for sglang refit ] [project.optional-dependencies] diff --git a/uv.lock b/uv.lock index e1c73b270e..ade34b2f1b 100644 --- a/uv.lock +++ b/uv.lock @@ -1556,7 +1556,7 @@ name = "decord" version = "0.6.0" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "numpy" }, + { name = "numpy", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux') or (sys_platform == 'darwin' and extra == 'extra-7-nemo-rl-automodel' and extra == 'extra-7-nemo-rl-sglang') or (sys_platform == 'darwin' and extra == 'extra-7-nemo-rl-fsdp' and extra == 'extra-7-nemo-rl-sglang') or (sys_platform == 'darwin' and extra == 'extra-7-nemo-rl-mcore' and extra == 'extra-7-nemo-rl-sglang') or (sys_platform == 'darwin' and extra == 'extra-7-nemo-rl-sglang' and extra == 'extra-7-nemo-rl-vllm') or (sys_platform == 'linux' and extra == 'extra-7-nemo-rl-automodel' and extra == 'extra-7-nemo-rl-sglang') or (sys_platform == 'linux' and extra == 'extra-7-nemo-rl-fsdp' and 
extra == 'extra-7-nemo-rl-sglang') or (sys_platform == 'linux' and extra == 'extra-7-nemo-rl-mcore' and extra == 'extra-7-nemo-rl-sglang') or (sys_platform == 'linux' and extra == 'extra-7-nemo-rl-sglang' and extra == 'extra-7-nemo-rl-vllm')" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/11/79/936af42edf90a7bd4e41a6cac89c913d4b47fa48a26b042d5129a9242ee3/decord-0.6.0-py3-none-manylinux2010_x86_64.whl", hash = "sha256:51997f20be8958e23b7c4061ba45d0efcd86bffd5fe81c695d0befee0d442976", size = 13602299, upload-time = "2021-06-14T21:30:55.486Z" }, @@ -4696,7 +4696,6 @@ dependencies = [ { name = "cuda-bindings" }, { name = "datasets" }, { name = "debugpy" }, - { name = "decord" }, { name = "decord2" }, { name = "hydra-core" }, { name = "math-verify" }, @@ -4837,7 +4836,6 @@ requires-dist = [ { name = "cuda-python", marker = "extra == 'vllm'" }, { name = "datasets", specifier = ">=4.0.0" }, { name = "debugpy" }, - { name = "decord", specifier = ">=0.6.0" }, { name = "decord2" }, { name = "deep-ep", marker = "extra == 'automodel'", git = "https://github.com/deepseek-ai/DeepEP.git?rev=bfded34800dfec415b71503f8205181de90b2480" }, { name = "deep-ep", marker = "extra == 'mcore'", git = "https://github.com/deepseek-ai/DeepEP.git?rev=bfded34800dfec415b71503f8205181de90b2480" }, From 0428ef1bd12e6a3b19630f31dac48104b0eea263 Mon Sep 17 00:00:00 2001 From: Yuki Huang Date: Wed, 4 Mar 2026 08:26:28 -0800 Subject: [PATCH 07/10] update math_gdpo_data_processor and system_prompt Signed-off-by: Yuki Huang --- examples/configs/gdpo_math_1B.yaml | 2 +- .../data/datasets/response_datasets/gsm8k.py | 13 ----- nemo_rl/data/processors.py | 51 ++++++++----------- 3 files changed, 23 insertions(+), 43 deletions(-) diff --git a/examples/configs/gdpo_math_1B.yaml b/examples/configs/gdpo_math_1B.yaml index 4c1674b441..47450a17a7 100644 --- a/examples/configs/gdpo_math_1B.yaml +++ b/examples/configs/gdpo_math_1B.yaml @@ -39,7 +39,7 @@ data: split: test default: - prompt_file: "examples/prompts/cot.txt" + prompt_file: null system_prompt_file: "examples/prompts/gsm8k.txt" processor: "math_gdpo_data_processor" env_name: "math_multi_reward" diff --git a/nemo_rl/data/datasets/response_datasets/gsm8k.py b/nemo_rl/data/datasets/response_datasets/gsm8k.py index f1b4dac89b..e71428206b 100644 --- a/nemo_rl/data/datasets/response_datasets/gsm8k.py +++ b/nemo_rl/data/datasets/response_datasets/gsm8k.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os from typing import Any from datasets import load_dataset @@ -20,16 +19,6 @@ from nemo_rl.data.datasets.raw_dataset import RawDataset -def _load_system_prompt(system_prompt_file: str | None) -> str: - """Load system prompt from file. 
Returns empty string if path is None or missing."""
-    if not system_prompt_file:
-        return ""
-    if os.path.exists(system_prompt_file):
-        with open(system_prompt_file, "r", encoding="utf-8") as f:
-            return f.read()
-    raise FileNotFoundError(f"System prompt file {system_prompt_file!r} not found.")
-
-
 def _extract_hash_answer(text: str) -> str | None:
     if "####" not in text:
         return None
@@ -52,7 +41,6 @@ def __init__(self,
     ) -> None:
         self.task_name = "gsm8k"
         self.extract_answer = extract_answer
-        self._system_prompt = _load_system_prompt(system_prompt_file)
 
         # load from huggingface
         self.dataset = load_dataset("openai/gsm8k", "main")[split]
@@ -71,7 +59,6 @@ def format_data(self, data: dict[str, Any]) -> dict[str, Any]:
 
         return {
             "messages": [
-                {"role": "system", "content": self._system_prompt},
                 {"role": "user", "content": data["question"]},
                 {"role": "assistant", "content": answer},
             ],
diff --git a/nemo_rl/data/processors.py b/nemo_rl/data/processors.py
index 57a88ba644..5d3f51cccf 100644
--- a/nemo_rl/data/processors.py
+++ b/nemo_rl/data/processors.py
@@ -381,6 +381,7 @@ def math_data_processor(
     return output
 
 
+# TODO: @yukih: unify to math_hf_data_processor once https://github.com/NVIDIA-NeMo/RL/issues/2060 is resolved.
 def math_gdpo_data_processor(
     datum_dict: dict[str, Any],
     task_data_spec: TaskDataSpec,
@@ -389,44 +390,37 @@ def math_gdpo_data_processor(
     idx: int,
 ) -> DatumSpec:
     """Process a datum dictionary (directly loaded from data/hf_datasets/openmathinstruct2.py) into a DatumSpec for the Reward Model Environment."""
     user_message = datum_dict["messages"]
-
-    # print(f"user_message {user_message}")
-
-    problem = user_message[1]["content"]
-
-    extra_env_info = {"ground_truth": user_message[2]["content"]}
-
-    message_log: LLMMessageLogType = []
-
-
-    system_message = {
-        "role": "system",
-        "content": user_message[0]["content"]
-
-    }
-
-    user_message = {
-        "role": "user",
-        "content": problem,
-    }
+    problem = user_message[0]["content"]
+    extra_env_info = {"ground_truth": user_message[1]["content"]}
+    # merge system prompt and user prompt
+    message_list = []
+    # system prompt
+    if task_data_spec.system_prompt:
+        message_list.append({
+            "role": "system",
+            "content": task_data_spec.system_prompt,
+        })
+    # user prompt
+    if task_data_spec.prompt:
+        problem = task_data_spec.prompt.format(problem)
+    message_list.append({"role": "user", "content": problem})
 
     message: list[str] = tokenizer.apply_chat_template(  # type: ignore
-        [system_message, user_message],
+        message_list,
         tokenize=False,
        add_generation_prompt=True,
         add_special_tokens=False,
     )
-
-    user_message["token_ids"] = tokenizer(
-        message,
-        return_tensors="pt",
-        add_special_tokens=False,
+    token_ids = tokenizer(
+        message, return_tensors="pt", add_special_tokens=False
     )["input_ids"][0]
-    user_message["content"] = message
-    message_log.append(user_message)
+
+    message_log: LLMMessageLogType = [
+        {"role": "user", "content": message, "token_ids": token_ids}
+    ]
 
     length = sum(len(m["token_ids"]) for m in message_log)

From 665219df950891552ca26cfc38ecb747985ab856 Mon Sep 17 00:00:00 2001
From: Yuki Huang
Date: Wed, 4 Mar 2026 08:43:19 -0800
Subject: [PATCH 08/10] add assert and revert some missing comments

Signed-off-by: Yuki Huang
---
 examples/run_grpo.py       |  7 +++++++
 nemo_rl/algorithms/grpo.py | 16 ++++++++++------
 2 files changed, 17 insertions(+), 6 deletions(-)

diff --git a/examples/run_grpo.py b/examples/run_grpo.py
index 6130b99018..0e9f8bf24a 100644
--- a/examples/run_grpo.py
+++ b/examples/run_grpo.py
@@ -139,6 +139,13 @@ def main()
-> None: "use_multiple_dataloader is not supported with async GRPO" ) + # Async GDPO is not supported + if config["grpo"]["adv_estimator"]["name"] == "gdpo": + raise NotImplementedError( + "GDPO is not supported for async training, " + "please set grpo.async_grpo.enabled to false in your config." + ) + from nemo_rl.algorithms.grpo import async_grpo_train print("šŸš€ Running async GRPO training") diff --git a/nemo_rl/algorithms/grpo.py b/nemo_rl/algorithms/grpo.py index 588c0c9db2..f5b8892ed7 100644 --- a/nemo_rl/algorithms/grpo.py +++ b/nemo_rl/algorithms/grpo.py @@ -1046,7 +1046,10 @@ def _create_advantage_estimator(master_config: MasterConfig): """ grpo_config = master_config["grpo"] loss_config = master_config["loss_fn"] + # Provide backward-compatible defaults when adv_estimator is not in config. + # Fall back to top-level grpo.normalize_rewards / grpo.use_leave_one_out_baseline + # which older configs still use. adv_estimator_config = grpo_config.get( "adv_estimator", { @@ -1061,6 +1064,10 @@ def _create_advantage_estimator(master_config: MasterConfig): adv_estimator_name = adv_estimator_config["name"] if adv_estimator_name == "gdpo": + assert not _should_use_async_rollouts(master_config), ( + "GDPO is not supported for async rollouts, " + "please set policy.generation.vllm_cfg.async_engine to false in your config." + ) adv_estimator = GDPOAdvantageEstimator(adv_estimator_config, loss_config) print(" āœ“ Using GDPO advantage estimator (multi-reward)") elif adv_estimator_name == "grpo": @@ -1372,10 +1379,9 @@ def grpo_train( val_period = master_config["grpo"]["val_period"] colocated_inference = master_config["policy"]["generation"]["colocated"]["enabled"] - # Create advantage estimator + # Initialize advantage estimator adv_estimator = _create_advantage_estimator(master_config) - # Run validation at the start if configured # TODO: Add validation with kv scales if needed if val_at_start and current_step == 0: @@ -1596,8 +1602,8 @@ def grpo_train( # Calculate rewards & advantages memory_tracker.snapshot_start_of_stage("Processing rewards", dir()) print("ā–¶ Processing rewards...,", flush=True) - # GDPO with timer.time("reward_calculation"): + # Extract rewards from final_batch rewards = repeated_batch["total_reward"] # Store input_ids in batch so that after dynamic_sampling it stays aligned with # the (possibly filtered) batch: select_indices / from_batches / slice all @@ -1788,8 +1794,6 @@ def grpo_train( sample_mask = train_data["sample_mask"] mask = token_mask * sample_mask.unsqueeze(-1) - - train_data["advantages"] = adv_estimator.compute_advantage( repeated_batch=repeated_batch, mask=mask, @@ -2445,7 +2449,7 @@ def async_grpo_train( val_at_end = master_config["grpo"]["val_at_end"] colocated_inference = master_config["policy"]["generation"]["colocated"]["enabled"] - # Create advantage estimator + # Initialize advantage estimator adv_estimator = _create_advantage_estimator(master_config) assert not colocated_inference, ( From b7597c656c82d347faabc62bee89db1e720e7f7a Mon Sep 17 00:00:00 2001 From: Yuki Huang Date: Wed, 4 Mar 2026 08:47:49 -0800 Subject: [PATCH 09/10] revert async change and enable fast test Signed-off-by: Yuki Huang --- .../data/datasets/response_datasets/gsm8k.py | 2 +- nemo_rl/experience/rollouts.py | 106 +++++++----------- tests/functional/L1_Functional_Tests_GPU.sh | 2 +- 3 files changed, 43 insertions(+), 67 deletions(-) diff --git a/nemo_rl/data/datasets/response_datasets/gsm8k.py b/nemo_rl/data/datasets/response_datasets/gsm8k.py index e71428206b..ce3affd869 
100644 --- a/nemo_rl/data/datasets/response_datasets/gsm8k.py +++ b/nemo_rl/data/datasets/response_datasets/gsm8k.py @@ -1,4 +1,4 @@ -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/nemo_rl/experience/rollouts.py b/nemo_rl/experience/rollouts.py index 52fc560881..52f3996244 100644 --- a/nemo_rl/experience/rollouts.py +++ b/nemo_rl/experience/rollouts.py @@ -934,35 +934,25 @@ async def run_single_sample_with_error_handling(i, sample_state): # Reconstruct batch from sample results batch_size = len(final_sample_states) - final_batch_dict = { - "message_log": [state["message_log"] for state in final_sample_states], - "extra_env_info": [ - state["extra_env_info"] for state in final_sample_states - ], - "task_name": [state["task_name"] for state in final_sample_states], - "total_reward": torch.stack( - [state["total_reward"] for state in final_sample_states] - ), - "idx": [ - state.get("idx", i) for i, state in enumerate(final_sample_states) - ], - "truncated": torch.tensor( - [metrics["truncated"] for metrics in all_sample_metrics], - dtype=torch.bool, - ), - } - - # Add any reward component keys (reward1, reward2, ...) from the first state - reward_keys = [ - k for k in final_sample_states[0] - if k.startswith("reward") and k[6:].isdigit() - ] - reward_keys = sorted(reward_keys, key=lambda k: int(k[6:])) - for key in reward_keys: - final_batch_dict[key] = torch.stack( - [state[key] for state in final_sample_states] - ) - final_batch = BatchedDataDict[DatumSpec](final_batch_dict) + final_batch = BatchedDataDict[DatumSpec]( + { + "message_log": [state["message_log"] for state in final_sample_states], + "extra_env_info": [ + state["extra_env_info"] for state in final_sample_states + ], + "task_name": [state["task_name"] for state in final_sample_states], + "total_reward": torch.stack( + [state["total_reward"] for state in final_sample_states] + ), + "idx": [ + state.get("idx", i) for i, state in enumerate(final_sample_states) + ], + "truncated": torch.tensor( + [metrics["truncated"] for metrics in all_sample_metrics], + dtype=torch.bool, + ), + } + ) # Preserve additional fields from the original input_batch for key in input_batch.keys(): @@ -1237,42 +1227,28 @@ def run_async_nemo_gym_rollout( ) input_ids = batched_flat["token_ids"] - final_batch_dict = { - "agent_ref": [r["agent_ref"] for r in results], - "message_log": [r["message_log"] for r in results], - # length is used downstream for mean_prompt_length - "length": torch.tensor( - [len(r["input_message_log"][0]["token_ids"]) for r in results] - ), - "loss_multiplier": input_batch["loss_multiplier"], - # Unnecessary parts of the DatumSpec unused by the GRPO algorithm - # extra_env_info: dict[str, Any] - # idx: int - # task_name: NotRequired[str] - # stop_strings: NotRequired[list[str]] # Optional stop strings for generation - # Extra information not in the DatumSpec used by the GRPO algorithm - "total_reward": torch.tensor([r["full_result"]["reward"] for r in results]), - # Add truncated field to match other rollout paths (reusing hit_max_tokens logic) - "truncated": torch.tensor( - [m["hit_max_tokens"] for m in all_sample_metrics], dtype=torch.bool - ), - } - - # Add any reward component keys (reward1, reward2, ...) 
from full_result - if results: - full_result = results[0].get("full_result", {}) - reward_keys = sorted( - [ - k for k in full_result - if isinstance(k, str) and k.startswith("reward") and k[6:].isdigit() - ], - key=lambda k: int(k[6:]), - ) - for key in reward_keys: - final_batch_dict[key] = torch.tensor( - [r["full_result"][key] for r in results] - ) - final_batch = BatchedDataDict[DatumSpec](final_batch_dict) + final_batch = BatchedDataDict[DatumSpec]( + { + "agent_ref": [r["agent_ref"] for r in results], + "message_log": [r["message_log"] for r in results], + # length is used downstream for mean_prompt_length + "length": torch.tensor( + [len(r["input_message_log"][0]["token_ids"]) for r in results] + ), + "loss_multiplier": input_batch["loss_multiplier"], + # Unnecessary parts of the DatumSpec unused by the GRPO algorithm + # extra_env_info: dict[str, Any] + # idx: int + # task_name: NotRequired[str] + # stop_strings: NotRequired[list[str]] # Optional stop strings for generation + # Extra information not in the DatumSpec used by the GRPO algorithm + "total_reward": torch.tensor([r["full_result"]["reward"] for r in results]), + # Add truncated field to match other rollout paths (reusing hit_max_tokens logic) + "truncated": torch.tensor( + [m["hit_max_tokens"] for m in all_sample_metrics], dtype=torch.bool + ), + } + ) return AsyncNemoGymRolloutResult( input_ids=input_ids, diff --git a/tests/functional/L1_Functional_Tests_GPU.sh b/tests/functional/L1_Functional_Tests_GPU.sh index af8919df5e..ced8ffacaf 100644 --- a/tests/functional/L1_Functional_Tests_GPU.sh +++ b/tests/functional/L1_Functional_Tests_GPU.sh @@ -45,7 +45,7 @@ run_test uv run --no-sync bash ./tests/functional/dpo_automodel_lora.sh run_test uv run --no-sync bash ./tests/functional/dpo_megatron.sh run_test uv run --no-sync bash ./tests/functional/eval.sh run_test uv run --no-sync bash ./tests/functional/eval_async.sh -run_test uv run --no-sync bash ./tests/functional/gdpo.sh +run_test fast uv run --no-sync bash ./tests/functional/gdpo.sh run_test fast uv run --no-sync bash ./tests/functional/grpo.sh run_test fast uv run --no-sync bash ./tests/functional/grpo_async_gym.sh run_test uv run --no-sync bash ./tests/functional/grpo_automodel_lora.sh From e898d01dcbca70a61b6f27f8d20e12eab974d469 Mon Sep 17 00:00:00 2001 From: Shih-Yang Liu Date: Thu, 5 Mar 2026 01:13:38 -0800 Subject: [PATCH 10/10] addressed all comments and fixed bugs --- .../Megatron-Bridge-workspace/Megatron-Bridge | 2 +- 3rdparty/Megatron-LM-workspace/Megatron-LM | 2 +- nemo_rl/algorithms/advantage_estimator.py | 51 +++++++++++++------ nemo_rl/algorithms/grpo.py | 4 ++ nemo_rl/experience/rollouts.py | 12 +++-- pyproject.toml | 13 ++--- tests/unit/algorithms/test_grpo.py | 49 +++++++++++++++--- uv.lock | 4 +- 8 files changed, 103 insertions(+), 34 deletions(-) diff --git a/3rdparty/Megatron-Bridge-workspace/Megatron-Bridge b/3rdparty/Megatron-Bridge-workspace/Megatron-Bridge index f91542b909..15398e08fc 160000 --- a/3rdparty/Megatron-Bridge-workspace/Megatron-Bridge +++ b/3rdparty/Megatron-Bridge-workspace/Megatron-Bridge @@ -1 +1 @@ -Subproject commit f91542b90908ad08b7e13672feea03e27bedee27 +Subproject commit 15398e08fc86be3de084c7382116527246ab1852 diff --git a/3rdparty/Megatron-LM-workspace/Megatron-LM b/3rdparty/Megatron-LM-workspace/Megatron-LM index 23dd639cf3..193463c4f8 160000 --- a/3rdparty/Megatron-LM-workspace/Megatron-LM +++ b/3rdparty/Megatron-LM-workspace/Megatron-LM @@ -1 +1 @@ -Subproject commit 23dd639cf3de30f3b9d8d0fae71ee31180be9ddd 
+Subproject commit 193463c4f8414e6906a40dd527a450bca50706b1 diff --git a/nemo_rl/algorithms/advantage_estimator.py b/nemo_rl/algorithms/advantage_estimator.py index 976c142a44..e964a1d548 100644 --- a/nemo_rl/algorithms/advantage_estimator.py +++ b/nemo_rl/algorithms/advantage_estimator.py @@ -37,20 +37,26 @@ def __init__(self, estimator_config: dict, loss_config: dict): self.use_leave_one_out_baseline = estimator_config["use_leave_one_out_baseline"] self.normalize_rewards = estimator_config["normalize_rewards"] - def compute_advantage(self, repeated_batch, mask, **kwargs): + def compute_advantage( + self, + prompt_ids, + rewards, + repeated_batch, + mask, + **kwargs, + ): """Compute GRPO advantages. Args: - repeated_batch: Batch containing _input_ids_for_baseline and total_reward. + prompt_ids: Tensor of shape [batch_size] identifying which prompt each sample belongs to. + rewards: Tensor of shape [batch_size] containing reward for each sample. + repeated_batch: Batch (unused; for interface consistency). mask: Response token mask of shape [batch_size, seq_len], 1 for valid response tokens, 0 for padding. - Used only for expanding advantages to token-level shape. **kwargs: Additional arguments (unused). Returns: Advantages tensor of shape [batch_size, seq_len]. """ - prompt_ids = repeated_batch["_input_ids_for_baseline"] - rewards = repeated_batch["total_reward"] baseline, std = calculate_baseline_and_std_per_prompt( prompt_ids, rewards, @@ -80,13 +86,21 @@ def __init__(self, estimator_config: dict, loss_config: dict): self.use_leave_one_out_baseline = estimator_config["use_leave_one_out_baseline"] self.normalize_rewards = estimator_config["normalize_rewards"] - def compute_advantage(self, repeated_batch, mask, **kwargs): + def compute_advantage( + self, + prompt_ids, + rewards, + repeated_batch, + mask, + **kwargs, + ): """Compute GDPO advantages. Args: + prompt_ids: Unused; for interface consistency. + rewards: Unused; for interface consistency. repeated_batch: Batch containing _input_ids_for_baseline and reward1, reward2, ... keys. mask: Response token mask of shape [batch_size, seq_len], 1 for valid response tokens, 0 for padding. - Used only for expanding advantages to token-level shape. **kwargs: Additional arguments (unused). Returns: @@ -153,23 +167,30 @@ def __init__(self, estimator_config: dict, loss_config: dict): self.kl_coef = loss_config["reference_policy_kl_penalty"] self.kl_type = loss_config["reference_policy_kl_type"] - def compute_advantage(self, repeated_batch, mask, logprobs_policy=None, logprobs_reference=None, **kwargs): + def compute_advantage( + self, + prompt_ids, + rewards, + repeated_batch, + mask, + logprobs_policy=None, + logprobs_reference=None, + **kwargs, + ): """Compute Reinforce++ advantages with optional KL penalty. Args: - repeated_batch: Batch containing _input_ids_for_baseline and total_reward. + prompt_ids: Tensor identifying which prompt each sample belongs to (for baseline). + rewards: Tensor of shape [batch_size] containing reward for each sample. + repeated_batch: Unused; for interface consistency. mask: Response token mask of shape [batch_size, seq_len], 1 for valid response tokens, 0 for padding. - Used for: (1) expanding advantages to token-level shape, (2) global normalization - that only considers valid tokens. - logprobs_policy: Policy log probabilities of shape [batch_size, seq_len], required if use_kl_in_reward. - logprobs_reference: Reference policy log probabilities of shape [batch_size, seq_len], required if use_kl_in_reward. 
+ logprobs_policy: Policy log probabilities, required if use_kl_in_reward. + logprobs_reference: Reference policy log probabilities, required if use_kl_in_reward. **kwargs: Additional arguments (unused). Returns: Advantages tensor of shape [batch_size, seq_len], globally normalized across valid tokens. """ - prompt_ids = repeated_batch["_input_ids_for_baseline"] - rewards = repeated_batch["total_reward"] # minus baseline if self.minus_baseline: mean, _ = calculate_baseline_and_std_per_prompt( diff --git a/nemo_rl/algorithms/grpo.py b/nemo_rl/algorithms/grpo.py index f5b8892ed7..a5ec44fc69 100644 --- a/nemo_rl/algorithms/grpo.py +++ b/nemo_rl/algorithms/grpo.py @@ -1795,6 +1795,8 @@ def grpo_train( mask = token_mask * sample_mask.unsqueeze(-1) train_data["advantages"] = adv_estimator.compute_advantage( + prompt_ids=prompt_ids_for_adv, + rewards=rewards, repeated_batch=repeated_batch, mask=mask, logprobs_policy=train_data["prev_logprobs"], @@ -2827,6 +2829,8 @@ def async_grpo_train( mask = token_mask * sample_mask.unsqueeze(-1) train_data["advantages"] = adv_estimator.compute_advantage( + prompt_ids=prompt_ids_for_adv, + rewards=rewards, repeated_batch=repeated_batch, mask=mask, logprobs_policy=train_data["prev_logprobs"], diff --git a/nemo_rl/experience/rollouts.py b/nemo_rl/experience/rollouts.py index 52f3996244..a467f3f0b4 100644 --- a/nemo_rl/experience/rollouts.py +++ b/nemo_rl/experience/rollouts.py @@ -328,8 +328,14 @@ def calculate_rewards( range(len(all_indices_order)), key=lambda k: all_indices_order[k] ) # Stack rewards: each element may be scalar (single-reward env) or 1d (multi-reward env). - # torch.stack preserves shape: scalars -> (N,), shape (K,) -> (N, K). - rewards = torch.stack([all_rewards[i] for i in sorted_indices]) + # Envs may return Python floats or tensors; ensure tensors for torch.stack. + # Handle empty batch (torch.stack requires non-empty list). 
+ if not sorted_indices: + rewards = torch.tensor([], dtype=torch.float32) + else: + rewards = torch.stack( + [torch.as_tensor(all_rewards[i]) for i in sorted_indices] + ) env_observations = [all_env_observations[i] for i in sorted_indices] terminateds = torch.tensor([all_terminateds[i] for i in sorted_indices]) next_stop_strings = [all_next_stop_strings[i] for i in sorted_indices] @@ -476,7 +482,7 @@ def run_multi_turn_rollout( number_of_rewards = 1 # multi_rewards left None: GRPO uses total_reward only; reward1 unused # Accumulate rewards: env may return shape (N,) or (N, K) - if env_output.rewards.ndim >= 2: + if number_of_rewards > 1: multi_rewards[active_indices] += env_output.rewards total_rewards[active_indices] += env_output.rewards.sum(dim=1) else: diff --git a/pyproject.toml b/pyproject.toml index 5ba0782a0e..bb00194adb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,8 +17,8 @@ requires-python = ">=3.12" license = { text = "Apache 2.0" } dependencies = [ "setuptools", - "pip", # Required for frozen environments; uv venv --seed may not reliably install pip - "ninja", # for flash-attn parallel build + "pip", # Required for frozen environments; uv venv --seed may not reliably install pip + "ninja", # for flash-attn parallel build "torch==2.9.0", "triton; sys_platform == 'linux' and (platform_machine == 'x86_64' or platform_machine == 'aarch64')", "colored==2.2.3", @@ -44,16 +44,17 @@ dependencies = [ "sympy>=1.14.0", "pillow>=11.3.0", "torchvision>=0.22.0", - "num2words>=0.5.14", # for SmolVLM + "num2words>=0.5.14", # for SmolVLM "mlflow>=3.5.0,<3.6.0", "nvidia-nvshmem-cu12; sys_platform == 'linux' and (platform_machine == 'x86_64' or platform_machine == 'aarch64')", # for deep_ep build "swanlab", "pyzmq", "decord2", "nvidia-resiliency-ext", - "nccl4py", # for non-colocated refit - "cuda-bindings", # for non-colocated refit - "pybase64", # for sglang refit + "nccl4py", # for non-colocated refit + "cuda-bindings", # for non-colocated refit + "pybase64", # for sglang refit + "decord>=0.6.0", ] [project.optional-dependencies] diff --git a/tests/unit/algorithms/test_grpo.py b/tests/unit/algorithms/test_grpo.py index e0bdb8d2b1..dc6439d6f0 100644 --- a/tests/unit/algorithms/test_grpo.py +++ b/tests/unit/algorithms/test_grpo.py @@ -1644,7 +1644,12 @@ def test_grpo_advantage_estimator_zero_std(): mask = torch.ones(4, 5) repeated_batch = {"_input_ids_for_baseline": prompt_ids, "total_reward": rewards} - result = estimator.compute_advantage(repeated_batch, mask) + result = estimator.compute_advantage( + prompt_ids=prompt_ids, + rewards=rewards, + repeated_batch=repeated_batch, + mask=mask, + ) # prompt 0: std=0 -> skip normalization, advantage=0 (reward - mean = 0) # prompt 1: With Bessel correction for 2 samples, std = sqrt(2), normalized = ±1/sqrt(2) ā‰ˆ ±0.7071 @@ -1676,7 +1681,12 @@ def test_grpo_advantage_estimator_tensor_shapes(): mask = torch.ones(2, 3) repeated_batch = {"_input_ids_for_baseline": prompt_ids, "total_reward": rewards} - result = estimator.compute_advantage(repeated_batch, mask) + result = estimator.compute_advantage( + prompt_ids=prompt_ids, + rewards=rewards, + repeated_batch=repeated_batch, + mask=mask, + ) assert result.shape == (2, 3) # Verify normalized values: (reward - mean) / std @@ -1691,7 +1701,12 @@ def test_grpo_advantage_estimator_tensor_shapes(): mask = torch.ones(10, 5) repeated_batch = {"_input_ids_for_baseline": prompt_ids, "total_reward": rewards} - result = estimator.compute_advantage(repeated_batch, mask) + result = 
estimator.compute_advantage( + prompt_ids=prompt_ids, + rewards=rewards, + repeated_batch=repeated_batch, + mask=mask, + ) assert result.shape == (10, 5) # After normalization, mean should be ~0 @@ -1717,7 +1732,12 @@ def test_grpo_advantage_estimator_negative_advantages(): mask = torch.ones(3, 4) repeated_batch = {"_input_ids_for_baseline": prompt_ids, "total_reward": rewards} - result = estimator.compute_advantage(repeated_batch, mask) + result = estimator.compute_advantage( + prompt_ids=prompt_ids, + rewards=rewards, + repeated_batch=repeated_batch, + mask=mask, + ) # Verify ordering: first should be negative, middle ~0, last positive assert result[0, 0] < 0 # below mean -> negative advantage @@ -1748,7 +1768,12 @@ def test_grpo_advantage_estimator_zero_std_and_zero_advantage(): mask = torch.ones(4, 3) repeated_batch = {"_input_ids_for_baseline": prompt_ids, "total_reward": rewards} - result = estimator.compute_advantage(repeated_batch, mask) + result = estimator.compute_advantage( + prompt_ids=prompt_ids, + rewards=rewards, + repeated_batch=repeated_batch, + mask=mask, + ) # All advantages should be exactly 0 expected = torch.zeros(4, 3) @@ -1775,7 +1800,12 @@ def test_grpo_advantage_estimator_small_nonzero_std(): mask = torch.ones(2, 3) repeated_batch = {"_input_ids_for_baseline": prompt_ids, "total_reward": rewards} - result = estimator.compute_advantage(repeated_batch, mask) + result = estimator.compute_advantage( + prompt_ids=prompt_ids, + rewards=rewards, + repeated_batch=repeated_batch, + mask=mask, + ) # Even with small std, normalization should still happen # After normalization, the values should be ±1/sqrt(2) (for 2 samples with Bessel) @@ -1816,7 +1846,12 @@ def test_reinforce_plus_plus_global_normalization(): mask = torch.ones(4, 5) repeated_batch = {"_input_ids_for_baseline": prompt_ids, "total_reward": rewards} - result = estimator.compute_advantage(repeated_batch, mask) + result = estimator.compute_advantage( + prompt_ids=prompt_ids, + rewards=rewards, + repeated_batch=repeated_batch, + mask=mask, + ) # After global normalization, mean should be ~0 result_mean = (result * mask).sum() / mask.sum() diff --git a/uv.lock b/uv.lock index ade34b2f1b..e1c73b270e 100644 --- a/uv.lock +++ b/uv.lock @@ -1556,7 +1556,7 @@ name = "decord" version = "0.6.0" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "numpy", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux') or (sys_platform == 'darwin' and extra == 'extra-7-nemo-rl-automodel' and extra == 'extra-7-nemo-rl-sglang') or (sys_platform == 'darwin' and extra == 'extra-7-nemo-rl-fsdp' and extra == 'extra-7-nemo-rl-sglang') or (sys_platform == 'darwin' and extra == 'extra-7-nemo-rl-mcore' and extra == 'extra-7-nemo-rl-sglang') or (sys_platform == 'darwin' and extra == 'extra-7-nemo-rl-sglang' and extra == 'extra-7-nemo-rl-vllm') or (sys_platform == 'linux' and extra == 'extra-7-nemo-rl-automodel' and extra == 'extra-7-nemo-rl-sglang') or (sys_platform == 'linux' and extra == 'extra-7-nemo-rl-fsdp' and extra == 'extra-7-nemo-rl-sglang') or (sys_platform == 'linux' and extra == 'extra-7-nemo-rl-mcore' and extra == 'extra-7-nemo-rl-sglang') or (sys_platform == 'linux' and extra == 'extra-7-nemo-rl-sglang' and extra == 'extra-7-nemo-rl-vllm')" }, + { name = "numpy" }, ] wheels = [ { url = 
"https://files.pythonhosted.org/packages/11/79/936af42edf90a7bd4e41a6cac89c913d4b47fa48a26b042d5129a9242ee3/decord-0.6.0-py3-none-manylinux2010_x86_64.whl", hash = "sha256:51997f20be8958e23b7c4061ba45d0efcd86bffd5fe81c695d0befee0d442976", size = 13602299, upload-time = "2021-06-14T21:30:55.486Z" }, @@ -4696,6 +4696,7 @@ dependencies = [ { name = "cuda-bindings" }, { name = "datasets" }, { name = "debugpy" }, + { name = "decord" }, { name = "decord2" }, { name = "hydra-core" }, { name = "math-verify" }, @@ -4836,6 +4837,7 @@ requires-dist = [ { name = "cuda-python", marker = "extra == 'vllm'" }, { name = "datasets", specifier = ">=4.0.0" }, { name = "debugpy" }, + { name = "decord", specifier = ">=0.6.0" }, { name = "decord2" }, { name = "deep-ep", marker = "extra == 'automodel'", git = "https://github.com/deepseek-ai/DeepEP.git?rev=bfded34800dfec415b71503f8205181de90b2480" }, { name = "deep-ep", marker = "extra == 'mcore'", git = "https://github.com/deepseek-ai/DeepEP.git?rev=bfded34800dfec415b71503f8205181de90b2480" },