From 416b9e8d02c46a29dd436f2e3e07ab6dda6131fb Mon Sep 17 00:00:00 2001
From: Wei Du
Date: Thu, 4 Sep 2025 09:23:04 -0700
Subject: [PATCH 1/6] fix: stop jobs after timeout and add warning for validation

Signed-off-by: Wei Du
---
 nemo_rl/algorithms/dpo.py  | 5 +++++
 nemo_rl/algorithms/grpo.py | 6 ++++++
 nemo_rl/algorithms/sft.py  | 5 +++++
 3 files changed, 16 insertions(+)

diff --git a/nemo_rl/algorithms/dpo.py b/nemo_rl/algorithms/dpo.py
index 579099c530..5338466be8 100644
--- a/nemo_rl/algorithms/dpo.py
+++ b/nemo_rl/algorithms/dpo.py
@@ -334,6 +334,9 @@ def validate_one_dataset(
 ):
     """Run validation on one validation dataset."""
     if val_dataloader is None:
+        assert val_dataloader is not None or master_config["dpo"]["val_period"] == 0, (
+            "val_dataloader is None, so dpo.val_period must be 0"
+        )
         print(" ⚠️ No validation dataloader provided, skipping validation")
         return
 
@@ -640,6 +643,8 @@ def dpo_train(
             current_step += 1
             total_steps += 1
 
+            if should_save_by_timeout:
+                return
             if total_steps >= master_config["dpo"]["max_num_steps"]:
                 return
 
diff --git a/nemo_rl/algorithms/grpo.py b/nemo_rl/algorithms/grpo.py
index d3c1d4902c..ce5bf59a37 100644
--- a/nemo_rl/algorithms/grpo.py
+++ b/nemo_rl/algorithms/grpo.py
@@ -911,6 +911,9 @@ def grpo_train(
 
         timer.reset()
         step += 1
+
+        if should_save_by_timeout:
+            break
         if step >= master_config["grpo"]["max_num_steps"]:
             break
 
@@ -925,6 +928,9 @@ def validate(
 ) -> tuple[dict[str, Any], dict[str, Any]]:
     """Run validation on the validation dataset."""
     if val_dataloader is None:
+        assert val_dataloader is not None or master_config["grpo"]["val_period"] == 0, (
+            "val_dataloader is None, so grpo.val_period must be 0"
+        )
         print(" ⚠️ No validation dataloader provided, skipping validation", flush=True)
         return {}, {}
 
diff --git a/nemo_rl/algorithms/sft.py b/nemo_rl/algorithms/sft.py
index 6de9ac81f1..3de8af278a 100644
--- a/nemo_rl/algorithms/sft.py
+++ b/nemo_rl/algorithms/sft.py
@@ -228,6 +228,9 @@ def validate(
 ):
     """Run validation on the validation dataset."""
     if val_dataloader is None:
+        assert val_dataloader is not None or master_config["sft"]["val_period"] == 0, (
+            "val_dataloader is None, so sft.val_period must be 0"
+        )
         print(" ⚠️ No validation dataloader provided, skipping validation")
         return
 
@@ -564,6 +567,8 @@ def sft_train(
             current_step += 1
             total_steps += 1
 
+            if should_save_by_timeout:
+                return
             if total_steps >= master_config["sft"]["max_num_steps"]:
                 return
 

From 5f8edeed853a89d6009a5fa9267b8dcab39130b3 Mon Sep 17 00:00:00 2001
From: Wei Du
Date: Thu, 4 Sep 2025 09:43:19 -0700
Subject: [PATCH 2/6] fix: stop jobs after timeout and add warning for validation

Signed-off-by: Wei Du
---
 nemo_rl/algorithms/rm.py | 23 ++++++++++++++++++++---
 1 file changed, 20 insertions(+), 3 deletions(-)

diff --git a/nemo_rl/algorithms/rm.py b/nemo_rl/algorithms/rm.py
index b1aa9f01be..9bd4a6dd6f 100644
--- a/nemo_rl/algorithms/rm.py
+++ b/nemo_rl/algorithms/rm.py
@@ -40,7 +40,7 @@
 from nemo_rl.utils.checkpoint import CheckpointingConfig, CheckpointManager
 from nemo_rl.utils.logger import Logger, LoggerConfig
 from nemo_rl.utils.nsys import maybe_gpu_profile_step
-from nemo_rl.utils.timer import Timer
+from nemo_rl.utils.timer import TimeoutChecker, Timer
 
 
 class RMSaveState(TypedDict):
@@ -293,6 +293,9 @@ def validate_one_dataset(
 ):
     """Run validation on one validation dataset."""
     if val_dataloader is None:
+        assert val_dataloader is not None or master_config["rm"]["val_period"] == 0, (
+            "val_dataloader is None, so rm.val_period must be 0"
+        )
         print(" ⚠️ No validation dataloader provided, skipping validation")
         return
 
@@ -408,7 +411,11 @@ def rm_train(
 ):
     # Run basic rm training
     timer = Timer()
-
+    timeout = TimeoutChecker(
+        timeout=master_config["checkpointing"]["checkpoint_must_save_by"],
+        fit_last_save_time=True,
+    )
+    timeout.start_iterations()
     if rm_save_state is None:
         rm_save_state = _default_rm_save_state()
     current_epoch = 0
@@ -494,13 +501,21 @@ def rm_train(
             )
 
             ## Checkpointing
+            timeout.mark_iteration()
+
             rm_save_state["consumed_samples"] += master_config["policy"][
                 "train_global_batch_size"
             ]
 
-            if master_config["checkpointing"]["enabled"] and (
+            should_save_by_step = (
                 is_last_step
                 or (total_steps + 1) % master_config["checkpointing"]["save_period"] == 0
+            )
+            should_save_by_timeout = timeout.check_save()
+
+            if master_config["checkpointing"]["enabled"] and (
+                should_save_by_step or should_save_by_timeout
             ):
                 ## +1 because step is 0-indexed
                 rm_save_state["step"] = (current_step + 1) % len(train_dataloader)
@@ -597,6 +612,8 @@ def rm_train(
             current_step += 1
             total_steps += 1
 
+            if should_save_by_timeout:
+                return
             if (
                 master_config["rm"]["max_num_steps"] != -1
                 and total_steps >= master_config["rm"]["max_num_steps"]

From 7c63461b5e53edfb62134788c11ed636ccae630d Mon Sep 17 00:00:00 2001
From: Wei Du
Date: Thu, 4 Sep 2025 09:47:03 -0700
Subject: [PATCH 3/6] fix: update yaml

Signed-off-by: Wei Du
---
 examples/configs/rm.yaml | 1 +
 1 file changed, 1 insertion(+)

diff --git a/examples/configs/rm.yaml b/examples/configs/rm.yaml
index 744538d5ed..6377582455 100644
--- a/examples/configs/rm.yaml
+++ b/examples/configs/rm.yaml
@@ -19,6 +19,7 @@ checkpointing:
   higher_is_better: false
   keep_top_k: 3
   save_period: ${rm.val_period}
+  checkpoint_must_save_by: null
 
 policy:
   model_name: "meta-llama/Llama-3.2-1B-Instruct"

From b973491f5888a7750e378be63ebfcdb3a1fc9cd4 Mon Sep 17 00:00:00 2001
From: Wei Du
Date: Thu, 4 Sep 2025 15:32:00 -0700
Subject: [PATCH 4/6] fix: update test_rm.py for unit test

Signed-off-by: Wei Du
---
 tests/unit/algorithms/test_rm.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/tests/unit/algorithms/test_rm.py b/tests/unit/algorithms/test_rm.py
index a373cfc12f..5aa63ce218 100644
--- a/tests/unit/algorithms/test_rm.py
+++ b/tests/unit/algorithms/test_rm.py
@@ -99,7 +99,10 @@ def val_iter(self):
             },
             "train_micro_batch_size": 1,
         },
-        "checkpointing": {"enabled": False},
+        "checkpointing": {
+            "enabled": False,
+            "checkpoint_must_save_by": None,
+        },
     }
 
     return {

From 08b1d2245ab1d4cb933a12e02c5b31b8cfa1e278 Mon Sep 17 00:00:00 2001
From: Wei Du
Date: Wed, 10 Sep 2025 08:13:55 -0700
Subject: [PATCH 5/6] fix conflict

Signed-off-by: Wei Du
---
 nemo_rl/algorithms/grpo.py | 722 ++++++++++++++++++++-----------------
 1 file changed, 387 insertions(+), 335 deletions(-)

diff --git a/nemo_rl/algorithms/grpo.py b/nemo_rl/algorithms/grpo.py
index ce5bf59a37..e9bc749d60 100644
--- a/nemo_rl/algorithms/grpo.py
+++ b/nemo_rl/algorithms/grpo.py
@@ -77,6 +77,7 @@ class GRPOConfig(TypedDict):
 
     num_prompts_per_step: int
     num_generations_per_prompt: int
+    max_num_epochs: int
     max_num_steps: int
     max_rollout_turns: int
     normalize_rewards: bool
@@ -90,18 +91,22 @@ class GRPOConfig(TypedDict):
 
 
 class GRPOSaveState(TypedDict):
-    step: int
+    consumed_samples: int
+    current_step: int
+    current_epoch: int
+    total_steps: int
     val_reward: NotRequired[
         float
     ]  # Optional field - may not be present during training
-    consumed_samples: int
 
 
 def _default_grpo_save_state() -> GRPOSaveState:
     return {
-        "step": 0,
-        "val_reward": -99999999.0,
         "consumed_samples": 0,
+        "current_step": 0,
+        "current_epoch": 0,
+        "total_steps": 0,
+        "val_reward": -99999999.0,
     }
 
 
@@ -352,6 +357,11 @@ def setup(
     weights_path = None
     optimizer_path = None
 
+    if policy_config.get("megatron_cfg", {}).get("enabled", False):
+        ## NOTE: this is equal to the total number of scheduler steps
+        total_train_iters = min(grpo_config["max_num_steps"], len(dataloader))
+        policy_config["megatron_cfg"]["train_iters"] = total_train_iters
+
     policy = Policy(
         cluster=train_cluster,
         config=policy_config,
@@ -528,14 +538,24 @@ def grpo_train(
     assert policy_generation is not None  # for mypy type check
 
     # common config/state itmes
-    step = grpo_save_state["step"]
-    consumed_samples = grpo_save_state["consumed_samples"]
-    val_period = master_config["grpo"]["val_period"]
+    current_step = grpo_save_state["current_step"]  # current step within an epoch
+    total_steps = grpo_save_state["total_steps"]  # total steps across all epochs
+    max_num_steps = master_config["grpo"][
+        "max_num_steps"
+    ]  # max number of steps to train for
+    current_epoch = grpo_save_state["current_epoch"]  # current epoch
+    max_num_epochs = master_config["grpo"][
+        "max_num_epochs"
+    ]  # max number of epochs to train for
+    consumed_samples = grpo_save_state[
+        "consumed_samples"
+    ]  # total samples consumed across all epochs
     val_at_start = master_config["grpo"]["val_at_start"]
+    val_period = master_config["grpo"]["val_period"]
     colocated_inference = master_config["policy"]["generation"]["colocated"]["enabled"]
 
     # Run validation at the start if configured
-    if val_at_start and step == 0:
+    if val_at_start and current_step == 0:
         print("\n🔍 Running initial validation...", flush=True)
         if NEED_REFIT and POLICY_GENERATION_STALE:
             refit_policy_generation(policy, policy_generation, colocated_inference)
@@ -551,371 +571,403 @@ def grpo_train(
             master_config=master_config,
         )
         policy_generation.finish_generation()
-        logger.log_metrics(val_metrics, step, prefix="validation")
-        logger.log_metrics(validation_timings, step, prefix="timing/validation")
+        logger.log_metrics(val_metrics, current_step, prefix="validation")
+        logger.log_metrics(validation_timings, current_step, prefix="timing/validation")
 
-    # Run grpo training (single-turn)
-    batch: BatchedDataDict[DatumSpec]
-    for batch in dataloader:
-        print(
-            f"\n{'=' * 25} Step {step + 1}/{min(len(dataloader), master_config['grpo']['max_num_steps'])} {'=' * 25}",
-            flush=True,
-        )
-        maybe_gpu_profile_step(policy, step + 1)
-        if policy != policy_generation:
-            maybe_gpu_profile_step(policy_generation, step + 1)
-        val_metrics, validation_timings = None, None
-
-        with timer.time("total_step_time"):
-            # Prepare batch
-            print("▶ Preparing batch...", flush=True)
-            with timer.time("data_processing"):
-                # Repeat batch items
-                repeated_batch: BatchedDataDict[DatumSpec] = batch.repeat_interleave(
-                    master_config["grpo"]["num_generations_per_prompt"]
-                )
-                # Convert LLMMessageLogType to FlatMessagesType for generation
-                batched_flat, input_lengths = batched_message_log_to_flat_message(
-                    repeated_batch["message_log"],
-                    pad_value_dict={"token_ids": tokenizer.pad_token_id},
-                )
-                input_ids = batched_flat["token_ids"]
+    while current_epoch < max_num_epochs and total_steps < max_num_steps:
+        print(f"\n{'=' * 25} Epoch {current_epoch + 1}/{max_num_epochs} {'=' * 25}")
+        # Run grpo training (single-turn)
+        batch: BatchedDataDict[DatumSpec]
 
-            # Generate responses - this updates the LLMMessageLogType in repeated_batch
+        for batch in dataloader:
+            print(
f"\n{'=' * 25} Step {current_step + 1}/{min(len(dataloader), max_num_steps)} {'=' * 25}", flush=True, ) - with timer.time("prepare_for_generation"): - if NEED_REFIT and POLICY_GENERATION_STALE: - refit_policy_generation( - policy, policy_generation, colocated_inference, timer=timer - ) - POLICY_GENERATION_STALE = False - else: - policy_generation.prepare_for_generation() - - with timer.time("generation"): - # Use async rollouts if vLLM async engine is enabled - if _should_use_async_rollouts(master_config): - ( - repeated_batch, - rollout_metrics, - ) = run_async_multi_turn_rollout( - policy_generation=policy_generation, - input_batch=repeated_batch, - tokenizer=tokenizer, - task_to_env=task_to_env, - max_seq_len=master_config["policy"][ - "max_total_sequence_length" - ], - max_rollout_turns=master_config["grpo"]["max_rollout_turns"], - greedy=False, + maybe_gpu_profile_step(policy, total_steps + 1) + if policy != policy_generation: + maybe_gpu_profile_step(policy_generation, total_steps + 1) + val_metrics, validation_timings = None, None + + with timer.time("total_step_time"): + # Prepare batch + print("▶ Preparing batch...", flush=True) + with timer.time("data_processing"): + # Repeat batch items + repeated_batch: BatchedDataDict[DatumSpec] = ( + batch.repeat_interleave( + master_config["grpo"]["num_generations_per_prompt"] + ) ) - else: - repeated_batch, rollout_metrics = run_multi_turn_rollout( - policy_generation=policy_generation, - input_batch=repeated_batch, - tokenizer=tokenizer, - task_to_env=task_to_env, - max_seq_len=master_config["policy"][ - "max_total_sequence_length" - ], - max_rollout_turns=master_config["grpo"]["max_rollout_turns"], - greedy=False, + # Convert LLMMessageLogType to FlatMessagesType for generation + batched_flat, input_lengths = batched_message_log_to_flat_message( + repeated_batch["message_log"], + pad_value_dict={"token_ids": tokenizer.pad_token_id}, ) - policy_generation.finish_generation() - - # Calculate rewards & advantages - print("▶ Processing rewards...", flush=True) - with timer.time("reward_calculation"): - # Extract rewards from final_batch - rewards = repeated_batch["total_reward"] - - print("▶ Computing advantages...", flush=True) - baseline, std = calculate_baseline_and_std_per_prompt( - input_ids, - rewards, - torch.ones_like(rewards), - leave_one_out_baseline=master_config["grpo"][ - "use_leave_one_out_baseline" - ], - ) - advantages = (rewards - baseline).unsqueeze(-1) + input_ids = batched_flat["token_ids"] - if master_config["grpo"]["normalize_rewards"]: - # don't sharpen the ones with no variation - zero_std_mask = std > 0 - advantages[zero_std_mask] = ( - advantages[zero_std_mask] / std.unsqueeze(-1)[zero_std_mask] + # Generate responses - this updates the LLMMessageLogType in repeated_batch + print( + f"▶ Generating responses for batch of size {repeated_batch.size}...", + flush=True, + ) + with timer.time("prepare_for_generation/total"): + if NEED_REFIT and POLICY_GENERATION_STALE: + refit_policy_generation( + policy, policy_generation, colocated_inference, timer=timer + ) + POLICY_GENERATION_STALE = False + else: + policy_generation.prepare_for_generation() + + with timer.time("generation"): + # Use async rollouts if vLLM async engine is enabled + if _should_use_async_rollouts(master_config): + ( + repeated_batch, + rollout_metrics, + ) = run_async_multi_turn_rollout( + policy_generation=policy_generation, + input_batch=repeated_batch, + tokenizer=tokenizer, + task_to_env=task_to_env, + max_seq_len=master_config["policy"][ + 
"max_total_sequence_length" + ], + max_rollout_turns=master_config["grpo"][ + "max_rollout_turns" + ], + greedy=False, + ) + else: + repeated_batch, rollout_metrics = run_multi_turn_rollout( + policy_generation=policy_generation, + input_batch=repeated_batch, + tokenizer=tokenizer, + task_to_env=task_to_env, + max_seq_len=master_config["policy"][ + "max_total_sequence_length" + ], + max_rollout_turns=master_config["grpo"][ + "max_rollout_turns" + ], + greedy=False, + ) + policy_generation.finish_generation() + + # Calculate rewards & advantages + print("▶ Processing rewards...,", flush=True) + with timer.time("reward_calculation"): + # Extract rewards from final_batch + rewards = repeated_batch["total_reward"] + + print("▶ Computing advantages...", flush=True) + baseline, std = calculate_baseline_and_std_per_prompt( + input_ids, + rewards, + torch.ones_like(rewards), + leave_one_out_baseline=master_config["grpo"][ + "use_leave_one_out_baseline" + ], ) + advantages = (rewards - baseline).unsqueeze(-1) - with timer.time("data_processing"): - use_overlong_filtering = master_config["grpo"]["overlong_filtering"] - if use_overlong_filtering: - loss_multiplier = repeated_batch["loss_multiplier"].clone() - truncated = repeated_batch["truncated"] - - if isinstance(truncated, list): - truncated = torch.tensor(truncated, dtype=torch.bool) - - loss_multiplier[truncated] = 0 - repeated_batch["loss_multiplier"] = loss_multiplier - # Add loss mask and advantages to each message in LLMMessageLogType - for i, message_log in enumerate(repeated_batch["message_log"]): - for j, message in enumerate(message_log): - if message["role"] == "assistant": - message["token_loss_mask"] = torch.ones_like( - message["token_ids"] - ) - else: - message["token_loss_mask"] = torch.zeros_like( - message["token_ids"] - ) - if "generation_logprobs" not in message: - message["generation_logprobs"] = torch.zeros_like( - message["token_ids"], dtype=torch.float32 - ) - message["advantages"] = advantages[i].expand( - message["token_ids"].shape + if master_config["grpo"]["normalize_rewards"]: + # don't sharpen the ones with no variation + zero_std_mask = std > 0 + advantages[zero_std_mask] = ( + advantages[zero_std_mask] / std.unsqueeze(-1)[zero_std_mask] ) - # Convert updated LLMMessageLogType to FlatMessagesType for training - flat_messages, input_lengths = batched_message_log_to_flat_message( - repeated_batch["message_log"], - pad_value_dict={"token_ids": tokenizer.pad_token_id}, - make_sequence_length_divisible_by=master_config["policy"][ - "make_sequence_length_divisible_by" - ], - ) + with timer.time("data_processing"): + use_overlong_filtering = master_config["grpo"]["overlong_filtering"] + if use_overlong_filtering: + loss_multiplier = repeated_batch["loss_multiplier"].clone() + truncated = repeated_batch["truncated"] + + if isinstance(truncated, list): + truncated = torch.tensor(truncated, dtype=torch.bool) + + loss_multiplier[truncated] = 0 + repeated_batch["loss_multiplier"] = loss_multiplier + # Add loss mask and advantages to each message in LLMMessageLogType + for i, message_log in enumerate(repeated_batch["message_log"]): + for j, message in enumerate(message_log): + if message["role"] == "assistant": + message["token_loss_mask"] = torch.ones_like( + message["token_ids"] + ) + else: + message["token_loss_mask"] = torch.zeros_like( + message["token_ids"] + ) + if "generation_logprobs" not in message: + message["generation_logprobs"] = torch.zeros_like( + message["token_ids"], dtype=torch.float32 + ) + 
message["advantages"] = advantages[i].expand( + message["token_ids"].shape + ) - # Create training data from flattened messages - train_data = BatchedDataDict[ClippedPGLossDataDict]( - { - "input_ids": flat_messages["token_ids"], - "input_lengths": input_lengths, - "advantages": flat_messages["advantages"], - "generation_logprobs": flat_messages["generation_logprobs"], - "token_mask": flat_messages["token_loss_mask"], - "sample_mask": repeated_batch["loss_multiplier"], - } - ) - # this will be mini-batched inside the policy, so maintain the packed multimodal structure - train_data.update(flat_messages.get_multimodal_dict(as_tensors=False)) - train_data.to("cpu") - - print("▶ Preparing for logprob inference...", flush=True) - with timer.time("logprob_inference_prep"): - policy.prepare_for_lp_inference() - - print("▶ Computing logprobs...", flush=True) - with timer.time("policy_and_reference_logprobs"): - fprop_logprobs = policy.get_logprobs(train_data)["logprobs"] - reference_logprobs = policy.get_reference_policy_logprobs(train_data)[ - "reference_logprobs" - ] - train_data["prev_logprobs"] = fprop_logprobs - train_data["reference_policy_logprobs"] = reference_logprobs - - print("▶ Preparing for training...", flush=True) - with timer.time("training_prep"): - policy.prepare_for_training() # set model train and reload optim to GPU - POLICY_GENERATION_STALE = True - - print("▶ Training policy...", flush=True) - with timer.time("policy_training"): - train_results = policy.train(train_data, loss_fn) - - is_last_step = step + 1 == min( - master_config["grpo"]["max_num_steps"], len(dataloader) - ) + # Convert updated LLMMessageLogType to FlatMessagesType for training + flat_messages, input_lengths = batched_message_log_to_flat_message( + repeated_batch["message_log"], + pad_value_dict={"token_ids": tokenizer.pad_token_id}, + make_sequence_length_divisible_by=master_config["policy"][ + "make_sequence_length_divisible_by" + ], + ) - # Run validation if it's a validation step - if val_period > 0 and (step + 1) % val_period == 0: - if NEED_REFIT and POLICY_GENERATION_STALE: - refit_policy_generation( - policy, policy_generation, colocated_inference + # Create training data from flattened messages + train_data = BatchedDataDict[ClippedPGLossDataDict]( + { + "input_ids": flat_messages["token_ids"], + "input_lengths": input_lengths, + "advantages": flat_messages["advantages"], + "generation_logprobs": flat_messages["generation_logprobs"], + "token_mask": flat_messages["token_loss_mask"], + "sample_mask": repeated_batch["loss_multiplier"], + } ) - POLICY_GENERATION_STALE = False - else: - policy_generation.prepare_for_generation() - val_metrics, validation_timings = validate( - policy_generation, - val_dataloader, - tokenizer, - val_task_to_env, - step=step + 1, - master_config=master_config, - ) - policy_generation.finish_generation() - logger.log_metrics( - validation_timings, step + 1, prefix="timing/validation" + # this will be mini-batched inside the policy, so maintain the packed multimodal structure + train_data.update( + flat_messages.get_multimodal_dict(as_tensors=False) + ) + train_data.to("cpu") + + print("▶ Preparing for logprob inference...", flush=True) + with timer.time("logprob_inference_prep"): + policy.prepare_for_lp_inference() + + print("▶ Computing logprobs...", flush=True) + with timer.time("policy_and_reference_logprobs"): + fprop_logprobs = policy.get_logprobs(train_data)["logprobs"] + reference_logprobs = policy.get_reference_policy_logprobs( + train_data + 
)["reference_logprobs"] + train_data["prev_logprobs"] = fprop_logprobs + train_data["reference_policy_logprobs"] = reference_logprobs + + print("▶ Preparing for training...", flush=True) + with timer.time("training_prep"): + policy.prepare_for_training() # set model train and reload optim to GPU + POLICY_GENERATION_STALE = True + + print("▶ Training policy...", flush=True) + with timer.time("policy_training"): + train_results = policy.train(train_data, loss_fn) + + is_last_step = (total_steps + 1 >= max_num_steps) or ( + (current_epoch + 1 == max_num_epochs) + and (current_step + 1 == len(dataloader)) ) - logger.log_metrics(val_metrics, step + 1, prefix="validation") - - ## Checkpointing - consumed_samples += master_config["grpo"]["num_prompts_per_step"] - timeout.mark_iteration() - - should_save_by_step = ( - is_last_step - or (step + 1) % master_config["checkpointing"]["save_period"] == 0 - ) - # +1 because step is 0-indexed - # Check if timeout-based checkpointing is enabled in config. - should_save_by_timeout = timeout.check_save() - if master_config["checkpointing"]["enabled"] and ( - should_save_by_step or should_save_by_timeout - ): - policy.prepare_for_training() - - grpo_save_state["step"] = step + 1 - if val_metrics is not None: - grpo_save_state["val_reward"] = val_metrics["accuracy"] - elif "val_reward" in grpo_save_state: - del grpo_save_state["val_reward"] - grpo_save_state["consumed_samples"] = consumed_samples - - if master_config["checkpointing"]["metric_name"] is not None: - if ( - master_config["checkpointing"]["metric_name"] - not in grpo_save_state - ): - warnings.warn( - f"You asked to save checkpoints based on {master_config['checkpointing']['metric_name']} but the metric is not found in the save state. " - "Saving most recent k checkpoints instead." 
+ # Run validation if it's a validation step + if val_period > 0 and (total_steps + 1) % val_period == 0: + if NEED_REFIT and POLICY_GENERATION_STALE: + refit_policy_generation( + policy, policy_generation, colocated_inference ) - master_config["checkpointing"]["metric_name"] = None - - with timer.time("checkpointing"): - print(f"Saving checkpoint for step {step + 1}...", flush=True) - checkpoint_path = checkpointer.init_tmp_checkpoint( - step + 1, grpo_save_state, master_config + POLICY_GENERATION_STALE = False + else: + policy_generation.prepare_for_generation() + val_metrics, validation_timings = validate( + policy_generation, + val_dataloader, + tokenizer, + val_task_to_env, + step=total_steps + 1, + master_config=master_config, ) - policy.save_checkpoint( - weights_path=os.path.join(checkpoint_path, "policy", "weights"), - optimizer_path=os.path.join( - checkpoint_path, "policy", "optimizer" - ), - tokenizer_path=os.path.join( - checkpoint_path, "policy", "tokenizer" - ), + policy_generation.finish_generation() + logger.log_metrics( + validation_timings, total_steps + 1, prefix="timing/validation" ) - torch.save( - dataloader.state_dict(), - os.path.join(checkpoint_path, "train_dataloader.pt"), + logger.log_metrics( + val_metrics, total_steps + 1, prefix="validation" ) - checkpointer.finalize_checkpoint(checkpoint_path) - - # Logging - # Log training data - log_data = {"content": flat_messages["content"]} - log_data["rewards"] = rewards.tolist() - log_data["generation_logprobs"] = train_data["generation_logprobs"].tolist() - log_data["prev_logprobs"] = train_data["prev_logprobs"].tolist() - log_data["input_lengths"] = input_lengths.tolist() - logger.log_batched_dict_as_jsonl(log_data, f"train_data_step{step}.jsonl") - - metrics = { - "loss": train_results["loss"].numpy(), - "reward": rewards.numpy(), - "grad_norm": train_results["grad_norm"].numpy(), - "mean_prompt_length": repeated_batch["length"].numpy(), - "total_num_tokens": input_lengths.numpy(), - } - metrics.update(train_results["all_mb_metrics"]) - for k, v in metrics.items(): - if k in { - "lr", - "wd", - "reward", - "global_valid_seqs", - "global_valid_toks", - "mean_prompt_length", - }: - metrics[k] = np.mean(v).item() - else: - metrics[k] = np.sum(v).item() - metrics.update(rollout_metrics) - timing_metrics: dict[str, float] = timer.get_timing_metrics(reduction_op="sum") # type: ignore - # track example with high token mult prob error above 1.05 - if metrics["token_mult_prob_error"] > 1.05: - logger.log_plot_token_mult_prob_error( - { - "prompt_lengths": repeated_batch["length"], - "full_lengths": input_lengths, - "generation_logprobs": train_data["generation_logprobs"], - "prev_logprobs": train_data["prev_logprobs"], - "token_mask": train_data["token_mask"], - "sample_mask": train_data["sample_mask"], - }, - step + 1, - name="train/token_mult_prob_error_plot_sample", - ) + ## Checkpointing + consumed_samples += master_config["grpo"]["num_prompts_per_step"] + timeout.mark_iteration() - print("\n📊 Training Results:") + should_save_by_step = ( + is_last_step + or (total_steps + 1) % master_config["checkpointing"]["save_period"] + == 0 + ) + # +1 because step is 0-indexed + # Check if timeout-based checkpointing is enabled in config. 
+ should_save_by_timeout = timeout.check_save() + + if master_config["checkpointing"]["enabled"] and ( + should_save_by_step or should_save_by_timeout + ): + policy.prepare_for_training() + + # +1 because step is 0-indexed + grpo_save_state["current_step"] = current_step + 1 + grpo_save_state["total_steps"] = total_steps + 1 + grpo_save_state["current_epoch"] = current_epoch + if val_metrics is not None: + grpo_save_state["val_reward"] = val_metrics["accuracy"] + elif "val_reward" in grpo_save_state: + del grpo_save_state["val_reward"] + grpo_save_state["consumed_samples"] = consumed_samples + + if master_config["checkpointing"]["metric_name"] is not None: + if ( + master_config["checkpointing"]["metric_name"] + not in grpo_save_state + ): + warnings.warn( + f"You asked to save checkpoints based on {master_config['checkpointing']['metric_name']} but the metric is not found in the save state. " + "Saving most recent k checkpoints instead." + ) + master_config["checkpointing"]["metric_name"] = None - print(f" • Loss: {metrics['loss']:.4f}") - print(f" • Avg Reward: {np.mean(rewards.numpy()):.4f}") - print( - f" • Mean Generation Length: {rollout_metrics['mean_gen_tokens_per_sample']:.4f}", - flush=True, - ) - if "total_flops" in train_results: - total_tflops = ( - train_results["total_flops"] / timing_metrics["policy_training"] / 1e12 + with timer.time("checkpointing"): + print( + f"Saving checkpoint for step {total_steps + 1}...", + flush=True, + ) + checkpoint_path = checkpointer.init_tmp_checkpoint( + total_steps + 1, grpo_save_state, master_config + ) + policy.save_checkpoint( + weights_path=os.path.join( + checkpoint_path, "policy", "weights" + ), + optimizer_path=os.path.join( + checkpoint_path, "policy", "optimizer" + ), + tokenizer_path=os.path.join( + checkpoint_path, "policy", "tokenizer" + ), + ) + torch.save( + dataloader.state_dict(), + os.path.join(checkpoint_path, "train_dataloader.pt"), + ) + checkpointer.finalize_checkpoint(checkpoint_path) + + # Logging + # Log training data + log_data = {"content": flat_messages["content"]} + log_data["rewards"] = rewards.tolist() + log_data["generation_logprobs"] = train_data["generation_logprobs"].tolist() + log_data["prev_logprobs"] = train_data["prev_logprobs"].tolist() + log_data["input_lengths"] = input_lengths.tolist() + logger.log_batched_dict_as_jsonl( + log_data, f"train_data_step{total_steps}.jsonl" ) - num_ranks = train_results["num_ranks"] + + metrics = { + "loss": train_results["loss"].numpy(), + "reward": rewards.numpy(), + "grad_norm": train_results["grad_norm"].numpy(), + "mean_prompt_length": repeated_batch["length"].numpy(), + "total_num_tokens": input_lengths.numpy(), + } + metrics.update(train_results["all_mb_metrics"]) + for k, v in metrics.items(): + if k in { + "lr", + "wd", + "reward", + "global_valid_seqs", + "global_valid_toks", + "mean_prompt_length", + }: + metrics[k] = np.mean(v).item() + else: + metrics[k] = np.sum(v).item() + metrics.update(rollout_metrics) + + timing_metrics: dict[str, float] = timer.get_timing_metrics( + reduction_op="sum" + ) # type: ignore + # track example with high token mult prob error above 1.05 + if metrics["token_mult_prob_error"] > 1.05: + logger.log_plot_token_mult_prob_error( + { + "prompt_lengths": repeated_batch["length"], + "full_lengths": input_lengths, + "generation_logprobs": train_data["generation_logprobs"], + "prev_logprobs": train_data["prev_logprobs"], + "token_mask": train_data["token_mask"], + "sample_mask": train_data["sample_mask"], + }, + total_steps + 1, + 
name="train/token_mult_prob_error_plot_sample", + ) + + print("\n📊 Training Results:") + + print(f" • Loss: {metrics['loss']:.4f}") + print(f" • Avg Reward: {np.mean(rewards.numpy()):.4f}") print( - f" • Training FLOPS: {total_tflops:.2f} TFLOPS ({total_tflops / num_ranks:.2f} TFLOPS per rank)", + f" • Mean Generation Length: {rollout_metrics['mean_gen_tokens_per_sample']:.4f}", flush=True, ) - if "theoretical_tflops" in train_results: - theoretical_tflops = train_results["theoretical_tflops"] + if "total_flops" in train_results: + total_tflops = ( + train_results["total_flops"] + / timing_metrics["policy_training"] + / 1e12 + ) + num_ranks = train_results["num_ranks"] print( - f" • Training Model Floating Point Utilization: {100 * total_tflops / theoretical_tflops:.2f}%", + f" • Training FLOPS: {total_tflops:.2f} TFLOPS ({total_tflops / num_ranks:.2f} TFLOPS per rank)", flush=True, ) - metrics["train_fp_utilization"] = total_tflops / theoretical_tflops + if "theoretical_tflops" in train_results: + theoretical_tflops = train_results["theoretical_tflops"] + print( + f" • Training Model Floating Point Utilization: {100 * total_tflops / theoretical_tflops:.2f}%", + flush=True, + ) + metrics["train_fp_utilization"] = total_tflops / theoretical_tflops - print("\n⏱️ Timing:", flush=True) - # Display total time first, separately - total_time = timing_metrics.get("total_step_time", 0) + print("\n⏱️ Timing:", flush=True) + # Display total time first, separately + total_time = timing_metrics.get("total_step_time", 0) - total_num_gpus = ( - master_config["cluster"]["num_nodes"] - * master_config["cluster"]["gpus_per_node"] - ) - metrics.update( - { - "tokens_per_sec_per_gpu": metrics["total_num_tokens"] - / total_time - / total_num_gpus - } - ) + total_num_gpus = ( + master_config["cluster"]["num_nodes"] + * master_config["cluster"]["gpus_per_node"] + ) + metrics.update( + { + "tokens_per_sec_per_gpu": metrics["total_num_tokens"] + / total_time + / total_num_gpus + } + ) - print(f" • Total step time: {total_time:.2f}s", flush=True) + print(f" • Total step time: {total_time:.2f}s", flush=True) - # Display all other timing metrics - for k, v in sorted( - timing_metrics.items(), key=lambda item: item[1], reverse=True - ): - if k != "total_step_time": - percent = (v / total_time * 100) if total_time > 0 else 0 - print(f" • {k}: {v:.2f}s ({percent:.1f}%)", flush=True) + # Display all other timing metrics + for k, v in sorted( + timing_metrics.items(), key=lambda item: item[1], reverse=True + ): + if k != "total_step_time": + percent = (v / total_time * 100) if total_time > 0 else 0 + print(f" • {k}: {v:.2f}s ({percent:.1f}%)", flush=True) - logger.log_metrics(metrics, step + 1, prefix="train") - logger.log_metrics(timing_metrics, step + 1, prefix="timing/train") + logger.log_metrics(metrics, total_steps + 1, prefix="train") + logger.log_metrics(timing_metrics, total_steps + 1, prefix="timing/train") - timer.reset() - step += 1 + timer.reset() + current_step += 1 + total_steps += 1 + if should_save_by_timeout: + break + if total_steps >= max_num_steps: + break - if should_save_by_timeout: - break - if step >= master_config["grpo"]["max_num_steps"]: - break + current_epoch += 1 + current_step = 0 # Reset step counter for new epoch def validate( From a15a0feddf0189be7b7d3af6fd3782566a1bc02a Mon Sep 17 00:00:00 2001 From: Wei Du Date: Fri, 12 Sep 2025 07:43:02 -0700 Subject: [PATCH 6/6] fix unit test bug Signed-off-by: Wei Du --- tests/unit/algorithms/test_rm.py | 1 + 1 file changed, 1 insertion(+) diff --git 
a/tests/unit/algorithms/test_rm.py b/tests/unit/algorithms/test_rm.py
index 5aa63ce218..8bfcabd81a 100644
--- a/tests/unit/algorithms/test_rm.py
+++ b/tests/unit/algorithms/test_rm.py
@@ -102,6 +102,7 @@ def val_iter(self):
         "checkpointing": {
             "enabled": False,
             "checkpoint_must_save_by": None,
+            "save_period": 10,
         },
     }
 
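
Note (illustrative only, not part of the patch series): the patches above assume that nemo_rl.utils.timer.TimeoutChecker exposes start_iterations(), mark_iteration(), and check_save(), and that setting checkpointing.checkpoint_must_save_by to null disables the behavior. The sketch below shows that intended control flow end to end. SimpleTimeoutChecker and train_loop are hypothetical stand-ins written for this note, not the real NeMo RL implementation, so the exact timing heuristic may differ from the actual TimeoutChecker.

import time
from typing import Optional


class SimpleTimeoutChecker:
    """Hypothetical stand-in for nemo_rl.utils.timer.TimeoutChecker (illustration only)."""

    def __init__(self, timeout: Optional[float], fit_last_save_time: bool = True):
        # `timeout` is a wall-clock budget in seconds; None disables the check,
        # mirroring `checkpoint_must_save_by: null` in examples/configs/rm.yaml.
        self.timeout = timeout
        self.fit_last_save_time = fit_last_save_time
        self.start_time = 0.0
        self.last_mark = 0.0
        self.longest_iteration = 0.0

    def start_iterations(self) -> None:
        self.start_time = time.monotonic()
        self.last_mark = self.start_time

    def mark_iteration(self) -> None:
        now = time.monotonic()
        self.longest_iteration = max(self.longest_iteration, now - self.last_mark)
        self.last_mark = now

    def check_save(self) -> bool:
        if self.timeout is None:
            return False
        elapsed = time.monotonic() - self.start_time
        # Optionally keep headroom for one more iteration before the deadline.
        headroom = self.longest_iteration if self.fit_last_save_time else 0.0
        return elapsed + headroom >= self.timeout


def train_loop(max_num_steps: int, checkpoint_must_save_by: Optional[float]) -> None:
    timeout = SimpleTimeoutChecker(timeout=checkpoint_must_save_by)
    timeout.start_iterations()
    for step in range(max_num_steps):
        time.sleep(0.01)  # stand-in for one training step
        timeout.mark_iteration()
        should_save_by_timeout = timeout.check_save()
        if should_save_by_timeout:
            print(f"step {step + 1}: save checkpoint and stop (timeout reached)")
            return  # the patches return/break here so the job exits after saving
    print("finished all steps without hitting the timeout")


if __name__ == "__main__":
    train_loop(max_num_steps=100, checkpoint_must_save_by=0.1)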