diff --git a/docs/design_docs/checkpointing.md b/docs/design_docs/checkpointing.md
new file mode 100644
index 0000000000..9b9a6f6826
--- /dev/null
+++ b/docs/design_docs/checkpointing.md
@@ -0,0 +1,22 @@
+# Checkpointing with HuggingFace Models
+
+## Checkpoint Format
+Reinforcer provides two checkpoint formats for HuggingFace models: Torch distributed and HuggingFace format. Torch distributed is used by default for efficiency, and the HuggingFace format is provided for compatibility with HuggingFace's `AutoModel.from_pretrained` API. Note that HuggingFace-format checkpoints save only the model weights and ignore the optimizer states. We recommend saving intermediate checkpoints in Torch distributed format and saving a HuggingFace checkpoint only at the end of training.
+
+There are two ways to get a Reinforcer checkpoint in HuggingFace format.
+
+1. (Recommended) Save the HuggingFace checkpoint directly by passing `save_hf=True` to `HfPolicy`'s `save_checkpoint`:
+
+    ```python
+    policy.save_checkpoint(
+        weights_path=<path to save weights>,
+        optimizer_path=<path to save optimizer states>,
+        save_torch_dist=True,
+        save_hf=True,
+    )
+    ```
+2. Convert a Torch distributed checkpoint to HuggingFace format after training. We provide a conversion script for this purpose; note that it writes the converted checkpoint to `<hf-ckpt-path>-hf`.
+
+    ```sh
+    uv run examples/convert_dcp_to_hf.py --config=<path to config.json> --dcp-ckpt-path=<DCP checkpoint path> --hf-ckpt-path=<path to save HF checkpoint>
+    ```
\ No newline at end of file
diff --git a/docs/index.md b/docs/index.md
index 0628f19953..553778ff98 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -47,4 +47,5 @@ design_docs/logger.md
 design_docs/uv.md
 design_docs/chat_datasets.md
 design_docs/generation.md
+design_docs/checkpointing.md
 ```
diff --git a/examples/configs/grpo_math_1B.yaml b/examples/configs/grpo_math_1B.yaml
index 422e869f56..3d8fdfce43 100644
--- a/examples/configs/grpo_math_1B.yaml
+++ b/examples/configs/grpo_math_1B.yaml
@@ -25,6 +25,7 @@ checkpointing:
 
 policy:
   model_name: "meta-llama/Llama-3.2-1B-Instruct"
+  tokenizer_name: ${policy.model_name} ## specify if you'd like to use a tokenizer different from the model's default
   train_global_batch_size: 512
   train_micro_batch_size: 4
   generation_batch_size: 32 # Only used when generating using HF backend
diff --git a/examples/configs/grpo_math_8B.yaml b/examples/configs/grpo_math_8B.yaml
index 261db927b1..f2e0576fbc 100644
--- a/examples/configs/grpo_math_8B.yaml
+++ b/examples/configs/grpo_math_8B.yaml
@@ -7,6 +7,7 @@ grpo:
 
 policy:
   model_name: "meta-llama/Llama-3.1-8B-Instruct"
+  tokenizer_name: ${policy.model_name} ## specify if you'd like to use a tokenizer different from the model's default
  train_global_batch_size: 512
   train_micro_batch_size: 1
   generation_batch_size: 32 # Only used when generating using HF backend
diff --git a/examples/configs/sft.yaml b/examples/configs/sft.yaml
index e4b116a351..bb8467165f 100644
--- a/examples/configs/sft.yaml
+++ b/examples/configs/sft.yaml
@@ -18,6 +18,7 @@ checkpointing:
 
 policy:
   model_name: "meta-llama/Llama-3.2-1B"
+  tokenizer_name: ${policy.model_name} ## specify if you'd like to use a tokenizer different from the model's default
   train_global_batch_size: 32
   train_micro_batch_size: 1
   max_total_sequence_length: 1024
diff --git a/examples/convert_dcp_to_hf.py b/examples/convert_dcp_to_hf.py
new file mode 100644
index 0000000000..ee347eeb9e
--- /dev/null
+++ b/examples/convert_dcp_to_hf.py
@@ -0,0 +1,92 @@
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import json
+import os
+
+from nemo_reinforcer.distributed.virtual_cluster import init_ray, RayVirtualCluster
+from nemo_reinforcer.models.policy.hf_policy import HfPolicy
+
+
+def parse_args():
+    """Parse command line arguments."""
+    parser = argparse.ArgumentParser(
+        description="Convert Torch DCP checkpoint to HF checkpoint"
+    )
+    parser.add_argument(
+        "--config",
+        type=str,
+        required=True,
+        help="Path to the config.json file saved in the checkpoint directory",
+    )
+    parser.add_argument(
+        "--dcp-ckpt-path", type=str, required=True, help="Path to DCP checkpoint"
+    )
+    parser.add_argument(
+        "--hf-ckpt-path", type=str, required=True, help="Path to save HF checkpoint"
+    )
+    args = parser.parse_args()
+
+    return args
+
+
+def main():
+    """Main entry point."""
+    args = parse_args()
+
+    with open(args.config, "r") as f:
+        config = json.load(f)
+
+    dcp_ckpt = args.dcp_ckpt_path
+    hf_ckpt = args.hf_ckpt_path
+
+    # Extract individual configs for easier access
+    policy_config = config["policy"]
+    cluster_config = config["cluster"]
+
+    init_ray()
+
+    cluster = RayVirtualCluster(
+        name="convert_cluster",
+        bundle_ct_per_node_list=[cluster_config["gpus_per_node"]]
+        * cluster_config["num_nodes"],
+        use_gpus=True,
+        num_gpus_per_node=cluster_config["gpus_per_node"],
+        max_colocated_worker_groups=1,
+    )
+
+    # Load the DCP checkpoint into the policy; the optimizer state is not
+    # needed for conversion, so skip initializing it.
+    policy = HfPolicy(
+        cluster=cluster,
+        config=policy_config,
+        weights_path=dcp_ckpt,
+        init_optimizer=False,
+    )
+
+    # Save in HuggingFace format only; the HF checkpoint is written to "<hf_ckpt>-hf"
+    policy.save_checkpoint(
+        weights_path=os.path.abspath(hf_ckpt),
+        save_hf=True,
+        save_torch_dist=False,
+    )
+
+    print(f"Saved HF checkpoint to: {hf_ckpt}-hf")
+
+    policy.worker_group.shutdown()
+    cluster.shutdown()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/nemo_reinforcer/algorithms/grpo.py b/nemo_reinforcer/algorithms/grpo.py
index 6ea5a56dd6..0eda853375 100644
--- a/nemo_reinforcer/algorithms/grpo.py
+++ b/nemo_reinforcer/algorithms/grpo.py
@@ -236,10 +236,10 @@ def setup(
     policy = HfPolicy(
         cluster=cluster,
         config=policy_config,
-        weights_path=Path(last_checkpoint_path) / "policy.pt"
+        weights_path=Path(last_checkpoint_path) / "policy" / "weights"
         if last_checkpoint_path
         else None,
-        optimizer_path=Path(last_checkpoint_path) / "policy_optimizer.pt"
+        optimizer_path=Path(last_checkpoint_path) / "policy" / "optimizer"
         if last_checkpoint_path
         else None,
         init_optimizer=True,
@@ -608,6 +608,13 @@ def grpo_train(
             and (step + 1) % master_config["checkpointing"]["save_period"] == 0
         ):  # +1 because step is 0-indexed
             policy.prepare_for_training()
+
+            is_last_checkpoint = (
+                min(len(dataloader), master_config["grpo"]["max_num_steps"])
+                - (step + 1)
+                < master_config["checkpointing"]["save_period"]
+            )
+
             grpo_save_state["step"] = step + 1
             grpo_save_state["val_reward"] = val_metrics["accuracy"]
             grpo_save_state["consumed_samples"] = consumed_samples
@@ -617,8 +624,11 @@
                 step + 1, grpo_save_state, master_config
             )
             policy.save_checkpoint(
-                os.path.join(checkpoint_path, "policy.pt"),
-                os.path.join(checkpoint_path, "policy_optimizer.pt"),
+                weights_path=os.path.join(checkpoint_path, "policy", "weights"),
+                optimizer_path=os.path.join(
+                    checkpoint_path, "policy", "optimizer"
+                ),
+                save_hf=is_last_checkpoint,
             )
             torch.save(
                 dataloader.state_dict(),
diff --git a/nemo_reinforcer/algorithms/sft.py b/nemo_reinforcer/algorithms/sft.py
index 8f9e34f9da..b5bb41aec5 100644
--- a/nemo_reinforcer/algorithms/sft.py
+++ b/nemo_reinforcer/algorithms/sft.py
@@ -175,10 +175,10 @@ def setup(
     policy = HfPolicy(
         cluster=cluster,
         config=policy_config,
-        weights_path=Path(last_checkpoint_path) / "policy.pt"
+        weights_path=Path(last_checkpoint_path) / "policy" / "weights"
         if last_checkpoint_path
         else None,
-        optimizer_path=Path(last_checkpoint_path) / "policy_optimizer.pt"
+        optimizer_path=Path(last_checkpoint_path) / "policy" / "optimizer"
         if last_checkpoint_path
         else None,
         init_optimizer=True,
@@ -311,9 +311,7 @@
         sft_save_state = _default_sft_save_state()
         step = 0
     else:
-        step = (
-            sft_save_state["step"] + 1
-        )  # N+1 because the checkpoint is _after_ SFT iteration N
+        step = sft_save_state["step"]
 
     sft_config = master_config["sft"]
     # Validation configuration
@@ -399,19 +397,26 @@
             master_config["checkpointing"]["enabled"]
             and (step + 1) % master_config["checkpointing"]["save_period"] == 0
         ):  # +1 because step is 0-indexed
-            sft_save_state["step"] = step
+            is_last_checkpoint = (
+                min(len(train_dataloader), master_config["sft"]["max_num_steps"])
+                - (step + 1)
+                < master_config["checkpointing"]["save_period"]
+            )
+
+            sft_save_state["step"] = step + 1
             sft_save_state["val_loss"] = val_metrics["val_loss"]
             with timer.time("checkpointing"):
                 print(f"Saving checkpoint for step {step + 1}...")
                 checkpoint_path = checkpointer.init_tmp_checkpoint(
                     step + 1, sft_save_state, master_config
                 )
+
                 policy.save_checkpoint(
-                    os.path.join(checkpoint_path, "policy.pt"),
-                    os.path.join(checkpoint_path, "policy_optimizer.pt"),
-                    ## NOTE: below is a workaround to avoid a bug with checkpointing
-                    ## this should be removed once the bug is fixed
-                    offload_to_cpu=False,
+                    weights_path=os.path.join(checkpoint_path, "policy", "weights"),
+                    optimizer_path=os.path.join(
+                        checkpoint_path, "policy", "optimizer"
+                    ),
+                    save_hf=is_last_checkpoint,
                 )
                 torch.save(
                     train_dataloader.state_dict(),
diff --git a/nemo_reinforcer/models/policy/__init__.py b/nemo_reinforcer/models/policy/__init__.py
index ee2bf2389e..24390b9670 100644
--- a/nemo_reinforcer/models/policy/__init__.py
+++ b/nemo_reinforcer/models/policy/__init__.py
@@ -19,6 +19,7 @@ class PolicyConfig(TypedDict):
     model_name: str
+    tokenizer_name: str
     train_global_batch_size: int
     train_micro_batch_size: int
     learning_rate: float
diff --git a/nemo_reinforcer/models/policy/hf_policy.py b/nemo_reinforcer/models/policy/hf_policy.py
index c36bc0fec7..051e56e23f 100644
--- a/nemo_reinforcer/models/policy/hf_policy.py
+++ b/nemo_reinforcer/models/policy/hf_policy.py
@@ -23,9 +23,7 @@
 from torch.distributed.device_mesh import init_device_mesh
 from torch.distributed.fsdp import (
     FullyShardedDataParallel,
-    FullStateDictConfig,
     MixedPrecision,
-    StateDictType,
 )
 from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy
 from transformers import AutoModelForCausalLM
@@ -47,6 +45,10 @@
 from nemo_reinforcer.distributed.virtual_cluster import (
     PY_EXECUTABLES,
 )
+from nemo_reinforcer.utils.native_checkpoint import (
+    save_checkpoint,
+    load_checkpoint,
+)
 
 
 @ray.remote
@@ -77,6 +79,7 @@ def __init__(
         rank = torch.distributed.get_rank()
         world_size = torch.distributed.get_world_size()
         model_name = self.cfg["model_name"]
+        tokenizer_name = self.cfg["tokenizer_name"]
         if self.cfg["precision"] == "float32":
             self.dtype = torch.float32
         elif self.cfg["precision"] == "bfloat16":
@@ -98,7 +101,7 @@
             )
         else:
             self.reference_model = None
-        self.tokenizer = get_tokenizer(model_name)
+        self.tokenizer = get_tokenizer(tokenizer_name)
 
         # ------------------------------------------------
         # 3) Move to GPU + Composable FSDP
@@ -139,7 +142,7 @@ def do_fsdp(model):
         else:
             self.optimizer = None
 
-        if "scheduler" in self.cfg:
+        if "scheduler" in self.cfg and self.optimizer is not None:
             if isinstance(self.cfg["scheduler"], dict):
                 scheduler_cls = import_class_from_path(self.cfg["scheduler"]["name"])
                 self.scheduler = scheduler_cls(
@@ -165,7 +168,7 @@ def do_fsdp(model):
                 self.optimizer, schedulers, milestones
             )
 
-        else:
+        elif self.optimizer is not None:
             ## default to a passthrough LR schedule
             self.scheduler = torch.optim.lr_scheduler.LambdaLR(
                 self.optimizer, lr_lambda=lambda epoch: 1
@@ -173,7 +176,10 @@ def do_fsdp(model):
 
         # restore
         if weights_path:
-            self.load_checkpoint(weights_path, optimizer_path)
+            self.load_checkpoint(
+                weights_path,
+                optimizer_path,
+            )
         else:
             print(
                 "No weights path provided. Starting from scratch (default policy init)"
@@ -817,78 +823,50 @@ def save_checkpoint(
         self,
         weights_path: str,
         optimizer_path: Optional[str] = None,
-        offload_to_cpu: bool = True,
+        save_torch_dist: bool = True,
+        save_hf: bool = False,
     ):
-        # Config to save full state dict on rank 0, offloaded to CPU
-        state_dict_config = FullStateDictConfig(
-            offload_to_cpu=offload_to_cpu, rank0_only=True
+        """Save a checkpoint of the model.
+
+        The checkpoint is saved in the following format:
+
+        weights_path/
+            __0_0.distcp
+            __1_0.distcp
+            ...
+        weights_path-hf/
+            config.json
+            generation_config.json
+            model-00001-of-<num_shards>.safetensors
+            ...
+            model.safetensors.index.json
+        optimizer_path/
+            __0_0.distcp
+            __1_0.distcp
+            ...
+
+        The HuggingFace checkpoint is saved only if `save_hf` is True,
+        and the optimizer states are saved only if `optimizer` and `optimizer_path` are provided.
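+
+        Args:
+            weights_path: Directory to write the Torch distributed (DCP) model weights to.
+            optimizer_path: Directory to write optimizer and scheduler state to.
+            save_torch_dist: Whether to save the Torch distributed checkpoint.
+            save_hf: Whether to also save a HuggingFace-format checkpoint.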
+ """ + save_checkpoint( + model=self.model, + weights_path=weights_path, + optimizer=self.optimizer if optimizer_path else None, + scheduler=self.scheduler if optimizer_path else None, + optimizer_path=optimizer_path, + save_torch_dist=save_torch_dist, + save_hf=save_hf, ) - with FullyShardedDataParallel.state_dict_type( - self.model, - state_dict_type=StateDictType.FULL_STATE_DICT, - state_dict_config=state_dict_config, - ): - # Save model state dict - model_state_dict = self.model.state_dict() - optim_state_dict = FullyShardedDataParallel.optim_state_dict( - self.model, self.optimizer - ) - scheduler_state_dict = self.scheduler.state_dict() - - optim_and_scheduler_state_dict = { - "optimizer": optim_state_dict, - "scheduler": scheduler_state_dict, - } - - if torch.distributed.get_rank() == 0: - # check if weights_path dir exists - weights_dir = os.path.dirname(weights_path) - if not os.path.exists(weights_dir): - print( - f"Creating weights directory {weights_dir} DOESN'T EXIST SOMEHOW" - ) - os.makedirs(weights_dir) - torch.save(model_state_dict, weights_path) - if optimizer_path is not None: - torch.save(optim_and_scheduler_state_dict, optimizer_path) - def load_checkpoint(self, weights_path: str, optimizer_path: Optional[str] = None): - print(f"Loading Policy from {weights_path} and optimizer from {optimizer_path}") - state_dict_config = FullStateDictConfig(offload_to_cpu=True, rank0_only=True) - - state_dict = torch.load(weights_path) - if optimizer_path is not None: - optim_data = torch.load(optimizer_path) - optimizer_state_dict = optim_data["optimizer"] - scheduler_state_dict = optim_data.get("scheduler") - else: - optimizer_state_dict = None - scheduler_state_dict = None - with FullyShardedDataParallel.state_dict_type( - self.model, - state_dict_type=StateDictType.FULL_STATE_DICT, - state_dict_config=state_dict_config, - ): - # Load model weights - self.model.load_state_dict(state_dict if state_dict else None) - - # Load optimizer state - if optimizer_state_dict is not None: - optim_state_dict = FullyShardedDataParallel.shard_full_optim_state_dict( - optimizer_state_dict, self.model - ) - if self.optimizer is not None: - self.optimizer.load_state_dict(optim_state_dict) - else: - print("WARNING: initializing without optimizer") - else: - print("WARNING: No optimizer checkpoint provided") - - if scheduler_state_dict is not None: - self.scheduler.load_state_dict(scheduler_state_dict) - else: - print("WARNING: No scheduler checkpoint provided") + """Load a checkpoint into the model.""" + load_checkpoint( + model=self.model, + weights_path=weights_path, + optimizer=self.optimizer if optimizer_path else None, + scheduler=self.scheduler if optimizer_path else None, + optimizer_path=optimizer_path, + ) def shutdown(self): """Shutdown the policy.""" @@ -1107,14 +1085,16 @@ def save_checkpoint( self, weights_path: str, optimizer_path: Optional[str] = None, - offload_to_cpu: bool = True, + save_torch_dist: bool = True, + save_hf: bool = False, ): """Save a checkpoint of the model.""" futures = self.worker_group.run_all_workers_single_data( "save_checkpoint", weights_path, optimizer_path, - offload_to_cpu=offload_to_cpu, + save_torch_dist, + save_hf, respect_tied_workers=True, ) ray.get(futures) diff --git a/nemo_reinforcer/utils/native_checkpoint.py b/nemo_reinforcer/utils/native_checkpoint.py new file mode 100644 index 0000000000..6f22ea82fd --- /dev/null +++ b/nemo_reinforcer/utils/native_checkpoint.py @@ -0,0 +1,204 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Checkpoint management utilities for HF models.""" + +import os +from pathlib import Path +from typing import Any, Optional +import torch + +import torch.distributed.checkpoint as dcp +from torch.distributed.checkpoint.stateful import Stateful +from torch.distributed.checkpoint.state_dict import ( + get_model_state_dict, + set_model_state_dict, + get_optimizer_state_dict, + set_optimizer_state_dict, +) + + +## modified from pytorch tutorial https://pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html +class ModelState(Stateful): + """Helper class for tracking model state in distributed checkpointing. + + This class is compliant with the Stateful protocol, allowing DCP to automatically + call state_dict/load_state_dict as needed in the dcp.save/load APIs. + + Args: + model: The PyTorch model to track. + """ + + def __init__(self, model): + self.model = model + + def state_dict(self): + """Get the model's state dictionary. + + Returns: + dict: Dictionary containing the model's state dict with CPU offloading enabled. + """ + # this line automatically manages FSDP FQN's, as well as sets the default state dict type to FSDP.SHARDED_STATE_DICT + model_state_dict = get_model_state_dict( + self.model, + options=torch.distributed.checkpoint.state_dict.StateDictOptions( + cpu_offload=True + ), + ) + return model_state_dict + + def load_state_dict(self, state_dict): + """Load the state dictionary into the model. + + Args: + state_dict (dict): State dictionary to load. + """ + # sets our state dicts on the model, now that we've loaded + set_model_state_dict( + self.model, + state_dict, + ) + + +class OptimizerState(Stateful): + """Helper class for tracking optimizer state in distributed checkpointing. + + This class is compliant with the Stateful protocol, allowing DCP to automatically + call state_dict/load_state_dict as needed in the dcp.save/load APIs. + + Args: + model: The PyTorch model associated with the optimizer. + optimizer: The optimizer to track. + scheduler: Optional learning rate scheduler. + """ + + def __init__(self, model, optimizer, scheduler=None): + self.model = model + self.optimizer = optimizer + self.scheduler = scheduler + + def state_dict(self): + """Get the optimizer and scheduler state dictionaries. + + Returns: + dict: Dictionary containing the optimizer and scheduler state dicts with CPU offloading enabled. + """ + # this line automatically manages FSDP FQN's, as well as sets the default state dict type to FSDP.SHARDED_STATE_DICT + optimizer_state_dict = get_optimizer_state_dict( + self.model, + self.optimizer, + options=torch.distributed.checkpoint.state_dict.StateDictOptions( + cpu_offload=True + ), + ) + + state_dict = { + "optim": optimizer_state_dict, + } + if self.scheduler is not None: + state_dict["sched"] = self.scheduler.state_dict() + + return state_dict + + def load_state_dict(self, state_dict): + """Load the state dictionaries into the optimizer and scheduler. 
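+
+        The scheduler state (stored under the "sched" key) is restored only if it is present.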
+
+        Args:
+            state_dict (dict): State dictionary containing optimizer and scheduler states to load.
+        """
+        # sets our state dicts on the optimizer, now that we've loaded
+        set_optimizer_state_dict(
+            self.model,
+            self.optimizer,
+            state_dict["optim"],
+        )
+
+        ## load the scheduler state if it exists and a scheduler was provided
+        if "sched" in state_dict and self.scheduler is not None:
+            self.scheduler.load_state_dict(state_dict["sched"])
+
+
+def save_checkpoint(
+    model,
+    weights_path: str,
+    optimizer: Optional[torch.optim.Optimizer] = None,
+    scheduler: Optional[Any] = None,
+    optimizer_path: Optional[str] = None,
+    save_torch_dist: bool = True,
+    save_hf: bool = False,
+) -> None:
+    """Save a checkpoint of the model and optionally optimizer state.
+
+    Args:
+        model: The PyTorch model to save
+        weights_path: Path to save model weights
+        optimizer: Optional optimizer to save
+        scheduler: Optional scheduler to save
+        optimizer_path: Path to save optimizer state (required if optimizer provided)
+        save_torch_dist: Whether to save in PyTorch distributed format
+        save_hf: Whether to save in HuggingFace format
+    """
+    if save_hf:
+        # Gather a full state dict from the FSDP-wrapped module for HF export
+        model_state_dict = model._fsdp_wrapped_module.state_dict()
+
+        if torch.distributed.get_rank() == 0:
+            # Create a new path by appending "-hf" to the weights path
+            hf_weights_path = f"{Path(weights_path)}-hf"
+
+            model.save_pretrained(
+                hf_weights_path,
+                state_dict=model_state_dict,
+            )
+
+    if save_torch_dist:
+        model_state = {"model": ModelState(model)}
+        dcp.save(model_state, checkpoint_id=weights_path)
+
+    if optimizer is not None:
+        if optimizer_path is None:
+            raise ValueError(
+                "optimizer_path must be provided when saving optimizer state"
+            )
+        optimizer_state = {"optim": OptimizerState(model, optimizer, scheduler)}
+        dcp.save(optimizer_state, checkpoint_id=optimizer_path)
+
+
+def load_checkpoint(
+    model,
+    weights_path: str,
+    optimizer: Optional[torch.optim.Optimizer] = None,
+    scheduler: Optional[Any] = None,
+    optimizer_path: Optional[str] = None,
+) -> None:
+    """Load model weights and optionally optimizer state.
+
+    Args:
+        model: The PyTorch model whose weights to update
+        weights_path: Path to load model weights from
+        optimizer: Optional optimizer to load state into
+        scheduler: Optional scheduler to load state into
+        optimizer_path: Path to load optimizer state from (required if optimizer provided)
+    """
+    print(f"Loading weights from {weights_path}")
+    model_state_dict = {"model": ModelState(model)}
+    dcp.load(state_dict=model_state_dict, checkpoint_id=weights_path)
+
+    if optimizer is not None:
+        if optimizer_path is None:
+            raise ValueError(
+                "optimizer_path must be provided when loading optimizer state"
+            )
+        print(f"Loading optimizer from {optimizer_path}")
+        optimizer_state_dict = {"optim": OptimizerState(model, optimizer, scheduler)}
+        dcp.load(state_dict=optimizer_state_dict, checkpoint_id=optimizer_path)
diff --git a/tests/functional/sft.sh b/tests/functional/sft.sh
index 82d263c9da..85282d64f4 100755
--- a/tests/functional/sft.sh
+++ b/tests/functional/sft.sh
@@ -1,5 +1,8 @@
 #!/bin/bash
 
+## clean up checkpoint directory on exit
+trap "rm -rf /tmp/sft_checkpoints" EXIT
+
 SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd)
 PROJECT_ROOT=$(realpath $SCRIPT_DIR/../..)
 # Mark the current repo as safe, since wandb fetches metadata about the repo
@@ -26,7 +29,9 @@ python -u $PROJECT_ROOT/examples/run_sft.py \
     logger.tensorboard_enabled=true \
     logger.log_dir=$LOG_DIR \
     logger.wandb_enabled=false \
-    checkpointing.enabled=false \
+    checkpointing.enabled=true \
+    checkpointing.save_every_n_steps=10 \
+    checkpointing.checkpoint_dir=/tmp/sft_checkpoints \
     $@ \
     2>&1 | tee $RUN_LOG
diff --git a/tests/unit/models/generation/test_vllm_generation.py b/tests/unit/models/generation/test_vllm_generation.py
index ed90267d10..ba1bade3fc 100644
--- a/tests/unit/models/generation/test_vllm_generation.py
+++ b/tests/unit/models/generation/test_vllm_generation.py
@@ -29,6 +29,7 @@
 basic_vllm_test_config: VllmConfig = {
     "backend": "vllm",
     "model_name": "meta-llama/Llama-3.2-1B",  # Small model for testing
+    "tokenizer_name": "meta-llama/Llama-3.2-1B",
     "dtype": "bfloat16",
     "max_new_tokens": 10,
     "temperature": 1.0,
@@ -204,6 +205,7 @@ def test_vllm_generation_with_hf_training(cluster, tokenizer):
     # Create HF-specific config with required parameters
     hf_config = {
         "model_name": basic_vllm_test_config["model_name"],
+        "tokenizer_name": basic_vllm_test_config["tokenizer_name"],
         # Required training parameters
         "train_global_batch_size": 4,
         "train_micro_batch_size": 1,
@@ -507,6 +509,7 @@ def test_vllm_weight_update_and_prefix_cache_reset(
 
     hf_config = {
         "model_name": basic_vllm_test_config["model_name"],
+        "tokenizer_name": basic_vllm_test_config["tokenizer_name"],
         "train_global_batch_size": 1,
         "train_micro_batch_size": 1,
         "learning_rate": 1e-6,
diff --git a/tests/unit/models/policy/test_hf_ray_policy.py b/tests/unit/models/policy/test_hf_ray_policy.py
index 76926960cf..7cde591049 100644
--- a/tests/unit/models/policy/test_hf_ray_policy.py
+++ b/tests/unit/models/policy/test_hf_ray_policy.py
@@ -28,6 +28,7 @@
 basic_llama_test_config: PolicyConfig = {
     "model_name": "meta-llama/Llama-3.2-1B",
+    "tokenizer_name": "meta-llama/Llama-3.2-1B",
     "generation_batch_size": 1,  # Small batch size for testing
     "train_global_batch_size": 4,
     "train_micro_batch_size": 1,
diff --git a/tests/unit/utils/test_native_checkpoint.py b/tests/unit/utils/test_native_checkpoint.py
new file mode 100755
index 0000000000..8f71badea1
--- /dev/null
+++ b/tests/unit/utils/test_native_checkpoint.py
@@ -0,0 +1,288 @@
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
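+
+"""Unit tests for the DCP (Torch distributed) and HuggingFace checkpoint save/load utilities."""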
+
+import os
+import pytest
+import torch
+from tempfile import TemporaryDirectory
+
+from nemo_reinforcer.distributed.batched_data_dict import BatchedDataDict
+from nemo_reinforcer.distributed.virtual_cluster import RayVirtualCluster
+from nemo_reinforcer.models.policy.hf_policy import HfPolicy
+from transformers import AutoTokenizer, AutoModelForCausalLM
+from nemo_reinforcer.utils.native_checkpoint import (
+    load_checkpoint,
+    save_checkpoint,
+    ModelState,
+    OptimizerState,
+)
+
+# Define basic test config
+simple_policy_config = {
+    "model_name": "meta-llama/Llama-3.2-1B",  # "hf-internal-testing/tiny-random-Gemma3ForCausalLM"
+    "tokenizer_name": "meta-llama/Llama-3.2-1B",  # "hf-internal-testing/tiny-random-Gemma3ForCausalLM"
+    "train_global_batch_size": 32,
+    "train_micro_batch_size": 1,
+    "logprob_batch_size": 1,
+    "max_total_sequence_length": 1024,
+    "precision": "float32",
+}
+
+
+@pytest.fixture
+def mock_experiment():
+    """Build a small model/optimizer/scheduler triple for checkpoint tests."""
+    model = torch.nn.ModuleList(
+        [
+            torch.nn.Linear(4, 4),
+            torch.nn.LayerNorm(4),
+            torch.nn.ReLU(),
+            torch.nn.Linear(4, 1),
+        ]
+    ).to("cuda")
+
+    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
+    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1)
+
+    return model, optimizer, scheduler
+
+
+@pytest.fixture(scope="module")
+def cluster():
+    """Create a virtual cluster for testing."""
+    # Create a cluster with a single node exposing 2 GPU bundles
+    virtual_cluster = RayVirtualCluster(
+        bundle_ct_per_node_list=[2],  # 1 node with 2 GPU bundles
+        use_gpus=True,
+        max_colocated_worker_groups=1,
+        num_gpus_per_node=2,  # Use available GPUs
+        name="test-cluster",
+    )
+    yield virtual_cluster
+    virtual_cluster.shutdown()
+
+
+@pytest.fixture(scope="function")
+def tokenizer():
+    """Initialize tokenizer for the test model."""
+    tokenizer_name = simple_policy_config["tokenizer_name"]
+    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, trust_remote_code=True)
+    if tokenizer.pad_token is None:
+        tokenizer.pad_token = tokenizer.eos_token
+    return tokenizer
+
+
+@pytest.fixture(scope="function")
+def policy(cluster, tokenizer):
+    """Initialize the policy."""
+    return HfPolicy(
+        cluster=cluster,
+        config=simple_policy_config,
+        init_optimizer=False,
+        init_reference_model=False,
+    )
+
+
+def get_dummy_state_dict(state_dict, dummy_dict=None):
+    """Recursively get a dummy state dict
+    by replacing tensors with random ones of the same shape.
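+    Non-tensor leaves are copied through unchanged.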
+ """ + for k in state_dict.keys(): + if isinstance(state_dict[k], dict): + dummy_dict[k] = get_dummy_state_dict(state_dict[k], {}) + elif isinstance(state_dict[k], torch.Tensor): + dummy_dict[k] = torch.randn(state_dict[k].shape) + else: + dummy_dict[k] = state_dict[k] + return dummy_dict + + +def check_dict_equality(dict1, dict2): + """Recursively check equality of two dictionaries""" + for k in dict1.keys(): + if isinstance(dict1[k], dict): + check_dict_equality(dict1[k], dict2[k]) + elif isinstance(dict1[k], torch.Tensor): + assert torch.allclose(dict1[k], dict2[k]) + else: + assert dict1[k] == dict2[k] + + +def test_model_state(mock_experiment): + test_model, _, _ = mock_experiment + model_state = ModelState(test_model) + state_dict = model_state.state_dict() + + ## relu has no parameters + expected_keys = { + "0.bias", + "0.weight", + "1.bias", + "1.weight", + "3.bias", + "3.weight", + } + assert set(state_dict.keys()) == expected_keys + + dummy_model_state_dict = get_dummy_state_dict(state_dict, {}) + + ## update the model's state dict and verify that the model's parameters are updated + model_state.load_state_dict(dummy_model_state_dict) + new_model_state_dict = model_state.state_dict() + check_dict_equality(new_model_state_dict, dummy_model_state_dict) + + +def test_optimizer_state(mock_experiment): + test_model, optimizer, scheduler = mock_experiment + + optim_state = OptimizerState(test_model, optimizer, scheduler) + state_dict = optim_state.state_dict() + + assert set(state_dict.keys()) == {"optim", "sched"} + + ## relu has no parameters + expected_keys = { + "0.bias", + "0.weight", + "1.bias", + "1.weight", + "3.bias", + "3.weight", + } + + assert set(state_dict["optim"]["state"].keys()) == expected_keys + + dummy_state_dict = get_dummy_state_dict(state_dict, {}) + + optim_state.load_state_dict(dummy_state_dict) + new_state_dict = optim_state.state_dict() + check_dict_equality(new_state_dict, dummy_state_dict) + + +def test_save_and_load_model_only(mock_experiment): + test_model, _, _ = mock_experiment + + with TemporaryDirectory() as tmp_dir: + save_checkpoint(test_model, os.path.join(tmp_dir, "test_model_only")) + assert os.path.exists(os.path.join(tmp_dir, "test_model_only")) + assert not os.path.exists(os.path.join(tmp_dir, "test_model_only-hf")) + assert set(os.listdir(os.path.join(tmp_dir, "test_model_only"))) == { + ".metadata", + "__0_0.distcp", + } + + +def test_save_and_load_model_and_optimizer(mock_experiment): + test_model, optimizer, scheduler = mock_experiment + for _ in range(5): + scheduler.step() + + with TemporaryDirectory() as tmp_dir: + save_checkpoint( + test_model, + os.path.join(tmp_dir, "model_and_optimizer/model"), + optimizer, + scheduler, + optimizer_path=os.path.join(tmp_dir, "model_and_optimizer/optimizer"), + ) + + assert set(os.listdir(os.path.join(tmp_dir, "model_and_optimizer/model"))) == { + ".metadata", + "__0_0.distcp", + } + assert set( + os.listdir(os.path.join(tmp_dir, "model_and_optimizer/optimizer")) + ) == { + ".metadata", + "__0_0.distcp", + } + + ## modify the model, optimizer, and scheduler and verify that loading the checkpoint overrides the values + new_linear = torch.nn.Linear(4, 4) + new_linear.weight = torch.nn.Parameter(torch.ones([4, 4]).to("cuda")) + new_linear.bias = torch.nn.Parameter(torch.ones(4).to("cuda")) + new_model = torch.nn.ModuleList( + [ + new_linear, + torch.nn.LayerNorm(4), + torch.nn.ReLU(), + torch.nn.Linear(4, 1), + ] + ).to("cuda") + + new_optimizer = torch.optim.Adam(new_model.parameters(), lr=0.001) + 
+        new_scheduler = torch.optim.lr_scheduler.StepLR(
+            new_optimizer, step_size=4, gamma=0.2
+        )
+        load_checkpoint(
+            new_model,
+            os.path.join(tmp_dir, "model_and_optimizer/model"),
+            new_optimizer,
+            new_scheduler,
+            optimizer_path=os.path.join(tmp_dir, "model_and_optimizer/optimizer"),
+        )
+
+        assert scheduler.state_dict() == new_scheduler.state_dict()
+        check_dict_equality(new_model.state_dict(), test_model.state_dict())
+        check_dict_equality(new_optimizer.state_dict(), optimizer.state_dict())
+
+
+def test_save_and_load_hf_checkpoint(policy):
+    ## warm up with a forward pass
+    ## this is needed before saving a checkpoint because FSDP does some lazy initialization
+    input_ids = torch.randint(0, 16000, (4, 128))  # 4 sequences, each of length 128
+    attention_mask = torch.ones(4, 128)
+    input_lengths = attention_mask.sum(dim=1).to(torch.int32)
+    dummy_fwd_dict = BatchedDataDict(
+        {
+            "input_ids": input_ids,
+            "input_lengths": input_lengths,
+            "attention_mask": attention_mask,
+            "labels": torch.randint(0, 16000, (4, 128)),
+        }
+    )
+    policy.get_logprobs(dummy_fwd_dict)
+
+    with TemporaryDirectory() as tmp_dir:
+        policy.save_checkpoint(
+            os.path.join(tmp_dir, "test_hf_and_dcp"),
+            save_hf=True,
+            save_torch_dist=True,
+        )
+
+        ## make sure we save both HF and DCP checkpoints
+        assert set(os.listdir(os.path.join(tmp_dir, "test_hf_and_dcp"))) == {
+            "__0_0.distcp",
+            "__1_0.distcp",
+            ".metadata",
+        }
+        ## 1B model has two shards
+        assert set(os.listdir(os.path.join(tmp_dir, "test_hf_and_dcp-hf"))) == {
+            "config.json",
+            "generation_config.json",
+            "model-00001-of-00002.safetensors",
+            "model-00002-of-00002.safetensors",
+            "model.safetensors.index.json",
+        }
+
+        converted_model = AutoModelForCausalLM.from_pretrained(
+            os.path.join(tmp_dir, "test_hf_and_dcp-hf")
+        )
+        original_model = AutoModelForCausalLM.from_pretrained(
+            simple_policy_config["model_name"]
+        )
+
+        ## make sure the converted model matches the original
+        check_dict_equality(converted_model.state_dict(), original_model.state_dict())
+
+    policy.worker_group.shutdown()