From 6b10bec9128492558b67214463b834c0b7786d50 Mon Sep 17 00:00:00 2001 From: ashors1 Date: Wed, 26 Mar 2025 15:14:53 -0700 Subject: [PATCH 01/33] switch to torch distributed checkpointing Signed-off-by: ashors1 --- nemo_reinforcer/algorithms/grpo.py | 6 +- nemo_reinforcer/algorithms/sft.py | 11 +- nemo_reinforcer/models/policy/hf_policy.py | 115 +++++---------------- nemo_reinforcer/utils/checkpoint.py | 40 +++++++ 4 files changed, 69 insertions(+), 103 deletions(-) diff --git a/nemo_reinforcer/algorithms/grpo.py b/nemo_reinforcer/algorithms/grpo.py index e08c848522..4907bb7eb1 100644 --- a/nemo_reinforcer/algorithms/grpo.py +++ b/nemo_reinforcer/algorithms/grpo.py @@ -241,10 +241,7 @@ def setup( policy = HfPolicy( cluster=cluster, config=policy_config, - weights_path=Path(last_checkpoint_path) / "policy.pt" - if last_checkpoint_path - else None, - optimizer_path=Path(last_checkpoint_path) / "policy_optimizer.pt" + checkpoint_dir=Path(last_checkpoint_path) / "policy" if last_checkpoint_path else None, init_optimizer=True, @@ -623,7 +620,6 @@ def grpo_train( ) policy.save_checkpoint( os.path.join(checkpoint_path, "policy.pt"), - os.path.join(checkpoint_path, "policy_optimizer.pt"), ) torch.save( dataloader.state_dict(), diff --git a/nemo_reinforcer/algorithms/sft.py b/nemo_reinforcer/algorithms/sft.py index 0d06ad6366..85526fb8e7 100644 --- a/nemo_reinforcer/algorithms/sft.py +++ b/nemo_reinforcer/algorithms/sft.py @@ -175,10 +175,7 @@ def setup( policy = HfPolicy( cluster=cluster, config=policy_config, - weights_path=Path(last_checkpoint_path) / "policy.pt" - if last_checkpoint_path - else None, - optimizer_path=Path(last_checkpoint_path) / "policy_optimizer.pt" + checkpoint_dir=Path(last_checkpoint_path) / "policy" if last_checkpoint_path else None, init_optimizer=True, @@ -406,11 +403,7 @@ def sft_train( step + 1, sft_save_state, master_config ) policy.save_checkpoint( - os.path.join(checkpoint_path, "policy.pt"), - os.path.join(checkpoint_path, "policy_optimizer.pt"), - ## NOTE: below is a workaround to avoid a bug with checkpointing - ## this should be removed once the bug is fixed - offload_to_cpu=False, + os.path.join(checkpoint_path, "policy"), ) torch.save( train_dataloader.state_dict(), diff --git a/nemo_reinforcer/models/policy/hf_policy.py b/nemo_reinforcer/models/policy/hf_policy.py index 34a621063d..143591e391 100644 --- a/nemo_reinforcer/models/policy/hf_policy.py +++ b/nemo_reinforcer/models/policy/hf_policy.py @@ -28,6 +28,8 @@ StateDictType, ) from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy +import torch.distributed.checkpoint as dcp +from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict from transformers import AutoModelForCausalLM, AutoTokenizer from nemo_reinforcer.algorithms.interfaces import LossFunction @@ -46,6 +48,7 @@ from nemo_reinforcer.distributed.virtual_cluster import ( PY_EXECUTABLES, ) +from nemo_reinforcer.utils.checkpoint import ExpState @ray.remote @@ -65,8 +68,7 @@ def __repr__(self): def __init__( self, config: PolicyConfig, - weights_path: Optional[str] = None, - optimizer_path: Optional[str] = None, + checkpoint_dir: Optional[str] = None, init_optimizer: bool = True, ): self.cfg = config @@ -170,8 +172,8 @@ def do_fsdp(model): ) # restore - if weights_path: - self.load_checkpoint(weights_path, optimizer_path) + if checkpoint_dir: + self.load_checkpoint(checkpoint_dir) else: print( "No weights path provided. Starting from scratch (default policy init)" @@ -786,80 +788,23 @@ def move_to_cpu(self, model): def save_checkpoint( self, - weights_path: str, - optimizer_path: Optional[str] = None, - offload_to_cpu: bool = True, + save_path: str, ): - # Config to save full state dict on rank 0, offloaded to CPU - state_dict_config = FullStateDictConfig( - offload_to_cpu=offload_to_cpu, rank0_only=True - ) - - with FullyShardedDataParallel.state_dict_type( - self.model, - state_dict_type=StateDictType.FULL_STATE_DICT, - state_dict_config=state_dict_config, - ): - # Save model state dict - model_state_dict = self.model.state_dict() - optim_state_dict = FullyShardedDataParallel.optim_state_dict( - self.model, self.optimizer - ) - scheduler_state_dict = self.scheduler.state_dict() - - optim_and_scheduler_state_dict = { - "optimizer": optim_state_dict, - "scheduler": scheduler_state_dict, - } - - if torch.distributed.get_rank() == 0: - # check if weights_path dir exists - weights_dir = os.path.dirname(weights_path) - if not os.path.exists(weights_dir): - print( - f"Creating weights directory {weights_dir} DOESN'T EXIST SOMEHOW" - ) - os.makedirs(weights_dir) - torch.save(model_state_dict, weights_path) - if optimizer_path is not None: - torch.save(optim_and_scheduler_state_dict, optimizer_path) - - def load_checkpoint(self, weights_path: str, optimizer_path: Optional[str] = None): - print(f"Loading Policy from {weights_path} and optimizer from {optimizer_path}") - state_dict_config = FullStateDictConfig(offload_to_cpu=True, rank0_only=True) - - state_dict = torch.load(weights_path) - if optimizer_path is not None: - optim_data = torch.load(optimizer_path) - optimizer_state_dict = optim_data["optimizer"] - scheduler_state_dict = optim_data.get("scheduler") - else: - optimizer_state_dict = None - scheduler_state_dict = None - with FullyShardedDataParallel.state_dict_type( - self.model, - state_dict_type=StateDictType.FULL_STATE_DICT, - state_dict_config=state_dict_config, - ): - # Load model weights - self.model.load_state_dict(state_dict if state_dict else None) + state_dict = { + "experiment": ExpState(self.model, self.optimizer, self.scheduler) + } + dcp.save(state_dict, checkpoint_id=save_path) - # Load optimizer state - if optimizer_state_dict is not None: - optim_state_dict = FullyShardedDataParallel.shard_full_optim_state_dict( - optimizer_state_dict, self.model - ) - if self.optimizer is not None: - self.optimizer.load_state_dict(optim_state_dict) - else: - print("WARNING: initializing without optimizer") - else: - print("WARNING: No optimizer checkpoint provided") + def load_checkpoint(self, save_path: str): + print(f"Loading weights from {save_path}") - if scheduler_state_dict is not None: - self.scheduler.load_state_dict(scheduler_state_dict) - else: - print("WARNING: No scheduler checkpoint provided") + state_dict = { + "experiment": ExpState(self.model, self.optimizer, self.scheduler) + } + dcp.load( + state_dict=state_dict, + checkpoint_id=save_path, + ) def shutdown(self): """Shutdown the policy.""" @@ -875,20 +820,16 @@ def __init__( name_prefix: str = "hf_policy", workers_per_node: Optional[Union[int, List[int]]] = None, init_optimizer: bool = True, - weights_path: Optional[str] = None, - optimizer_path: Optional[str] = None, + checkpoint_dir: Optional[str] = None, ): - if weights_path: - weights_path = os.path.abspath(weights_path) - if optimizer_path: - optimizer_path = os.path.abspath(optimizer_path) + if checkpoint_dir: + checkpoint_dir = os.path.abspath(checkpoint_dir) worker_builder = RayWorkerBuilder( HfPolicyWorker, config, init_optimizer=init_optimizer, - weights_path=weights_path, - optimizer_path=optimizer_path, + checkpoint_dir=checkpoint_dir, ) self.worker_group = RayWorkerGroup( cluster, @@ -1073,16 +1014,12 @@ def offload_after_refit(self): def save_checkpoint( self, - weights_path: str, - optimizer_path: Optional[str] = None, - offload_to_cpu: bool = True, + save_path: str, ): """Save a checkpoint of the model.""" futures = self.worker_group.run_all_workers_single_data( "save_checkpoint", - weights_path, - optimizer_path, - offload_to_cpu=offload_to_cpu, + save_path, respect_tied_workers=True, ) ray.get(futures) diff --git a/nemo_reinforcer/utils/checkpoint.py b/nemo_reinforcer/utils/checkpoint.py index 2425996400..5566ef2149 100644 --- a/nemo_reinforcer/utils/checkpoint.py +++ b/nemo_reinforcer/utils/checkpoint.py @@ -26,6 +26,46 @@ import torch import numpy as np +from torch.distributed.checkpoint.stateful import Stateful +from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict + + +## modified from pytorch tutorial https://pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html +class ExpState(Stateful): + """This is a useful wrapper for checkpointing the Application State. Since this object is compliant with the Stateful protocol, DCP will automatically call state_dict/load_stat_dict as needed in the dcp.save/load APIs.""" + + def __init__(self, model, optimizer=None, scheduler=None): + self.model = model + self.optimizer = optimizer + self.scheduler = scheduler + + def state_dict(self): + # this line automatically manages FSDP FQN's, as well as sets the default state dict type to FSDP.SHARDED_STATE_DICT + model_state_dict, optimizer_state_dict = get_state_dict( + self.model, + self.optimizer, + options=torch.distributed.checkpoint.state_dict.StateDictOptions( + cpu_offload=True + ), + ) + return { + "model": model_state_dict, + "optim": optimizer_state_dict, + "sched": self.scheduler.state_dict(), + } + + def load_state_dict(self, state_dict): + # sets our state dicts on the model and optimizer, now that we've loaded + set_state_dict( + self.model, + self.optimizer, + model_state_dict=state_dict["model"], + optim_state_dict=state_dict["optim"], + ) + + scheduler_state_dict = state_dict["sched"] + self.scheduler.load_state_dict(scheduler_state_dict) + class CheckpointingConfig(TypedDict): """Configuration for checkpoint management. From 83ca3a0075a829ef7e217b96c1fb90ff4123aad2 Mon Sep 17 00:00:00 2001 From: ashors1 Date: Mon, 31 Mar 2025 15:34:20 -0700 Subject: [PATCH 02/33] save model weights and optimizer states to separate directories Signed-off-by: ashors1 --- nemo_reinforcer/algorithms/grpo.py | 10 +++- nemo_reinforcer/algorithms/sft.py | 10 +++- nemo_reinforcer/models/policy/hf_policy.py | 61 +++++++++++++++------- nemo_reinforcer/utils/checkpoint.py | 43 ++++++++++++--- 4 files changed, 92 insertions(+), 32 deletions(-) diff --git a/nemo_reinforcer/algorithms/grpo.py b/nemo_reinforcer/algorithms/grpo.py index 4907bb7eb1..934e33d4b7 100644 --- a/nemo_reinforcer/algorithms/grpo.py +++ b/nemo_reinforcer/algorithms/grpo.py @@ -241,7 +241,10 @@ def setup( policy = HfPolicy( cluster=cluster, config=policy_config, - checkpoint_dir=Path(last_checkpoint_path) / "policy" + weights_path=Path(last_checkpoint_path) / "policy" / "weights" + if last_checkpoint_path + else None, + optimizer_path=Path(last_checkpoint_path) / "policy" / "optimizer" if last_checkpoint_path else None, init_optimizer=True, @@ -619,7 +622,10 @@ def grpo_train( step + 1, grpo_save_state, master_config ) policy.save_checkpoint( - os.path.join(checkpoint_path, "policy.pt"), + weights_path=os.path.join(checkpoint_path, "policy", "weights"), + optimizer_path=os.path.join( + checkpoint_path, "policy", "optimizer" + ), ) torch.save( dataloader.state_dict(), diff --git a/nemo_reinforcer/algorithms/sft.py b/nemo_reinforcer/algorithms/sft.py index 85526fb8e7..a4dbe674d3 100644 --- a/nemo_reinforcer/algorithms/sft.py +++ b/nemo_reinforcer/algorithms/sft.py @@ -175,7 +175,10 @@ def setup( policy = HfPolicy( cluster=cluster, config=policy_config, - checkpoint_dir=Path(last_checkpoint_path) / "policy" + weights_path=Path(last_checkpoint_path) / "policy" / "weights" + if last_checkpoint_path + else None, + optimizer_path=Path(last_checkpoint_path) / "policy" / "optimizer" if last_checkpoint_path else None, init_optimizer=True, @@ -403,7 +406,10 @@ def sft_train( step + 1, sft_save_state, master_config ) policy.save_checkpoint( - os.path.join(checkpoint_path, "policy"), + weights_path=os.path.join(checkpoint_path, "policy", "weights"), + optimizer_path=os.path.join( + checkpoint_path, "policy", "optimizer" + ), ) torch.save( train_dataloader.state_dict(), diff --git a/nemo_reinforcer/models/policy/hf_policy.py b/nemo_reinforcer/models/policy/hf_policy.py index 143591e391..e5e55477e2 100644 --- a/nemo_reinforcer/models/policy/hf_policy.py +++ b/nemo_reinforcer/models/policy/hf_policy.py @@ -48,7 +48,7 @@ from nemo_reinforcer.distributed.virtual_cluster import ( PY_EXECUTABLES, ) -from nemo_reinforcer.utils.checkpoint import ExpState +from nemo_reinforcer.utils.checkpoint import ModelState, OptimizerState @ray.remote @@ -68,7 +68,8 @@ def __repr__(self): def __init__( self, config: PolicyConfig, - checkpoint_dir: Optional[str] = None, + weights_path: Optional[str] = None, + optimizer_path: Optional[str] = None, init_optimizer: bool = True, ): self.cfg = config @@ -172,8 +173,11 @@ def do_fsdp(model): ) # restore - if checkpoint_dir: - self.load_checkpoint(checkpoint_dir) + if weights_path: + self.load_checkpoint( + weights_path, + optimizer_path, + ) else: print( "No weights path provided. Starting from scratch (default policy init)" @@ -788,24 +792,36 @@ def move_to_cpu(self, model): def save_checkpoint( self, - save_path: str, + weights_path: str, + optimizer_path: str, ): - state_dict = { - "experiment": ExpState(self.model, self.optimizer, self.scheduler) + model_state_dict = {"model": ModelState(self.model)} + dcp.save(model_state_dict, checkpoint_id=weights_path) + + optimizer_state_dict = { + "optim": OptimizerState(self.model, self.optimizer, self.scheduler) } - dcp.save(state_dict, checkpoint_id=save_path) + dcp.save(optimizer_state_dict, checkpoint_id=optimizer_path) + + def load_checkpoint(self, weights_path: str, optimizer_path: str): + print(f"Loading weights from {weights_path}") - def load_checkpoint(self, save_path: str): - print(f"Loading weights from {save_path}") + model_state_dict = {"model": ModelState(self.model)} + dcp.load( + state_dict=model_state_dict, + checkpoint_id=weights_path, + ) - state_dict = { - "experiment": ExpState(self.model, self.optimizer, self.scheduler) + optimizer_state_dict = { + "optim": OptimizerState(self.model, self.optimizer, self.scheduler) } dcp.load( - state_dict=state_dict, - checkpoint_id=save_path, + state_dict=optimizer_state_dict, + checkpoint_id=optimizer_path, ) + print(f"{self.scheduler.state_dict()=}") + def shutdown(self): """Shutdown the policy.""" # @@ -820,16 +836,19 @@ def __init__( name_prefix: str = "hf_policy", workers_per_node: Optional[Union[int, List[int]]] = None, init_optimizer: bool = True, - checkpoint_dir: Optional[str] = None, + weights_path: Optional[str] = None, + optimizer_path: Optional[str] = None, ): - if checkpoint_dir: - checkpoint_dir = os.path.abspath(checkpoint_dir) + if weights_path: + weights_path = os.path.abspath(weights_path) + optimizer_path = os.path.abspath(optimizer_path) worker_builder = RayWorkerBuilder( HfPolicyWorker, config, init_optimizer=init_optimizer, - checkpoint_dir=checkpoint_dir, + weights_path=weights_path, + optimizer_path=optimizer_path, ) self.worker_group = RayWorkerGroup( cluster, @@ -1014,12 +1033,14 @@ def offload_after_refit(self): def save_checkpoint( self, - save_path: str, + weights_path: str, + optimizer_path: str, ): """Save a checkpoint of the model.""" futures = self.worker_group.run_all_workers_single_data( "save_checkpoint", - save_path, + weights_path, + optimizer_path, respect_tied_workers=True, ) ray.get(futures) diff --git a/nemo_reinforcer/utils/checkpoint.py b/nemo_reinforcer/utils/checkpoint.py index 5566ef2149..e35fffa8f5 100644 --- a/nemo_reinforcer/utils/checkpoint.py +++ b/nemo_reinforcer/utils/checkpoint.py @@ -27,21 +27,50 @@ import numpy as np from torch.distributed.checkpoint.stateful import Stateful -from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict +from torch.distributed.checkpoint.state_dict import ( + get_model_state_dict, + set_model_state_dict, + get_optimizer_state_dict, + set_optimizer_state_dict, +) ## modified from pytorch tutorial https://pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html -class ExpState(Stateful): +class ModelState(Stateful): """This is a useful wrapper for checkpointing the Application State. Since this object is compliant with the Stateful protocol, DCP will automatically call state_dict/load_stat_dict as needed in the dcp.save/load APIs.""" - def __init__(self, model, optimizer=None, scheduler=None): + def __init__(self, model): + self.model = model + + def state_dict(self): + # this line automatically manages FSDP FQN's, as well as sets the default state dict type to FSDP.SHARDED_STATE_DICT + model_state_dict = get_model_state_dict( + self.model, + options=torch.distributed.checkpoint.state_dict.StateDictOptions( + cpu_offload=True + ), + ) + return { + "model": model_state_dict, + } + + def load_state_dict(self, state_dict): + # sets our state dicts on the model and optimizer, now that we've loaded + set_model_state_dict( + self.model, + state_dict["model"], + ) + + +class OptimizerState(Stateful): + def __init__(self, model, optimizer, scheduler=None): self.model = model self.optimizer = optimizer self.scheduler = scheduler def state_dict(self): # this line automatically manages FSDP FQN's, as well as sets the default state dict type to FSDP.SHARDED_STATE_DICT - model_state_dict, optimizer_state_dict = get_state_dict( + optimizer_state_dict = get_optimizer_state_dict( self.model, self.optimizer, options=torch.distributed.checkpoint.state_dict.StateDictOptions( @@ -49,18 +78,16 @@ def state_dict(self): ), ) return { - "model": model_state_dict, "optim": optimizer_state_dict, "sched": self.scheduler.state_dict(), } def load_state_dict(self, state_dict): # sets our state dicts on the model and optimizer, now that we've loaded - set_state_dict( + set_optimizer_state_dict( self.model, self.optimizer, - model_state_dict=state_dict["model"], - optim_state_dict=state_dict["optim"], + state_dict["optim"], ) scheduler_state_dict = state_dict["sched"] From 4fbb3130e80890783456db7341168808300e9749 Mon Sep 17 00:00:00 2001 From: ashors1 Date: Mon, 31 Mar 2025 16:25:26 -0700 Subject: [PATCH 03/33] only load optimizer if path provided Signed-off-by: ashors1 --- nemo_reinforcer/models/policy/hf_policy.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/nemo_reinforcer/models/policy/hf_policy.py b/nemo_reinforcer/models/policy/hf_policy.py index e5e55477e2..b781477a88 100644 --- a/nemo_reinforcer/models/policy/hf_policy.py +++ b/nemo_reinforcer/models/policy/hf_policy.py @@ -812,15 +812,15 @@ def load_checkpoint(self, weights_path: str, optimizer_path: str): checkpoint_id=weights_path, ) - optimizer_state_dict = { - "optim": OptimizerState(self.model, self.optimizer, self.scheduler) - } - dcp.load( - state_dict=optimizer_state_dict, - checkpoint_id=optimizer_path, - ) - - print(f"{self.scheduler.state_dict()=}") + if optimizer_path: + print(f"Loading optimizer from {optimizer_path}") + optimizer_state_dict = { + "optim": OptimizerState(self.model, self.optimizer, self.scheduler) + } + dcp.load( + state_dict=optimizer_state_dict, + checkpoint_id=optimizer_path, + ) def shutdown(self): """Shutdown the policy.""" From 05bd71e0166a361e580ea0311eed3ab95c9a9f8a Mon Sep 17 00:00:00 2001 From: ashors1 Date: Mon, 31 Mar 2025 21:34:16 -0700 Subject: [PATCH 04/33] option to save model weights in hf format Signed-off-by: ashors1 --- nemo_reinforcer/models/policy/hf_policy.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/nemo_reinforcer/models/policy/hf_policy.py b/nemo_reinforcer/models/policy/hf_policy.py index b781477a88..dcd7706db2 100644 --- a/nemo_reinforcer/models/policy/hf_policy.py +++ b/nemo_reinforcer/models/policy/hf_policy.py @@ -794,7 +794,14 @@ def save_checkpoint( self, weights_path: str, optimizer_path: str, + save_hf: bool = False, ## whether to save the model in hf format ): + ## gathers the model weights and saves in HF format + ## note that HF format has no way to save optimizer state + if save_hf: + self.model.save_pretrained(weights_path) + return + model_state_dict = {"model": ModelState(self.model)} dcp.save(model_state_dict, checkpoint_id=weights_path) From 94d9227321d172c755d21734400f792d9c8fae02 Mon Sep 17 00:00:00 2001 From: ashors1 Date: Tue, 1 Apr 2025 10:16:13 -0700 Subject: [PATCH 05/33] address comments Signed-off-by: ashors1 --- nemo_reinforcer/models/policy/hf_policy.py | 16 +-- nemo_reinforcer/utils/checkpoint.py | 59 ---------- nemo_reinforcer/utils/hf_checkpoint.py | 129 +++++++++++++++++++++ 3 files changed, 138 insertions(+), 66 deletions(-) create mode 100644 nemo_reinforcer/utils/hf_checkpoint.py diff --git a/nemo_reinforcer/models/policy/hf_policy.py b/nemo_reinforcer/models/policy/hf_policy.py index dcd7706db2..2c11f583ca 100644 --- a/nemo_reinforcer/models/policy/hf_policy.py +++ b/nemo_reinforcer/models/policy/hf_policy.py @@ -48,7 +48,7 @@ from nemo_reinforcer.distributed.virtual_cluster import ( PY_EXECUTABLES, ) -from nemo_reinforcer.utils.checkpoint import ModelState, OptimizerState +from nemo_reinforcer.utils.hf_checkpoint import ModelState, OptimizerState @ray.remote @@ -793,7 +793,7 @@ def move_to_cpu(self, model): def save_checkpoint( self, weights_path: str, - optimizer_path: str, + optimizer_path: Optional[str] = None, save_hf: bool = False, ## whether to save the model in hf format ): ## gathers the model weights and saves in HF format @@ -805,12 +805,13 @@ def save_checkpoint( model_state_dict = {"model": ModelState(self.model)} dcp.save(model_state_dict, checkpoint_id=weights_path) - optimizer_state_dict = { - "optim": OptimizerState(self.model, self.optimizer, self.scheduler) - } - dcp.save(optimizer_state_dict, checkpoint_id=optimizer_path) + if optimizer_path: + optimizer_state_dict = { + "optim": OptimizerState(self.model, self.optimizer, self.scheduler) + } + dcp.save(optimizer_state_dict, checkpoint_id=optimizer_path) - def load_checkpoint(self, weights_path: str, optimizer_path: str): + def load_checkpoint(self, weights_path: str, optimizer_path: Optional[str] = None): print(f"Loading weights from {weights_path}") model_state_dict = {"model": ModelState(self.model)} @@ -848,6 +849,7 @@ def __init__( ): if weights_path: weights_path = os.path.abspath(weights_path) + if optimizer_path: optimizer_path = os.path.abspath(optimizer_path) worker_builder = RayWorkerBuilder( diff --git a/nemo_reinforcer/utils/checkpoint.py b/nemo_reinforcer/utils/checkpoint.py index e35fffa8f5..4b184dd511 100644 --- a/nemo_reinforcer/utils/checkpoint.py +++ b/nemo_reinforcer/utils/checkpoint.py @@ -35,65 +35,6 @@ ) -## modified from pytorch tutorial https://pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html -class ModelState(Stateful): - """This is a useful wrapper for checkpointing the Application State. Since this object is compliant with the Stateful protocol, DCP will automatically call state_dict/load_stat_dict as needed in the dcp.save/load APIs.""" - - def __init__(self, model): - self.model = model - - def state_dict(self): - # this line automatically manages FSDP FQN's, as well as sets the default state dict type to FSDP.SHARDED_STATE_DICT - model_state_dict = get_model_state_dict( - self.model, - options=torch.distributed.checkpoint.state_dict.StateDictOptions( - cpu_offload=True - ), - ) - return { - "model": model_state_dict, - } - - def load_state_dict(self, state_dict): - # sets our state dicts on the model and optimizer, now that we've loaded - set_model_state_dict( - self.model, - state_dict["model"], - ) - - -class OptimizerState(Stateful): - def __init__(self, model, optimizer, scheduler=None): - self.model = model - self.optimizer = optimizer - self.scheduler = scheduler - - def state_dict(self): - # this line automatically manages FSDP FQN's, as well as sets the default state dict type to FSDP.SHARDED_STATE_DICT - optimizer_state_dict = get_optimizer_state_dict( - self.model, - self.optimizer, - options=torch.distributed.checkpoint.state_dict.StateDictOptions( - cpu_offload=True - ), - ) - return { - "optim": optimizer_state_dict, - "sched": self.scheduler.state_dict(), - } - - def load_state_dict(self, state_dict): - # sets our state dicts on the model and optimizer, now that we've loaded - set_optimizer_state_dict( - self.model, - self.optimizer, - state_dict["optim"], - ) - - scheduler_state_dict = state_dict["sched"] - self.scheduler.load_state_dict(scheduler_state_dict) - - class CheckpointingConfig(TypedDict): """Configuration for checkpoint management. diff --git a/nemo_reinforcer/utils/hf_checkpoint.py b/nemo_reinforcer/utils/hf_checkpoint.py new file mode 100644 index 0000000000..95d480c639 --- /dev/null +++ b/nemo_reinforcer/utils/hf_checkpoint.py @@ -0,0 +1,129 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Checkpoint management utilities for HF models.""" + +import os +import json +import glob +from typing import Dict, Any, Optional, List, Tuple, TypedDict +import shutil +from pathlib import Path +import torch +import numpy as np + +from torch.distributed.checkpoint.stateful import Stateful +from torch.distributed.checkpoint.state_dict import ( + get_model_state_dict, + set_model_state_dict, + get_optimizer_state_dict, + set_optimizer_state_dict, +) + + +## modified from pytorch tutorial https://pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html +class ModelState(Stateful): + """Helper class for tracking model state in distributed checkpointing. + + This class is compliant with the Stateful protocol, allowing DCP to automatically + call state_dict/load_state_dict as needed in the dcp.save/load APIs. + + Args: + model: The PyTorch model to track. + """ + + def __init__(self, model): + self.model = model + + def state_dict(self): + """Get the model's state dictionary. + + Returns: + dict: Dictionary containing the model's state dict with CPU offloading enabled. + """ + # this line automatically manages FSDP FQN's, as well as sets the default state dict type to FSDP.SHARDED_STATE_DICT + model_state_dict = get_model_state_dict( + self.model, + options=torch.distributed.checkpoint.state_dict.StateDictOptions( + cpu_offload=True + ), + ) + return { + "model": model_state_dict, + } + + def load_state_dict(self, state_dict): + """Load the state dictionary into the model. + + Args: + state_dict (dict): State dictionary to load. + """ + # sets our state dicts on the model and optimizer, now that we've loaded + set_model_state_dict( + self.model, + state_dict["model"], + ) + + +class OptimizerState(Stateful): + """Helper class for tracking optimizer state in distributed checkpointing. + + This class is compliant with the Stateful protocol, allowing DCP to automatically + call state_dict/load_state_dict as needed in the dcp.save/load APIs. + + Args: + model: The PyTorch model associated with the optimizer. + optimizer: The optimizer to track. + scheduler: Optional learning rate scheduler. + """ + + def __init__(self, model, optimizer, scheduler=None): + self.model = model + self.optimizer = optimizer + self.scheduler = scheduler + + def state_dict(self): + """Get the optimizer and scheduler state dictionaries. + + Returns: + dict: Dictionary containing the optimizer and scheduler state dicts with CPU offloading enabled. + """ + # this line automatically manages FSDP FQN's, as well as sets the default state dict type to FSDP.SHARDED_STATE_DICT + optimizer_state_dict = get_optimizer_state_dict( + self.model, + self.optimizer, + options=torch.distributed.checkpoint.state_dict.StateDictOptions( + cpu_offload=True + ), + ) + return { + "optim": optimizer_state_dict, + "sched": self.scheduler.state_dict(), + } + + def load_state_dict(self, state_dict): + """Load the state dictionaries into the optimizer and scheduler. + + Args: + state_dict (dict): State dictionary containing optimizer and scheduler states to load. + """ + # sets our state dicts on the model and optimizer, now that we've loaded + set_optimizer_state_dict( + self.model, + self.optimizer, + state_dict["optim"], + ) + + scheduler_state_dict = state_dict["sched"] + self.scheduler.load_state_dict(scheduler_state_dict) From 16f070f6fec9d1a13dacac5b7276f3fbf23d390b Mon Sep 17 00:00:00 2001 From: ashors1 Date: Tue, 1 Apr 2025 11:49:30 -0700 Subject: [PATCH 06/33] refactor Signed-off-by: ashors1 --- nemo_reinforcer/models/policy/hf_policy.py | 55 ++++++--------- nemo_reinforcer/utils/hf_checkpoint.py | 81 ++++++++++++++++++++-- 2 files changed, 98 insertions(+), 38 deletions(-) diff --git a/nemo_reinforcer/models/policy/hf_policy.py b/nemo_reinforcer/models/policy/hf_policy.py index 2c11f583ca..cfbe54b798 100644 --- a/nemo_reinforcer/models/policy/hf_policy.py +++ b/nemo_reinforcer/models/policy/hf_policy.py @@ -48,7 +48,12 @@ from nemo_reinforcer.distributed.virtual_cluster import ( PY_EXECUTABLES, ) -from nemo_reinforcer.utils.hf_checkpoint import ModelState, OptimizerState +from nemo_reinforcer.utils.hf_checkpoint import ( + ModelState, + OptimizerState, + save_checkpoint, + load_checkpoint, +) @ray.remote @@ -794,42 +799,28 @@ def save_checkpoint( self, weights_path: str, optimizer_path: Optional[str] = None, - save_hf: bool = False, ## whether to save the model in hf format + save_hf: bool = False, ): - ## gathers the model weights and saves in HF format - ## note that HF format has no way to save optimizer state - if save_hf: - self.model.save_pretrained(weights_path) - return - - model_state_dict = {"model": ModelState(self.model)} - dcp.save(model_state_dict, checkpoint_id=weights_path) - - if optimizer_path: - optimizer_state_dict = { - "optim": OptimizerState(self.model, self.optimizer, self.scheduler) - } - dcp.save(optimizer_state_dict, checkpoint_id=optimizer_path) + """Save a checkpoint of the model.""" + save_checkpoint( + model=self.model, + weights_path=weights_path, + optimizer=self.optimizer if optimizer_path else None, + scheduler=self.scheduler if optimizer_path else None, + optimizer_path=optimizer_path, + save_hf=save_hf, + ) def load_checkpoint(self, weights_path: str, optimizer_path: Optional[str] = None): - print(f"Loading weights from {weights_path}") - - model_state_dict = {"model": ModelState(self.model)} - dcp.load( - state_dict=model_state_dict, - checkpoint_id=weights_path, + """Load a checkpoint into the model.""" + load_checkpoint( + model=self.model, + weights_path=weights_path, + optimizer=self.optimizer if optimizer_path else None, + scheduler=self.scheduler if optimizer_path else None, + optimizer_path=optimizer_path, ) - if optimizer_path: - print(f"Loading optimizer from {optimizer_path}") - optimizer_state_dict = { - "optim": OptimizerState(self.model, self.optimizer, self.scheduler) - } - dcp.load( - state_dict=optimizer_state_dict, - checkpoint_id=optimizer_path, - ) - def shutdown(self): """Shutdown the policy.""" # diff --git a/nemo_reinforcer/utils/hf_checkpoint.py b/nemo_reinforcer/utils/hf_checkpoint.py index 95d480c639..0422ed5950 100644 --- a/nemo_reinforcer/utils/hf_checkpoint.py +++ b/nemo_reinforcer/utils/hf_checkpoint.py @@ -69,7 +69,7 @@ def load_state_dict(self, state_dict): Args: state_dict (dict): State dictionary to load. """ - # sets our state dicts on the model and optimizer, now that we've loaded + # sets our state dicts on the model, now that we've loaded set_model_state_dict( self.model, state_dict["model"], @@ -107,10 +107,14 @@ def state_dict(self): cpu_offload=True ), ) - return { + + state_dict = { "optim": optimizer_state_dict, - "sched": self.scheduler.state_dict(), } + if self.scheduler is not None: + state_dict["sched"] = self.scheduler.state_dict() + + return state_dict def load_state_dict(self, state_dict): """Load the state dictionaries into the optimizer and scheduler. @@ -118,12 +122,77 @@ def load_state_dict(self, state_dict): Args: state_dict (dict): State dictionary containing optimizer and scheduler states to load. """ - # sets our state dicts on the model and optimizer, now that we've loaded + # sets our state dicts on the optimizer, now that we've loaded set_optimizer_state_dict( self.model, self.optimizer, state_dict["optim"], ) - scheduler_state_dict = state_dict["sched"] - self.scheduler.load_state_dict(scheduler_state_dict) + ## load the scheduler state if it exists + if "sched" in state_dict: + self.scheduler.load_state_dict(state_dict["sched"]) + + +def save_checkpoint( + model, + weights_path: str, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[Any] = None, + optimizer_path: Optional[str] = None, + save_hf: bool = False, +) -> None: + """Save a checkpoint of the model and optionally optimizer state. + + Args: + model: The PyTorch model to save + weights_path: Path to save model weights + optimizer: Optional optimizer to save + scheduler: Optional scheduler to save + optimizer_path: Path to save optimizer state (required if optimizer provided) + save_hf: Whether to save in HuggingFace format instead of DCP format + """ + if save_hf: + model.save_pretrained(weights_path) + return + + model_state_dict = {"model": ModelState(model)} + dcp.save(model_state_dict, checkpoint_id=weights_path) + + if optimizer is not None: + if optimizer_path is None: + raise ValueError( + "optimizer_path must be provided when saving optimizer state" + ) + optimizer_state_dict = {"optim": OptimizerState(model, optimizer, scheduler)} + dcp.save(optimizer_state_dict, checkpoint_id=optimizer_path) + + +def load_checkpoint( + model, + weights_path: str, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[Any] = None, + optimizer_path: Optional[str] = None, +) -> None: + """Load a checkpoint into the model and optionally optimizer. + + Args: + model: The PyTorch model to load into + weights_path: Path to load model weights from + optimizer: Optional optimizer to load state into + scheduler: Optional scheduler to load state into + optimizer_path: Path to load optimizer state from (required if optimizer provided) + """ + print(f"Loading weights from {weights_path}") + model_state_dict = {"model": ModelState(model)} + dcp.load(state_dict=model_state_dict, checkpoint_id=weights_path) + + if optimizer is not None: + if optimizer_path is None: + raise ValueError( + "optimizer_path must be provided when loading optimizer state" + ) + print(f"Loading optimizer from {optimizer_path}") + optimizer_state_dict = {"optim": OptimizerState(model, optimizer, scheduler)} + dcp.load(state_dict=optimizer_state_dict, checkpoint_id=optimizer_path) From 7ad920f6d7f277ce4e28ad9018d6e2f5c474fff8 Mon Sep 17 00:00:00 2001 From: ashors1 Date: Tue, 1 Apr 2025 11:50:52 -0700 Subject: [PATCH 07/33] remove unused imports Signed-off-by: ashors1 --- nemo_reinforcer/utils/checkpoint.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/nemo_reinforcer/utils/checkpoint.py b/nemo_reinforcer/utils/checkpoint.py index 4b184dd511..2425996400 100644 --- a/nemo_reinforcer/utils/checkpoint.py +++ b/nemo_reinforcer/utils/checkpoint.py @@ -26,14 +26,6 @@ import torch import numpy as np -from torch.distributed.checkpoint.stateful import Stateful -from torch.distributed.checkpoint.state_dict import ( - get_model_state_dict, - set_model_state_dict, - get_optimizer_state_dict, - set_optimizer_state_dict, -) - class CheckpointingConfig(TypedDict): """Configuration for checkpoint management. From bb35d8e5cd51bd094cdc98054114fdbb58f3eeac Mon Sep 17 00:00:00 2001 From: ashors1 Date: Tue, 1 Apr 2025 11:54:48 -0700 Subject: [PATCH 08/33] update imports Signed-off-by: ashors1 --- nemo_reinforcer/models/policy/hf_policy.py | 6 ------ nemo_reinforcer/utils/hf_checkpoint.py | 9 ++------- 2 files changed, 2 insertions(+), 13 deletions(-) diff --git a/nemo_reinforcer/models/policy/hf_policy.py b/nemo_reinforcer/models/policy/hf_policy.py index cfbe54b798..02a93717d2 100644 --- a/nemo_reinforcer/models/policy/hf_policy.py +++ b/nemo_reinforcer/models/policy/hf_policy.py @@ -23,13 +23,9 @@ from torch.distributed.device_mesh import init_device_mesh from torch.distributed.fsdp import ( FullyShardedDataParallel, - FullStateDictConfig, MixedPrecision, - StateDictType, ) from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy -import torch.distributed.checkpoint as dcp -from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict from transformers import AutoModelForCausalLM, AutoTokenizer from nemo_reinforcer.algorithms.interfaces import LossFunction @@ -49,8 +45,6 @@ PY_EXECUTABLES, ) from nemo_reinforcer.utils.hf_checkpoint import ( - ModelState, - OptimizerState, save_checkpoint, load_checkpoint, ) diff --git a/nemo_reinforcer/utils/hf_checkpoint.py b/nemo_reinforcer/utils/hf_checkpoint.py index 0422ed5950..070c06f647 100644 --- a/nemo_reinforcer/utils/hf_checkpoint.py +++ b/nemo_reinforcer/utils/hf_checkpoint.py @@ -14,15 +14,10 @@ """Checkpoint management utilities for HF models.""" -import os -import json -import glob -from typing import Dict, Any, Optional, List, Tuple, TypedDict -import shutil -from pathlib import Path +from typing import Any, Optional import torch -import numpy as np +import torch.distributed.checkpoint as dcp from torch.distributed.checkpoint.stateful import Stateful from torch.distributed.checkpoint.state_dict import ( get_model_state_dict, From 10a25a48ed356e4511f083b574420c4c6b8d76dd Mon Sep 17 00:00:00 2001 From: ashors1 Date: Tue, 1 Apr 2025 12:04:02 -0700 Subject: [PATCH 09/33] support saving both HF and torch DCP checkpoints Signed-off-by: ashors1 --- nemo_reinforcer/models/policy/hf_policy.py | 2 ++ nemo_reinforcer/utils/hf_checkpoint.py | 29 ++++++++++++---------- 2 files changed, 18 insertions(+), 13 deletions(-) diff --git a/nemo_reinforcer/models/policy/hf_policy.py b/nemo_reinforcer/models/policy/hf_policy.py index 02a93717d2..1071a56287 100644 --- a/nemo_reinforcer/models/policy/hf_policy.py +++ b/nemo_reinforcer/models/policy/hf_policy.py @@ -793,6 +793,7 @@ def save_checkpoint( self, weights_path: str, optimizer_path: Optional[str] = None, + save_torch_dist: bool = True, save_hf: bool = False, ): """Save a checkpoint of the model.""" @@ -802,6 +803,7 @@ def save_checkpoint( optimizer=self.optimizer if optimizer_path else None, scheduler=self.scheduler if optimizer_path else None, optimizer_path=optimizer_path, + save_torch_dist=save_torch_dist, save_hf=save_hf, ) diff --git a/nemo_reinforcer/utils/hf_checkpoint.py b/nemo_reinforcer/utils/hf_checkpoint.py index 070c06f647..7f2127a850 100644 --- a/nemo_reinforcer/utils/hf_checkpoint.py +++ b/nemo_reinforcer/utils/hf_checkpoint.py @@ -135,6 +135,7 @@ def save_checkpoint( optimizer: Optional[torch.optim.Optimizer] = None, scheduler: Optional[Any] = None, optimizer_path: Optional[str] = None, + save_torch_dist: bool = True, save_hf: bool = False, ) -> None: """Save a checkpoint of the model and optionally optimizer state. @@ -148,19 +149,21 @@ def save_checkpoint( save_hf: Whether to save in HuggingFace format instead of DCP format """ if save_hf: - model.save_pretrained(weights_path) - return - - model_state_dict = {"model": ModelState(model)} - dcp.save(model_state_dict, checkpoint_id=weights_path) - - if optimizer is not None: - if optimizer_path is None: - raise ValueError( - "optimizer_path must be provided when saving optimizer state" - ) - optimizer_state_dict = {"optim": OptimizerState(model, optimizer, scheduler)} - dcp.save(optimizer_state_dict, checkpoint_id=optimizer_path) + model.save_pretrained(os.path.join(weights_path, "hf_weights")) + + if save_torch_dist: + model_state_dict = {"model": ModelState(model)} + dcp.save(model_state_dict, checkpoint_id=weights_path) + + if optimizer is not None: + if optimizer_path is None: + raise ValueError( + "optimizer_path must be provided when saving optimizer state" + ) + optimizer_state_dict = { + "optim": OptimizerState(model, optimizer, scheduler) + } + dcp.save(optimizer_state_dict, checkpoint_id=optimizer_path) def load_checkpoint( From f712f66a0704616b2a3c6de6efcf732e1a9a401a Mon Sep 17 00:00:00 2001 From: ashors1 Date: Tue, 1 Apr 2025 12:38:30 -0700 Subject: [PATCH 10/33] cleanup, add tokenizer_name to config Signed-off-by: ashors1 --- examples/configs/grpo_math_1B.yaml | 1 + examples/configs/grpo_math_8B.yaml | 1 + examples/configs/sft.yaml | 1 + examples/run_grpo_math.py | 2 +- examples/run_sft.py | 2 +- nemo_reinforcer/models/policy/hf_policy.py | 7 ++++++- nemo_reinforcer/utils/hf_checkpoint.py | 9 +++++++-- 7 files changed, 18 insertions(+), 5 deletions(-) diff --git a/examples/configs/grpo_math_1B.yaml b/examples/configs/grpo_math_1B.yaml index 72aad000ce..4483fb7eaa 100644 --- a/examples/configs/grpo_math_1B.yaml +++ b/examples/configs/grpo_math_1B.yaml @@ -24,6 +24,7 @@ checkpointing: policy: model_name: "meta-llama/Llama-3.2-1B-Instruct" + tokenizer_name: ${policy.model_name} ## specify if you'd like to use a tokenizer different from the model's default train_global_batch_size: 512 train_micro_batch_size: 4 generation_batch_size: 32 diff --git a/examples/configs/grpo_math_8B.yaml b/examples/configs/grpo_math_8B.yaml index d747e9249f..9bcf8a5523 100644 --- a/examples/configs/grpo_math_8B.yaml +++ b/examples/configs/grpo_math_8B.yaml @@ -3,6 +3,7 @@ defaults: "grpo_math_1B.yaml" policy: model_name: "meta-llama/Llama-3.1-8B-Instruct" + tokenizer_name: ${policy.model_name} ## specify if you'd like to use a tokenizer different from the model's default train_global_batch_size: 512 train_micro_batch_size: 1 generation_batch_size: 32 diff --git a/examples/configs/sft.yaml b/examples/configs/sft.yaml index 1282285fc3..5e723f5ff9 100644 --- a/examples/configs/sft.yaml +++ b/examples/configs/sft.yaml @@ -18,6 +18,7 @@ checkpointing: policy: model_name: "meta-llama/Meta-Llama-3-8B" + tokenizer_name: ${policy.model_name} ## specify if you'd like to use a tokenizer different from the model's default train_global_batch_size: 128 train_micro_batch_size: 1 max_total_sequence_length: 2048 diff --git a/examples/run_grpo_math.py b/examples/run_grpo_math.py index 7e2f3e693d..21c7dbc503 100644 --- a/examples/run_grpo_math.py +++ b/examples/run_grpo_math.py @@ -187,7 +187,7 @@ def setup_data(data_config: DataConfig, policy_config: PolicyConfig, env_configs else: raise ValueError(f"No processor for dataset {data_config['dataset_name']}.") - tokenizer = AutoTokenizer.from_pretrained(policy_config["model_name"]) + tokenizer = AutoTokenizer.from_pretrained(policy_config["tokenizer_name"]) if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token diff --git a/examples/run_sft.py b/examples/run_sft.py index 950938b4da..cb8ad82ec2 100644 --- a/examples/run_sft.py +++ b/examples/run_sft.py @@ -100,7 +100,7 @@ def setup_data(data_config: DataConfig, policy_config: PolicyConfig): val_dataset = data.formatted_ds["validation"] sft_task_spec = data.task_spec - tokenizer = AutoTokenizer.from_pretrained(policy_config["model_name"]) + tokenizer = AutoTokenizer.from_pretrained(policy_config["tokenizer_name"]) train_dataset = AllTaskProcessedDataset( train_dataset, diff --git a/nemo_reinforcer/models/policy/hf_policy.py b/nemo_reinforcer/models/policy/hf_policy.py index 1071a56287..c952994533 100644 --- a/nemo_reinforcer/models/policy/hf_policy.py +++ b/nemo_reinforcer/models/policy/hf_policy.py @@ -77,6 +77,7 @@ def __init__( rank = torch.distributed.get_rank() world_size = torch.distributed.get_world_size() model_name = self.cfg["model_name"] + tokenizer_name = self.cfg["tokenizer_name"] if self.cfg["precision"] == "float32": self.dtype = torch.float32 elif self.cfg["precision"] == "bfloat16": @@ -96,7 +97,7 @@ def __init__( torch_dtype=torch.float32, # use full precision in sft until https://github.com/NVIDIA/reinforcer/issues/13 is fixed ) - self.tokenizer = AutoTokenizer.from_pretrained(model_name) + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) # If no pad token is defined, you might need: if self.tokenizer.pad_token is None: self.tokenizer.pad_token = self.tokenizer.eos_token @@ -1031,12 +1032,16 @@ def save_checkpoint( self, weights_path: str, optimizer_path: str, + save_torch_dist: bool = True, + save_hf: bool = False, ): """Save a checkpoint of the model.""" futures = self.worker_group.run_all_workers_single_data( "save_checkpoint", weights_path, optimizer_path, + save_torch_dist, + save_hf, respect_tied_workers=True, ) ray.get(futures) diff --git a/nemo_reinforcer/utils/hf_checkpoint.py b/nemo_reinforcer/utils/hf_checkpoint.py index 7f2127a850..3dc06feacc 100644 --- a/nemo_reinforcer/utils/hf_checkpoint.py +++ b/nemo_reinforcer/utils/hf_checkpoint.py @@ -14,6 +14,8 @@ """Checkpoint management utilities for HF models.""" +import os +from pathlib import Path from typing import Any, Optional import torch @@ -146,10 +148,13 @@ def save_checkpoint( optimizer: Optional optimizer to save scheduler: Optional scheduler to save optimizer_path: Path to save optimizer state (required if optimizer provided) - save_hf: Whether to save in HuggingFace format instead of DCP format + save_torch_dist: Whether to save in PyTorch distributed format + save_hf: Whether to save in HuggingFace format """ if save_hf: - model.save_pretrained(os.path.join(weights_path, "hf_weights")) + # Create a new path by appending "-hf" to the weights path + hf_weights_path = f"{Path(weights_path)}-hf" + model.save_pretrained(hf_weights_path) if save_torch_dist: model_state_dict = {"model": ModelState(model)} From e2c264c1011841a3a222808087eab75f7d6de025 Mon Sep 17 00:00:00 2001 From: ashors1 Date: Tue, 1 Apr 2025 14:16:19 -0700 Subject: [PATCH 11/33] add example conversion script Signed-off-by: ashors1 --- examples/convert_dcp_to_hf.py | 80 ++++++++++++++++++++++ nemo_reinforcer/models/policy/hf_policy.py | 6 +- 2 files changed, 83 insertions(+), 3 deletions(-) create mode 100644 examples/convert_dcp_to_hf.py diff --git a/examples/convert_dcp_to_hf.py b/examples/convert_dcp_to_hf.py new file mode 100644 index 0000000000..82ab2ea114 --- /dev/null +++ b/examples/convert_dcp_to_hf.py @@ -0,0 +1,80 @@ +import argparse +import os +from omegaconf import OmegaConf + +from nemo_reinforcer.distributed.virtual_cluster import RayVirtualCluster +from nemo_reinforcer.models.policy.hf_policy import HfPolicy +from nemo_reinforcer.utils.config import load_config + + +def parse_args(): + """Parse command line arguments.""" + parser = argparse.ArgumentParser( + description="Convert Torch DCP checkpoint to HF checkpoint" + ) + parser.add_argument( + "--config", type=str, default=None, help="Path to YAML config file" + ) + parser.add_argument( + "--dcp-ckpt-path", type=str, default=None, help="Path to DCP checkpoint" + ) + parser.add_argument( + "--hf-ckpt-path", type=str, default=None, help="Path to save HF checkpoint" + ) + # Parse known args for the script + args, remaining = parser.parse_known_args() + + # Convert remaining args to OmegaConf format + overrides = OmegaConf.from_dotlist(remaining) + + return args, overrides + + +def main(): + """Main entry point.""" + # Parse arguments + args, overrides = parse_args() + + if not args.config: + args.config = os.path.join(os.path.dirname(__file__), "configs", "sft.yaml") + + config = load_config(args.config) + print(f"Loaded configuration from: {args.config}") + + if overrides: + print(f"Overrides: {overrides}") + config = OmegaConf.merge(config, overrides) + + dcp_ckpt = args.dcp_ckpt_path + hf_ckpt = args.hf_ckpt_path + + # Extract individual configs for easier access + policy_config = config["policy"] + cluster_config = config["cluster"] + + cluster = RayVirtualCluster( + name="sft_cluster", + bundle_ct_per_node_list=[cluster_config["gpus_per_node"]] + * cluster_config["num_nodes"], + use_gpus=True, + num_gpus_per_node=cluster_config["gpus_per_node"], + max_colocated_worker_groups=1, + ) + + policy = HfPolicy( + cluster=cluster, + config=policy_config, + weights_path=dcp_ckpt, + init_optimizer=False, + ) + + policy.save_checkpoint( + weights_path=hf_ckpt, + save_hf=True, + save_torch_dist=False, + ) + print(f"Saved HF checkpoint to: {hf_ckpt}-hf") + + +if __name__ == "__main__": + main() diff --git a/nemo_reinforcer/models/policy/hf_policy.py b/nemo_reinforcer/models/policy/hf_policy.py index c952994533..2fb4cb044e 100644 --- a/nemo_reinforcer/models/policy/hf_policy.py +++ b/nemo_reinforcer/models/policy/hf_policy.py @@ -140,7 +140,7 @@ def do_fsdp(model): else: self.optimizer = None - if "scheduler" in self.cfg: + if "scheduler" in self.cfg and self.optimizer is not None: if isinstance(self.cfg["scheduler"], dict): scheduler_cls = import_class_from_path(self.cfg["scheduler"]["name"]) self.scheduler = scheduler_cls( @@ -166,7 +166,7 @@ def do_fsdp(model): self.optimizer, schedulers, milestones ) - else: + elif self.optimizer is not None: ## default to a passthrough LR schedule self.scheduler = torch.optim.lr_scheduler.LambdaLR( self.optimizer, lr_lambda=lambda epoch: 1 @@ -1031,7 +1031,7 @@ def offload_after_refit(self): def save_checkpoint( self, weights_path: str, - optimizer_path: str, + optimizer_path: Optional[str] = None, save_torch_dist: bool = True, save_hf: bool = False, ): From 2affa7c233f77bd52e8bc7082d7054d634953446 Mon Sep 17 00:00:00 2001 From: ashors1 Date: Tue, 1 Apr 2025 15:18:22 -0700 Subject: [PATCH 12/33] copyright Signed-off-by: ashors1 --- examples/convert_dcp_to_hf.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/examples/convert_dcp_to_hf.py b/examples/convert_dcp_to_hf.py index 82ab2ea114..c2328ac60e 100644 --- a/examples/convert_dcp_to_hf.py +++ b/examples/convert_dcp_to_hf.py @@ -1,3 +1,17 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import argparse import os from omegaconf import OmegaConf From b05df3dcfef7e97afbcbd2596e342d35f8ebd73a Mon Sep 17 00:00:00 2001 From: ashors1 Date: Tue, 1 Apr 2025 16:07:12 -0700 Subject: [PATCH 13/33] address comments Signed-off-by: ashors1 --- examples/convert_dcp_to_hf.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/examples/convert_dcp_to_hf.py b/examples/convert_dcp_to_hf.py index c2328ac60e..37d3d38fd9 100644 --- a/examples/convert_dcp_to_hf.py +++ b/examples/convert_dcp_to_hf.py @@ -16,7 +16,7 @@ import os from omegaconf import OmegaConf -from nemo_reinforcer.distributed.virtual_cluster import RayVirtualCluster +from nemo_reinforcer.distributed.virtual_cluster import init_ray, RayVirtualCluster from nemo_reinforcer.models.policy.hf_policy import HfPolicy from nemo_reinforcer.utils.config import load_config @@ -27,7 +27,10 @@ def parse_args(): description="Convert Torch DCP checkpoint to HF checkpoint" ) parser.add_argument( - "--config", type=str, default=None, help="Path to YAML config file" + "--config", + type=str, + default=None, + help="Path to YAML config file used during model training", ) parser.add_argument( "--dcp-ckpt-path", type=str, default=None, help="Path to DCP checkpoint" @@ -66,8 +69,10 @@ def main(): policy_config = config["policy"] cluster_config = config["cluster"] + init_ray() + cluster = RayVirtualCluster( - name="sft_cluster", + name="convert_cluster", bundle_ct_per_node_list=[cluster_config["gpus_per_node"]] * cluster_config["num_nodes"], use_gpus=True, From bb902c099b98c5c27c89823ae0e5f319b27cebba Mon Sep 17 00:00:00 2001 From: Anna Shors Date: Tue, 1 Apr 2025 21:40:11 -0700 Subject: [PATCH 14/33] set is_main_process during hf save Signed-off-by: Anna Shors --- nemo_reinforcer/utils/hf_checkpoint.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nemo_reinforcer/utils/hf_checkpoint.py b/nemo_reinforcer/utils/hf_checkpoint.py index 3dc06feacc..3f789e62d2 100644 --- a/nemo_reinforcer/utils/hf_checkpoint.py +++ b/nemo_reinforcer/utils/hf_checkpoint.py @@ -154,7 +154,7 @@ def save_checkpoint( if save_hf: # Create a new path by appending "-hf" to the weights path hf_weights_path = f"{Path(weights_path)}-hf" - model.save_pretrained(hf_weights_path) + model.save_pretrained(hf_weights_path, is_main_process=(torch.distributed.get_rank()==0)) if save_torch_dist: model_state_dict = {"model": ModelState(model)} From aac057fca28de34aed53722ff3dc56b470445241 Mon Sep 17 00:00:00 2001 From: Anna Shors Date: Tue, 1 Apr 2025 22:10:08 -0700 Subject: [PATCH 15/33] save to absolute path Signed-off-by: Anna Shors --- examples/convert_dcp_to_hf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/convert_dcp_to_hf.py b/examples/convert_dcp_to_hf.py index 37d3d38fd9..9c2b793f30 100644 --- a/examples/convert_dcp_to_hf.py +++ b/examples/convert_dcp_to_hf.py @@ -88,7 +88,7 @@ def main(): ) policy.save_checkpoint( - weights_path=hf_ckpt, + weights_path=os.path.abspath(hf_ckpt), save_hf=True, save_torch_dist=False, ) From 78c2a912b504074ff5317ed3465f5e83a15c2334 Mon Sep 17 00:00:00 2001 From: ashors1 Date: Wed, 2 Apr 2025 09:38:24 -0700 Subject: [PATCH 16/33] add documentation and save hf checkpoint at the end of training Signed-off-by: ashors1 --- docs/design_docs/checkpointing.md | 21 +++++++++++++++++++++ nemo_reinforcer/algorithms/grpo.py | 7 +++++++ nemo_reinforcer/algorithms/sft.py | 13 +++++++++---- 3 files changed, 37 insertions(+), 4 deletions(-) create mode 100644 docs/design_docs/checkpointing.md diff --git a/docs/design_docs/checkpointing.md b/docs/design_docs/checkpointing.md new file mode 100644 index 0000000000..4daf433f08 --- /dev/null +++ b/docs/design_docs/checkpointing.md @@ -0,0 +1,21 @@ +# Checkpointing with HuggingFace Models + +## Checkpoint Format +Reinforcer provides two checkpoint formats for HuggingFace models: Torch distributed and HuggingFace format. Torch distributed is used by default for efficiency, and HuggingFace format is provided for compatibility with HuggingFace's `AutoModel.from_pretrained` API. Note that HuggingFace format checkpoints save only the model weights, ignoring the optimizer states. It is recommended to use Torch distributed format to save intermediate checkpoints and to save a HuggingFace checkpoint only at the end of training. + +There are two ways to get a Reinforcer checkpoint in HuggingFace format. + +1. (Recommended) Save the HuggingFace checkpoint directly by passing `save_hf=True` to `HFPolicy`'s `save_checkpoint`: + + ```python + policy.save_checkpoint( + weights_path=, + optimizer_path=, + save_torch_dist=True, + save_hf=True, + ) + ``` +2. Convert a torch distributed checkpoint checkpoint to HuggingFace format after training. We provide a conversion script for this purpose. + + ```python + uv run examples/convert_dcp_to_hf.py --config= --dcp-ckpt-path= --hf-ckpt-path= \ No newline at end of file diff --git a/nemo_reinforcer/algorithms/grpo.py b/nemo_reinforcer/algorithms/grpo.py index 934e33d4b7..4689e6d793 100644 --- a/nemo_reinforcer/algorithms/grpo.py +++ b/nemo_reinforcer/algorithms/grpo.py @@ -613,6 +613,12 @@ def grpo_train( and (step + 1) % master_config["checkpointing"]["save_period"] == 0 ): # +1 because step is 0-indexed policy.prepare_for_training() + + is_last_checkpoint = ( + master_config["sft"]["max_num_steps"] - (step + 1) + < master_config["checkpointing"]["save_period"] + ) + grpo_save_state["step"] = step + 1 grpo_save_state["val_reward"] = val_metrics["accuracy"] grpo_save_state["consumed_samples"] = consumed_samples @@ -626,6 +632,7 @@ def grpo_train( optimizer_path=os.path.join( checkpoint_path, "policy", "optimizer" ), + save_hf=is_last_checkpoint, ) torch.save( dataloader.state_dict(), diff --git a/nemo_reinforcer/algorithms/sft.py b/nemo_reinforcer/algorithms/sft.py index 40eaca3f87..0d8078890a 100644 --- a/nemo_reinforcer/algorithms/sft.py +++ b/nemo_reinforcer/algorithms/sft.py @@ -311,9 +311,7 @@ def sft_train( sft_save_state = _default_sft_save_state() step = 0 else: - step = ( - sft_save_state["step"] + 1 - ) # N+1 because the checkpoint is _after_ SFT iteration N + step = sft_save_state["step"] sft_config = master_config["sft"] # Validation configuration @@ -399,18 +397,25 @@ def sft_train( master_config["checkpointing"]["enabled"] and (step + 1) % master_config["checkpointing"]["save_period"] == 0 ): # +1 because step is 0-indexed - sft_save_state["step"] = step + is_last_checkpoint = ( + master_config["sft"]["max_num_steps"] - (step + 1) + < master_config["checkpointing"]["save_period"] + ) + + sft_save_state["step"] = step + 1 sft_save_state["val_loss"] = val_metrics["val_loss"] with timer.time("checkpointing"): print(f"Saving checkpoint for step {step + 1}...") checkpoint_path = checkpointer.init_tmp_checkpoint( step + 1, sft_save_state, master_config ) + policy.save_checkpoint( weights_path=os.path.join(checkpoint_path, "policy", "weights"), optimizer_path=os.path.join( checkpoint_path, "policy", "optimizer" ), + save_hf=is_last_checkpoint, ) torch.save( train_dataloader.state_dict(), From 874d4462720164e0c90c72cc05b9badc7400d7ae Mon Sep 17 00:00:00 2001 From: ashors1 Date: Wed, 2 Apr 2025 09:39:06 -0700 Subject: [PATCH 17/33] formatting Signed-off-by: ashors1 --- docs/design_docs/checkpointing.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/design_docs/checkpointing.md b/docs/design_docs/checkpointing.md index 4daf433f08..cc3f35a0fc 100644 --- a/docs/design_docs/checkpointing.md +++ b/docs/design_docs/checkpointing.md @@ -18,4 +18,5 @@ There are two ways to get a Reinforcer checkpoint in HuggingFace format. 2. Convert a torch distributed checkpoint checkpoint to HuggingFace format after training. We provide a conversion script for this purpose. ```python - uv run examples/convert_dcp_to_hf.py --config= --dcp-ckpt-path= --hf-ckpt-path= \ No newline at end of file + uv run examples/convert_dcp_to_hf.py --config= --dcp-ckpt-path= --hf-ckpt-path= + ``` \ No newline at end of file From a422db41d9681df218c9acb180f40fc3436ba1d4 Mon Sep 17 00:00:00 2001 From: ashors1 Date: Wed, 2 Apr 2025 09:40:32 -0700 Subject: [PATCH 18/33] capitalization Signed-off-by: ashors1 --- docs/design_docs/checkpointing.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/design_docs/checkpointing.md b/docs/design_docs/checkpointing.md index cc3f35a0fc..9b9a6f6826 100644 --- a/docs/design_docs/checkpointing.md +++ b/docs/design_docs/checkpointing.md @@ -15,7 +15,7 @@ There are two ways to get a Reinforcer checkpoint in HuggingFace format. save_hf=True, ) ``` -2. Convert a torch distributed checkpoint checkpoint to HuggingFace format after training. We provide a conversion script for this purpose. +2. Convert a Torch distributed checkpoint checkpoint to HuggingFace format after training. We provide a conversion script for this purpose. ```python uv run examples/convert_dcp_to_hf.py --config= --dcp-ckpt-path= --hf-ckpt-path= From 848b654d7db710b8961580c3eb2a1f80be871486 Mon Sep 17 00:00:00 2001 From: ashors1 Date: Wed, 2 Apr 2025 13:57:42 -0700 Subject: [PATCH 19/33] add hf checkpointing unit tests Signed-off-by: ashors1 --- tests/unit/utils/test_hf_checkpoint.py | 163 +++++++++++++++++++++++++ 1 file changed, 163 insertions(+) create mode 100755 tests/unit/utils/test_hf_checkpoint.py diff --git a/tests/unit/utils/test_hf_checkpoint.py b/tests/unit/utils/test_hf_checkpoint.py new file mode 100755 index 0000000000..438fd5a932 --- /dev/null +++ b/tests/unit/utils/test_hf_checkpoint.py @@ -0,0 +1,163 @@ +import copy +import os +import pytest +import torch + +from nemo_reinforcer.utils.hf_checkpoint import ( + load_checkpoint, + save_checkpoint, + ModelState, + OptimizerState, +) + + +@pytest.fixture +def test_experiment(): + model = torch.nn.ModuleList( + [ + torch.nn.Linear(4, 4), + torch.nn.LayerNorm(4), + torch.nn.ReLU(), + torch.nn.Linear(4, 1), + ] + ).to("cuda") + + optimizer = torch.optim.Adam(model.parameters(), lr=0.001) + scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1) + + return model, optimizer, scheduler + + +## recursively get the dummy state dict +## by replacing tensors with random ones of the same shape +def get_dummy_state_dict(state_dict, dummy_dict={}): + for k in state_dict.keys(): + if isinstance(state_dict[k], dict): + dummy_dict[k] = get_dummy_state_dict(state_dict[k], {}) + elif isinstance(state_dict[k], torch.Tensor): + dummy_dict[k] = torch.randn(state_dict[k].shape) + else: + dummy_dict[k] = state_dict[k] + return dummy_dict + + +## recursively check equality of two dictionaries +def check_dict_equality(dict1, dict2): + for k in dict1.keys(): + if isinstance(dict1[k], dict): + check_dict_equality(dict1[k], dict2[k]) + elif isinstance(dict1[k], torch.Tensor): + assert torch.allclose(dict1[k], dict2[k]) + else: + assert dict1[k] == dict2[k] + + +def test_model_state(test_experiment): + test_model, _, _ = test_experiment + model_state = ModelState(test_model) + state_dict = model_state.state_dict() + assert set(state_dict) == {"model"} + + ## relu has no parameters + expected_keys = { + "0.bias", + "0.weight", + "1.bias", + "1.weight", + "3.bias", + "3.weight", + } + assert set(state_dict["model"].keys()) == expected_keys + + dummy_model_state_dict = get_dummy_state_dict(state_dict, {}) + + ## update the model's state dict and verify that the model's parameters are updated + model_state.load_state_dict(dummy_model_state_dict) + new_model_state_dict = model_state.state_dict() + check_dict_equality(new_model_state_dict, dummy_model_state_dict) + + +def test_optimizer_state(test_experiment): + test_model, optimizer, scheduler = test_experiment + + optim_state = OptimizerState(test_model, optimizer, scheduler) + state_dict = optim_state.state_dict() + + assert set(state_dict.keys()) == {"optim", "sched"} + + ## relu has no parameters + expected_keys = { + "0.bias", + "0.weight", + "1.bias", + "1.weight", + "3.bias", + "3.weight", + } + + assert set(state_dict["optim"]["state"].keys()) == expected_keys + + dummy_state_dict = get_dummy_state_dict(state_dict, {}) + + optim_state.load_state_dict(dummy_state_dict) + new_state_dict = optim_state.state_dict() + check_dict_equality(new_state_dict, dummy_state_dict) + + +def test_save_and_load_model_only(test_experiment): + test_model, _, _ = test_experiment + save_checkpoint(test_model, "/tmp/test_model_only") + assert os.path.exists("/tmp/test_model_only") + assert not os.path.exists("/tmp/test_model_only-hf") + assert os.listdir("/tmp/test_model_only") == [".metadata", "__0_0.distcp"] + + +def test_save_and_load_model_and_optimizer(test_experiment): + test_model, optimizer, scheduler = test_experiment + for _ in range(5): + scheduler.step() + + save_checkpoint( + test_model, + "/tmp/model_and_optimizer/model", + optimizer, + scheduler, + optimizer_path="/tmp/model_and_optimizer/optimizer", + ) + + assert os.path.exists("/tmp/model_and_optimizer/model") + assert os.path.exists("/tmp/model_and_optimizer/optimizer") + assert os.listdir("/tmp/model_and_optimizer/model") == [".metadata", "__0_0.distcp"] + assert os.listdir("/tmp/model_and_optimizer/optimizer") == [ + ".metadata", + "__0_0.distcp", + ] + + ## modify the model, optimizer, and scheduler and verify that loading the checkpoint overrides the values + new_linear = torch.nn.Linear(4, 4) + new_linear.weight = torch.nn.Parameter(torch.ones([4, 4]).to("cuda")) + new_linear.bias = torch.nn.Parameter(torch.ones(4).to("cuda")) + new_model = torch.nn.ModuleList( + [ + new_linear, + torch.nn.LayerNorm(4), + torch.nn.ReLU(), + torch.nn.Linear(4, 1), + ] + ).to("cuda") + + new_optimizer = torch.optim.Adam(new_model.parameters(), lr=0.001) + new_scheduler = torch.optim.lr_scheduler.StepLR( + new_optimizer, step_size=4, gamma=0.2 + ) + load_checkpoint( + new_model, + "/tmp/model_and_optimizer/model", + new_optimizer, + new_scheduler, + optimizer_path="/tmp/model_and_optimizer/optimizer", + ) + + assert scheduler.state_dict() == new_scheduler.state_dict() + check_dict_equality(new_model.state_dict(), test_model.state_dict()) + check_dict_equality(new_optimizer.state_dict(), optimizer.state_dict()) From 8ca4dbdffc4cbca15cc83b9ce80ade08041ab55f Mon Sep 17 00:00:00 2001 From: ashors1 Date: Wed, 2 Apr 2025 13:59:38 -0700 Subject: [PATCH 20/33] add copyright Signed-off-by: ashors1 --- tests/unit/utils/test_hf_checkpoint.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/tests/unit/utils/test_hf_checkpoint.py b/tests/unit/utils/test_hf_checkpoint.py index 438fd5a932..64a72d2059 100755 --- a/tests/unit/utils/test_hf_checkpoint.py +++ b/tests/unit/utils/test_hf_checkpoint.py @@ -1,3 +1,16 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import copy import os import pytest From 163c6c9e9c9c8a757010f0fb3efe94fe5594aae4 Mon Sep 17 00:00:00 2001 From: ashors1 Date: Wed, 2 Apr 2025 14:08:15 -0700 Subject: [PATCH 21/33] linting Signed-off-by: ashors1 --- nemo_reinforcer/utils/hf_checkpoint.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/nemo_reinforcer/utils/hf_checkpoint.py b/nemo_reinforcer/utils/hf_checkpoint.py index 3f789e62d2..98a077fae1 100644 --- a/nemo_reinforcer/utils/hf_checkpoint.py +++ b/nemo_reinforcer/utils/hf_checkpoint.py @@ -154,7 +154,9 @@ def save_checkpoint( if save_hf: # Create a new path by appending "-hf" to the weights path hf_weights_path = f"{Path(weights_path)}-hf" - model.save_pretrained(hf_weights_path, is_main_process=(torch.distributed.get_rank()==0)) + model.save_pretrained( + hf_weights_path, is_main_process=(torch.distributed.get_rank() == 0) + ) if save_torch_dist: model_state_dict = {"model": ModelState(model)} From 83785ecc6bd94a225994ff511f4ffb7a93199742 Mon Sep 17 00:00:00 2001 From: ashors1 Date: Wed, 2 Apr 2025 14:48:41 -0700 Subject: [PATCH 22/33] fix tests Signed-off-by: ashors1 --- tests/unit/models/generation/test_vllm_generation.py | 1 + tests/unit/models/policy/test_hf_ray_policy.py | 1 + 2 files changed, 2 insertions(+) diff --git a/tests/unit/models/generation/test_vllm_generation.py b/tests/unit/models/generation/test_vllm_generation.py index 8c810e31dc..e8ff245de6 100644 --- a/tests/unit/models/generation/test_vllm_generation.py +++ b/tests/unit/models/generation/test_vllm_generation.py @@ -27,6 +27,7 @@ basic_vllm_test_config: VllmConfig = { "backend": "vllm", "model_name": "meta-llama/Llama-3.2-1B", # Small model for testing + "tokenizer_name": "meta-llama/Llama-3.2-1B", "dtype": "bfloat16", "max_new_tokens": 10, "temperature": 1.0, diff --git a/tests/unit/models/policy/test_hf_ray_policy.py b/tests/unit/models/policy/test_hf_ray_policy.py index ded244feac..1c7f85f284 100644 --- a/tests/unit/models/policy/test_hf_ray_policy.py +++ b/tests/unit/models/policy/test_hf_ray_policy.py @@ -27,6 +27,7 @@ basic_llama_test_config: PolicyConfig = { "model_name": "meta-llama/Llama-3.2-1B", + "tokenizer_name": "meta-llama/Llama-3.2-1B", "generation_batch_size": 1, # Small batch size for testing "train_global_batch_size": 4, "train_micro_batch_size": 1, From 7eee5689e14948d9a6555cb0d3b5b7db60435461 Mon Sep 17 00:00:00 2001 From: ashors1 Date: Wed, 2 Apr 2025 17:49:12 -0700 Subject: [PATCH 23/33] add HF checkpoint save test Signed-off-by: ashors1 --- tests/unit/utils/test_hf_checkpoint.py | 109 +++++++++++++++++++++++-- 1 file changed, 103 insertions(+), 6 deletions(-) diff --git a/tests/unit/utils/test_hf_checkpoint.py b/tests/unit/utils/test_hf_checkpoint.py index 64a72d2059..23c912a8ce 100755 --- a/tests/unit/utils/test_hf_checkpoint.py +++ b/tests/unit/utils/test_hf_checkpoint.py @@ -16,6 +16,10 @@ import pytest import torch +from nemo_reinforcer.distributed.batched_data_dict import BatchedDataDict +from nemo_reinforcer.distributed.virtual_cluster import RayVirtualCluster +from nemo_reinforcer.models.policy.hf_policy import HfPolicy +from transformers import AutoTokenizer, AutoModelForCausalLM from nemo_reinforcer.utils.hf_checkpoint import ( load_checkpoint, save_checkpoint, @@ -23,6 +27,17 @@ OptimizerState, ) +# Define basic test config +simple_policy_config = { + "model_name": "meta-llama/Llama-3.2-1B", # "hf-internal-testing/tiny-random-Gemma3ForCausalLM", + "tokenizer_name": "meta-llama/Llama-3.2-1B", # "hf-internal-testing/tiny-random-Gemma3ForCausalLM", + "train_global_batch_size": 32, + "train_micro_batch_size": 1, + "logprob_batch_size": 1, + "max_total_sequence_length": 1024, + "precision": "float32", +} + @pytest.fixture def test_experiment(): @@ -41,6 +56,43 @@ def test_experiment(): return model, optimizer, scheduler +## TODO: check scope +@pytest.fixture(scope="module") +def cluster(): + """Create a virtual cluster for testing.""" + # Create a cluster with 1 GPU + virtual_cluster = RayVirtualCluster( + bundle_ct_per_node_list=[1], # [2], # 1 node with 2 GPU bundle + use_gpus=True, + max_colocated_worker_groups=1, # 2, + num_gpus_per_node=1, # 2, # Use available GPUs + name="test-cluster", + ) + yield virtual_cluster + virtual_cluster.shutdown() + + +@pytest.fixture(scope="function") +def tokenizer(): + """Initialize tokenizer for the test model.""" + tokenizer_name = simple_policy_config["tokenizer_name"] + tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, trust_remote_code=True) + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + return tokenizer + + +@pytest.fixture(scope="function") +def policy(cluster, tokenizer): + """Initialize the policy.""" + return HfPolicy( + cluster=cluster, + config=simple_policy_config, + init_optimizer=False, + init_reference_model=False, + ) + + ## recursively get the dummy state dict ## by replacing tensors with random ones of the same shape def get_dummy_state_dict(state_dict, dummy_dict={}): @@ -122,7 +174,7 @@ def test_save_and_load_model_only(test_experiment): save_checkpoint(test_model, "/tmp/test_model_only") assert os.path.exists("/tmp/test_model_only") assert not os.path.exists("/tmp/test_model_only-hf") - assert os.listdir("/tmp/test_model_only") == [".metadata", "__0_0.distcp"] + assert set(os.listdir("/tmp/test_model_only")) == {".metadata", "__0_0.distcp"} def test_save_and_load_model_and_optimizer(test_experiment): @@ -138,13 +190,14 @@ def test_save_and_load_model_and_optimizer(test_experiment): optimizer_path="/tmp/model_and_optimizer/optimizer", ) - assert os.path.exists("/tmp/model_and_optimizer/model") - assert os.path.exists("/tmp/model_and_optimizer/optimizer") - assert os.listdir("/tmp/model_and_optimizer/model") == [".metadata", "__0_0.distcp"] - assert os.listdir("/tmp/model_and_optimizer/optimizer") == [ + assert set(os.listdir("/tmp/model_and_optimizer/model")) == { + ".metadata", + "__0_0.distcp", + } + assert set(os.listdir("/tmp/model_and_optimizer/optimizer")) == { ".metadata", "__0_0.distcp", - ] + } ## modify the model, optimizer, and scheduler and verify that loading the checkpoint overrides the values new_linear = torch.nn.Linear(4, 4) @@ -174,3 +227,47 @@ def test_save_and_load_model_and_optimizer(test_experiment): assert scheduler.state_dict() == new_scheduler.state_dict() check_dict_equality(new_model.state_dict(), test_model.state_dict()) check_dict_equality(new_optimizer.state_dict(), optimizer.state_dict()) + + +def test_save_and_load_hf_checkpoint(policy): + ## warm up with a forward pass + ## this is needed before saving a checkpoint because FSDP does some lazy initialization + input_ids = torch.randint(0, 16000, (4, 128)) # 4 sequences, each of length 128 + attention_mask = torch.ones(4, 128) + input_lengths = attention_mask.sum(dim=1).to(torch.int32) + dummy_fwd_dict = BatchedDataDict( + { + "input_ids": input_ids, + "input_lengths": input_lengths, + "attention_mask": attention_mask, + "labels": torch.randint(0, 16000, (4, 128)), + } + ) + policy.get_logprobs(dummy_fwd_dict) + + policy.save_checkpoint( + "/tmp/test_hf_and_dcp", + save_hf=True, + save_torch_dist=True, + ) + + ## make sure we save both HF and DCP checkpoints + assert set(os.listdir("/tmp/test_hf_and_dcp")) == {"__0_0.distcp", ".metadata"} + ## 1B model has two shards + assert set(os.listdir("/tmp/test_hf_and_dcp-hf")) == { + "config.json", + "generation_config.json", + "model-00001-of-00002.safetensors", + "model-00002-of-00002.safetensors", + "model.safetensors.index.json", + } + + coverted_model = AutoModelForCausalLM.from_pretrained("/tmp/test_hf_and_dcp-hf") + original_model = AutoModelForCausalLM.from_pretrained( + simple_policy_config["model_name"] + ) + + ## make sure this model matches the original + check_dict_equality(coverted_model.state_dict(), original_model.state_dict()) + + policy.worker_group.shutdown() From 1cb948325f30c5cf5d5f1b6e99664f6c41568251 Mon Sep 17 00:00:00 2001 From: ashors1 Date: Wed, 2 Apr 2025 21:50:11 -0700 Subject: [PATCH 24/33] fix tests and improve docstrings Signed-off-by: ashors1 --- nemo_reinforcer/algorithms/grpo.py | 2 +- nemo_reinforcer/utils/hf_checkpoint.py | 4 +- .../models/generation/test_vllm_generation.py | 2 + tests/unit/utils/test_hf_checkpoint.py | 180 ++++++++++-------- 4 files changed, 103 insertions(+), 85 deletions(-) diff --git a/nemo_reinforcer/algorithms/grpo.py b/nemo_reinforcer/algorithms/grpo.py index 4689e6d793..7daa762487 100644 --- a/nemo_reinforcer/algorithms/grpo.py +++ b/nemo_reinforcer/algorithms/grpo.py @@ -615,7 +615,7 @@ def grpo_train( policy.prepare_for_training() is_last_checkpoint = ( - master_config["sft"]["max_num_steps"] - (step + 1) + master_config["grpo"]["max_num_steps"] - (step + 1) < master_config["checkpointing"]["save_period"] ) diff --git a/nemo_reinforcer/utils/hf_checkpoint.py b/nemo_reinforcer/utils/hf_checkpoint.py index 98a077fae1..69e7e2d329 100644 --- a/nemo_reinforcer/utils/hf_checkpoint.py +++ b/nemo_reinforcer/utils/hf_checkpoint.py @@ -180,10 +180,10 @@ def load_checkpoint( scheduler: Optional[Any] = None, optimizer_path: Optional[str] = None, ) -> None: - """Load a checkpoint into the model and optionally optimizer. + """Load a model weights and optionally optimizer state. Args: - model: The PyTorch model to load into + model: The PyTorch model whose weights to update weights_path: Path to load model weights from optimizer: Optional optimizer to load state into scheduler: Optional scheduler to load state into diff --git a/tests/unit/models/generation/test_vllm_generation.py b/tests/unit/models/generation/test_vllm_generation.py index e8ff245de6..52878dbb48 100644 --- a/tests/unit/models/generation/test_vllm_generation.py +++ b/tests/unit/models/generation/test_vllm_generation.py @@ -218,6 +218,7 @@ def test_vllm_generation_with_hf_training(cluster, tokenizer): # Create HF-specific config with required parameters hf_config = { "model_name": basic_vllm_test_config["model_name"], + "tokenizer_name": basic_vllm_test_config["tokenizer_name"], # Required training parameters "train_global_batch_size": 4, "train_micro_batch_size": 1, @@ -474,6 +475,7 @@ def test_vllm_policy_weight_update(cluster, tokenizer, tensor_parallel_size): # Create HF-specific config with required parameters hf_config = { "model_name": basic_vllm_test_config["model_name"], + "tokenizer_name": basic_vllm_test_config["tokenizer_name"], # Required training parameters "train_global_batch_size": 4, "train_micro_batch_size": 1, diff --git a/tests/unit/utils/test_hf_checkpoint.py b/tests/unit/utils/test_hf_checkpoint.py index 23c912a8ce..d6c085157a 100755 --- a/tests/unit/utils/test_hf_checkpoint.py +++ b/tests/unit/utils/test_hf_checkpoint.py @@ -15,6 +15,7 @@ import os import pytest import torch +from tempfile import TemporaryDirectory from nemo_reinforcer.distributed.batched_data_dict import BatchedDataDict from nemo_reinforcer.distributed.virtual_cluster import RayVirtualCluster @@ -40,7 +41,7 @@ @pytest.fixture -def test_experiment(): +def mock_experiment(): model = torch.nn.ModuleList( [ torch.nn.Linear(4, 4), @@ -56,16 +57,15 @@ def test_experiment(): return model, optimizer, scheduler -## TODO: check scope @pytest.fixture(scope="module") def cluster(): """Create a virtual cluster for testing.""" - # Create a cluster with 1 GPU + # Create a cluster with 2 GPU virtual_cluster = RayVirtualCluster( - bundle_ct_per_node_list=[1], # [2], # 1 node with 2 GPU bundle + bundle_ct_per_node_list=[2], # 1 node with 2 GPU bundle use_gpus=True, - max_colocated_worker_groups=1, # 2, - num_gpus_per_node=1, # 2, # Use available GPUs + max_colocated_worker_groups=1, + num_gpus_per_node=2, # Use available GPUs name="test-cluster", ) yield virtual_cluster @@ -93,9 +93,10 @@ def policy(cluster, tokenizer): ) -## recursively get the dummy state dict -## by replacing tensors with random ones of the same shape def get_dummy_state_dict(state_dict, dummy_dict={}): + """Recursively get the dummy state dict + by replacing tensors with random ones of the same shape. + """ for k in state_dict.keys(): if isinstance(state_dict[k], dict): dummy_dict[k] = get_dummy_state_dict(state_dict[k], {}) @@ -106,8 +107,8 @@ def get_dummy_state_dict(state_dict, dummy_dict={}): return dummy_dict -## recursively check equality of two dictionaries def check_dict_equality(dict1, dict2): + """Recursively check equality of two dictionaries""" for k in dict1.keys(): if isinstance(dict1[k], dict): check_dict_equality(dict1[k], dict2[k]) @@ -117,8 +118,8 @@ def check_dict_equality(dict1, dict2): assert dict1[k] == dict2[k] -def test_model_state(test_experiment): - test_model, _, _ = test_experiment +def test_model_state(mock_experiment): + test_model, _, _ = mock_experiment model_state = ModelState(test_model) state_dict = model_state.state_dict() assert set(state_dict) == {"model"} @@ -142,8 +143,8 @@ def test_model_state(test_experiment): check_dict_equality(new_model_state_dict, dummy_model_state_dict) -def test_optimizer_state(test_experiment): - test_model, optimizer, scheduler = test_experiment +def test_optimizer_state(mock_experiment): + test_model, optimizer, scheduler = mock_experiment optim_state = OptimizerState(test_model, optimizer, scheduler) state_dict = optim_state.state_dict() @@ -169,60 +170,68 @@ def test_optimizer_state(test_experiment): check_dict_equality(new_state_dict, dummy_state_dict) -def test_save_and_load_model_only(test_experiment): - test_model, _, _ = test_experiment - save_checkpoint(test_model, "/tmp/test_model_only") - assert os.path.exists("/tmp/test_model_only") - assert not os.path.exists("/tmp/test_model_only-hf") - assert set(os.listdir("/tmp/test_model_only")) == {".metadata", "__0_0.distcp"} +def test_save_and_load_model_only(mock_experiment): + test_model, _, _ = mock_experiment + + with TemporaryDirectory() as tmp_dir: + save_checkpoint(test_model, os.path.join(tmp_dir, "test_model_only")) + assert os.path.exists(os.path.join(tmp_dir, "test_model_only")) + assert not os.path.exists(os.path.join(tmp_dir, "test_model_only-hf")) + assert set(os.listdir(os.path.join(tmp_dir, "test_model_only"))) == { + ".metadata", + "__0_0.distcp", + } -def test_save_and_load_model_and_optimizer(test_experiment): - test_model, optimizer, scheduler = test_experiment +def test_save_and_load_model_and_optimizer(mock_experiment): + test_model, optimizer, scheduler = mock_experiment for _ in range(5): scheduler.step() - save_checkpoint( - test_model, - "/tmp/model_and_optimizer/model", - optimizer, - scheduler, - optimizer_path="/tmp/model_and_optimizer/optimizer", - ) - - assert set(os.listdir("/tmp/model_and_optimizer/model")) == { - ".metadata", - "__0_0.distcp", - } - assert set(os.listdir("/tmp/model_and_optimizer/optimizer")) == { - ".metadata", - "__0_0.distcp", - } - - ## modify the model, optimizer, and scheduler and verify that loading the checkpoint overrides the values - new_linear = torch.nn.Linear(4, 4) - new_linear.weight = torch.nn.Parameter(torch.ones([4, 4]).to("cuda")) - new_linear.bias = torch.nn.Parameter(torch.ones(4).to("cuda")) - new_model = torch.nn.ModuleList( - [ - new_linear, - torch.nn.LayerNorm(4), - torch.nn.ReLU(), - torch.nn.Linear(4, 1), - ] - ).to("cuda") + with TemporaryDirectory() as tmp_dir: + save_checkpoint( + test_model, + os.path.join(tmp_dir, "model_and_optimizer/model"), + optimizer, + scheduler, + optimizer_path=os.path.join(tmp_dir, "model_and_optimizer/optimizer"), + ) + + assert set(os.listdir(os.path.join(tmp_dir, "model_and_optimizer/model"))) == { + ".metadata", + "__0_0.distcp", + } + assert set( + os.listdir(os.path.join(tmp_dir, "model_and_optimizer/optimizer")) + ) == { + ".metadata", + "__0_0.distcp", + } - new_optimizer = torch.optim.Adam(new_model.parameters(), lr=0.001) - new_scheduler = torch.optim.lr_scheduler.StepLR( - new_optimizer, step_size=4, gamma=0.2 - ) - load_checkpoint( - new_model, - "/tmp/model_and_optimizer/model", - new_optimizer, - new_scheduler, - optimizer_path="/tmp/model_and_optimizer/optimizer", - ) + ## modify the model, optimizer, and scheduler and verify that loading the checkpoint overrides the values + new_linear = torch.nn.Linear(4, 4) + new_linear.weight = torch.nn.Parameter(torch.ones([4, 4]).to("cuda")) + new_linear.bias = torch.nn.Parameter(torch.ones(4).to("cuda")) + new_model = torch.nn.ModuleList( + [ + new_linear, + torch.nn.LayerNorm(4), + torch.nn.ReLU(), + torch.nn.Linear(4, 1), + ] + ).to("cuda") + + new_optimizer = torch.optim.Adam(new_model.parameters(), lr=0.001) + new_scheduler = torch.optim.lr_scheduler.StepLR( + new_optimizer, step_size=4, gamma=0.2 + ) + load_checkpoint( + new_model, + os.path.join(tmp_dir, "model_and_optimizer/model"), + new_optimizer, + new_scheduler, + optimizer_path=os.path.join(tmp_dir, "model_and_optimizer/optimizer"), + ) assert scheduler.state_dict() == new_scheduler.state_dict() check_dict_equality(new_model.state_dict(), test_model.state_dict()) @@ -245,29 +254,36 @@ def test_save_and_load_hf_checkpoint(policy): ) policy.get_logprobs(dummy_fwd_dict) - policy.save_checkpoint( - "/tmp/test_hf_and_dcp", - save_hf=True, - save_torch_dist=True, - ) - - ## make sure we save both HF and DCP checkpoints - assert set(os.listdir("/tmp/test_hf_and_dcp")) == {"__0_0.distcp", ".metadata"} - ## 1B model has two shards - assert set(os.listdir("/tmp/test_hf_and_dcp-hf")) == { - "config.json", - "generation_config.json", - "model-00001-of-00002.safetensors", - "model-00002-of-00002.safetensors", - "model.safetensors.index.json", - } + with TemporaryDirectory() as tmp_dir: + policy.save_checkpoint( + os.path.join(tmp_dir, "test_hf_and_dcp"), + save_hf=True, + save_torch_dist=True, + ) + + ## make sure we save both HF and DCP checkpoints + assert set(os.listdir(os.path.join(tmp_dir, "test_hf_and_dcp"))) == { + "__0_0.distcp", + "__1_0.distcp", + ".metadata", + } + ## 1B model has two shards + assert set(os.listdir(os.path.join(tmp_dir, "test_hf_and_dcp-hf"))) == { + "config.json", + "generation_config.json", + "model-00001-of-00002.safetensors", + "model-00002-of-00002.safetensors", + "model.safetensors.index.json", + } - coverted_model = AutoModelForCausalLM.from_pretrained("/tmp/test_hf_and_dcp-hf") - original_model = AutoModelForCausalLM.from_pretrained( - simple_policy_config["model_name"] - ) + coverted_model = AutoModelForCausalLM.from_pretrained( + os.path.join(tmp_dir, "test_hf_and_dcp-hf") + ) + original_model = AutoModelForCausalLM.from_pretrained( + simple_policy_config["model_name"] + ) - ## make sure this model matches the original + ## make sure converted model matches the original check_dict_equality(coverted_model.state_dict(), original_model.state_dict()) policy.worker_group.shutdown() From 3affdfd276bebce4c89afa400e1a5268ecb96056 Mon Sep 17 00:00:00 2001 From: ashors1 Date: Thu, 3 Apr 2025 10:48:29 -0700 Subject: [PATCH 25/33] address comments Signed-off-by: ashors1 --- docs/index.md | 1 + examples/convert_dcp_to_hf.py | 29 +++++++++-------------- nemo_reinforcer/models/policy/__init__.py | 1 + 3 files changed, 13 insertions(+), 18 deletions(-) diff --git a/docs/index.md b/docs/index.md index 0628f19953..553778ff98 100644 --- a/docs/index.md +++ b/docs/index.md @@ -47,4 +47,5 @@ design_docs/logger.md design_docs/uv.md design_docs/chat_datasets.md design_docs/generation.md +design_docs/checkpointing.md ``` diff --git a/examples/convert_dcp_to_hf.py b/examples/convert_dcp_to_hf.py index 9c2b793f30..ee347eeb9e 100644 --- a/examples/convert_dcp_to_hf.py +++ b/examples/convert_dcp_to_hf.py @@ -14,7 +14,7 @@ import argparse import os -from omegaconf import OmegaConf +import json from nemo_reinforcer.distributed.virtual_cluster import init_ray, RayVirtualCluster from nemo_reinforcer.models.policy.hf_policy import HfPolicy @@ -30,7 +30,7 @@ def parse_args(): "--config", type=str, default=None, - help="Path to YAML config file used during model training", + help="Path to config.json file in the checkpoint directory", ) parser.add_argument( "--dcp-ckpt-path", type=str, default=None, help="Path to DCP checkpoint" @@ -39,28 +39,17 @@ def parse_args(): "--hf-ckpt-path", type=str, default=None, help="Path to save HF checkpoint" ) # Parse known args for the script - args, remaining = parser.parse_known_args() + args = parser.parse_args() - # Convert remaining args to OmegaConf format - overrides = OmegaConf.from_dotlist(remaining) - - return args, overrides + return args def main(): """Main entry point.""" - # Parse arguments - args, overrides = parse_args() - - if not args.config: - args.config = os.path.join(os.path.dirname(__file__), "configs", "sft.yaml") - - config = load_config(args.config) - print(f"Loaded configuration from: {args.config}") + args = parse_args() - if overrides: - print(f"Overrides: {overrides}") - config = OmegaConf.merge(config, overrides) + with open(args.config, "r") as f: + config = json.load(f) dcp_ckpt = args.dcp_ckpt_path hf_ckpt = args.hf_ckpt_path @@ -92,8 +81,12 @@ def main(): save_hf=True, save_torch_dist=False, ) + print(f"Saved HF checkpoint to: {hf_ckpt}-hf") + cluster.shutdown() + policy.worker_group.shutdown() + if __name__ == "__main__": main() diff --git a/nemo_reinforcer/models/policy/__init__.py b/nemo_reinforcer/models/policy/__init__.py index ee2bf2389e..24390b9670 100644 --- a/nemo_reinforcer/models/policy/__init__.py +++ b/nemo_reinforcer/models/policy/__init__.py @@ -19,6 +19,7 @@ class PolicyConfig(TypedDict): model_name: str + tokenizer_name: str train_global_batch_size: int train_micro_batch_size: int learning_rate: float From 174ef813eb30de0de8bfab3f81f1972acdcaaac3 Mon Sep 17 00:00:00 2001 From: ashors1 Date: Thu, 3 Apr 2025 11:20:48 -0700 Subject: [PATCH 26/33] cleanup Signed-off-by: ashors1 --- nemo_reinforcer/utils/hf_checkpoint.py | 26 +++++++++++++++----------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/nemo_reinforcer/utils/hf_checkpoint.py b/nemo_reinforcer/utils/hf_checkpoint.py index 69e7e2d329..3be392e61c 100644 --- a/nemo_reinforcer/utils/hf_checkpoint.py +++ b/nemo_reinforcer/utils/hf_checkpoint.py @@ -56,9 +56,7 @@ def state_dict(self): cpu_offload=True ), ) - return { - "model": model_state_dict, - } + return model_state_dict def load_state_dict(self, state_dict): """Load the state dictionary into the model. @@ -69,7 +67,7 @@ def load_state_dict(self, state_dict): # sets our state dicts on the model, now that we've loaded set_model_state_dict( self.model, - state_dict["model"], + state_dict, ) @@ -154,23 +152,29 @@ def save_checkpoint( if save_hf: # Create a new path by appending "-hf" to the weights path hf_weights_path = f"{Path(weights_path)}-hf" + + ## make sure we save the checkpoint from rank 0 only + def custom_save(obj, f) -> None: + if torch.distributed.get_rank == 0: + torch.save(obj, f) + model.save_pretrained( - hf_weights_path, is_main_process=(torch.distributed.get_rank() == 0) + hf_weights_path, + is_main_process=(torch.distributed.get_rank() == 0), + save_function=custom_save, ) if save_torch_dist: - model_state_dict = {"model": ModelState(model)} - dcp.save(model_state_dict, checkpoint_id=weights_path) + model_state = {"model": ModelState(model)} + dcp.save(model_state, checkpoint_id=weights_path) if optimizer is not None: if optimizer_path is None: raise ValueError( "optimizer_path must be provided when saving optimizer state" ) - optimizer_state_dict = { - "optim": OptimizerState(model, optimizer, scheduler) - } - dcp.save(optimizer_state_dict, checkpoint_id=optimizer_path) + optimizer_state = {"optim": OptimizerState(model, optimizer, scheduler)} + dcp.save(optimizer_state, checkpoint_id=optimizer_path) def load_checkpoint( From f6e56d01ba8833cb09d91df713fa09eb9feb191d Mon Sep 17 00:00:00 2001 From: ashors1 Date: Thu, 3 Apr 2025 13:06:34 -0700 Subject: [PATCH 27/33] make HF save multi-process safe Signed-off-by: ashors1 --- nemo_reinforcer/utils/hf_checkpoint.py | 30 +++++++++++++++----------- 1 file changed, 17 insertions(+), 13 deletions(-) diff --git a/nemo_reinforcer/utils/hf_checkpoint.py b/nemo_reinforcer/utils/hf_checkpoint.py index 3be392e61c..dec4b149bc 100644 --- a/nemo_reinforcer/utils/hf_checkpoint.py +++ b/nemo_reinforcer/utils/hf_checkpoint.py @@ -150,19 +150,23 @@ def save_checkpoint( save_hf: Whether to save in HuggingFace format """ if save_hf: - # Create a new path by appending "-hf" to the weights path - hf_weights_path = f"{Path(weights_path)}-hf" - - ## make sure we save the checkpoint from rank 0 only - def custom_save(obj, f) -> None: - if torch.distributed.get_rank == 0: - torch.save(obj, f) - - model.save_pretrained( - hf_weights_path, - is_main_process=(torch.distributed.get_rank() == 0), - save_function=custom_save, - ) + ## NOTE rank0_only is False because True causes a hanf with SFT + ## this causes multiple copies of the model to get offloaded to CPU + with torch.distributed.fsdp.FullyShardedDataParallel.summon_full_params( + model, + offload_to_cpu=True, + writeback=False, + ): + state_dict = model.state_dict() + + if torch.distributed.get_rank() == 0: + # Create a new path by appending "-hf" to the weights path + hf_weights_path = f"{Path(weights_path)}-hf" + + model.save_pretrained( + hf_weights_path, + state_dict=state_dict, + ) if save_torch_dist: model_state = {"model": ModelState(model)} From 58a0d3d07e2e070c741feb964fd3ce183b7b7a4f Mon Sep 17 00:00:00 2001 From: ashors1 Date: Fri, 4 Apr 2025 09:56:34 -0700 Subject: [PATCH 28/33] hf checkpointing fix Signed-off-by: ashors1 --- nemo_reinforcer/utils/hf_checkpoint.py | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/nemo_reinforcer/utils/hf_checkpoint.py b/nemo_reinforcer/utils/hf_checkpoint.py index dec4b149bc..6f22ea82fd 100644 --- a/nemo_reinforcer/utils/hf_checkpoint.py +++ b/nemo_reinforcer/utils/hf_checkpoint.py @@ -150,14 +150,7 @@ def save_checkpoint( save_hf: Whether to save in HuggingFace format """ if save_hf: - ## NOTE rank0_only is False because True causes a hanf with SFT - ## this causes multiple copies of the model to get offloaded to CPU - with torch.distributed.fsdp.FullyShardedDataParallel.summon_full_params( - model, - offload_to_cpu=True, - writeback=False, - ): - state_dict = model.state_dict() + model_state_dict = model._fsdp_wrapped_module.state_dict() if torch.distributed.get_rank() == 0: # Create a new path by appending "-hf" to the weights path @@ -165,7 +158,7 @@ def save_checkpoint( model.save_pretrained( hf_weights_path, - state_dict=state_dict, + state_dict=model_state_dict, ) if save_torch_dist: From fe3f47c83b5065b44a86aa0d96c3fb43afaae607 Mon Sep 17 00:00:00 2001 From: ashors1 Date: Fri, 4 Apr 2025 15:14:19 -0700 Subject: [PATCH 29/33] enable checkpointing in sft functional test Signed-off-by: ashors1 --- tests/functional/sft.sh | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/tests/functional/sft.sh b/tests/functional/sft.sh index 82d263c9da..ac14216adf 100755 --- a/tests/functional/sft.sh +++ b/tests/functional/sft.sh @@ -26,10 +26,15 @@ python -u $PROJECT_ROOT/examples/run_sft.py \ logger.tensorboard_enabled=true \ logger.log_dir=$LOG_DIR \ logger.wandb_enabled=false \ - checkpointing.enabled=false \ + checkpointing.enabled=true \ + checkpointing.save_every_n_steps=10 \ + checkpointing.checkpoint_dir=/tmp/sft_checkpoints \ $@ \ 2>&1 | tee $RUN_LOG +## clean up checkpoint directory +rm -r /tmp/sft_checkpoints + cd $SCRIPT_DIR python json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS From 4f174313ce5620da3de1290ea9538d26ba763337 Mon Sep 17 00:00:00 2001 From: ashors1 Date: Mon, 7 Apr 2025 09:04:52 -0700 Subject: [PATCH 30/33] update unit test Signed-off-by: ashors1 --- tests/unit/utils/test_hf_checkpoint.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/unit/utils/test_hf_checkpoint.py b/tests/unit/utils/test_hf_checkpoint.py index d6c085157a..f662f273b0 100755 --- a/tests/unit/utils/test_hf_checkpoint.py +++ b/tests/unit/utils/test_hf_checkpoint.py @@ -122,7 +122,6 @@ def test_model_state(mock_experiment): test_model, _, _ = mock_experiment model_state = ModelState(test_model) state_dict = model_state.state_dict() - assert set(state_dict) == {"model"} ## relu has no parameters expected_keys = { @@ -133,7 +132,7 @@ def test_model_state(mock_experiment): "3.bias", "3.weight", } - assert set(state_dict["model"].keys()) == expected_keys + assert set(state_dict.keys()) == expected_keys dummy_model_state_dict = get_dummy_state_dict(state_dict, {}) From 4e11a2e2283bb101e02b7258350984715b125837 Mon Sep 17 00:00:00 2001 From: ashors1 Date: Mon, 7 Apr 2025 11:25:29 -0700 Subject: [PATCH 31/33] address comments Signed-off-by: ashors1 --- nemo_reinforcer/algorithms/grpo.py | 3 ++- nemo_reinforcer/algorithms/sft.py | 3 ++- nemo_reinforcer/models/policy/hf_policy.py | 23 ++++++++++++++++++- ...{hf_checkpoint.py => native_checkpoint.py} | 0 tests/functional/sft.sh | 2 +- ...heckpoint.py => test_native_checkpoint.py} | 2 +- 6 files changed, 28 insertions(+), 5 deletions(-) rename nemo_reinforcer/utils/{hf_checkpoint.py => native_checkpoint.py} (100%) rename tests/unit/utils/{test_hf_checkpoint.py => test_native_checkpoint.py} (99%) diff --git a/nemo_reinforcer/algorithms/grpo.py b/nemo_reinforcer/algorithms/grpo.py index 11fd149504..841fc91552 100644 --- a/nemo_reinforcer/algorithms/grpo.py +++ b/nemo_reinforcer/algorithms/grpo.py @@ -615,7 +615,8 @@ def grpo_train( policy.prepare_for_training() is_last_checkpoint = ( - master_config["grpo"]["max_num_steps"] - (step + 1) + min(len(dataloader), master_config["grpo"]["max_num_steps"]) + - (step + 1) < master_config["checkpointing"]["save_period"] ) diff --git a/nemo_reinforcer/algorithms/sft.py b/nemo_reinforcer/algorithms/sft.py index e5e684fbbc..b5bb41aec5 100644 --- a/nemo_reinforcer/algorithms/sft.py +++ b/nemo_reinforcer/algorithms/sft.py @@ -398,7 +398,8 @@ def sft_train( and (step + 1) % master_config["checkpointing"]["save_period"] == 0 ): # +1 because step is 0-indexed is_last_checkpoint = ( - master_config["sft"]["max_num_steps"] - (step + 1) + min(len(train_dataloader), master_config["sft"]["max_num_steps"]) + - (step + 1) < master_config["checkpointing"]["save_period"] ) diff --git a/nemo_reinforcer/models/policy/hf_policy.py b/nemo_reinforcer/models/policy/hf_policy.py index 88c2928be7..fb459a75e6 100644 --- a/nemo_reinforcer/models/policy/hf_policy.py +++ b/nemo_reinforcer/models/policy/hf_policy.py @@ -820,7 +820,28 @@ def save_checkpoint( save_torch_dist: bool = True, save_hf: bool = False, ): - """Save a checkpoint of the model.""" + """Save a checkpoint of the model. + + The checkpoint is saved in the following format: + + weights_path/ + __0_1.distcp + __1_0.distcp + ... + weights_path-hf/ + config.json + generation_config.json + model-00001-of-.safetensors + ... + model.safetensors.index.json + optimizer_path/ + __0_0.distcp + __1_0.distcp + ... + + the HuggingFace checkpoint is saved only if `save_hf` is True, + and the optimizer states are saved only if `optimizer` and `optimizer_path` are provided. + """ save_checkpoint( model=self.model, weights_path=weights_path, diff --git a/nemo_reinforcer/utils/hf_checkpoint.py b/nemo_reinforcer/utils/native_checkpoint.py similarity index 100% rename from nemo_reinforcer/utils/hf_checkpoint.py rename to nemo_reinforcer/utils/native_checkpoint.py diff --git a/tests/functional/sft.sh b/tests/functional/sft.sh index ac14216adf..ff3c3d6007 100755 --- a/tests/functional/sft.sh +++ b/tests/functional/sft.sh @@ -33,7 +33,7 @@ python -u $PROJECT_ROOT/examples/run_sft.py \ 2>&1 | tee $RUN_LOG ## clean up checkpoint directory -rm -r /tmp/sft_checkpoints +trap "rm -r /tmp/sft_checkpoints" EXIT cd $SCRIPT_DIR python json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS diff --git a/tests/unit/utils/test_hf_checkpoint.py b/tests/unit/utils/test_native_checkpoint.py similarity index 99% rename from tests/unit/utils/test_hf_checkpoint.py rename to tests/unit/utils/test_native_checkpoint.py index f662f273b0..8f71badea1 100755 --- a/tests/unit/utils/test_hf_checkpoint.py +++ b/tests/unit/utils/test_native_checkpoint.py @@ -21,7 +21,7 @@ from nemo_reinforcer.distributed.virtual_cluster import RayVirtualCluster from nemo_reinforcer.models.policy.hf_policy import HfPolicy from transformers import AutoTokenizer, AutoModelForCausalLM -from nemo_reinforcer.utils.hf_checkpoint import ( +from nemo_reinforcer.utils.native_checkpoint import ( load_checkpoint, save_checkpoint, ModelState, From 7412223e8dee2566ed579e4f76bed14f9c885557 Mon Sep 17 00:00:00 2001 From: ashors1 Date: Mon, 7 Apr 2025 11:36:28 -0700 Subject: [PATCH 32/33] small fixes Signed-off-by: ashors1 --- nemo_reinforcer/models/policy/hf_policy.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/nemo_reinforcer/models/policy/hf_policy.py b/nemo_reinforcer/models/policy/hf_policy.py index 840d7d6255..051e56e23f 100644 --- a/nemo_reinforcer/models/policy/hf_policy.py +++ b/nemo_reinforcer/models/policy/hf_policy.py @@ -45,7 +45,7 @@ from nemo_reinforcer.distributed.virtual_cluster import ( PY_EXECUTABLES, ) -from nemo_reinforcer.utils.hf_checkpoint import ( +from nemo_reinforcer.utils.native_checkpoint import ( save_checkpoint, load_checkpoint, ) @@ -102,9 +102,6 @@ def __init__( else: self.reference_model = None self.tokenizer = get_tokenizer(tokenizer_name) - # If no pad token is defined, you might need: - if self.tokenizer.pad_token is None: - self.tokenizer.pad_token = self.tokenizer.eos_token # ------------------------------------------------ # 3) Move to GPU + Composable FSDP From 0749217065543241d3ea0612df75f5c8e88908da Mon Sep 17 00:00:00 2001 From: ashors1 Date: Mon, 7 Apr 2025 14:13:02 -0700 Subject: [PATCH 33/33] fix trap command Signed-off-by: ashors1 --- tests/functional/sft.sh | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/functional/sft.sh b/tests/functional/sft.sh index ff3c3d6007..85282d64f4 100755 --- a/tests/functional/sft.sh +++ b/tests/functional/sft.sh @@ -1,5 +1,8 @@ #!/bin/bash +## clean up checkpoint directory on exit +trap "rm -rf /tmp/sft_checkpoints" EXIT + SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd) PROJECT_ROOT=$(realpath $SCRIPT_DIR/../..) # Mark the current repo as safe, since wandb fetchs metadata about the repo @@ -32,9 +35,6 @@ python -u $PROJECT_ROOT/examples/run_sft.py \ $@ \ 2>&1 | tee $RUN_LOG -## clean up checkpoint directory -trap "rm -r /tmp/sft_checkpoints" EXIT - cd $SCRIPT_DIR python json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS