Merged
Changes from all commits (36 commits):
- 6b10bec: switch to torch distributed checkpointing (ashors1, Mar 26, 2025)
- 83ca3a0: save model weights and optimizer states to separate directories (ashors1, Mar 31, 2025)
- 4fbb313: only load optimizer if path provided (ashors1, Mar 31, 2025)
- 05bd71e: option to save model weights in hf format (ashors1, Apr 1, 2025)
- 94d9227: address comments (ashors1, Apr 1, 2025)
- 16f070f: refactor (ashors1, Apr 1, 2025)
- 7ad920f: remove unused imports (ashors1, Apr 1, 2025)
- bb35d8e: update imports (ashors1, Apr 1, 2025)
- 10a25a4: support saving both HF and torch DCP checkpoints (ashors1, Apr 1, 2025)
- f712f66: cleanup, add tokenizer_name to config (ashors1, Apr 1, 2025)
- e2c264c: add example conversion script (ashors1, Apr 1, 2025)
- 65a2715: Merge branch 'main' of github.com:NVIDIA/reinforcer into ashors/dist-… (ashors1, Apr 1, 2025)
- 2affa7c: copyright (ashors1, Apr 1, 2025)
- b05df3d: address comments (ashors1, Apr 1, 2025)
- bb902c0: set is_main_process during hf save (ashors1, Apr 2, 2025)
- aac057f: save to absolute path (ashors1, Apr 2, 2025)
- 78c2a91: add documentation and save hf checkpoint at the end of training (ashors1, Apr 2, 2025)
- 874d446: formatting (ashors1, Apr 2, 2025)
- a422db4: capitalization (ashors1, Apr 2, 2025)
- 848b654: add hf checkpointing unit tests (ashors1, Apr 2, 2025)
- 8ca4dbd: add copyright (ashors1, Apr 2, 2025)
- 163c6c9: linting (ashors1, Apr 2, 2025)
- 83785ec: fix tests (ashors1, Apr 2, 2025)
- 7eee568: add HF checkpoint save test (ashors1, Apr 3, 2025)
- 1cb9483: fix tests and improve docstrings (ashors1, Apr 3, 2025)
- 3affdfd: address comments (ashors1, Apr 3, 2025)
- 174ef81: cleanup (ashors1, Apr 3, 2025)
- f6e56d0: make HF save multi-process safe (ashors1, Apr 3, 2025)
- 58a0d3d: hf checkpointing fix (ashors1, Apr 4, 2025)
- fe3f47c: enable checkpointing in sft functional test (ashors1, Apr 4, 2025)
- eb6eca7: Merge branch 'main' of github.com:NVIDIA/reinforcer into ashors/dist-… (ashors1, Apr 5, 2025)
- 4f17431: update unit test (ashors1, Apr 7, 2025)
- 4e11a2e: address comments (ashors1, Apr 7, 2025)
- 57c5811: Merge branch 'main' of github.com:NVIDIA/reinforcer into ashors/dist-… (ashors1, Apr 7, 2025)
- 7412223: small fixes (ashors1, Apr 7, 2025)
- 0749217: fix trap command (ashors1, Apr 7, 2025)
22 changes: 22 additions & 0 deletions docs/design_docs/checkpointing.md
@@ -0,0 +1,22 @@
# Checkpointing with HuggingFace Models

## Checkpoint Format
Reinforcer provides two checkpoint formats for HuggingFace models: Torch distributed (DCP) and HuggingFace format. Torch distributed is the default because it is more efficient; HuggingFace format is provided for compatibility with HuggingFace's `AutoModel.from_pretrained` API. Note that HuggingFace-format checkpoints save only the model weights and omit the optimizer states, so it is recommended to save intermediate checkpoints in Torch distributed format and to save a HuggingFace checkpoint only at the end of training.
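
For example, a checkpoint exported in HuggingFace format can be loaded back with the standard `transformers` API. The snippet below is a minimal sketch: the checkpoint path is a placeholder, and the tokenizer is loaded from the original model name in case tokenizer files are not part of the export.

```python
# Minimal sketch: load a Reinforcer HF-format export with transformers.
# "path/to/hf-ckpt" is a placeholder for the exported checkpoint directory.
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("path/to/hf-ckpt")
# If tokenizer files are not included in the export, load the tokenizer from
# the original model (or from `tokenizer_name` in your training config).
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")
```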

There are two ways to get a Reinforcer checkpoint in HuggingFace format.

1. (Recommended) Save the HuggingFace checkpoint directly by passing `save_hf=True` to `HfPolicy`'s `save_checkpoint`:

```python
policy.save_checkpoint(
    weights_path=<WHERE_TO_SAVE_MODEL_WEIGHTS>,
    optimizer_path=<WHERE_TO_SAVE_OPTIM_STATE>,
    save_torch_dist=True,
    save_hf=True,
)
```
2. Convert a Torch distributed checkpoint to HuggingFace format after training. We provide a conversion script for this purpose:

```bash
uv run examples/convert_dcp_to_hf.py \
    --config=<YAML CONFIG USED DURING TRAINING> \
    <ANY CONFIG OVERRIDES USED DURING TRAINING> \
    --dcp-ckpt-path=<PATH TO DIST CHECKPOINT TO CONVERT> \
    --hf-ckpt-path=<WHERE TO SAVE HF CHECKPOINT>
```
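
For reference, resuming training from a Torch distributed checkpoint uses the same `HfPolicy` constructor pattern that appears in the `grpo.py` and `sft.py` changes below. This is a sketch, assuming `cluster`, `policy_config`, and `checkpoint_path` are set up as in those scripts:

```python
# Sketch: re-create the policy from a saved Torch distributed checkpoint.
# `cluster`, `policy_config`, and `checkpoint_path` are assumed to exist.
from pathlib import Path

policy = HfPolicy(
    cluster=cluster,
    config=policy_config,
    weights_path=Path(checkpoint_path) / "policy" / "weights",
    optimizer_path=Path(checkpoint_path) / "policy" / "optimizer",
    init_optimizer=True,
)
```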
1 change: 1 addition & 0 deletions docs/index.md
@@ -47,4 +47,5 @@
design_docs/logger.md
design_docs/uv.md
design_docs/chat_datasets.md
design_docs/generation.md
design_docs/checkpointing.md
```
1 change: 1 addition & 0 deletions examples/configs/grpo_math_1B.yaml
@@ -25,6 +25,7 @@ checkpointing:

policy:
model_name: "meta-llama/Llama-3.2-1B-Instruct"
tokenizer_name: ${policy.model_name} ## specify if you'd like to use a tokenizer different from the model's default
train_global_batch_size: 512
train_micro_batch_size: 4
generation_batch_size: 32 # Only used when generating using HF backend
1 change: 1 addition & 0 deletions examples/configs/grpo_math_8B.yaml
@@ -7,6 +7,7 @@ grpo:

policy:
model_name: "meta-llama/Llama-3.1-8B-Instruct"
tokenizer_name: ${policy.model_name} ## specify if you'd like to use a tokenizer different from the model's default
train_global_batch_size: 512
train_micro_batch_size: 1
generation_batch_size: 32 # Only used when generating using HF backend
1 change: 1 addition & 0 deletions examples/configs/sft.yaml
@@ -18,6 +18,7 @@ checkpointing:

policy:
model_name: "meta-llama/Llama-3.2-1B"
tokenizer_name: ${policy.model_name} ## specify if you'd like to use a tokenizer different from the model's default
train_global_batch_size: 32
train_micro_batch_size: 1
max_total_sequence_length: 1024
92 changes: 92 additions & 0 deletions examples/convert_dcp_to_hf.py
@@ -0,0 +1,92 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import json
import os

from nemo_reinforcer.distributed.virtual_cluster import init_ray, RayVirtualCluster
from nemo_reinforcer.models.policy.hf_policy import HfPolicy


def parse_args():
    """Parse command line arguments."""
    parser = argparse.ArgumentParser(
        description="Convert Torch DCP checkpoint to HF checkpoint"
    )
    parser.add_argument(
        "--config",
        type=str,
        default=None,
        help="Path to config.json file in the checkpoint directory",
    )
    parser.add_argument(
        "--dcp-ckpt-path", type=str, default=None, help="Path to DCP checkpoint"
    )
    parser.add_argument(
        "--hf-ckpt-path", type=str, default=None, help="Path to save HF checkpoint"
    )
    return parser.parse_args()


def main():
    """Main entry point."""
    args = parse_args()

    with open(args.config, "r") as f:
        config = json.load(f)

    dcp_ckpt = args.dcp_ckpt_path
    hf_ckpt = args.hf_ckpt_path

    # Extract individual configs for easier access
    policy_config = config["policy"]
    cluster_config = config["cluster"]

    init_ray()

    cluster = RayVirtualCluster(
        name="convert_cluster",
        bundle_ct_per_node_list=[cluster_config["gpus_per_node"]]
        * cluster_config["num_nodes"],
        use_gpus=True,
        num_gpus_per_node=cluster_config["gpus_per_node"],
        max_colocated_worker_groups=1,
    )

    # Load the model weights from the Torch DCP checkpoint. The optimizer
    # state is not needed for conversion, so skip initializing it.
    policy = HfPolicy(
        cluster=cluster,
        config=policy_config,
        weights_path=dcp_ckpt,
        init_optimizer=False,
    )

    # Save the weights in HuggingFace format only (no Torch DCP copy).
    policy.save_checkpoint(
        weights_path=os.path.abspath(hf_ckpt),
        save_hf=True,
        save_torch_dist=False,
    )

    print(f"Saved HF checkpoint to: {hf_ckpt}-hf")

    cluster.shutdown()
    policy.worker_group.shutdown()


if __name__ == "__main__":
    main()
18 changes: 14 additions & 4 deletions nemo_reinforcer/algorithms/grpo.py
@@ -236,10 +236,10 @@ def setup(
policy = HfPolicy(
cluster=cluster,
config=policy_config,
weights_path=Path(last_checkpoint_path) / "policy.pt"
weights_path=Path(last_checkpoint_path) / "policy" / "weights"
if last_checkpoint_path
else None,
optimizer_path=Path(last_checkpoint_path) / "policy_optimizer.pt"
optimizer_path=Path(last_checkpoint_path) / "policy" / "optimizer"
if last_checkpoint_path
else None,
init_optimizer=True,
@@ -608,6 +608,13 @@ def grpo_train(
and (step + 1) % master_config["checkpointing"]["save_period"] == 0
): # +1 because step is 0-indexed
policy.prepare_for_training()

+ is_last_checkpoint = (
+     min(len(dataloader), master_config["grpo"]["max_num_steps"])
+     - (step + 1)
+     < master_config["checkpointing"]["save_period"]
+ )

grpo_save_state["step"] = step + 1
grpo_save_state["val_reward"] = val_metrics["accuracy"]
grpo_save_state["consumed_samples"] = consumed_samples
@@ -617,8 +624,11 @@
step + 1, grpo_save_state, master_config
)
policy.save_checkpoint(
os.path.join(checkpoint_path, "policy.pt"),
os.path.join(checkpoint_path, "policy_optimizer.pt"),
weights_path=os.path.join(checkpoint_path, "policy", "weights"),
optimizer_path=os.path.join(
checkpoint_path, "policy", "optimizer"
),
save_hf=is_last_checkpoint,
)
torch.save(
dataloader.state_dict(),
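
The `is_last_checkpoint` arithmetic above marks the final periodic save so the HuggingFace export happens exactly once per run. A small standalone check with hypothetical values (`save_period=2`, `max_num_steps=6`, a dataloader of length 10):

```python
# Hypothetical values to illustrate when is_last_checkpoint becomes True.
save_period, max_num_steps, dataloader_len = 2, 6, 10

for step in range(max_num_steps):
    if (step + 1) % save_period == 0:  # a checkpoint step
        is_last = min(dataloader_len, max_num_steps) - (step + 1) < save_period
        print(f"step {step + 1}: save_hf={is_last}")
# Prints: step 2: save_hf=False, step 4: save_hf=False, step 6: save_hf=True
```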
27 changes: 16 additions & 11 deletions nemo_reinforcer/algorithms/sft.py
@@ -175,10 +175,10 @@ def setup(
policy = HfPolicy(
cluster=cluster,
config=policy_config,
weights_path=Path(last_checkpoint_path) / "policy.pt"
weights_path=Path(last_checkpoint_path) / "policy" / "weights"
if last_checkpoint_path
else None,
- optimizer_path=Path(last_checkpoint_path) / "policy_optimizer.pt"
+ optimizer_path=Path(last_checkpoint_path) / "policy" / "optimizer"
if last_checkpoint_path
else None,
init_optimizer=True,
@@ -311,9 +311,7 @@ def sft_train(
sft_save_state = _default_sft_save_state()
step = 0
else:
- step = (
-     sft_save_state["step"] + 1
- )  # N+1 because the checkpoint is _after_ SFT iteration N
+ step = sft_save_state["step"]

sft_config = master_config["sft"]
# Validation configuration
@@ -399,19 +397,26 @@
master_config["checkpointing"]["enabled"]
and (step + 1) % master_config["checkpointing"]["save_period"] == 0
): # +1 because step is 0-indexed
- sft_save_state["step"] = step
+ is_last_checkpoint = (
+     min(len(train_dataloader), master_config["sft"]["max_num_steps"])
+     - (step + 1)
+     < master_config["checkpointing"]["save_period"]
+ )
+
+ sft_save_state["step"] = step + 1
sft_save_state["val_loss"] = val_metrics["val_loss"]
with timer.time("checkpointing"):
print(f"Saving checkpoint for step {step + 1}...")
checkpoint_path = checkpointer.init_tmp_checkpoint(
step + 1, sft_save_state, master_config
)

policy.save_checkpoint(
os.path.join(checkpoint_path, "policy.pt"),
os.path.join(checkpoint_path, "policy_optimizer.pt"),
## NOTE: below is a workaround to avoid a bug with checkpointing
## this should be removed once the bug is fixed
offload_to_cpu=False,
weights_path=os.path.join(checkpoint_path, "policy", "weights"),
optimizer_path=os.path.join(
checkpoint_path, "policy", "optimizer"
),
save_hf=is_last_checkpoint,
)
torch.save(
train_dataloader.state_dict(),
1 change: 1 addition & 0 deletions nemo_reinforcer/models/policy/__init__.py
@@ -19,6 +19,7 @@

class PolicyConfig(TypedDict):
model_name: str
tokenizer_name: str
train_global_batch_size: int
train_micro_batch_size: int
learning_rate: float
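
For illustration, a config dict matching `PolicyConfig` with the new `tokenizer_name` field might look like the sketch below; the values are hypothetical, and only the fields visible in this diff are shown:

```python
# Hypothetical example values. In the YAML configs above, tokenizer_name
# defaults to the model's own tokenizer via ${policy.model_name}.
policy_config = {
    "model_name": "meta-llama/Llama-3.2-1B-Instruct",
    "tokenizer_name": "meta-llama/Llama-3.2-1B-Instruct",
    "train_global_batch_size": 512,
    "train_micro_batch_size": 4,
    "learning_rate": 5e-6,
}
```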