Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 29 additions & 11 deletions nemo_reinforcer/models/policy/hf_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,14 +134,6 @@ def do_fsdp(model):
else:
self.optimizer = None

# restore
if weights_path:
self.load_checkpoint(weights_path, optimizer_path)
else:
print(
"No weights path provided. Starting from scratch (default policy init)"
)

if "scheduler" in self.cfg:
if isinstance(self.cfg["scheduler"], dict):
scheduler_cls = import_class_from_path(self.cfg["scheduler"]["name"])
Expand Down Expand Up @@ -174,6 +166,14 @@ def do_fsdp(model):
self.optimizer, lr_lambda=lambda epoch: 1
)

# restore
if weights_path:
self.load_checkpoint(weights_path, optimizer_path)
else:
print(
"No weights path provided. Starting from scratch (default policy init)"
)

def is_alive(self):
    """Liveness probe for the Ray worker framework.

    Returns:
        bool: always ``True`` — a successful remote call is itself the
        health signal, so no internal state is inspected.
    """
    return True

Expand Down Expand Up @@ -758,6 +758,12 @@ def save_checkpoint(self, weights_path: str, optimizer_path: Optional[str] = Non
optim_state_dict = FullyShardedDataParallel.optim_state_dict(
self.model, self.optimizer
)
scheduler_state_dict = self.scheduler.state_dict()

optim_and_scheduler_state_dict = {
"optimizer": optim_state_dict,
"scheduler": scheduler_state_dict,
}

if torch.distributed.get_rank() == 0:
# check if weights_path dir exists
Expand All @@ -769,18 +775,20 @@ def save_checkpoint(self, weights_path: str, optimizer_path: Optional[str] = Non
os.makedirs(weights_dir)
torch.save(model_state_dict, weights_path)
if optimizer_path is not None:
torch.save(optim_state_dict, optimizer_path)
torch.save(optim_and_scheduler_state_dict, optimizer_path)

def load_checkpoint(self, weights_path: str, optimizer_path: Optional[str] = None):
print(f"Loading Policy from {weights_path} and optimizer from {optimizer_path}")
state_dict_config = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)

state_dict = torch.load(weights_path)
if optimizer_path is not None:
optimizer_state_dict = torch.load(optimizer_path)
optim_data = torch.load(optimizer_path)
optimizer_state_dict = optim_data["optimizer"]
scheduler_state_dict = optim_data.get("scheduler")
else:
optimizer_state_dict = None

scheduler_state_dict = None
with FullyShardedDataParallel.state_dict_type(
self.model,
state_dict_type=StateDictType.FULL_STATE_DICT,
Expand All @@ -801,6 +809,11 @@ def load_checkpoint(self, weights_path: str, optimizer_path: Optional[str] = Non
else:
print("WARNING: No optimizer checkpoint provided")

if scheduler_state_dict is not None:
self.scheduler.load_state_dict(scheduler_state_dict)
else:
print("WARNING: No scheduler checkpoint provided")


class HfPolicy(PolicyInterface, GenerationInterface):
def __init__(
Expand All @@ -813,6 +826,11 @@ def __init__(
weights_path: Optional[str] = None,
optimizer_path: Optional[str] = None,
):
if weights_path:
weights_path = os.path.abspath(weights_path)
if optimizer_path:
optimizer_path = os.path.abspath(optimizer_path)

worker_builder = RayWorkerBuilder(
HfPolicyWorker,
config,
Expand Down
2 changes: 1 addition & 1 deletion nemo_reinforcer/utils/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def init_tmp_checkpoint(
with open(save_dir / "config.json", "w") as f:
json.dump(run_config, f)

return save_dir
return Path(os.path.abspath(save_dir))

def finalize_checkpoint(self, checkpoint_path: os.PathLike) -> None:
"""Complete a checkpoint by moving it from temporary to permanent location.
Expand Down
Loading