diff --git a/examples/configs/rm.yaml b/examples/configs/rm.yaml index 756ce650b8..0395037a23 100644 --- a/examples/configs/rm.yaml +++ b/examples/configs/rm.yaml @@ -19,6 +19,7 @@ checkpointing: higher_is_better: false keep_top_k: 3 save_period: ${rm.val_period} + checkpoint_must_save_by: null policy: model_name: "meta-llama/Llama-3.2-1B-Instruct" diff --git a/nemo_rl/algorithms/dpo.py b/nemo_rl/algorithms/dpo.py index 83f367658d..b22f059f74 100644 --- a/nemo_rl/algorithms/dpo.py +++ b/nemo_rl/algorithms/dpo.py @@ -362,6 +362,9 @@ def validate_one_dataset( ): """Run validation on one validation dataset.""" if val_dataloader is None: + assert val_dataloader is not None or master_config["dpo"]["val_period"] == 0, ( + "val_dataloader is None, so dpo.val_period must be 0" + ) print(" ⚠️ No validation dataloader provided, skipping validation") return @@ -707,6 +710,8 @@ def dpo_train( current_step += 1 total_steps += 1 + if should_save_by_timeout: + return if total_steps >= master_config["dpo"]["max_num_steps"]: return diff --git a/nemo_rl/algorithms/grpo.py b/nemo_rl/algorithms/grpo.py index 1b2a2ad50f..e9bc749d60 100644 --- a/nemo_rl/algorithms/grpo.py +++ b/nemo_rl/algorithms/grpo.py @@ -961,6 +961,8 @@ def grpo_train( timer.reset() current_step += 1 total_steps += 1 + if should_save_by_timeout: + break if total_steps >= max_num_steps: break @@ -978,6 +980,9 @@ def validate( ) -> tuple[dict[str, Any], dict[str, Any]]: """Run validation on the validation dataset.""" if val_dataloader is None: + assert val_dataloader is not None or master_config["grpo"]["val_period"] == 0, ( + "val_dataloader is None, so grpo.val_period must be 0" + ) print(" ⚠️ No validation dataloader provided, skipping validation", flush=True) return {}, {} diff --git a/nemo_rl/algorithms/rm.py b/nemo_rl/algorithms/rm.py index 640bc0823d..ad646d0021 100644 --- a/nemo_rl/algorithms/rm.py +++ b/nemo_rl/algorithms/rm.py @@ -40,7 +40,7 @@ from nemo_rl.utils.checkpoint import CheckpointingConfig, 
CheckpointManager from nemo_rl.utils.logger import Logger, LoggerConfig from nemo_rl.utils.nsys import maybe_gpu_profile_step -from nemo_rl.utils.timer import Timer +from nemo_rl.utils.timer import TimeoutChecker, Timer class RMSaveState(TypedDict): @@ -305,6 +305,9 @@ def validate_one_dataset( ): """Run validation on one validation dataset.""" if val_dataloader is None: + assert val_dataloader is not None or master_config["rm"]["val_period"] == 0, ( + "val_dataloader is None, so rm.val_period must be 0" + ) print(" ⚠️ No validation dataloader provided, skipping validation") return @@ -426,7 +429,11 @@ def rm_train( ): # Run basic rm training timer = Timer() - + timeout = TimeoutChecker( + timeout=master_config["checkpointing"]["checkpoint_must_save_by"], + fit_last_save_time=True, + ) + timeout.start_iterations() if rm_save_state is None: rm_save_state = _default_rm_save_state() current_epoch = 0 @@ -512,13 +519,21 @@ def rm_train( ) ## Checkpointing + timeout.mark_iteration() + rm_save_state["consumed_samples"] += master_config["policy"][ "train_global_batch_size" ] - if master_config["checkpointing"]["enabled"] and ( + + should_save_by_step = ( is_last_step or (total_steps + 1) % master_config["checkpointing"]["save_period"] == 0 + ) + should_save_by_timeout = timeout.check_save() + + if master_config["checkpointing"]["enabled"] and ( + should_save_by_step or should_save_by_timeout ): ## +1 because step is 0-indexed rm_save_state["step"] = (current_step + 1) % len(train_dataloader) @@ -615,6 +630,8 @@ def rm_train( current_step += 1 total_steps += 1 + if should_save_by_timeout: + return if ( master_config["rm"]["max_num_steps"] != -1 and total_steps >= master_config["rm"]["max_num_steps"] diff --git a/nemo_rl/algorithms/sft.py b/nemo_rl/algorithms/sft.py index 9bd69958ab..71853294d9 100644 --- a/nemo_rl/algorithms/sft.py +++ b/nemo_rl/algorithms/sft.py @@ -234,6 +234,9 @@ def validate( ): """Run validation on the validation dataset.""" if val_dataloader is 
None: + assert val_dataloader is not None or master_config["sft"]["val_period"] == 0, ( + "val_dataloader is None, so sft.val_period must be 0" + ) print(" ⚠️ No validation dataloader provided, skipping validation") return @@ -578,6 +581,8 @@ def sft_train( current_step += 1 total_steps += 1 + if should_save_by_timeout: + return if total_steps >= master_config["sft"]["max_num_steps"]: return diff --git a/tests/unit/algorithms/test_rm.py b/tests/unit/algorithms/test_rm.py index a373cfc12f..8bfcabd81a 100644 --- a/tests/unit/algorithms/test_rm.py +++ b/tests/unit/algorithms/test_rm.py @@ -99,7 +99,11 @@ def val_iter(self): }, "train_micro_batch_size": 1, }, - "checkpointing": {"enabled": False}, + "checkpointing": { + "enabled": False, + "checkpoint_must_save_by": None, + "save_period": 10, + }, } return {