-
Notifications
You must be signed in to change notification settings - Fork 283
Chenhany/megatron export per layer #881
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,123 @@ | ||
| # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| """Hugging Face checkpoint utility.""" | ||
|
|
||
| import json | ||
| import os | ||
| import shutil | ||
| from pathlib import Path | ||
|
|
||
| import torch | ||
| from safetensors.torch import safe_open | ||
| from tqdm import tqdm | ||
|
|
||
|
|
||
| def copy_remote_code( | ||
| pretrained_model_path: str | os.PathLike, | ||
| save_directory: str | os.PathLike, | ||
| ): | ||
| """Copy remote code from pretrained model to save directory. | ||
|
|
||
| For models that keep configuration and modeling files as part of the checkpoint, | ||
| we need to copy them to the export directory for seamless integration with inference | ||
| frameworks. | ||
|
|
||
| Args: | ||
| pretrained_model_path: Path to the pretrained model. | ||
| save_directory: Path to the save directory. | ||
|
|
||
| Raises: | ||
| ValueError: If the pretrained model path is not a directory. | ||
| """ | ||
| hf_checkpoint_path = Path(pretrained_model_path) | ||
| save_dir = Path(save_directory) | ||
|
|
||
| if not hf_checkpoint_path.is_dir(): | ||
| raise ValueError( | ||
| f"Invalid pretrained model path: {pretrained_model_path}. It should be a directory." | ||
| ) | ||
|
|
||
| for py_file in hf_checkpoint_path.glob("*.py"): | ||
| if py_file.is_file(): | ||
| shutil.copy(py_file, save_dir / py_file.name) | ||
|
|
||
|
|
||
| def load_multimodal_components( | ||
| pretrained_model_path: str | os.PathLike, | ||
| ) -> dict[str, torch.Tensor]: | ||
| """Load multimodal components from safetensors file. | ||
|
|
||
| Args: | ||
| pretrained_model_path: Path to the pretrained model. | ||
|
|
||
| Returns: | ||
| A dictionary of multimodal components. | ||
| """ | ||
| hf_checkpoint_path = Path(pretrained_model_path) | ||
| if not hf_checkpoint_path.is_dir(): | ||
| raise ValueError( | ||
| f"Invalid pretrained model path: {pretrained_model_path}. It should be a directory." | ||
| ) | ||
|
|
||
| safetensors_file = Path(hf_checkpoint_path) / "model.safetensors" | ||
| safetensors_index_file = Path(hf_checkpoint_path) / "model.safetensors.index.json" | ||
|
|
||
| multimodal_state_dict = {} | ||
|
|
||
| if safetensors_file.is_file(): | ||
| print(f"Loading multimodal components from single file: {safetensors_file}") | ||
| with safe_open(safetensors_file, framework="pt") as f: | ||
| multimodal_keys = [ | ||
| key | ||
| for key in f.keys() # noqa: SIM118 | ||
| if key.startswith(("multi_modal_projector", "vision_model")) | ||
| ] | ||
| for key in tqdm(multimodal_keys, desc="Loading multimodal tensors"): | ||
| multimodal_state_dict[key] = f.get_tensor(key) | ||
|
|
||
| elif safetensors_index_file.is_file(): | ||
| print(f"Loading multimodal components from sharded model: {hf_checkpoint_path}") | ||
| with open(safetensors_index_file) as f: | ||
| safetensors_index = json.load(f) | ||
|
|
||
| # For multimodal models, vision_model and multi_modal_projector are in the first shard | ||
| all_shard_files = sorted(set(safetensors_index["weight_map"].values())) | ||
| first_shard_file = all_shard_files[0] # e.g., "model-00001-of-00050.safetensors" | ||
|
|
||
| # Load multimodal components from the first shard file | ||
| safetensors_filepath = Path(hf_checkpoint_path) / first_shard_file | ||
| print(f"Loading multimodal components from {first_shard_file}") | ||
|
|
||
| with safe_open(safetensors_filepath, framework="pt") as f: | ||
| shard_keys = list(f.keys()) | ||
| multimodal_keys_in_shard = [ | ||
| k for k in shard_keys if k.startswith(("multi_modal_projector", "vision_model")) | ||
| ] | ||
|
|
||
| if multimodal_keys_in_shard: | ||
| print( | ||
| f"Found {len(multimodal_keys_in_shard)} multimodal tensors in {first_shard_file}" | ||
| ) | ||
| for key in tqdm(multimodal_keys_in_shard, desc="Loading multimodal tensors"): | ||
| multimodal_state_dict[key] = f.get_tensor(key) | ||
| else: | ||
| print(f"No multimodal components found in {first_shard_file}") | ||
|
|
||
|
Comment on lines
+91
to
+118
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Multimodal components may span multiple shards — only loading the first shard is fragile. The code assumes all multimodal tensors reside in the first shard (line 96-98), but Suggested fix: use the weight_map to load from the correct shards elif safetensors_index_file.is_file():
print(f"Loading multimodal components from sharded model: {hf_checkpoint_path}")
with open(safetensors_index_file) as f:
safetensors_index = json.load(f)
- # For multimodal models, vision_model and multi_modal_projector are in the first shard
- all_shard_files = sorted(set(safetensors_index["weight_map"].values()))
- first_shard_file = all_shard_files[0] # e.g., "model-00001-of-00050.safetensors"
-
- # Load multimodal components from the first shard file
- safetensors_filepath = Path(hf_checkpoint_path) / first_shard_file
- print(f"Loading multimodal components from {first_shard_file}")
-
- with safe_open(safetensors_filepath, framework="pt") as f:
- shard_keys = list(f.keys())
- multimodal_keys_in_shard = [
- k for k in shard_keys if k.startswith(("multi_modal_projector", "vision_model"))
- ]
-
- if multimodal_keys_in_shard:
- print(
- f"Found {len(multimodal_keys_in_shard)} multimodal tensors in {first_shard_file}"
- )
- for key in tqdm(multimodal_keys_in_shard, desc="Loading multimodal tensors"):
- multimodal_state_dict[key] = f.get_tensor(key)
- else:
- print(f"No multimodal components found in {first_shard_file}")
+ # Find shards that contain multimodal keys using the weight_map
+ multimodal_shard_keys = {}
+ for key, shard_file in safetensors_index["weight_map"].items():
+ if key.startswith(("multi_modal_projector", "vision_model")):
+ multimodal_shard_keys.setdefault(shard_file, []).append(key)
+
+ for shard_file, keys in multimodal_shard_keys.items():
+ safetensors_filepath = Path(hf_checkpoint_path) / shard_file
+ print(f"Loading {len(keys)} multimodal tensors from {shard_file}")
+ with safe_open(safetensors_filepath, framework="pt") as f:
+ for key in tqdm(keys, desc=f"Loading from {shard_file}"):
+ multimodal_state_dict[key] = f.get_tensor(key)
+
+ if not multimodal_shard_keys:
+ print("No multimodal components found in weight_map")🤖 Prompt for AI Agents |
||
| else: | ||
| print(f"Warning: No safetensors files found in {hf_checkpoint_path}") | ||
|
|
||
| print(f"Successfully loaded {len(multimodal_state_dict)} multimodal tensors") | ||
| return multimodal_state_dict | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Quick question: are the multimodal components always in the first shard? Could they span the first several shards instead?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good question. We will leave this to @yueshen2016 to improve.