Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 123 additions & 0 deletions modelopt/torch/export/plugins/hf_checkpoint_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Hugging Face checkpoint utility."""

import json
import os
import shutil
from pathlib import Path

import torch
from safetensors.torch import safe_open
from tqdm import tqdm


def copy_remote_code(
pretrained_model_path: str | os.PathLike,
save_directory: str | os.PathLike,
):
"""Copy remote code from pretrained model to save directory.

For models that keep configuration and modeling files as part of the checkpoint,
we need to copy them to the export directory for seamless integration with inference
frameworks.

Args:
pretrained_model_path: Path to the pretrained model.
save_directory: Path to the save directory.

Raises:
ValueError: If the pretrained model path is not a directory.
"""
hf_checkpoint_path = Path(pretrained_model_path)
save_dir = Path(save_directory)

if not hf_checkpoint_path.is_dir():
raise ValueError(
f"Invalid pretrained model path: {pretrained_model_path}. It should be a directory."
)

for py_file in hf_checkpoint_path.glob("*.py"):
if py_file.is_file():
shutil.copy(py_file, save_dir / py_file.name)


def load_multimodal_components(
    pretrained_model_path: str | os.PathLike,
) -> dict[str, torch.Tensor]:
    """Load multimodal components from safetensors file.

    Multimodal tensors are identified by key prefix: ``multi_modal_projector``
    or ``vision_model``. For sharded checkpoints, the
    ``model.safetensors.index.json`` ``weight_map`` is used to locate every
    multimodal key in whichever shard it lives, so components split across
    multiple shards are fully loaded (not just the first shard).

    Args:
        pretrained_model_path: Path to the pretrained model directory.

    Returns:
        A dictionary mapping tensor keys to their loaded tensors.

    Raises:
        ValueError: If the pretrained model path is not a directory.
    """
    hf_checkpoint_path = Path(pretrained_model_path)
    if not hf_checkpoint_path.is_dir():
        raise ValueError(
            f"Invalid pretrained model path: {pretrained_model_path}. It should be a directory."
        )

    multimodal_prefixes = ("multi_modal_projector", "vision_model")
    safetensors_file = hf_checkpoint_path / "model.safetensors"
    safetensors_index_file = hf_checkpoint_path / "model.safetensors.index.json"

    multimodal_state_dict: dict[str, torch.Tensor] = {}

    if safetensors_file.is_file():
        # Single-file checkpoint: scan all keys directly.
        print(f"Loading multimodal components from single file: {safetensors_file}")
        with safe_open(safetensors_file, framework="pt") as f:
            multimodal_keys = [
                key
                for key in f.keys()  # noqa: SIM118
                if key.startswith(multimodal_prefixes)
            ]
            for key in tqdm(multimodal_keys, desc="Loading multimodal tensors"):
                multimodal_state_dict[key] = f.get_tensor(key)

    elif safetensors_index_file.is_file():
        # Sharded checkpoint: the index's weight_map maps each key to its
        # shard file, so load from exactly the shards that hold multimodal
        # keys instead of assuming they all live in the first shard.
        print(f"Loading multimodal components from sharded model: {hf_checkpoint_path}")
        with open(safetensors_index_file) as f:
            safetensors_index = json.load(f)

        multimodal_keys_by_shard: dict[str, list[str]] = {}
        for key, shard_file in safetensors_index["weight_map"].items():
            if key.startswith(multimodal_prefixes):
                multimodal_keys_by_shard.setdefault(shard_file, []).append(key)

        if multimodal_keys_by_shard:
            for shard_file, keys in sorted(multimodal_keys_by_shard.items()):
                safetensors_filepath = hf_checkpoint_path / shard_file
                print(f"Found {len(keys)} multimodal tensors in {shard_file}")
                with safe_open(safetensors_filepath, framework="pt") as f:
                    for key in tqdm(keys, desc=f"Loading multimodal tensors from {shard_file}"):
                        multimodal_state_dict[key] = f.get_tensor(key)
        else:
            print(f"No multimodal components found in {safetensors_index_file}")

    else:
        print(f"Warning: No safetensors files found in {hf_checkpoint_path}")

    print(f"Successfully loaded {len(multimodal_state_dict)} multimodal tensors")
    return multimodal_state_dict
53 changes: 53 additions & 0 deletions modelopt/torch/export/plugins/mcore_custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,59 @@ def save_safetensors(state_dict, save_directory: str | os.PathLike):
json.dump(safetensor_index, f, indent=4)


def save_safetensors_by_layer_index(
    layer_state_dicts: dict[int, dict[str, torch.Tensor]],
    total_layers: int,
    save_directory: str | os.PathLike,
    name_template: str = "model-{:05d}-of-{:05d}",
):
    """Save safetensors by layer index.

    Each rank writes one ``.safetensors`` shard per layer it owns, plus a
    sidecar ``.json`` file recording that shard's byte size and weight map.
    After a global barrier, rank 0 merges all sidecars into the final
    ``model.safetensors.index.json``. The per-layer ``.json`` sidecars are
    left on disk after merging.

    Args:
        layer_state_dicts: Mapping of layer index to that layer's state dict.
            NOTE(review): the rank-0 aggregation below reads sidecars for
            indices 1..total_layers, so indices are assumed to be 1-based and
            to cover every layer across all ranks — confirm against callers.
        total_layers: Total number of layers across all ranks.
        save_directory: Path to the save directory (``str`` or ``PathLike``).
        name_template: Template for the shard filename stem, formatted with
            (layer_index, total_layers).
    """
    # Use Path joins so os.PathLike inputs work; raw `save_directory + "/"`
    # string concatenation raises TypeError for PathLike arguments.
    save_dir = Path(save_directory)

    for layer_index, layer_state_dict in layer_state_dicts.items():
        filename = name_template.format(layer_index, total_layers)
        meta_filename = filename + ".json"
        ckpt_filename = filename + ".safetensors"

        weight_map = {}
        layer_total_size = 0
        for key, val in layer_state_dict.items():
            # total_size is the sum of raw tensor byte sizes, matching the
            # HF safetensors index convention.
            layer_total_size += val.numel() * val.element_size()
            weight_map[key] = ckpt_filename

        with open(save_dir / meta_filename, "w") as f:
            json.dump(
                {"metadata": {"total_size": layer_total_size}, "weight_map": weight_map},
                f,
                indent=4,
            )
        save_file(layer_state_dict, str(save_dir / ckpt_filename), metadata={"format": "pt"})

    # [TODO]: this global barrier needs to be replaced with something safer
    torch.distributed.barrier()

    if torch.distributed.get_rank() == 0:
        safetensor_index = {
            "metadata": {"total_size": 0},
            "weight_map": {},
        }
        # Merge every layer's sidecar into one global index (1-based indices).
        for layer_index in range(total_layers):
            meta_filename = name_template.format(layer_index + 1, total_layers) + ".json"
            with open(save_dir / meta_filename) as f:
                shard = json.load(f)
            safetensor_index["metadata"]["total_size"] += shard["metadata"]["total_size"]
            safetensor_index["weight_map"].update(shard["weight_map"])

        with open(save_dir / "model.safetensors.index.json", "w") as f:
            json.dump(safetensor_index, f, indent=4)


def _get_safetensors_file(pretrained_model_path: str | Path, key: str) -> Path | None:
"""Given a tensor key return the safetensors file that contains this tensor if exists.

Expand Down
4 changes: 2 additions & 2 deletions modelopt/torch/export/plugins/vllm_fakequant_megatron.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ class VllmFqGPTModelExporter(GPTModelExporter):
def save_pretrained(
self,
save_directory: str | os.PathLike,
pretrained_model_name_or_path: str | os.PathLike | None = None,
pretrained_model_name_or_path: str | os.PathLike,
):
os.makedirs(save_directory, exist_ok=True)
gather_mcore_vllm_fq_quantized_state_dict(self.model, self.state_dict, save_directory)
Expand All @@ -91,7 +91,7 @@ def _get_quantization_format(self, module: torch.nn.Module):

def export_mcore_gpt_to_hf_vllm_fq(
model: torch.nn.Module,
pretrained_model_name_or_path: str | os.PathLike | None = None,
pretrained_model_name_or_path: str | os.PathLike,
export_extra_modules: bool = False,
dtype: torch.dtype = torch.bfloat16,
export_dir: Path | str = tempfile.gettempdir(),
Expand Down
Loading