feat: Distributed checkpointing #99
Merged
36 commits
6b10bec switch to torch distributed checkpointing (ashors1)
83ca3a0 save model weights and optimizer states to separate directories (ashors1)
4fbb313 only load optimizer if path provided (ashors1)
05bd71e option to save model weights in hf format (ashors1)
94d9227 address comments (ashors1)
16f070f refactor (ashors1)
7ad920f remove unused imports (ashors1)
bb35d8e update imports (ashors1)
10a25a4 support saving both HF and torch DCP checkpoints (ashors1)
f712f66 cleanup, add tokenizer_name to config (ashors1)
e2c264c add example conversion script (ashors1)
65a2715 Merge branch 'main' of github.com:NVIDIA/reinforcer into ashors/dist-… (ashors1)
2affa7c copyright (ashors1)
b05df3d address comments (ashors1)
bb902c0 set is_main_process during hf save (ashors1)
aac057f save to absolute path (ashors1)
78c2a91 add documentation and save hf checkpoint at the end of training (ashors1)
874d446 formatting (ashors1)
a422db4 capitalization (ashors1)
848b654 add hf checkpointing unit tests (ashors1)
8ca4dbd add copyright (ashors1)
163c6c9 linting (ashors1)
83785ec fix tests (ashors1)
7eee568 add HF checkpoint save test (ashors1)
1cb9483 fix tests and improve docstrings (ashors1)
3affdfd address comments (ashors1)
174ef81 cleanup (ashors1)
f6e56d0 make HF save multi-process safe (ashors1)
58a0d3d hf checkpointing fix (ashors1)
fe3f47c enable checkpointing in sft functional test (ashors1)
eb6eca7 Merge branch 'main' of github.com:NVIDIA/reinforcer into ashors/dist-… (ashors1)
4f17431 update unit test (ashors1)
4e11a2e address comments (ashors1)
57c5811 Merge branch 'main' of github.com:NVIDIA/reinforcer into ashors/dist-… (ashors1)
7412223 small fixes (ashors1)
0749217 fix trap command (ashors1)
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,22 @@ | ||
| # Checkpointing with HuggingFace Models | ||
| | ||
| ## Checkpoint Format | ||
| Reinforcer provides two checkpoint formats for HuggingFace models: Torch distributed and HuggingFace format. Torch distributed is used by default for efficiency, and HuggingFace format is provided for compatibility with HuggingFace's `AutoModel.from_pretrained` API. HuggingFace-format checkpoints contain only the model weights, not the optimizer states. We recommend saving intermediate checkpoints in Torch distributed format and saving a HuggingFace checkpoint only at the end of training. | ||
| | ||
| There are two ways to get a Reinforcer checkpoint in HuggingFace format. | ||
| | ||
| 1. (Recommended) Save the HuggingFace checkpoint directly by passing `save_hf=True` to `HfPolicy`'s `save_checkpoint`: | ||
| | ||
| ```python | ||
| policy.save_checkpoint( | ||
|     weights_path=<WHERE_TO_SAVE_MODEL_WEIGHTS>, | ||
|     optimizer_path=<WHERE_TO_SAVE_OPTIM_STATE>, | ||
|     save_torch_dist=True, | ||
|     save_hf=True, | ||
| ) | ||
| ``` | ||
| 2. Convert a Torch distributed checkpoint to HuggingFace format after training. We provide a conversion script for this purpose. | ||
| | ||
| ```sh | ||
| uv run examples/convert_dcp_to_hf.py --config=<YAML CONFIG USED DURING TRAINING> <ANY CONFIG OVERRIDES USED DURING TRAINING> --dcp-ckpt-path=<PATH TO DIST CHECKPOINT TO CONVERT> --hf-ckpt-path=<WHERE TO SAVE HF CHECKPOINT> | ||
| ``` | ||
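The recommendation in the documentation above (Torch distributed format for intermediate checkpoints, HuggingFace format only at the end of training) could be expressed as a small helper that chooses the `save_torch_dist` and `save_hf` flags for a given step. This is only a sketch: `checkpoint_flags`, `save_period`, and `max_steps` are hypothetical names that do not appear in the PR; only the two flag names come from the `save_checkpoint` call shown above.

```python
from typing import Dict, Optional


def checkpoint_flags(
    step: int, max_steps: int, save_period: int
) -> Optional[Dict[str, bool]]:
    """Return keyword flags for a save_checkpoint call, or None to skip saving.

    Hypothetical helper: intermediate checkpoints use only the Torch
    distributed format; the final step additionally saves HuggingFace format.
    """
    if save_period <= 0:
        raise ValueError("save_period must be a positive integer")

    if step == max_steps:
        # End of training: also write the HuggingFace-format weights.
        return {"save_torch_dist": True, "save_hf": True}

    if step % save_period == 0:
        # Intermediate checkpoint: Torch distributed format only.
        return {"save_torch_dist": True, "save_hf": False}

    # Not a checkpointing step.
    return None
```

In a training loop, the returned dictionary (when not `None`) would be unpacked into the call alongside the paths, for example `policy.save_checkpoint(weights_path=..., optimizer_path=..., **flags)`.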
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,92 @@ | ||
| # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
| | ||
| import argparse | ||
| import os | ||
| import json | ||
| | ||
| from nemo_reinforcer.distributed.virtual_cluster import init_ray, RayVirtualCluster | ||
| from nemo_reinforcer.models.policy.hf_policy import HfPolicy | ||
| from nemo_reinforcer.utils.config import load_config | ||
| | ||
| | ||
| def parse_args(): | ||
|     """Parse command line arguments.""" | ||
|     parser = argparse.ArgumentParser( | ||
|         description="Convert Torch DCP checkpoint to HF checkpoint" | ||
|     ) | ||
|     parser.add_argument( | ||
|         "--config", | ||
|         type=str, | ||
|         default=None, | ||
|         help="Path to config.json file in the checkpoint directory", | ||
|     ) | ||
|     parser.add_argument( | ||
|         "--dcp-ckpt-path", type=str, default=None, help="Path to DCP checkpoint" | ||
|     ) | ||
|     parser.add_argument( | ||
|         "--hf-ckpt-path", type=str, default=None, help="Path to save HF checkpoint" | ||
|     ) | ||
|     # Parse known args for the script | ||
|     args = parser.parse_args() | ||
| | ||
|     return args | ||
| | ||
| | ||
| def main(): | ||
|     """Main entry point.""" | ||
|     args = parse_args() | ||
| | ||
|     with open(args.config, "r") as f: | ||
|         config = json.load(f) | ||
| | ||
|     dcp_ckpt = args.dcp_ckpt_path | ||
|     hf_ckpt = args.hf_ckpt_path | ||
| | ||
|     # Extract individual configs for easier access | ||
|     policy_config = config["policy"] | ||
|     cluster_config = config["cluster"] | ||
| | ||
|     init_ray() | ||
| | ||
|     cluster = RayVirtualCluster( | ||
|         name="convert_cluster", | ||
|         bundle_ct_per_node_list=[cluster_config["gpus_per_node"]] | ||
|         * cluster_config["num_nodes"], | ||
|         use_gpus=True, | ||
|         num_gpus_per_node=cluster_config["gpus_per_node"], | ||
|         max_colocated_worker_groups=1, | ||
|     ) | ||
| | ||
|     policy = HfPolicy( | ||
|         cluster=cluster, | ||
|         config=policy_config, | ||
|         weights_path=dcp_ckpt, | ||
|         init_optimizer=False, | ||
|     ) | ||
| | ||
|     policy.save_checkpoint( | ||
|         weights_path=os.path.abspath(hf_ckpt), | ||
|         save_hf=True, | ||
|         save_torch_dist=False, | ||
|     ) | ||
| | ||
|     print(f"Saved HF checkpoint to: {hf_ckpt}-hf") | ||
| | ||
|     cluster.shutdown() | ||
|     policy.worker_group.shutdown() | ||
| | ||
| | ||
| if __name__ == "__main__": | ||
|     main() | ||
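The documented usage of this script passes config overrides alongside the named flags, while `parser.parse_args()` exits with an error on any argument it does not recognize. One possible way to tolerate extra arguments is `argparse`'s `parse_known_args`, sketched below. The function name `parse_args_with_overrides`, the `required=True` settings, and the `key=value` override format in the usage note are assumptions for illustration and are not taken from the PR; how the overrides would be applied to the config is outside this sketch.

```python
import argparse
from typing import List, Optional, Tuple


def parse_args_with_overrides(
    argv: Optional[List[str]] = None,
) -> Tuple[argparse.Namespace, List[str]]:
    """Parse the script's flags and return unrecognized arguments separately.

    Hypothetical variant of the script's parse_args(); not the PR's code.
    """
    parser = argparse.ArgumentParser(
        description="Convert Torch DCP checkpoint to HF checkpoint"
    )
    # Marking the flags as required is an assumption: it gives a clear
    # usage error instead of a later failure when a value is None.
    parser.add_argument("--config", type=str, required=True, help="Path to config")
    parser.add_argument(
        "--dcp-ckpt-path", type=str, required=True, help="Path to DCP checkpoint"
    )
    parser.add_argument(
        "--hf-ckpt-path", type=str, required=True, help="Path to save HF checkpoint"
    )

    # parse_known_args returns (namespace, list_of_unrecognized_arguments)
    # instead of exiting when it meets an argument it does not know.
    args, overrides = parser.parse_known_args(argv)
    return args, overrides
```

Called with `["--config=cfg.json", "policy.foo=1", "--dcp-ckpt-path=/a", "--hf-ckpt-path=/b"]`, the sketch returns the three named values in `args` and leaves the hypothetical `policy.foo=1` token in `overrides` for the caller to interpret.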