Merged
11 changes: 10 additions & 1 deletion nemo_rl/models/megatron/community_import.py
@@ -76,7 +76,16 @@ def import_model_from_hf_name(
model_provider.pipeline_dtype = megatron_config["pipeline_dtype"]
model_provider.sequence_parallel = megatron_config["sequence_parallel"]
model_provider.finalize()
model_provider.initialize_model_parallel(seed=0)

from megatron.core import parallel_state

if not parallel_state.model_parallel_is_initialized():
model_provider.initialize_model_parallel(seed=0)
else:
from megatron.core.tensor_parallel import model_parallel_cuda_manual_seed

model_parallel_cuda_manual_seed(0)

megatron_model = model_provider.provide_distributed_model(wrap_with_ddp=False)

# The above parallelism settings are used to load the model in a distributed manner.
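The guard above makes the import safe to run when model parallelism has already been initialized by the caller: a fresh process still goes through initialize_model_parallel(seed=0), while an already-initialized context only has its model-parallel RNG re-seeded via model_parallel_cuda_manual_seed(0). A minimal sketch of the calling pattern this enables, mirroring the updated roundtrip test in this PR (the model name and output path are illustrative assumptions, not taken from this diff):

from megatron.bridge.training.model_load_save import temporary_distributed_context

from nemo_rl.models.megatron.community_import import import_model_from_hf_name

# temporary_distributed_context sets up a short-lived distributed environment,
# which can leave model-parallel state initialized by the time the import
# runs; the guard then skips initialize_model_parallel and only re-seeds.
with temporary_distributed_context(backend="gloo"):
    import_model_from_hf_name(
        "Qwen/Qwen2.5-0.5B",  # hypothetical HF model name
        "/tmp/megatron_checkpoint",  # hypothetical output directory
    )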
3 changes: 1 addition & 2 deletions tests/functional/L1_Functional_Tests_GPU.sh
@@ -50,8 +50,7 @@ time uv run --no-sync bash ./tests/functional/sft_megatron.sh
time uv run --no-sync bash ./tests/functional/sft_megatron_lora.sh
time uv run --no-sync bash ./tests/functional/sft_resume_diamond.sh
time uv run --no-sync bash ./tests/functional/test_automodel_extra_installed_correctly.sh
# Re-enable once DTensor v2 converter is fixed.
# time uv run --no-sync bash ./tests/functional/test_converters.sh
time uv run --no-sync bash ./tests/functional/test_converters.sh
time uv run --no-sync bash ./tests/functional/test_mcore_extra_installed_correctly.sh
time uv run --no-sync bash ./tests/functional/vlm_grpo.sh

155 changes: 121 additions & 34 deletions tests/functional/test_converter_roundtrip.py
@@ -24,8 +24,10 @@
5. Asserts that the converted DCP and Megatron checkpoints are identical and match the original HF checkpoint
"""

import copy
import os
import tempfile
import time
from typing import Any, Dict

import torch
Expand Down Expand Up @@ -70,7 +72,9 @@ def create_test_config() -> Dict[str, Any]:
"train_micro_batch_size": 2,
"max_total_sequence_length": 128,
"precision": "bfloat16",
"offload_optimizer_for_logprob": False,
"dtensor_cfg": {
"_v2": False,
"enabled": True,
"cpu_offload": False,
"sequence_parallel": False,
@@ -182,6 +186,15 @@ def assert_state_dicts_equal(
print(f"✓ {name1} and {name2} are identical")


def check_file_exists(path: str) -> bool:
"""Check if a file exists."""
for _ in range(10):
if os.path.exists(path):
return True
time.sleep(0.5)
return False
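A brief note on the helper above: it polls instead of checking once, presumably because a checkpoint directory can appear on disk slightly after the save call returns; the retry budget is 10 attempts at 0.5 s each (about 5 s total), and os.path.exists matches directories as well as files. Illustrative usage with a hypothetical path:

ckpt_path = "/tmp/dcp_checkpoint_v1"  # hypothetical path
if not check_file_exists(ckpt_path):
    raise FileNotFoundError(f"checkpoint never appeared at {ckpt_path}")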


def create_dcp_checkpoint(
model_name: str, config: Dict[str, Any], temp_dir: str
) -> str:
Expand Down Expand Up @@ -209,8 +222,18 @@ def create_dcp_checkpoint(
)

# Save checkpoint without any training
dcp_checkpoint_path = os.path.join(temp_dir, "dcp_checkpoint")
policy.save_checkpoint(dcp_checkpoint_path)
use_v2 = config["policy"]["dtensor_cfg"]["_v2"]
dcp_checkpoint_path = os.path.join(
temp_dir, "dcp_checkpoint" + ("_v2" if use_v2 else "_v1")
)
policy.save_checkpoint(
dcp_checkpoint_path, checkpointing_cfg=config["checkpointing"]
)

if not check_file_exists(dcp_checkpoint_path):
raise FileNotFoundError(
f"DCP checkpoint creation failed at {dcp_checkpoint_path}"
)

print(f"✓ DCP checkpoint saved to: {dcp_checkpoint_path}")
return dcp_checkpoint_path
@@ -220,8 +243,21 @@ def create_megatron_checkpoint(model_name: str, temp_dir: str) -> str:
"""Create a Megatron checkpoint using community import."""
print("Creating Megatron checkpoint...")

megatron_checkpoint_path = os.path.join(temp_dir, "megatron_checkpoint")
import_model_from_hf_name(model_name, megatron_checkpoint_path)
try:
from megatron.bridge.training.model_load_save import (
temporary_distributed_context,
)
except ImportError:
raise ImportError("megatron.bridge.training is not available.")

with temporary_distributed_context(backend="gloo"):
megatron_checkpoint_path = os.path.join(temp_dir, "megatron_checkpoint")
import_model_from_hf_name(model_name, megatron_checkpoint_path)

if not check_file_exists(megatron_checkpoint_path):
raise FileNotFoundError(
f"Megatron checkpoint creation failed at {megatron_checkpoint_path}"
)

print(f"✓ Megatron checkpoint saved to: {megatron_checkpoint_path}")
return os.path.join(megatron_checkpoint_path, "iter_0000000")
@@ -231,7 +267,8 @@ def convert_dcp_to_hf_checkpoint(dcp_path: str, model_name: str, temp_dir: str)
"""Convert DCP checkpoint to HF format."""
print("Converting DCP to HF format...")

hf_path = os.path.join(temp_dir, "dcp_to_hf")
use_v2 = dcp_path.endswith("_v2")
hf_path = os.path.join(temp_dir, "dcp_to_hf" + ("_v2" if use_v2 else "_v1"))
convert_dcp_to_hf(
dcp_ckpt_path=dcp_path,
hf_ckpt_path=hf_path,
@@ -290,73 +327,120 @@ def main():

# Step 2: Create DCP checkpoint
print("\n" + "=" * 60)
print("STEP 2: Creating DCP checkpoint")
print("STEP 2: Creating Dtensor V1 DCP checkpoint")
print("=" * 60)
config = create_test_config()
dcp_checkpoint_path = create_dcp_checkpoint(model_name, config, temp_dir)
config_v1 = create_test_config()
dcp_checkpoint_path_v1 = create_dcp_checkpoint(model_name, config_v1, temp_dir)

# Step 3: Create Megatron checkpoint
# Step 3: Create Dtensor V2 DCP checkpoint
print("\n" + "=" * 60)
print("STEP 3: Creating Megatron checkpoint")
print("STEP 3: Creating Dtensor V2 DCP checkpoint")
print("=" * 60)
config_v2 = copy.deepcopy(config_v1)
config_v2["policy"]["dtensor_cfg"]["_v2"] = True
config_v2["checkpointing"]["model_save_format"] = "torch_save"
dcp_checkpoint_path_v2 = create_dcp_checkpoint(model_name, config_v2, temp_dir)

# Step 4: Create Megatron checkpoint
print("\n" + "=" * 60)
print("STEP 4: Creating Megatron checkpoint")
print("=" * 60)
megatron_checkpoint_path = create_megatron_checkpoint(model_name, temp_dir)

# Step 4: Convert DCP to HF
# Step 5: Convert Dtensor V1 DCP to HF
print("\n" + "=" * 60)
print("STEP 5: Converting Dtensor V1 DCP to HF format")
print("=" * 60)
dcp_to_hf_path_v1 = convert_dcp_to_hf_checkpoint(
dcp_checkpoint_path_v1, model_name, temp_dir
)

# Step 6: Convert Dtensor V2 DCP to HF
print("\n" + "=" * 60)
print("STEP 4: Converting DCP to HF format")
print("STEP 6: Converting Dtensor V2 DCP to HF format")
print("=" * 60)
dcp_to_hf_path = convert_dcp_to_hf_checkpoint(
dcp_checkpoint_path, model_name, temp_dir
dcp_to_hf_path_v2 = convert_dcp_to_hf_checkpoint(
dcp_checkpoint_path_v2, model_name, temp_dir
)

# Step 5: Convert Megatron to HF
# Step 7: Convert Megatron to HF
print("\n" + "=" * 60)
print("STEP 5: Converting Megatron to HF format")
print("STEP 7: Converting Megatron to HF format")
print("=" * 60)
megatron_to_hf_path = convert_megatron_to_hf_checkpoint(
megatron_checkpoint_path, model_name, temp_dir
)

# Step 6: Load converted models and compare
# Step 8: Load converted models and compare
print("\n" + "=" * 60)
print("STEP 6: Loading converted models and comparing")
print("STEP 8: Loading converted models and comparing")
print("=" * 60)

# Load DCP-converted model
dcp_converted_model = AutoModelForCausalLM.from_pretrained(
dcp_to_hf_path, torch_dtype=torch.bfloat16, trust_remote_code=True
# Load Dtensor V1 DCP-converted model
dcp_converted_model_v1 = AutoModelForCausalLM.from_pretrained(
dcp_to_hf_path_v1, torch_dtype=torch.bfloat16, trust_remote_code=True
)
dcp_converted_state_dict = get_model_state_dict(dcp_converted_model)
dcp_converted_state_dict_v1 = get_model_state_dict(dcp_converted_model_v1)

# Load Dtensor V2 DCP-converted model
dcp_converted_model_v2 = AutoModelForCausalLM.from_pretrained(
dcp_to_hf_path_v2, torch_dtype=torch.bfloat16, trust_remote_code=True
)
dcp_converted_state_dict_v2 = get_model_state_dict(dcp_converted_model_v2)

# Load Megatron-converted model
megatron_converted_model = AutoModelForCausalLM.from_pretrained(
megatron_to_hf_path, torch_dtype=torch.bfloat16, trust_remote_code=True
)
megatron_converted_state_dict = get_model_state_dict(megatron_converted_model)

# Step 7: Assertions
# Step 9: Assertions
print("\n" + "=" * 60)
print("STEP 7: Running assertions")
print("STEP 9: Running assertions")
print("=" * 60)

# Compare DCP-converted vs Megatron-converted
print("Comparing DCP-converted HF model with Megatron-converted HF model...")
# Compare Dtensor V1 DCP-converted vs Original HF model
print("Comparing Dtensor V1 DCP-converted HF model with Original HF model...")
assert_state_dicts_equal(
dcp_converted_state_dict_v1,
original_state_dict,
"Dtensor V1 DCP-converted HF model",
"Original HF model",
)

# Compare Dtensor V2 DCP-converted vs Original HF model
print("Comparing Dtensor V2 DCP-converted HF model with Original HF model...")
assert_state_dicts_equal(
dcp_converted_state_dict_v2,
original_state_dict,
"Dtensor V2 DCP-converted HF model",
"Original HF model",
)

# Compare Megatron-converted vs Original HF model
print("Comparing Megatron-converted HF model with Original HF model...")
assert_state_dicts_equal(
dcp_converted_state_dict,
megatron_converted_state_dict,
"DCP-converted HF model",
original_state_dict,
"Megatron-converted HF model",
"Original HF model",
)

print("✓ DCP and Megatron roundtrip checkpoints are identical!")
print(
"✓ Dtensor V1 and Dtensor V2 DCP and Megatron roundtrip checkpoints are identical!"
)

# Verify that both converted models have the expected structure
expected_keys = set(original_state_dict.keys())
dcp_keys = set(dcp_converted_state_dict.keys())
dcp_keys_v1 = set(dcp_converted_state_dict_v1.keys())
dcp_keys_v2 = set(dcp_converted_state_dict_v2.keys())
megatron_keys = set(megatron_converted_state_dict.keys())

assert dcp_keys == expected_keys, (
f"DCP converted model missing keys: {expected_keys - dcp_keys}"
assert dcp_keys_v1 == expected_keys, (
f"Dtensor V1 DCP converted model missing keys: {expected_keys - dcp_keys_v1}"
)
assert dcp_keys_v2 == expected_keys, (
f"Dtensor V2 DCP converted model missing keys: {expected_keys - dcp_keys_v2}"
)
assert megatron_keys == expected_keys, (
f"Megatron converted model missing keys: {expected_keys - megatron_keys}"
@@ -369,10 +453,13 @@ def main():
test_input = torch.randint(0, 1000, (1, 10))

with torch.no_grad():
dcp_output = dcp_converted_model(test_input)
dcp_output_v1 = dcp_converted_model_v1(test_input)
dcp_output_v2 = dcp_converted_model_v2(test_input)
megatron_output = megatron_converted_model(test_input)

print("✓ Both converted models can perform forward passes")
print(
"✓ Dtensor V1 and Dtensor V2 DCP and Megatron converted models can perform forward passes"
)

print("\n" + "=" * 80)
print("✓ ALL TESTS PASSED!")