diff --git a/modelopt/torch/export/plugins/hf_checkpoint_utils.py b/modelopt/torch/export/plugins/hf_checkpoint_utils.py new file mode 100644 index 000000000..e89900cbb --- /dev/null +++ b/modelopt/torch/export/plugins/hf_checkpoint_utils.py @@ -0,0 +1,123 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Hugging Face checkpoint utility.""" + +import json +import os +import shutil +from pathlib import Path + +import torch +from safetensors.torch import safe_open +from tqdm import tqdm + + +def copy_remote_code( + pretrained_model_path: str | os.PathLike, + save_directory: str | os.PathLike, +): + """Copy remote code from pretrained model to save directory. + + For models that keep configuration and modeling files as part of the checkpoint, + we need to copy them to the export directory for seamless integration with inference + frameworks. + + Args: + pretrained_model_path: Path to the pretrained model. + save_directory: Path to the save directory. + + Raises: + ValueError: If the pretrained model path is not a directory. + """ + hf_checkpoint_path = Path(pretrained_model_path) + save_dir = Path(save_directory) + + if not hf_checkpoint_path.is_dir(): + raise ValueError( + f"Invalid pretrained model path: {pretrained_model_path}. It should be a directory." + ) + + for py_file in hf_checkpoint_path.glob("*.py"): + if py_file.is_file(): + shutil.copy(py_file, save_dir / py_file.name) + + +def load_multimodal_components( + pretrained_model_path: str | os.PathLike, +) -> dict[str, torch.Tensor]: + """Load multimodal components from safetensors file. + + Args: + pretrained_model_path: Path to the pretrained model. + + Returns: + A dictionary of multimodal components. + """ + hf_checkpoint_path = Path(pretrained_model_path) + if not hf_checkpoint_path.is_dir(): + raise ValueError( + f"Invalid pretrained model path: {pretrained_model_path}. It should be a directory." + ) + + safetensors_file = Path(hf_checkpoint_path) / "model.safetensors" + safetensors_index_file = Path(hf_checkpoint_path) / "model.safetensors.index.json" + + multimodal_state_dict = {} + + if safetensors_file.is_file(): + print(f"Loading multimodal components from single file: {safetensors_file}") + with safe_open(safetensors_file, framework="pt") as f: + multimodal_keys = [ + key + for key in f.keys() # noqa: SIM118 + if key.startswith(("multi_modal_projector", "vision_model")) + ] + for key in tqdm(multimodal_keys, desc="Loading multimodal tensors"): + multimodal_state_dict[key] = f.get_tensor(key) + + elif safetensors_index_file.is_file(): + print(f"Loading multimodal components from sharded model: {hf_checkpoint_path}") + with open(safetensors_index_file) as f: + safetensors_index = json.load(f) + + # For multimodal models, vision_model and multi_modal_projector are in the first shard + all_shard_files = sorted(set(safetensors_index["weight_map"].values())) + first_shard_file = all_shard_files[0] # e.g., "model-00001-of-00050.safetensors" + + # Load multimodal components from the first shard file + safetensors_filepath = Path(hf_checkpoint_path) / first_shard_file + print(f"Loading multimodal components from {first_shard_file}") + + with safe_open(safetensors_filepath, framework="pt") as f: + shard_keys = list(f.keys()) + multimodal_keys_in_shard = [ + k for k in shard_keys if k.startswith(("multi_modal_projector", "vision_model")) + ] + + if multimodal_keys_in_shard: + print( + f"Found {len(multimodal_keys_in_shard)} multimodal tensors in {first_shard_file}" + ) + for key in tqdm(multimodal_keys_in_shard, desc="Loading multimodal tensors"): + multimodal_state_dict[key] = f.get_tensor(key) + else: + print(f"No multimodal components found in {first_shard_file}") + + else: + print(f"Warning: No safetensors files found in {hf_checkpoint_path}") + + print(f"Successfully loaded {len(multimodal_state_dict)} multimodal tensors") + return multimodal_state_dict diff --git a/modelopt/torch/export/plugins/mcore_custom.py b/modelopt/torch/export/plugins/mcore_custom.py index c269cef1d..90c523d84 100644 --- a/modelopt/torch/export/plugins/mcore_custom.py +++ b/modelopt/torch/export/plugins/mcore_custom.py @@ -274,6 +274,59 @@ def save_safetensors(state_dict, save_directory: str | os.PathLike): json.dump(safetensor_index, f, indent=4) +def save_safetensors_by_layer_index( + layer_state_dicts: dict[int, dict[str, torch.Tensor]], + total_layers: int, + save_directory: str | os.PathLike, + name_template: str = "model-{:05d}-of-{:05d}", +): + """Save safetensors by layer index. + + Args: + layer_state_dicts: A dictionary of layer state dictionaries. + total_layers: Total number of layers. + save_directory: Path to the save directory. + name_template: Template for the filename. + """ + for layer_index, layer_state_dict in layer_state_dicts.items(): + filename = name_template.format(layer_index, total_layers) + meta_filename = filename + ".json" + ckpt_filename = filename + ".safetensors" + + weight_map = {} + layer_total_size = 0 + for key, val in layer_state_dict.items(): + tensor_size = val.numel() * val.element_size() + layer_total_size += tensor_size + weight_map[key] = ckpt_filename + + with open(save_directory + "/" + meta_filename, "w") as f: + json.dump( + {"metadata": {"total_size": layer_total_size}, "weight_map": weight_map}, + f, + indent=4, + ) + save_file(layer_state_dict, save_directory + "/" + ckpt_filename, metadata={"format": "pt"}) + + # [TODO]: this global barrier needs to be replaced with something safer + torch.distributed.barrier() + + if torch.distributed.get_rank() == 0: + safetensor_index = { + "metadata": {"total_size": 0}, + "weight_map": {}, + } + for layer_index in range(total_layers): + meta_filename = name_template.format(layer_index + 1, total_layers) + ".json" + with open(save_directory + "/" + meta_filename) as f: + shard = json.load(f) + safetensor_index["metadata"]["total_size"] += shard["metadata"]["total_size"] + safetensor_index["weight_map"].update(shard["weight_map"]) + + with open(save_directory + "/model.safetensors.index.json", "w") as f: + json.dump(safetensor_index, f, indent=4) + + def _get_safetensors_file(pretrained_model_path: str | Path, key: str) -> Path | None: """Given a tensor key return the safetensors file that contains this tensor if exists. diff --git a/modelopt/torch/export/plugins/vllm_fakequant_megatron.py b/modelopt/torch/export/plugins/vllm_fakequant_megatron.py index 95b194c3f..3f69271b0 100644 --- a/modelopt/torch/export/plugins/vllm_fakequant_megatron.py +++ b/modelopt/torch/export/plugins/vllm_fakequant_megatron.py @@ -72,7 +72,7 @@ class VllmFqGPTModelExporter(GPTModelExporter): def save_pretrained( self, save_directory: str | os.PathLike, - pretrained_model_name_or_path: str | os.PathLike | None = None, + pretrained_model_name_or_path: str | os.PathLike, ): os.makedirs(save_directory, exist_ok=True) gather_mcore_vllm_fq_quantized_state_dict(self.model, self.state_dict, save_directory) @@ -91,7 +91,7 @@ def _get_quantization_format(self, module: torch.nn.Module): def export_mcore_gpt_to_hf_vllm_fq( model: torch.nn.Module, - pretrained_model_name_or_path: str | os.PathLike | None = None, + pretrained_model_name_or_path: str | os.PathLike, export_extra_modules: bool = False, dtype: torch.dtype = torch.bfloat16, export_dir: Path | str = tempfile.gettempdir(), diff --git a/modelopt/torch/export/unified_export_megatron.py b/modelopt/torch/export/unified_export_megatron.py index 8a6d76b34..0567d0d1f 100644 --- a/modelopt/torch/export/unified_export_megatron.py +++ b/modelopt/torch/export/unified_export_megatron.py @@ -20,7 +20,6 @@ import json import os -import shutil import tempfile from collections import OrderedDict from pathlib import Path @@ -28,9 +27,8 @@ import torch import torch.distributed -from huggingface_hub import hf_hub_download, snapshot_download -from safetensors.torch import safe_open, save_file -from tqdm import tqdm +from huggingface_hub import hf_hub_download +from safetensors.torch import save_file from modelopt import __version__ from modelopt.torch.utils import import_plugin @@ -45,8 +43,13 @@ QUANTIZATION_NONE, QUANTIZATION_NVFP4, ) +from .plugins.hf_checkpoint_utils import copy_remote_code, load_multimodal_components from .plugins.mcore_common import all_mcore_hf_export_mapping -from .plugins.mcore_custom import CustomModuleMapping, get_safetensor, save_safetensors +from .plugins.mcore_custom import ( + CustomModuleMapping, + get_safetensor, + save_safetensors_by_layer_index, +) from .plugins.megatron_importer import GPTModelImporter from .quant_utils import ( get_activation_scaling_factor, @@ -119,6 +122,7 @@ def __init__( raise ValueError("Input to GPTModelExport must be a megatron.core.models.GPTModel!") self._state_dict = OrderedDict() + self._layer_state_dicts = OrderedDict() self._hf_pretrained_model_name = pretrained_model_name_or_path self._hf_config = transformers.AutoConfig.from_pretrained( pretrained_model_name_or_path, trust_remote_code=trust_remote_code @@ -221,10 +225,29 @@ def __init__( self._hf_extra_config.update(eagle_config_update) + def save_pretrained_extra_modules( + self, + save_directory: str | os.PathLike, + ): + """Save a EAGLE or Medusa checkpoints which can be deployed by vLLM and TensorRT-LLM.""" + # We use the last PP rank to write the config because + # medusa_heads and eagle_module only exist in the last stage. + pp_rank = get_pipeline_model_parallel_rank() + pp_size = get_pipeline_model_parallel_world_size() + is_last_stage_main_rank = pp_rank == pp_size - 1 + + state_dict = self.extra_state_dict + + if is_last_stage_main_rank and self._hf_extra_config is not None: + self._hf_extra_config.save_pretrained(save_directory) + save_file(state_dict, save_directory + "/model.safetensors", metadata={"format": "pt"}) + + torch.distributed.barrier() + def save_pretrained( self, save_directory: str | os.PathLike, - pretrained_model_name_or_path: str | os.PathLike | None = None, + pretrained_model_name_or_path: str | os.PathLike, ): """Save a unified checkpoint which can be deployed by vLLM and TensorRT-LLM. @@ -242,7 +265,7 @@ def save_pretrained( is_last_stage_main_rank = pp_rank == pp_size - 1 # Main export process - state_dict = self.extra_state_dict if self.export_extra_modules else self.state_dict + layer_state_dicts = self.layer_state_dicts quantization_format = self._get_quantization_format(self.model) quantization = None @@ -259,39 +282,36 @@ def save_pretrained( # We use the last PP rank and the 1st EP rank to write the config because # medusa_heads and eagle_module only exist in the last stage. if is_last_stage_main_rank: - if self.export_extra_modules and self._hf_extra_config is not None: - self._hf_extra_config.save_pretrained(save_directory) - else: - self._hf_config.save_pretrained(save_directory) - try: - generation_config = transformers.GenerationConfig.from_pretrained( - self._hf_pretrained_model_name - ) - generation_config.save_pretrained(save_directory) - except OSError: - pass - try: - tokenizer = transformers.AutoTokenizer.from_pretrained( - self._hf_pretrained_model_name - ) - tokenizer.save_pretrained(save_directory) - except OSError: - pass - except TypeError: - pass - try: - # Load and save preprocessor config from the original model - processor = AutoProcessor.from_pretrained( - self._hf_pretrained_model_name, trust_remote_code=self.trust_remote_code - ) - if hasattr(processor, "image_processor"): - processor.image_processor.save_pretrained(save_directory) - except (OSError, ValueError, ImportError): - pass + self._hf_config.save_pretrained(save_directory) + try: + generation_config = transformers.GenerationConfig.from_pretrained( + self._hf_pretrained_model_name + ) + generation_config.save_pretrained(save_directory) + except OSError: + pass + try: + tokenizer = transformers.AutoTokenizer.from_pretrained( + self._hf_pretrained_model_name + ) + tokenizer.save_pretrained(save_directory) + except OSError: + pass + except TypeError: + pass + try: + # Load and save preprocessor config from the original model + processor = AutoProcessor.from_pretrained( + self._hf_pretrained_model_name, trust_remote_code=self.trust_remote_code + ) + if hasattr(processor, "image_processor"): + processor.image_processor.save_pretrained(save_directory) + except (OSError, ValueError, ImportError): + pass mtp_state_dict = self._get_mtp_state_dict() if len(mtp_state_dict) > 0: - state_dict.update(mtp_state_dict) + layer_state_dicts[self.model.config.num_layers].update(mtp_state_dict) print(f"Successfully loaded {len(mtp_state_dict)} MTP tensors") combined_exclude_modules = self._gather_exclude_modules() @@ -314,121 +334,18 @@ def save_pretrained( with open(save_directory + "/hf_quant_config.json", "w") as f: json.dump(self._hf_quant_config, f, indent=4) - if ( - is_first_stage_main_rank - and self.is_multimodal - and pretrained_model_name_or_path is not None - ): - hf_checkpoint_path = Path(pretrained_model_name_or_path) - if not hf_checkpoint_path.is_dir(): - hf_checkpoint_path = tempfile.gettempdir() + "/" + pretrained_model_name_or_path - if not Path(hf_checkpoint_path).exists(): - snapshot_download( - repo_id=pretrained_model_name_or_path, - local_dir=hf_checkpoint_path, - ) - - safetensors_file = Path(hf_checkpoint_path) / "model.safetensors" - safetensors_index_file = Path(hf_checkpoint_path) / "model.safetensors.index.json" - - multimodal_state_dict = {} - - if safetensors_file.is_file(): - print(f"Loading multimodal components from single file: {safetensors_file}") - with safe_open(safetensors_file, framework="pt") as f: - multimodal_keys = [ - key - for key in f.keys() # noqa: SIM118 - if key.startswith(("multi_modal_projector", "vision_model")) - ] - for key in tqdm(multimodal_keys, desc="Loading multimodal tensors"): - multimodal_state_dict[key] = f.get_tensor(key) - - elif safetensors_index_file.is_file(): - print(f"Loading multimodal components from sharded model: {hf_checkpoint_path}") - with open(safetensors_index_file) as f: - safetensors_index = json.load(f) - - # For multimodal models, vision_model and multi_modal_projector are in the first shard - all_shard_files = sorted(set(safetensors_index["weight_map"].values())) - first_shard_file = all_shard_files[0] # e.g., "model-00001-of-00050.safetensors" - - # Load multimodal components from the first shard file - safetensors_filepath = Path(hf_checkpoint_path) / first_shard_file - print(f"Loading multimodal components from {first_shard_file}") - - with safe_open(safetensors_filepath, framework="pt") as f: - shard_keys = list(f.keys()) - multimodal_keys_in_shard = [ - k - for k in shard_keys - if k.startswith(("multi_modal_projector", "vision_model")) - ] - - if multimodal_keys_in_shard: - print( - f"Found {len(multimodal_keys_in_shard)} multimodal tensors in {first_shard_file}" - ) - for key in tqdm( - multimodal_keys_in_shard, desc="Loading multimodal tensors" - ): - multimodal_state_dict[key] = f.get_tensor(key) - else: - print(f"No multimodal components found in {first_shard_file}") - - else: - print(f"Warning: No safetensors files found in {hf_checkpoint_path}") - - print(f"Successfully loaded {len(multimodal_state_dict)} multimodal tensors") - # Add multimodal components to state_dict - state_dict.update(multimodal_state_dict) + # Add multimodal components to state_dict. Since only support decoder model quantization, + # no changes will be made to the multimodal components. We copy the multimodal components + # from the pretrained model directly to the state_dict to avoid implementing the export logic. + if is_first_stage_main_rank and self.is_multimodal: + multimodal_state_dict = load_multimodal_components(pretrained_model_name_or_path) + layer_state_dicts[0].update(multimodal_state_dict) # Barrier to ensure the export_dir has been created. torch.distributed.barrier() - if self.export_extra_modules: - if is_last_stage_main_rank: - save_file( - state_dict, save_directory + "/model.safetensors", metadata={"format": "pt"} - ) - torch.distributed.barrier() - return - - if ( - is_last_stage_main_rank - and self._hf_config is not None - and pretrained_model_name_or_path is not None - ): - # For models that keep configuration and modeling files as part of the checkpoint, - # we need to copy them to the export directory for seamless integration with inference - # frameworks. - hf_checkpoint_path = Path(pretrained_model_name_or_path) - model_type = getattr(self._hf_config, "model_type", None) - - if hf_checkpoint_path.is_dir(): - # Local directory - files should be there - config_file = hf_checkpoint_path / f"configuration_{model_type}.py" - modeling_file = hf_checkpoint_path / f"modeling_{model_type}.py" - else: - # Remote model ID - download from HuggingFace Hub (cached automatically) - try: - config_file = hf_hub_download( - repo_id=pretrained_model_name_or_path, - filename=f"configuration_{model_type}.py", - ) - except Exception: - config_file = "" - try: - modeling_file = hf_hub_download( - repo_id=pretrained_model_name_or_path, filename=f"modeling_{model_type}.py" - ) - except Exception: - modeling_file = "" - - if config_file and os.path.exists(config_file): - shutil.copy(config_file, f"{save_directory}/configuration_{model_type}.py") - if modeling_file and os.path.exists(modeling_file): - shutil.copy(modeling_file, f"{save_directory}/modeling_{model_type}.py") + if is_last_stage_main_rank and self._hf_config is not None: + copy_remote_code(pretrained_model_name_or_path, save_directory) # Newer versions of VLLM expect config.json with hf_quant_config config_json_file = save_directory + "/config.json" @@ -440,7 +357,13 @@ def save_pretrained( with open(config_json_file, "w") as f: json.dump(config_dict, f, indent=4) - save_safetensors(state_dict, save_directory) + # save_safetensors(state_dict, save_directory) + save_safetensors_by_layer_index( + layer_state_dicts=layer_state_dicts, + total_layers=self.model.config.num_layers, + save_directory=save_directory, + name_template="model-{:05d}-of-{:05d}", + ) @property def state_dict(self): @@ -449,6 +372,12 @@ def state_dict(self): self._get_state_dict() return self._state_dict + @property + def layer_state_dicts(self): + if len(self._layer_state_dicts) == 0: + self._get_state_dict() + return self._layer_state_dicts + @property def extra_state_dict(self): if len(self._state_dict) == 0: @@ -463,17 +392,6 @@ def _get_state_dict(self): if hasattr(model, "embedding"): self.rules["word_embeddings"](model.embedding.word_embeddings) - # Final layernorm - if hasattr(model.decoder, "final_layernorm") and model.decoder.final_layernorm: - self.rules["final_layernorm"](model.decoder.final_layernorm) - - if hasattr(model.decoder, "final_norm") and model.decoder.final_norm: - self.rules["final_norm"](model.decoder.final_norm) - - # Output layer - if hasattr(model, "output_layer") and not model.share_embeddings_and_output_weights: - self.rules["output_layer"](model.output_layer) - # Decoder layers for layer in model.decoder.layers: layer_id = layer.layer_number - 1 @@ -484,7 +402,20 @@ def _get_state_dict(self): else: raise ValueError("Only TransformerLayer or MambaLayer are supported.") - # TODO export MTP layer in the future + self._layer_state_dicts[layer.layer_number] = self._state_dict + if layer.layer_number != self.model.config.num_layers: + self._state_dict = OrderedDict() + + # Final layernorm + if hasattr(model.decoder, "final_layernorm") and model.decoder.final_layernorm: + self.rules["final_layernorm"](model.decoder.final_layernorm) + + if hasattr(model.decoder, "final_norm") and model.decoder.final_norm: + self.rules["final_norm"](model.decoder.final_norm) + + # Output layer + if hasattr(model, "output_layer") and not model.share_embeddings_and_output_weights: + self.rules["output_layer"](model.output_layer) def _get_transformer_layer_state_dict(self, layer, layer_id): if not isinstance(layer.input_layernorm, IdentityOp): @@ -761,8 +692,10 @@ def _get_quantized_state( """ name_to_value = {} qformat: str = self._get_quantization_format(module) - if qformat is None and "norm" not in prefix: # Add exclude layers for hf_quant_config - self.exclude_modules.append(prefix) + if qformat is None and "norm" not in prefix: + # Add exclude layers for hf_quant_config. Note that if the prefix is not an empty + # string then it usually ends with "." which needs to be removed. + self.exclude_modules.append(prefix.removesuffix(".")) block_size = get_weight_block_size(module) if hasattr(module, "weight") and module.weight is not None and module.weight.numel() > 0: @@ -1254,7 +1187,7 @@ def _gather_exclude_modules(self): def export_mcore_gpt_to_hf( model: torch.nn.Module, - pretrained_model_name_or_path: str | os.PathLike | None = None, + pretrained_model_name_or_path: str | os.PathLike, export_extra_modules: bool = False, dtype: torch.dtype = torch.bfloat16, export_dir: Path | str = tempfile.gettempdir(), @@ -1282,7 +1215,10 @@ def export_mcore_gpt_to_hf( trust_remote_code=trust_remote_code, moe_router_dtype=moe_router_dtype, ) - exporter.save_pretrained(export_dir, pretrained_model_name_or_path) + if exporter.export_extra_modules: + exporter.save_pretrained_extra_modules(export_dir) + else: + exporter.save_pretrained(export_dir, pretrained_model_name_or_path) def import_mcore_gpt_from_hf(