diff --git a/nemo_rl/models/megatron/community_import.py b/nemo_rl/models/megatron/community_import.py
index 271cda579c..2e52a54f9c 100644
--- a/nemo_rl/models/megatron/community_import.py
+++ b/nemo_rl/models/megatron/community_import.py
@@ -76,7 +76,16 @@ def import_model_from_hf_name(
     model_provider.pipeline_dtype = megatron_config["pipeline_dtype"]
     model_provider.sequence_parallel = megatron_config["sequence_parallel"]
     model_provider.finalize()
-    model_provider.initialize_model_parallel(seed=0)
+
+    from megatron.core import parallel_state
+
+    if not parallel_state.model_parallel_is_initialized():
+        model_provider.initialize_model_parallel(seed=0)
+    else:
+        from megatron.core.tensor_parallel import model_parallel_cuda_manual_seed
+
+        model_parallel_cuda_manual_seed(0)
+
     megatron_model = model_provider.provide_distributed_model(wrap_with_ddp=False)
 
     # The above parallelism settings are used to load the model in a distributed manner.
diff --git a/tests/functional/L1_Functional_Tests_GPU.sh b/tests/functional/L1_Functional_Tests_GPU.sh
index cf6bdca734..0a81bd8459 100644
--- a/tests/functional/L1_Functional_Tests_GPU.sh
+++ b/tests/functional/L1_Functional_Tests_GPU.sh
@@ -50,8 +50,7 @@
 time uv run --no-sync bash ./tests/functional/sft_megatron.sh
 time uv run --no-sync bash ./tests/functional/sft_megatron_lora.sh
 time uv run --no-sync bash ./tests/functional/sft_resume_diamond.sh
 time uv run --no-sync bash ./tests/functional/test_automodel_extra_installed_correctly.sh
-# Re-enable once DTensor v2 converter is fixed.
-# time uv run --no-sync bash ./tests/functional/test_converters.sh
+time uv run --no-sync bash ./tests/functional/test_converters.sh
 time uv run --no-sync bash ./tests/functional/test_mcore_extra_installed_correctly.sh
 time uv run --no-sync bash ./tests/functional/vlm_grpo.sh
diff --git a/tests/functional/test_converter_roundtrip.py b/tests/functional/test_converter_roundtrip.py
index afe6ef0659..5332f7835f 100644
--- a/tests/functional/test_converter_roundtrip.py
+++ b/tests/functional/test_converter_roundtrip.py
@@ -24,8 +24,10 @@
 5. Asserts that the converted DCP and Megatron checkpoints are identical and match the original HF checkpoint
 """
 
+import copy
 import os
 import tempfile
+import time
 from typing import Any, Dict
 
 import torch
@@ -70,7 +72,9 @@ def create_test_config() -> Dict[str, Any]:
             "train_micro_batch_size": 2,
             "max_total_sequence_length": 128,
             "precision": "bfloat16",
+            "offload_optimizer_for_logprob": False,
             "dtensor_cfg": {
+                "_v2": False,
                 "enabled": True,
                 "cpu_offload": False,
                 "sequence_parallel": False,
@@ -182,6 +186,15 @@ def assert_state_dicts_equal(
     print(f"✓ {name1} and {name2} are identical")
 
 
+def check_file_exists(path: str) -> bool:
+    """Check if a file exists."""
+    for _ in range(10):
+        if os.path.exists(path):
+            return True
+        time.sleep(0.5)
+    return False
+
+
 def create_dcp_checkpoint(
     model_name: str, config: Dict[str, Any], temp_dir: str
 ) -> str:
@@ -209,8 +222,18 @@ def create_dcp_checkpoint(
     )
 
     # Save checkpoint without any training
-    dcp_checkpoint_path = os.path.join(temp_dir, "dcp_checkpoint")
-    policy.save_checkpoint(dcp_checkpoint_path)
+    use_v2 = config["policy"]["dtensor_cfg"]["_v2"]
+    dcp_checkpoint_path = os.path.join(
+        temp_dir, "dcp_checkpoint" + ("_v2" if use_v2 else "_v1")
+    )
+    policy.save_checkpoint(
+        dcp_checkpoint_path, checkpointing_cfg=config["checkpointing"]
+    )
+
+    if not check_file_exists(dcp_checkpoint_path):
+        raise FileNotFoundError(
+            f"DCP checkpoint creation failed at {dcp_checkpoint_path}"
+        )
 
     print(f"✓ DCP checkpoint saved to: {dcp_checkpoint_path}")
     return dcp_checkpoint_path
@@ -220,8 +243,21 @@ def create_megatron_checkpoint(model_name: str, temp_dir: str) -> str:
     """Create a Megatron checkpoint using community import."""
     print("Creating Megatron checkpoint...")
 
-    megatron_checkpoint_path = os.path.join(temp_dir, "megatron_checkpoint")
-    import_model_from_hf_name(model_name, megatron_checkpoint_path)
+    try:
+        from megatron.bridge.training.model_load_save import (
+            temporary_distributed_context,
+        )
+    except ImportError:
+        raise ImportError("megatron.bridge.training is not available.")
+
+    with temporary_distributed_context(backend="gloo"):
+        megatron_checkpoint_path = os.path.join(temp_dir, "megatron_checkpoint")
+        import_model_from_hf_name(model_name, megatron_checkpoint_path)
+
+    if not check_file_exists(megatron_checkpoint_path):
+        raise FileNotFoundError(
+            f"Megatron checkpoint creation failed at {megatron_checkpoint_path}"
+        )
 
     print(f"✓ Megatron checkpoint saved to: {megatron_checkpoint_path}")
     return os.path.join(megatron_checkpoint_path, "iter_0000000")
@@ -231,7 +267,8 @@ def convert_dcp_to_hf_checkpoint(dcp_path: str, model_name: str, temp_dir: str)
    """Convert DCP checkpoint to HF format."""
     print("Converting DCP to HF format...")
 
-    hf_path = os.path.join(temp_dir, "dcp_to_hf")
+    use_v2 = dcp_path.endswith("_v2")
+    hf_path = os.path.join(temp_dir, "dcp_to_hf" + ("_v2" if use_v2 else "_v1"))
     convert_dcp_to_hf(
         dcp_ckpt_path=dcp_path,
         hf_ckpt_path=hf_path,
@@ -290,43 +327,66 @@ def main():
 
     # Step 2: Create DCP checkpoint
     print("\n" + "=" * 60)
-    print("STEP 2: Creating DCP checkpoint")
+    print("STEP 2: Creating Dtensor V1 DCP checkpoint")
     print("=" * 60)
-    config = create_test_config()
-    dcp_checkpoint_path = create_dcp_checkpoint(model_name, config, temp_dir)
+    config_v1 = create_test_config()
+    dcp_checkpoint_path_v1 = create_dcp_checkpoint(model_name, config_v1, temp_dir)
 
-    # Step 3: Create Megatron checkpoint
+    # Step 3: Create Dtensor V2 DCP checkpoint
     print("\n" + "=" * 60)
-    print("STEP 3: Creating Megatron checkpoint")
+    print("STEP 3: Creating Dtensor V2 DCP checkpoint")
+    print("=" * 60)
+    config_v2 = copy.deepcopy(config_v1)
+    config_v2["policy"]["dtensor_cfg"]["_v2"] = True
+    config_v2["checkpointing"]["model_save_format"] = "torch_save"
+    dcp_checkpoint_path_v2 = create_dcp_checkpoint(model_name, config_v2, temp_dir)
+
+    # Step 4: Create Megatron checkpoint
+    print("\n" + "=" * 60)
+    print("STEP 4: Creating Megatron checkpoint")
     print("=" * 60)
     megatron_checkpoint_path = create_megatron_checkpoint(model_name, temp_dir)
 
-    # Step 4: Convert DCP to HF
+    # Step 5: Convert Dtensor V1 DCP to HF
+    print("\n" + "=" * 60)
+    print("STEP 5: Converting Dtensor V1 DCP to HF format")
+    print("=" * 60)
+    dcp_to_hf_path_v1 = convert_dcp_to_hf_checkpoint(
+        dcp_checkpoint_path_v1, model_name, temp_dir
+    )
+
+    # Step 6: Convert Dtensor V2 DCP to HF
     print("\n" + "=" * 60)
-    print("STEP 4: Converting DCP to HF format")
+    print("STEP 6: Converting Dtensor V2 DCP to HF format")
     print("=" * 60)
-    dcp_to_hf_path = convert_dcp_to_hf_checkpoint(
-        dcp_checkpoint_path, model_name, temp_dir
+    dcp_to_hf_path_v2 = convert_dcp_to_hf_checkpoint(
+        dcp_checkpoint_path_v2, model_name, temp_dir
     )
 
-    # Step 5: Convert Megatron to HF
+    # Step 7: Convert Megatron to HF
     print("\n" + "=" * 60)
-    print("STEP 5: Converting Megatron to HF format")
+    print("STEP 7: Converting Megatron to HF format")
     print("=" * 60)
     megatron_to_hf_path = convert_megatron_to_hf_checkpoint(
         megatron_checkpoint_path, model_name, temp_dir
     )
 
-    # Step 6: Load converted models and compare
+    # Step 8: Load converted models and compare
     print("\n" + "=" * 60)
-    print("STEP 6: Loading converted models and comparing")
+    print("STEP 8: Loading converted models and comparing")
     print("=" * 60)
 
-    # Load DCP-converted model
-    dcp_converted_model = AutoModelForCausalLM.from_pretrained(
-        dcp_to_hf_path, torch_dtype=torch.bfloat16, trust_remote_code=True
+    # Load Dtensor V1 DCP-converted model
+    dcp_converted_model_v1 = AutoModelForCausalLM.from_pretrained(
+        dcp_to_hf_path_v1, torch_dtype=torch.bfloat16, trust_remote_code=True
     )
-    dcp_converted_state_dict = get_model_state_dict(dcp_converted_model)
+    dcp_converted_state_dict_v1 = get_model_state_dict(dcp_converted_model_v1)
+
+    # Load Dtensor V2 DCP-converted model
+    dcp_converted_model_v2 = AutoModelForCausalLM.from_pretrained(
+        dcp_to_hf_path_v2, torch_dtype=torch.bfloat16, trust_remote_code=True
+    )
+    dcp_converted_state_dict_v2 = get_model_state_dict(dcp_converted_model_v2)
 
     # Load Megatron-converted model
     megatron_converted_model = AutoModelForCausalLM.from_pretrained(
@@ -334,29 +394,53 @@ def main():
     )
     megatron_converted_state_dict = get_model_state_dict(megatron_converted_model)
 
-    # Step 7: Assertions
+    # Step 9: Assertions
     print("\n" + "=" * 60)
-    print("STEP 7: Running assertions")
+    print("STEP 9: Running assertions")
     print("=" * 60)
 
-    # Compare DCP-converted vs Megatron-converted
-    print("Comparing DCP-converted HF model with Megatron-converted HF model...")
+    # Compare Dtensor V1 DCP-converted vs Original HF model
+    print("Comparing Dtensor V1 DCP-converted HF model with Original HF model...")
+    assert_state_dicts_equal(
+        dcp_converted_state_dict_v1,
+        original_state_dict,
+        "Dtensor V1 DCP-converted HF model",
+        "Original HF model",
+    )
+
+    # Compare Dtensor V2 DCP-converted vs Original HF model
+    print("Comparing Dtensor V2 DCP-converted HF model with Original HF model...")
+    assert_state_dicts_equal(
+        dcp_converted_state_dict_v2,
+        original_state_dict,
+        "Dtensor V2 DCP-converted HF model",
+        "Original HF model",
+    )
+
+    # Compare Megatron-converted vs Original HF model
+    print("Comparing Megatron-converted HF model with Original HF model...")
     assert_state_dicts_equal(
-        dcp_converted_state_dict,
         megatron_converted_state_dict,
-        "DCP-converted HF model",
+        original_state_dict,
         "Megatron-converted HF model",
+        "Original HF model",
     )
 
-    print("✓ DCP and Megatron roundtrip checkpoints are identical!")
+    print(
+        "✓ Dtensor V1 and Dtensor V2 DCP and Megatron roundtrip checkpoints are identical!"
+    )
 
     # Verify that both converted models have the expected structure
     expected_keys = set(original_state_dict.keys())
-    dcp_keys = set(dcp_converted_state_dict.keys())
+    dcp_keys_v1 = set(dcp_converted_state_dict_v1.keys())
+    dcp_keys_v2 = set(dcp_converted_state_dict_v2.keys())
     megatron_keys = set(megatron_converted_state_dict.keys())
 
-    assert dcp_keys == expected_keys, (
-        f"DCP converted model missing keys: {expected_keys - dcp_keys}"
+    assert dcp_keys_v1 == expected_keys, (
+        f"Dtensor V1 DCP converted model missing keys: {expected_keys - dcp_keys_v1}"
+    )
+    assert dcp_keys_v2 == expected_keys, (
+        f"Dtensor V2 DCP converted model missing keys: {expected_keys - dcp_keys_v2}"
     )
     assert megatron_keys == expected_keys, (
         f"Megatron converted model missing keys: {expected_keys - megatron_keys}"
@@ -369,10 +453,13 @@ def main():
     test_input = torch.randint(0, 1000, (1, 10))
 
     with torch.no_grad():
-        dcp_output = dcp_converted_model(test_input)
+        dcp_output_v1 = dcp_converted_model_v1(test_input)
+        dcp_output_v2 = dcp_converted_model_v2(test_input)
         megatron_output = megatron_converted_model(test_input)
 
-    print("✓ Both converted models can perform forward passes")
+    print(
+        "✓ Dtensor V1 and Dtensor V2 DCP and Megatron converted models can perform forward passes"
+    )
 
     print("\n" + "=" * 80)
     print("✓ ALL TESTS PASSED!")