From ad32cb6b685f0d1fd39917063ad5ce56f3423ff6 Mon Sep 17 00:00:00 2001
From: ashors1
Date: Thu, 20 Mar 2025 20:16:45 -0700
Subject: [PATCH] make checkpoint paths absolute, save lr scheduler state to
 checkpoint

Signed-off-by: ashors1
---
 nemo_reinforcer/models/policy/hf_policy.py | 40 ++++++++++++++++------
 nemo_reinforcer/utils/checkpoint.py        |  2 +-
 2 files changed, 30 insertions(+), 12 deletions(-)

diff --git a/nemo_reinforcer/models/policy/hf_policy.py b/nemo_reinforcer/models/policy/hf_policy.py
index 627c221cee..cf908ee70a 100644
--- a/nemo_reinforcer/models/policy/hf_policy.py
+++ b/nemo_reinforcer/models/policy/hf_policy.py
@@ -134,14 +134,6 @@ def do_fsdp(model):
         else:
             self.optimizer = None
 
-        # restore
-        if weights_path:
-            self.load_checkpoint(weights_path, optimizer_path)
-        else:
-            print(
-                "No weights path provided. Starting from scratch (default policy init)"
-            )
-
         if "scheduler" in self.cfg:
             if isinstance(self.cfg["scheduler"], dict):
                 scheduler_cls = import_class_from_path(self.cfg["scheduler"]["name"])
@@ -174,6 +166,14 @@ def do_fsdp(model):
                 self.optimizer, lr_lambda=lambda epoch: 1
             )
 
+        # restore
+        if weights_path:
+            self.load_checkpoint(weights_path, optimizer_path)
+        else:
+            print(
+                "No weights path provided. Starting from scratch (default policy init)"
+            )
+
     def is_alive(self):
         return True
 
@@ -758,6 +758,12 @@ def save_checkpoint(self, weights_path: str, optimizer_path: Optional[str] = Non
             optim_state_dict = FullyShardedDataParallel.optim_state_dict(
                 self.model, self.optimizer
             )
+            scheduler_state_dict = self.scheduler.state_dict()
+
+            optim_and_scheduler_state_dict = {
+                "optimizer": optim_state_dict,
+                "scheduler": scheduler_state_dict,
+            }
 
         if torch.distributed.get_rank() == 0:
             # check if weights_path dir exists
@@ -769,7 +775,7 @@ def save_checkpoint(self, weights_path: str, optimizer_path: Optional[str] = Non
                 os.makedirs(weights_dir)
             torch.save(model_state_dict, weights_path)
             if optimizer_path is not None:
-                torch.save(optim_state_dict, optimizer_path)
+                torch.save(optim_and_scheduler_state_dict, optimizer_path)
 
     def load_checkpoint(self, weights_path: str, optimizer_path: Optional[str] = None):
         print(f"Loading Policy from {weights_path} and optimizer from {optimizer_path}")
@@ -777,10 +783,12 @@ def load_checkpoint(self, weights_path: str, optimizer_path: Optional[str] = Non
         state_dict = torch.load(weights_path)
 
         if optimizer_path is not None:
-            optimizer_state_dict = torch.load(optimizer_path)
+            optim_data = torch.load(optimizer_path)
+            optimizer_state_dict = optim_data["optimizer"]
+            scheduler_state_dict = optim_data.get("scheduler")
         else:
             optimizer_state_dict = None
-
+            scheduler_state_dict = None
         with FullyShardedDataParallel.state_dict_type(
             self.model,
             state_dict_type=StateDictType.FULL_STATE_DICT,
@@ -801,6 +809,11 @@ def load_checkpoint(self, weights_path: str, optimizer_path: Optional[str] = Non
         else:
             print("WARNING: No optimizer checkpoint provided")
 
+        if scheduler_state_dict is not None:
+            self.scheduler.load_state_dict(scheduler_state_dict)
+        else:
+            print("WARNING: No scheduler checkpoint provided")
+
 
 class HfPolicy(PolicyInterface, GenerationInterface):
     def __init__(
@@ -813,6 +826,11 @@ def __init__(
         weights_path: Optional[str] = None,
         optimizer_path: Optional[str] = None,
     ):
+        if weights_path:
+            weights_path = os.path.abspath(weights_path)
+        if optimizer_path:
+            optimizer_path = os.path.abspath(optimizer_path)
+
         worker_builder = RayWorkerBuilder(
             HfPolicyWorker,
             config,
diff --git a/nemo_reinforcer/utils/checkpoint.py b/nemo_reinforcer/utils/checkpoint.py
index 80344b3cda..2425996400 100644
--- a/nemo_reinforcer/utils/checkpoint.py
+++ b/nemo_reinforcer/utils/checkpoint.py
@@ -116,7 +116,7 @@ def init_tmp_checkpoint(
         with open(save_dir / "config.json", "w") as f:
             json.dump(run_config, f)
 
-        return save_dir
+        return Path(os.path.abspath(save_dir))
 
     def finalize_checkpoint(self, checkpoint_path: os.PathLike) -> None:
         """Complete a checkpoint by moving it from temporary to permanent location.
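
For reference, below is a minimal, self-contained sketch of the save/load round trip this patch introduces. The Linear model, AdamW optimizer, LambdaLR scheduler, and the optimizer.pt path are hypothetical stand-ins for illustration; only the {"optimizer": ..., "scheduler": ...} dict layout, the .get("scheduler") fallback, and the os.path.abspath normalization mirror the patch.

import os
import torch

# Hypothetical stand-ins for the policy model, optimizer, and scheduler.
model = torch.nn.Linear(4, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda epoch: 1)

# Save: bundle optimizer and scheduler state into a single file,
# mirroring save_checkpoint above; the path is made absolute the way
# the patch normalizes checkpoint paths.
optimizer_path = os.path.abspath("optimizer.pt")
torch.save(
    {"optimizer": optimizer.state_dict(), "scheduler": scheduler.state_dict()},
    optimizer_path,
)

# Load: .get("scheduler") returns None for older optimizer-only
# checkpoints, mirroring load_checkpoint above.
optim_data = torch.load(optimizer_path)
optimizer.load_state_dict(optim_data["optimizer"])
scheduler_state_dict = optim_data.get("scheduler")
if scheduler_state_dict is not None:
    scheduler.load_state_dict(scheduler_state_dict)
else:
    print("WARNING: No scheduler checkpoint provided")

Bundling both states in one file keeps optimizer and scheduler restores in lockstep, while the .get() fallback lets older optimizer-only checkpoints load without a KeyError.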