From 75b8edf466dea21099fae3db7e0c927543a9c6cb Mon Sep 17 00:00:00 2001 From: Anna Shors Date: Wed, 25 Jun 2025 10:55:23 -0700 Subject: [PATCH 1/6] add megatron --> hf conversion support Signed-off-by: Anna Shors --- docs/design-docs/checkpointing.md | 4 +- docs/guides/eval.md | 4 +- docs/guides/grpo-deepscaler.md | 2 +- docs/guides/sft-openmathinstruct2.md | 2 +- .../{ => converters}/convert_dcp_to_hf.py | 0 examples/converters/convert_megatron_to_hf.py | 69 +++++++++++++++++++ nemo_rl/models/megatron/community_import.py | 38 ++++++++++ 7 files changed, 113 insertions(+), 6 deletions(-) rename examples/{ => converters}/convert_dcp_to_hf.py (100%) create mode 100644 examples/converters/convert_megatron_to_hf.py diff --git a/docs/design-docs/checkpointing.md b/docs/design-docs/checkpointing.md index de7fb64fbe..5d3feae680 100644 --- a/docs/design-docs/checkpointing.md +++ b/docs/design-docs/checkpointing.md @@ -5,7 +5,7 @@ NeMo RL provides two checkpoint formats for Hugging Face models: Torch distribut A checkpoint converter is provided to convert a Torch distributed checkpoint checkpoint to Hugging Face format after training: ```sh -uv run examples/convert_dcp_to_hf.py --config= --dcp-ckpt-path= --hf-ckpt-path= +uv run examples/converters/convert_dcp_to_hf.py --config= --dcp-ckpt-path= --hf-ckpt-path= ``` Usually Hugging Face checkpoints keep the weights and tokenizer together (which we also recommend for provenance). You can copy it afterwards. Here's an end-to-end example: @@ -14,6 +14,6 @@ Usually Hugging Face checkpoints keep the weights and tokenizer together (which # Change to your appropriate checkpoint directory CKPT_DIR=results/sft/step_10 -uv run examples/convert_dcp_to_hf.py --config=$CKPT_DIR/config.yaml --dcp-ckpt-path=$CKPT_DIR/policy/weights --hf-ckpt-path=${CKPT_DIR}-hf +uv run examples/converters/convert_dcp_to_hf.py --config=$CKPT_DIR/config.yaml --dcp-ckpt-path=$CKPT_DIR/policy/weights --hf-ckpt-path=${CKPT_DIR}-hf rsync -ahP $CKPT_DIR/policy/tokenizer ${CKPT_DIR}-hf/ ``` diff --git a/docs/guides/eval.md b/docs/guides/eval.md index b6e312f574..0281bb21f7 100644 --- a/docs/guides/eval.md +++ b/docs/guides/eval.md @@ -9,11 +9,11 @@ To prepare for evaluation, first ensure your model is in the correct format, whi ### Convert DCP to HF (Optional) If you have trained a model and saved the checkpoint in the Pytorch DCP format, you first need to convert it to the Hugging Face format before running evaluation. -Use the `examples/convert_dcp_to_hf.py` script. You'll need the path to the training configuration file (`config.yaml`), the DCP checkpoint directory, and specify an output path for the HF format model. +Use the `examples/converters/convert_dcp_to_hf.py` script. You'll need the path to the training configuration file (`config.yaml`), the DCP checkpoint directory, and specify an output path for the HF format model. 
```sh # Example for a GRPO checkpoint at step 170 -uv run python examples/convert_dcp_to_hf.py \ +uv run python examples/converters/convert_dcp_to_hf.py \ --config results/grpo/step_170/config.yaml \ --dcp-ckpt-path results/grpo/step_170/policy/weights/ \ --hf-ckpt-path results/grpo/hf diff --git a/docs/guides/grpo-deepscaler.md b/docs/guides/grpo-deepscaler.md index 5beddf1689..456b2f2d8b 100644 --- a/docs/guides/grpo-deepscaler.md +++ b/docs/guides/grpo-deepscaler.md @@ -16,7 +16,7 @@ uv run examples/run_grpo_math.py --config=examples/configs/grpo-deepscaler-1.5b- At the end of each stage, you need to specify the Hugging Face checkpoint to continue training with. To get this checkpoint, we convert a model checkpoint to a Hugging Face checkpoint with the following command: ```sh -uv run examples/convert_dcp_to_hf.py --config=results/grpo-deepscaler-1.5b-8K/step_240/config.yaml --dcp-ckpt-path=results/grpo-deepscaler-1.5b-8K/step_240/policy/weights --hf-ckpt-path=results/grpo-deepscaler-1.5b-8K/step_240/hf +uv run examples/converters/convert_dcp_to_hf.py --config=results/grpo-deepscaler-1.5b-8K/step_240/config.yaml --dcp-ckpt-path=results/grpo-deepscaler-1.5b-8K/step_240/policy/weights --hf-ckpt-path=results/grpo-deepscaler-1.5b-8K/step_240/hf ``` When running the next command, we use the Hugging Face checkpoint as the initial checkpoint. We train with an 8K context window for 240 steps, a 16K context window for 290 steps, and a 24K context window for 50 steps. We run all experiments on a single 8XH100 80GB node or on a single 8XA100 80GB node. diff --git a/docs/guides/sft-openmathinstruct2.md b/docs/guides/sft-openmathinstruct2.md index dae8e8846d..6698c12bc0 100644 --- a/docs/guides/sft-openmathinstruct2.md +++ b/docs/guides/sft-openmathinstruct2.md @@ -26,7 +26,7 @@ The default config uses 8 GPUs (`cluster.gpus_per_node`) on 1 node (`cluster.num Throughout training, the checkpoints of the model will be saved to the `results/sft_openmathinstruct2` folder (specified by `checkpointing.checkpoint_dir`). To evaluate the model, we first need to convert the PyTorch distributed checkpoint to Hugging Face format: ``` -uv run examples/convert_dcp_to_hf.py \ +uv run examples/converters/convert_dcp_to_hf.py \ --config=results/sft_openmathinstruct2/step_1855/config.yaml \ --dcp-ckpt-path=results/sft_openmathinstruct2/step_1855/policy/weights \ --hf-ckpt-path=results/sft_openmathinstruct2/step_1855/hf diff --git a/examples/convert_dcp_to_hf.py b/examples/converters/convert_dcp_to_hf.py similarity index 100% rename from examples/convert_dcp_to_hf.py rename to examples/converters/convert_dcp_to_hf.py diff --git a/examples/converters/convert_megatron_to_hf.py b/examples/converters/convert_megatron_to_hf.py new file mode 100644 index 0000000000..c7c5d6be46 --- /dev/null +++ b/examples/converters/convert_megatron_to_hf.py @@ -0,0 +1,69 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
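+
+"""Convert a Megatron-format NeMo RL checkpoint to a Hugging Face checkpoint.
+
+The training config.yaml supplies the model name and tokenizer; the conversion
+itself is delegated to export_model_from_megatron.
+"""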
+
+import argparse
+
+import yaml
+
+from nemo_rl.models.megatron.community_import import export_model_from_megatron
+
+
+def parse_args():
+    """Parse command line arguments."""
+    parser = argparse.ArgumentParser(
+        description="Convert Megatron checkpoint to HF checkpoint"
+    )
+    parser.add_argument(
+        "--config",
+        type=str,
+        default=None,
+        help="Path to config.yaml file in the checkpoint directory",
+    )
+    parser.add_argument(
+        "--megatron-ckpt-path", type=str, default=None, help="Path to Megatron checkpoint"
+    )
+    parser.add_argument(
+        "--hf-ckpt-path", type=str, default=None, help="Path to save HF checkpoint"
+    )
+    # Parse the script's arguments
+    args = parser.parse_args()
+
+    return args
+
+def main():
+    """Main entry point."""
+    args = parse_args()
+
+    with open(args.config, "r") as f:
+        config = yaml.safe_load(f)
+
+    model_name = config["policy"]["model_name"]
+    # TODO: After the following PR gets merged:
+    # https://github.com/NVIDIA/NeMo-RL/pull/148/files
+    # tokenizer should be copied from policy/tokenizer/* instead of relying on the model name
+    # We can expose a arg at the top level --tokenizer_path to plumb that through.
+    # This is more stable than relying on the current NeMo-RL get_tokenizer() which can
+    # change release to release.
+    #tokenizer_name_or_path = config["policy"]["model_name"]
+
+    export_model_from_megatron(
+        hf_model_name=model_name,
+        input_path=args.megatron_ckpt_path,
+        output_path=args.hf_ckpt_path,
+        hf_tokenizer_path = config["policy"]["tokenizer"]["name"],
+    )
+
+
+if __name__ == "__main__":
+    main()
diff --git a/nemo_rl/models/megatron/community_import.py b/nemo_rl/models/megatron/community_import.py
index e83922e659..6653a65ce3 100644
--- a/nemo_rl/models/megatron/community_import.py
+++ b/nemo_rl/models/megatron/community_import.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import os
 
 def import_model_from_hf_name(hf_model_name: str, output_path: str):
     if "llama" in hf_model_name.lower():
@@ -37,3 +38,40 @@ def import_model_from_hf_name(hf_model_name: str, output_path: str):
     import megatron.core.rerun_state_machine
 
     megatron.core.rerun_state_machine.destroy_rerun_state_machine()
+
+def export_model_from_megatron(
+    hf_model_name: str,
+    input_path: str,
+    output_path: str,
+    hf_tokenizer_path: str,
+    overwrite: bool = False,
+    ):
+
+    if os.path.exists(output_path) and not overwrite:
+        raise FileExistsError(
+            f"HF checkpoint already exists at {hf_ckpt_path}. Delete it to run or set overwrite=True."
+ ) + + if "llama" in hf_model_name.lower(): + from nemo.tron.converter.llama import HFLlamaExporter + + print(f"Exporting model {hf_model_name} to {output_path}...") + exporter_cls = HFLlamaExporter + elif "qwen" in hf_model_name.lower(): + from nemo.tron.converter.qwen import HFQwen2Exporter + + print(f"Exporting model {hf_model_name} to {output_path}...") + exporter_cls = HFQwen2Exporter + else: + raise ValueError(f"Unknown model: {hf_model_name}") + exporter = exporter_cls( + input_path=input_path, + output_path=output_path, + hf_tokenizer_path=hf_tokenizer_path, + ) + exporter.apply() + # resetting mcore state + import megatron.core.rerun_state_machine + + megatron.core.rerun_state_machine.destroy_rerun_state_machine() + From 76416584957d330f143a56952feac721bfeb9d3d Mon Sep 17 00:00:00 2001 From: ashors1 Date: Wed, 25 Jun 2025 16:09:58 -0700 Subject: [PATCH 2/6] cleanup Signed-off-by: ashors1 --- examples/converters/convert_megatron_to_hf.py | 16 +++++++------- nemo_rl/models/megatron/community_import.py | 21 +++++++++---------- 2 files changed, 17 insertions(+), 20 deletions(-) diff --git a/examples/converters/convert_megatron_to_hf.py b/examples/converters/convert_megatron_to_hf.py index c7c5d6be46..ea4501286e 100644 --- a/examples/converters/convert_megatron_to_hf.py +++ b/examples/converters/convert_megatron_to_hf.py @@ -31,7 +31,10 @@ def parse_args(): help="Path to config.yaml file in the checkpoint directory", ) parser.add_argument( - "--megatron-ckpt-path", type=str, default=None, help="Path to Megatron checkpoint" + "--megatron-ckpt-path", + type=str, + default=None, + help="Path to Megatron checkpoint", ) parser.add_argument( "--hf-ckpt-path", type=str, default=None, help="Path to save HF checkpoint" @@ -41,6 +44,7 @@ def parse_args(): return args + def main(): """Main entry point.""" args = parse_args() @@ -49,19 +53,13 @@ def main(): config = yaml.safe_load(f) model_name = config["policy"]["model_name"] - # TODO: After the following PR gets merged: - # https://github.com/NVIDIA/NeMo-RL/pull/148/files - # tokenizer should be copied from policy/tokenizer/* instead of relying on the model name - # We can expose a arg at the top level --tokenizer_path to plumb that through. - # This is more stable than relying on the current NeMo-RL get_tokenizer() which can - # change release to release. 
- #tokenizer_name_or_path = config["policy"]["model_name"] + tokenizer_name = config["policy"]["tokenizer"]["name"] export_model_from_megatron( hf_model_name=model_name, input_path=args.megatron_ckpt_path, output_path=args.hf_ckpt_path, - hf_tokenizer_path = config["policy"]["tokenizer"]["name"], + hf_tokenizer_path=tokenizer_name, ) diff --git a/nemo_rl/models/megatron/community_import.py b/nemo_rl/models/megatron/community_import.py index 6653a65ce3..2dbfa13675 100644 --- a/nemo_rl/models/megatron/community_import.py +++ b/nemo_rl/models/megatron/community_import.py @@ -14,6 +14,7 @@ import os + def import_model_from_hf_name(hf_model_name: str, output_path: str): if "llama" in hf_model_name.lower(): from nemo.tron.converter.llama import HFLlamaImporter @@ -39,31 +40,30 @@ def import_model_from_hf_name(hf_model_name: str, output_path: str): megatron.core.rerun_state_machine.destroy_rerun_state_machine() -def export_model_from_megatron( - hf_model_name: str, - input_path: str, - output_path: str, - hf_tokenizer_path: str, - overwrite: bool = False, - ): +def export_model_from_megatron( + hf_model_name: str, + input_path: str, + output_path: str, + hf_tokenizer_path: str, + overwrite: bool = False, +): if os.path.exists(output_path) and not overwrite: raise FileExistsError( - f"HF checkpoint already exists at {hf_ckpt_path}. Delete it to run or set overwrite=True." + f"HF checkpoint already exists at {output_path}. Delete it to run or set overwrite=True." ) if "llama" in hf_model_name.lower(): from nemo.tron.converter.llama import HFLlamaExporter - print(f"Exporting model {hf_model_name} to {output_path}...") exporter_cls = HFLlamaExporter elif "qwen" in hf_model_name.lower(): from nemo.tron.converter.qwen import HFQwen2Exporter - print(f"Exporting model {hf_model_name} to {output_path}...") exporter_cls = HFQwen2Exporter else: raise ValueError(f"Unknown model: {hf_model_name}") + print(f"Exporting model {hf_model_name} to {output_path}...") exporter = exporter_cls( input_path=input_path, output_path=output_path, @@ -74,4 +74,3 @@ def export_model_from_megatron( import megatron.core.rerun_state_machine megatron.core.rerun_state_machine.destroy_rerun_state_machine() - From aee5bd824803d82d552d2f7895c95888afd5a442 Mon Sep 17 00:00:00 2001 From: ashors1 Date: Thu, 26 Jun 2025 15:09:53 -0700 Subject: [PATCH 3/6] add test, update nemo Signed-off-by: ashors1 --- 3rdparty/NeMo-workspace/NeMo | 2 +- nemo_rl/models/megatron/community_import.py | 8 +- tests/functional/test_converter_roundtrip.py | 370 +++++++++++++++++++ tests/functional/test_converters.sh | 1 + 4 files changed, 378 insertions(+), 3 deletions(-) create mode 100644 tests/functional/test_converter_roundtrip.py create mode 100644 tests/functional/test_converters.sh diff --git a/3rdparty/NeMo-workspace/NeMo b/3rdparty/NeMo-workspace/NeMo index bab66472d2..528e3c140b 160000 --- a/3rdparty/NeMo-workspace/NeMo +++ b/3rdparty/NeMo-workspace/NeMo @@ -1 +1 @@ -Subproject commit bab66472d2f2eb05ab621dbad66ad6031e4ee19e +Subproject commit 528e3c140b94b8bd85647a04a9671f7ae1e4c920 diff --git a/nemo_rl/models/megatron/community_import.py b/nemo_rl/models/megatron/community_import.py index 2dbfa13675..e23c477321 100644 --- a/nemo_rl/models/megatron/community_import.py +++ b/nemo_rl/models/megatron/community_import.py @@ -33,7 +33,9 @@ def import_model_from_hf_name(hf_model_name: str, output_path: str): output_path=output_path, ) else: - raise ValueError(f"Unknown model: {hf_model_name}") + raise ValueError( + f"Unknown model: {hf_model_name}. 
Currently, only Qwen2 and Llama are supported." + ) importer.apply() # resetting mcore state import megatron.core.rerun_state_machine @@ -62,7 +64,9 @@ def export_model_from_megatron( exporter_cls = HFQwen2Exporter else: - raise ValueError(f"Unknown model: {hf_model_name}") + raise ValueError( + f"Unknown model: {hf_model_name}. Currently, only Qwen2 and Llama are supported." + ) print(f"Exporting model {hf_model_name} to {output_path}...") exporter = exporter_cls( input_path=input_path, diff --git a/tests/functional/test_converter_roundtrip.py b/tests/functional/test_converter_roundtrip.py new file mode 100644 index 0000000000..b65a1b1a65 --- /dev/null +++ b/tests/functional/test_converter_roundtrip.py @@ -0,0 +1,370 @@ +#!/usr/bin/env python3 +""" +Functional test for converter roundtrip functionality. + +This test: +1. Starts with a HuggingFace Qwen/Qwen2-0.5B checkpoint +2. Converts the model to torch DCP format +3. Converts the model to Megatron format (using community import) +4. Converts both the DCP and Megatron checkpoints back to HF format +5. Asserts that the converted DCP and Megatron checkpoints are identical and match the original HF checkpoint +""" + +import os +from typing import Any, Dict + +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer + +from nemo_rl.algorithms.utils import get_tokenizer +from nemo_rl.distributed.virtual_cluster import RayVirtualCluster +from nemo_rl.models.megatron.community_import import ( + export_model_from_megatron, + import_model_from_hf_name, +) +from nemo_rl.models.policy.lm_policy import Policy +from nemo_rl.utils.native_checkpoint import convert_dcp_to_hf + + +## TODO: cleanup +def create_test_config() -> Dict[str, Any]: + """Create a test configuration for SFT training.""" + return { + "sft": { + "max_num_epochs": 1, + "max_num_steps": 2, # Very short training for testing + "val_period": 2, + "val_batches": 1, + "val_global_batch_size": 4, + "val_micro_batch_size": 2, + "val_at_start": False, + "seed": 42, + }, + "checkpointing": { + "enabled": True, + "checkpoint_dir": "/tmp/test_converter_checkpoints", + "metric_name": "val_loss", + "higher_is_better": False, + "keep_top_k": 1, + "save_period": 2, + }, + "policy": { + "model_name": "Qwen/Qwen2-0.5B", + "tokenizer": {"name": "Qwen/Qwen2-0.5B"}, + "train_global_batch_size": 4, + "train_micro_batch_size": 2, + "max_total_sequence_length": 128, + "precision": "bfloat16", + "fsdp_offload_enabled": False, + "activation_checkpointing_enabled": False, + "dtensor_cfg": { + "enabled": True, + "cpu_offload": False, + "sequence_parallel": False, + "activation_checkpointing": False, + "tensor_parallel_size": 1, + "context_parallel_size": 1, + "custom_parallel_plan": None, + }, + "dynamic_batching": {"enabled": False}, + "make_sequence_length_divisible_by": 1, + "max_grad_norm": 1.0, + "optimizer": { + "name": "torch.optim.AdamW", + "kwargs": { + "lr": 5.0e-6, + "weight_decay": 0.1, + "betas": [0.9, 0.98], + "eps": 1e-5, + "foreach": False, + "fused": False, + }, + }, + "megatron_cfg": { + "enabled": False, # We'll use DCP for this test + }, + }, + "data": { + "max_input_seq_length": 128, + "dataset_name": "squad", + "add_bos": True, + "add_eos": True, + "add_generation_prompt": False, + }, + "logger": { + "log_dir": "/tmp/test_converter_logs", + "wandb_enabled": False, + "tensorboard_enabled": False, + "monitor_gpus": False, + }, + "cluster": { + "gpus_per_node": 1, + "num_nodes": 1, + }, + } + + +def load_model_and_tokenizer(model_name: str): + """Load the original HF model 
and tokenizer.""" + print(f"Loading original model: {model_name}") + model = AutoModelForCausalLM.from_pretrained( + model_name, torch_dtype=torch.bfloat16, trust_remote_code=True + ) + tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + return model, tokenizer + + +def get_model_state_dict(model): + """Get the state dict of a model, ensuring all tensors are on CPU.""" + state_dict = model.state_dict() + cpu_state_dict = {} + for key, value in state_dict.items(): + if isinstance(value, torch.Tensor): + cpu_state_dict[key] = value.detach().cpu() + else: + cpu_state_dict[key] = value + return cpu_state_dict + + +def assert_state_dicts_equal( + state_dict1: Dict[str, Any], state_dict2: Dict[str, Any], name1: str, name2: str +): + """Assert that two state dictionaries are equal.""" + print(f"Comparing {name1} vs {name2}") + + # Check that keys match + keys1 = set(state_dict1.keys()) + keys2 = set(state_dict2.keys()) + + if keys1 != keys2: + missing_in_2 = keys1 - keys2 + missing_in_1 = keys2 - keys1 + raise AssertionError( + f"State dict keys don't match between {name1} and {name2}.\n" + f"Keys in {name1} but not in {name2}: {missing_in_2}\n" + f"Keys in {name2} but not in {name1}: {missing_in_1}" + ) + + # Check that values match + for key in keys1: + val1 = state_dict1[key] + val2 = state_dict2[key] + + if isinstance(val1, torch.Tensor) and isinstance(val2, torch.Tensor): + if not torch.allclose(val1, val2, rtol=1e-5, atol=1e-5): + max_diff = torch.max(torch.abs(val1 - val2)).item() + raise AssertionError( + f"Tensors for key '{key}' don't match between {name1} and {name2}. " + f"Max difference: {max_diff}" + ) + elif val1 != val2: + raise AssertionError( + f"Non-tensor values for key '{key}' don't match between {name1} and {name2}. 
" + f"{name1}: {val1}, {name2}: {val2}" + ) + + print(f"✓ {name1} and {name2} are identical") + + +def create_dcp_checkpoint( + model_name: str, config: Dict[str, Any], temp_dir: str +) -> str: + """Create a DCP checkpoint without training.""" + print("Creating DCP checkpoint...") + + # Create cluster + cluster = RayVirtualCluster( + name="test-converter-cluster", + bundle_ct_per_node_list=[1], + use_gpus=True, + num_gpus_per_node=1, + max_colocated_worker_groups=1, + ) + + # Get tokenizer + tokenizer = get_tokenizer(config["policy"]["tokenizer"]) + + # Create policy + policy = Policy( + cluster=cluster, + config=config["policy"], + tokenizer=tokenizer, + init_reference_model=False, + ) + + # Save checkpoint without any training + dcp_checkpoint_path = os.path.join(temp_dir, "dcp_checkpoint") + policy.save_checkpoint(dcp_checkpoint_path) + + print(f"✓ DCP checkpoint saved to: {dcp_checkpoint_path}") + return dcp_checkpoint_path + + +def create_megatron_checkpoint(model_name: str, temp_dir: str) -> str: + """Create a Megatron checkpoint using community import.""" + print("Creating Megatron checkpoint...") + + megatron_checkpoint_path = os.path.join(temp_dir, "megatron_checkpoint") + import_model_from_hf_name(model_name, megatron_checkpoint_path) + + print(f"✓ Megatron checkpoint saved to: {megatron_checkpoint_path}") + return os.path.join(megatron_checkpoint_path, "iter_0000000") + + +def convert_dcp_to_hf_checkpoint(dcp_path: str, model_name: str, temp_dir: str) -> str: + """Convert DCP checkpoint to HF format.""" + print("Converting DCP to HF format...") + + hf_path = os.path.join(temp_dir, "dcp_to_hf") + convert_dcp_to_hf( + dcp_ckpt_path=dcp_path, + hf_ckpt_path=hf_path, + model_name_or_path=model_name, + tokenizer_name_or_path=model_name, + overwrite=True, + ) + + print(f"✓ DCP to HF conversion saved to: {hf_path}") + return hf_path + + +def convert_megatron_to_hf_checkpoint( + megatron_path: str, model_name: str, temp_dir: str +) -> str: + """Convert Megatron checkpoint to HF format.""" + print("Converting Megatron to HF format...") + + hf_path = os.path.join(temp_dir, "megatron_to_hf") + + # Get tokenizer for the export + tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) + tokenizer_path = os.path.join(temp_dir, "tokenizer") + tokenizer.save_pretrained(tokenizer_path) + + export_model_from_megatron( + hf_model_name=model_name, + input_path=megatron_path, + output_path=hf_path, + hf_tokenizer_path=tokenizer_path, + overwrite=True, + ) + + print(f"✓ Megatron to HF conversion saved to: {hf_path}") + return hf_path + + +def main(): + """Main test function.""" + print("=" * 80) + print("Starting Converter Roundtrip Functional Test") + print("=" * 80) + + # TODO(@ashors): test more models + model_name = "Qwen/Qwen2-0.5B" + + # with tempfile.TemporaryDirectory() as temp_dir: + temp_dir = "/opt/test/test_converter_checkpoints" + print(f"Using temporary directory: {temp_dir}") + + # Step 1: Load original HF model + print("\n" + "=" * 60) + print("STEP 1: Loading original HuggingFace model") + print("=" * 60) + original_model, original_tokenizer = load_model_and_tokenizer(model_name) + original_state_dict = get_model_state_dict(original_model) + + # Step 2: Create DCP checkpoint + print("\n" + "=" * 60) + print("STEP 2: Creating DCP checkpoint") + print("=" * 60) + config = create_test_config() + dcp_checkpoint_path = create_dcp_checkpoint(model_name, config, temp_dir) + + # Step 3: Create Megatron checkpoint + print("\n" + "=" * 60) + print("STEP 3: Creating 
Megatron checkpoint") + print("=" * 60) + megatron_checkpoint_path = create_megatron_checkpoint(model_name, temp_dir) + + # Step 4: Convert DCP to HF + print("\n" + "=" * 60) + print("STEP 4: Converting DCP to HF format") + print("=" * 60) + dcp_to_hf_path = convert_dcp_to_hf_checkpoint( + dcp_checkpoint_path, model_name, temp_dir + ) + + # Step 5: Convert Megatron to HF + print("\n" + "=" * 60) + print("STEP 5: Converting Megatron to HF format") + print("=" * 60) + megatron_to_hf_path = convert_megatron_to_hf_checkpoint( + megatron_checkpoint_path, model_name, temp_dir + ) + + # Step 6: Load converted models and compare + print("\n" + "=" * 60) + print("STEP 6: Loading converted models and comparing") + print("=" * 60) + + # Load DCP-converted model + dcp_converted_model = AutoModelForCausalLM.from_pretrained( + dcp_to_hf_path, torch_dtype=torch.bfloat16, trust_remote_code=True + ) + dcp_converted_state_dict = get_model_state_dict(dcp_converted_model) + + # Load Megatron-converted model + megatron_converted_model = AutoModelForCausalLM.from_pretrained( + megatron_to_hf_path, torch_dtype=torch.bfloat16, trust_remote_code=True + ) + megatron_converted_state_dict = get_model_state_dict(megatron_converted_model) + + # Step 7: Assertions + print("\n" + "=" * 60) + print("STEP 7: Running assertions") + print("=" * 60) + + # Compare DCP-converted vs Megatron-converted + print("Comparing DCP-converted HF model with Megatron-converted HF model...") + assert_state_dicts_equal( + dcp_converted_state_dict, + megatron_converted_state_dict, + "DCP-converted HF model", + "Megatron-converted HF model", + ) + + print("✓ DCP and Megatron roundtrip checkpoints are identical!") + + # Verify that both converted models have the expected structure + expected_keys = set(original_state_dict.keys()) + dcp_keys = set(dcp_converted_state_dict.keys()) + megatron_keys = set(megatron_converted_state_dict.keys()) + + assert dcp_keys == expected_keys, ( + f"DCP converted model missing keys: {expected_keys - dcp_keys}" + ) + assert megatron_keys == expected_keys, ( + f"Megatron converted model missing keys: {expected_keys - megatron_keys}" + ) + + print("✓ All converted models have the expected structure") + + # Test that we can do a forward pass with both converted models + print("Testing forward passes...") + test_input = torch.randint(0, 1000, (1, 10)) + + with torch.no_grad(): + dcp_output = dcp_converted_model(test_input) + megatron_output = megatron_converted_model(test_input) + + print("✓ Both converted models can perform forward passes") + + print("\n" + "=" * 80) + print("✓ ALL TESTS PASSED!") + print("=" * 80) + + +if __name__ == "__main__": + main() diff --git a/tests/functional/test_converters.sh b/tests/functional/test_converters.sh new file mode 100644 index 0000000000..ef789ecf90 --- /dev/null +++ b/tests/functional/test_converters.sh @@ -0,0 +1 @@ +uv run --extra mcore tests/functional/test_converter_roundtrip.py \ No newline at end of file From 2ecb001c7a1b2e258f51048a280934784b751256 Mon Sep 17 00:00:00 2001 From: ashors1 Date: Thu, 26 Jun 2025 15:17:16 -0700 Subject: [PATCH 4/6] small test fixes Signed-off-by: ashors1 --- tests/functional/test_converter_roundtrip.py | 181 +++++++++---------- 1 file changed, 90 insertions(+), 91 deletions(-) diff --git a/tests/functional/test_converter_roundtrip.py b/tests/functional/test_converter_roundtrip.py index b65a1b1a65..e551d0e6b5 100644 --- a/tests/functional/test_converter_roundtrip.py +++ b/tests/functional/test_converter_roundtrip.py @@ -11,6 +11,7 @@ """ 
import os +import tempfile from typing import Any, Dict import torch @@ -26,13 +27,12 @@ from nemo_rl.utils.native_checkpoint import convert_dcp_to_hf -## TODO: cleanup def create_test_config() -> Dict[str, Any]: """Create a test configuration for SFT training.""" return { "sft": { - "max_num_epochs": 1, - "max_num_steps": 2, # Very short training for testing + "max_num_epochs": 1, ## unused, no training is actually done + "max_num_steps": 2, "val_period": 2, "val_batches": 1, "val_global_batch_size": 4, @@ -264,106 +264,105 @@ def main(): # TODO(@ashors): test more models model_name = "Qwen/Qwen2-0.5B" - # with tempfile.TemporaryDirectory() as temp_dir: - temp_dir = "/opt/test/test_converter_checkpoints" - print(f"Using temporary directory: {temp_dir}") - - # Step 1: Load original HF model - print("\n" + "=" * 60) - print("STEP 1: Loading original HuggingFace model") - print("=" * 60) - original_model, original_tokenizer = load_model_and_tokenizer(model_name) - original_state_dict = get_model_state_dict(original_model) - - # Step 2: Create DCP checkpoint - print("\n" + "=" * 60) - print("STEP 2: Creating DCP checkpoint") - print("=" * 60) - config = create_test_config() - dcp_checkpoint_path = create_dcp_checkpoint(model_name, config, temp_dir) - - # Step 3: Create Megatron checkpoint - print("\n" + "=" * 60) - print("STEP 3: Creating Megatron checkpoint") - print("=" * 60) - megatron_checkpoint_path = create_megatron_checkpoint(model_name, temp_dir) - - # Step 4: Convert DCP to HF - print("\n" + "=" * 60) - print("STEP 4: Converting DCP to HF format") - print("=" * 60) - dcp_to_hf_path = convert_dcp_to_hf_checkpoint( - dcp_checkpoint_path, model_name, temp_dir - ) + with tempfile.TemporaryDirectory() as temp_dir: + print(f"Using temporary directory: {temp_dir}") + + # Step 1: Load original HF model + print("\n" + "=" * 60) + print("STEP 1: Loading original HuggingFace model") + print("=" * 60) + original_model, original_tokenizer = load_model_and_tokenizer(model_name) + original_state_dict = get_model_state_dict(original_model) + + # Step 2: Create DCP checkpoint + print("\n" + "=" * 60) + print("STEP 2: Creating DCP checkpoint") + print("=" * 60) + config = create_test_config() + dcp_checkpoint_path = create_dcp_checkpoint(model_name, config, temp_dir) + + # Step 3: Create Megatron checkpoint + print("\n" + "=" * 60) + print("STEP 3: Creating Megatron checkpoint") + print("=" * 60) + megatron_checkpoint_path = create_megatron_checkpoint(model_name, temp_dir) + + # Step 4: Convert DCP to HF + print("\n" + "=" * 60) + print("STEP 4: Converting DCP to HF format") + print("=" * 60) + dcp_to_hf_path = convert_dcp_to_hf_checkpoint( + dcp_checkpoint_path, model_name, temp_dir + ) - # Step 5: Convert Megatron to HF - print("\n" + "=" * 60) - print("STEP 5: Converting Megatron to HF format") - print("=" * 60) - megatron_to_hf_path = convert_megatron_to_hf_checkpoint( - megatron_checkpoint_path, model_name, temp_dir - ) + # Step 5: Convert Megatron to HF + print("\n" + "=" * 60) + print("STEP 5: Converting Megatron to HF format") + print("=" * 60) + megatron_to_hf_path = convert_megatron_to_hf_checkpoint( + megatron_checkpoint_path, model_name, temp_dir + ) - # Step 6: Load converted models and compare - print("\n" + "=" * 60) - print("STEP 6: Loading converted models and comparing") - print("=" * 60) + # Step 6: Load converted models and compare + print("\n" + "=" * 60) + print("STEP 6: Loading converted models and comparing") + print("=" * 60) - # Load DCP-converted model - dcp_converted_model 
= AutoModelForCausalLM.from_pretrained( - dcp_to_hf_path, torch_dtype=torch.bfloat16, trust_remote_code=True - ) - dcp_converted_state_dict = get_model_state_dict(dcp_converted_model) + # Load DCP-converted model + dcp_converted_model = AutoModelForCausalLM.from_pretrained( + dcp_to_hf_path, torch_dtype=torch.bfloat16, trust_remote_code=True + ) + dcp_converted_state_dict = get_model_state_dict(dcp_converted_model) - # Load Megatron-converted model - megatron_converted_model = AutoModelForCausalLM.from_pretrained( - megatron_to_hf_path, torch_dtype=torch.bfloat16, trust_remote_code=True - ) - megatron_converted_state_dict = get_model_state_dict(megatron_converted_model) - - # Step 7: Assertions - print("\n" + "=" * 60) - print("STEP 7: Running assertions") - print("=" * 60) - - # Compare DCP-converted vs Megatron-converted - print("Comparing DCP-converted HF model with Megatron-converted HF model...") - assert_state_dicts_equal( - dcp_converted_state_dict, - megatron_converted_state_dict, - "DCP-converted HF model", - "Megatron-converted HF model", - ) + # Load Megatron-converted model + megatron_converted_model = AutoModelForCausalLM.from_pretrained( + megatron_to_hf_path, torch_dtype=torch.bfloat16, trust_remote_code=True + ) + megatron_converted_state_dict = get_model_state_dict(megatron_converted_model) + + # Step 7: Assertions + print("\n" + "=" * 60) + print("STEP 7: Running assertions") + print("=" * 60) + + # Compare DCP-converted vs Megatron-converted + print("Comparing DCP-converted HF model with Megatron-converted HF model...") + assert_state_dicts_equal( + dcp_converted_state_dict, + megatron_converted_state_dict, + "DCP-converted HF model", + "Megatron-converted HF model", + ) - print("✓ DCP and Megatron roundtrip checkpoints are identical!") + print("✓ DCP and Megatron roundtrip checkpoints are identical!") - # Verify that both converted models have the expected structure - expected_keys = set(original_state_dict.keys()) - dcp_keys = set(dcp_converted_state_dict.keys()) - megatron_keys = set(megatron_converted_state_dict.keys()) + # Verify that both converted models have the expected structure + expected_keys = set(original_state_dict.keys()) + dcp_keys = set(dcp_converted_state_dict.keys()) + megatron_keys = set(megatron_converted_state_dict.keys()) - assert dcp_keys == expected_keys, ( - f"DCP converted model missing keys: {expected_keys - dcp_keys}" - ) - assert megatron_keys == expected_keys, ( - f"Megatron converted model missing keys: {expected_keys - megatron_keys}" - ) + assert dcp_keys == expected_keys, ( + f"DCP converted model missing keys: {expected_keys - dcp_keys}" + ) + assert megatron_keys == expected_keys, ( + f"Megatron converted model missing keys: {expected_keys - megatron_keys}" + ) - print("✓ All converted models have the expected structure") + print("✓ All converted models have the expected structure") - # Test that we can do a forward pass with both converted models - print("Testing forward passes...") - test_input = torch.randint(0, 1000, (1, 10)) + # Test that we can do a forward pass with both converted models + print("Testing forward passes...") + test_input = torch.randint(0, 1000, (1, 10)) - with torch.no_grad(): - dcp_output = dcp_converted_model(test_input) - megatron_output = megatron_converted_model(test_input) + with torch.no_grad(): + dcp_output = dcp_converted_model(test_input) + megatron_output = megatron_converted_model(test_input) - print("✓ Both converted models can perform forward passes") + print("✓ Both converted models can perform 
forward passes") - print("\n" + "=" * 80) - print("✓ ALL TESTS PASSED!") - print("=" * 80) + print("\n" + "=" * 80) + print("✓ ALL TESTS PASSED!") + print("=" * 80) if __name__ == "__main__": From 5d2b1677ca99e65bf6820fdffa9016278df35f20 Mon Sep 17 00:00:00 2001 From: ashors1 Date: Thu, 26 Jun 2025 15:19:24 -0700 Subject: [PATCH 5/6] improve error message Signed-off-by: ashors1 --- nemo_rl/models/megatron/community_import.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/nemo_rl/models/megatron/community_import.py b/nemo_rl/models/megatron/community_import.py index e23c477321..5ad061c54a 100644 --- a/nemo_rl/models/megatron/community_import.py +++ b/nemo_rl/models/megatron/community_import.py @@ -34,7 +34,8 @@ def import_model_from_hf_name(hf_model_name: str, output_path: str): ) else: raise ValueError( - f"Unknown model: {hf_model_name}. Currently, only Qwen2 and Llama are supported." + f"Unknown model: {hf_model_name}. Currently, only Qwen2 and Llama are supported. " + "If you'd like to run with a different model, please raise an issue or consider adding your own converter." ) importer.apply() # resetting mcore state @@ -65,7 +66,8 @@ def export_model_from_megatron( exporter_cls = HFQwen2Exporter else: raise ValueError( - f"Unknown model: {hf_model_name}. Currently, only Qwen2 and Llama are supported." + f"Unknown model: {hf_model_name}. Currently, only Qwen2 and Llama are supported. " + "If you'd like to run with a different model, please raise an issue or consider adding your own converter." ) print(f"Exporting model {hf_model_name} to {output_path}...") exporter = exporter_cls( From 6e574672ae360cb9c0f80eb6b984ac4073071935 Mon Sep 17 00:00:00 2001 From: ashors1 Date: Fri, 27 Jun 2025 15:16:39 -0700 Subject: [PATCH 6/6] update nemo Signed-off-by: ashors1 --- 3rdparty/NeMo-workspace/NeMo | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/3rdparty/NeMo-workspace/NeMo b/3rdparty/NeMo-workspace/NeMo index 528e3c140b..4b7ded58d8 160000 --- a/3rdparty/NeMo-workspace/NeMo +++ b/3rdparty/NeMo-workspace/NeMo @@ -1 +1 @@ -Subproject commit 528e3c140b94b8bd85647a04a9671f7ae1e4c920 +Subproject commit 4b7ded58d804bf3470499c6cfa385c6fa915879d
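
For reference, the new Megatron-to-HF converter is invoked like the existing DCP converter. A minimal usage sketch, assuming a run directory laid out like the docs examples above; the step number and the `iter_0000000` weights subdirectory are illustrative, and `--extra mcore` mirrors `tests/functional/test_converters.sh`:

```sh
# Hypothetical paths; point --megatron-ckpt-path at your own Megatron
# checkpoint iteration directory (cf. iter_0000000 in the roundtrip test).
CKPT_DIR=results/grpo/step_170
uv run --extra mcore examples/converters/convert_megatron_to_hf.py \
    --config=$CKPT_DIR/config.yaml \
    --megatron-ckpt-path=$CKPT_DIR/policy/weights/iter_0000000 \
    --hf-ckpt-path=${CKPT_DIR}-hf
```

The script reads `policy.model_name` and `policy.tokenizer.name` from the saved `config.yaml`, so no separate tokenizer flag is needed.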