From ccbc488c48bce65fc2fb77f719d346f104dacf2a Mon Sep 17 00:00:00 2001 From: Kinjal Patel Date: Wed, 21 Jan 2026 23:37:34 +0000 Subject: [PATCH 01/15] Added support for HF modelopt state reload for vllm fakequant Signed-off-by: Kinjal Patel --- examples/llm_ptq/hf_ptq.py | 7 + examples/vllm_serve/README.md | 54 ++- examples/vllm_serve/fakequant_worker.py | 361 +++++++----------- examples/vllm_serve/vllm_reload_utils.py | 237 ++++++++++++ examples/vllm_serve/vllm_serve_fakequant.py | 4 +- .../torch/export/plugins/vllm_fakequant_hf.py | 23 +- .../export/plugins/vllm_fakequant_megatron.py | 62 ++- modelopt/torch/quantization/conversion.py | 2 +- .../quantization/nn/modules/quant_module.py | 18 +- 9 files changed, 519 insertions(+), 249 deletions(-) create mode 100644 examples/vllm_serve/vllm_reload_utils.py diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index d7aadf994..3119d3457 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -51,6 +51,7 @@ import modelopt.torch.sparsity as mts from modelopt.torch.export import ( export_hf_checkpoint, + export_hf_vllm_fq_checkpoint, export_tensorrt_llm_checkpoint, get_model_type, save_expert_token_count_table, @@ -1126,6 +1127,12 @@ def parse_args() -> argparse.Namespace: "(sensitivity scores, costs, etc.). Only used when auto_quantize_bits is specified." 
), ) + parser.add_argument( + "--export_vllm_fq", + help="Export vLLM fakequant checkpoint.", + default=False, + action="store_true", + ) return parser.parse_args() diff --git a/examples/vllm_serve/README.md b/examples/vllm_serve/README.md index ff0c4eea3..64310fef4 100644 --- a/examples/vllm_serve/README.md +++ b/examples/vllm_serve/README.md @@ -23,9 +23,11 @@ You can either edit the `quant_config` dictionary in `vllm_serve_fakequant.py`, |-----------------|--------------------------------------------------|---------------------| | QUANT_DATASET | Dataset name for calibration | cnn_dailymail | | QUANT_CALIB_SIZE| Number of samples used for calibration | 512 | -| QUANT_CFG | Quantization format | NVFP4_DEFAULT_CFG | -| KV_QUANT_CFG | Quantization format for KV Cache | None | -| AMAX_FILE_PATH | Optional path to amax file (for loading amax) | None | +| QUANT_CFG | Quantization config | None | +| KV_QUANT_CFG | KV-cache quantization config | None | +| QUANT_FILE_PATH | Optional path to exported quantizer state dict `quantizer_state.pth` | None | +| MODELOPT_STATE_PATH | Optional path to exported `modelopt_state.pth` (restores ModelOpt mode + weights) | None | +| CALIB_BATCH_SIZE | Calibration batch size | 1 | Set these variables in your shell or Docker environment as needed to customize calibration. @@ -60,17 +62,49 @@ Overwrite the calibrated amax value with prepared values from either QAT/PTQ. Step 1: export the model with bf16 weights and amax values. To export the model: -- For HF model use `modelopt.torch.export.export_hf_vllm_fq_checkpoint` function. -- For MCore model use `modelopt.torch.export.export_mcore_gpt_to_hf_vllm_fq` function. +- For **HF** models, you can use `modelopt.torch.export.export_hf_vllm_fq_checkpoint`: -Step 2: configure from exported model using AMAX_FILE_PATH environment variable in step 1. 
For example: + ```python + import torch + from modelopt.torch.export import export_hf_vllm_fq_checkpoint + + with torch.inference_mode(): + export_hf_vllm_fq_checkpoint( + model, # The quantized model. + export_dir, # The directory where the exported files will be stored. + ) + ``` + Or run the example script `examples/llm_ptq/hf_ptq.py` with the `--export_vllm_fq` **flag** to export a vLLM-fakequant-compatible ModelOpt state (it generates `vllm_fq_modelopt_state.pth`, which you can use via `MODELOPT_STATE_PATH`). + +- For **MCore** models, use `modelopt.torch.export.export_mcore_gpt_to_hf_vllm_fq`: + + ```python + from modelopt.torch.export import export_mcore_gpt_to_hf_vllm_fq + export_mcore_gpt_to_hf_vllm_fq( + unwrapped_model, # Quantized MCore model + args.pretrained_model_name, # HF model id/path (for config/tokenizer) + export_dir=args.export_dir, # Directory where exported files will be stored + ) + + ``` + This generates `quantizer_state.pth`, which contains quantizer tensors for vLLM reload via `QUANT_FILE_PATH`. + +Step 2: use the exported artifacts when serving: + +- **HF export**: pass the exported `vllm_fq_modelopt_state.pth` via `MODELOPT_STATE_PATH` + +```bash +# HF +MODELOPT_STATE_PATH=<path/to/vllm_fq_modelopt_state.pth> python vllm_serve_fakequant.py <model_path> -tp 8 --host 0.0.0.0 --port 8000 +``` + +- **MCore export**: pass the exported `quantizer_state.pth` via `QUANT_FILE_PATH` and set `QUANT_CFG` to match the MCore quantization recipe ```bash -AMAX_FILE_PATH= QUANT_CFG= python vllm_serve_fakequant.py -tp 8 --host 0.0.0.0 --port 8000 +# MCore +QUANT_CFG=<quant_cfg> QUANT_FILE_PATH=<path/to/quantizer_state.pth> python vllm_serve_fakequant.py <model_path> -tp 8 --host 0.0.0.0 --port 8000 ``` ## Known Problems -1. AWQ is not yet supported in vLLM. -2. QAT checkpoint export doesn't have KV Cache quantization enabled. KV Cache fake quantization works for PTQ. -3. Mixed precision checkpoint doesn't work currently. +1.
**MCore reload does not use `MODELOPT_STATE_PATH`**; use `QUANT_FILE_PATH` and make sure `QUANT_CFG` matches the quantization recipe used for the original MCore model (otherwise quantizer keys/config won’t align). diff --git a/examples/vllm_serve/fakequant_worker.py b/examples/vllm_serve/fakequant_worker.py index 772c6fe66..81ea0379b 100644 --- a/examples/vllm_serve/fakequant_worker.py +++ b/examples/vllm_serve/fakequant_worker.py @@ -15,13 +15,12 @@ import dataclasses import os -import re import warnings -from collections import defaultdict from contextlib import contextmanager from typing import Any import torch +from vllm_reload_utils import convert_dict_to_vllm, convert_modelopt_state_to_vllm from tqdm import tqdm from transformers import AutoTokenizer from vllm.sampling_params import SamplingParams @@ -29,102 +28,10 @@ from vllm.v1.worker.gpu_worker import Worker as BaseWorker import modelopt.torch.quantization as mtq +from modelopt.torch.opt.conversion import restore_from_modelopt_state from modelopt.torch.utils.dataset_utils import get_dataset_dataloader -def convert_amax_hf2vllm( - hf_state_dict: dict[str, torch.Tensor], fuse_experts: bool = False -) -> dict[str, torch.Tensor]: - """ - Convert amax values from HuggingFace format to vLLM format. 
- - This function merges: - - q_proj, k_proj, v_proj amax values into qkv_proj (taking max) - - gate_proj, up_proj amax values into gate_up_proj (taking max) - - Args: - hf_state_dict: HuggingFace state dict containing amax values - - Returns: - vLLM format state dict with merged amax values - """ - vllm_state_dict = {} - - # Group keys by their base pattern (without the specific projection name) - merge_groups = defaultdict(list) - - for key, value in hf_state_dict.items(): - if "_amax" not in key: - # Copy non-amax keys as-is - vllm_state_dict[key] = value - continue - - # Check if this is a q/k/v projection that needs merging - qkv_match = re.search(r"(.*\.)([qkv])_proj(\..+_amax)$", key) - if qkv_match: - base_pattern = qkv_match.group(1) + "qkv_proj" + qkv_match.group(3) - merge_groups[base_pattern].append((key, value)) - continue - - # Check if this is an expert gate/up projection - # Pattern: model.layers.0.mlp.experts.*.gate_proj.input_quantizer._amax and - # model.layers.0.mlp.experts.*.up_proj.input_quantizer._amax - # Maps to: model.layers.0.mlp.experts.w13_input_quantizer._amax - expert_gate_up_match = ( - "mixer" not in key - and fuse_experts - and re.search(r"(.*\.experts)\.\d+\.(gate|up)_proj\.([^.]+_quantizer\._amax)$", key) - ) - if expert_gate_up_match: - base_pattern = expert_gate_up_match.group(1) + ".w13_" + expert_gate_up_match.group(3) - merge_groups[base_pattern].append((key, value)) - continue - - # Check if this is a non-expert gate/up projection that needs merging - gate_up_match = ( - "mixer" not in key - and "experts" not in key - and re.search(r"(.*\.)(gate|up)_proj(\..+_amax)$", key) - ) - if gate_up_match: - base_pattern = gate_up_match.group(1) + "gate_up_proj" + gate_up_match.group(3) - merge_groups[base_pattern].append((key, value)) - continue - - # Check if this is an expert down_proj - # Pattern: model.layers.0.mlp.experts.*.down_proj.input_quantizer._amax - # Maps to: model.layers.0.mlp.experts.w2_input_quantizer._amax - 
expert_down_match = ( - "mixer" not in key - and fuse_experts - and re.search(r"(.*\.experts)\.\d+\.down_proj\.([^.]+_quantizer\._amax)$", key) - ) - if expert_down_match: - base_pattern = expert_down_match.group(1) + ".w2_" + expert_down_match.group(2) - merge_groups[base_pattern].append((key, value)) - continue - - # Copy other amax keys as-is (like o_proj, down_proj) - vllm_state_dict[key] = value - - # Merge grouped amax values by taking the maximum - for merged_key, key_value_pairs in merge_groups.items(): - if len(key_value_pairs) > 1: - # Take the maximum across all values for this merged key - values = [value for _, value in key_value_pairs] - merged_value = torch.stack(values).max(dim=0)[0] - vllm_state_dict[merged_key] = merged_value - print(f"Merged {len(key_value_pairs)} keys into {merged_key}") - for orig_key, _ in key_value_pairs: - print(f" - {orig_key}") - else: - # Single key, just rename it - _, value = key_value_pairs[0] - vllm_state_dict[merged_key] = value - - return vllm_state_dict - - @contextmanager def disable_compilation(model): do_not_compile = True @@ -151,7 +58,9 @@ def disable_compilation(model): "calib_size": int(os.environ.get("QUANT_CALIB_SIZE", 512)), "quant_cfg": os.environ.get("QUANT_CFG", None), "kv_quant_cfg": os.environ.get("KV_QUANT_CFG", None), - "amax_file_path": os.environ.get("AMAX_FILE_PATH", None), + "quant_file_path": os.environ.get("QUANT_FILE_PATH", None), + "modelopt_state_path": os.environ.get("MODELOPT_STATE_PATH", None), + "calib_batch_size": int(os.environ.get("CALIB_BATCH_SIZE", 1)), } @@ -194,137 +103,151 @@ def _fakequant_run_prolog_worker(self) -> None: if tokenizer.pad_token != "" or tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token - if quant_config["amax_file_path"]: - print("Will load amax, so only do a single sample calibration") - quant_config["calib_size"] = 1 - - calib_dataloader = get_dataset_dataloader( - dataset_name=quant_config["dataset"], - tokenizer=tokenizer, - 
batch_size=1, - num_samples=quant_config["calib_size"], - device=self.device, - ) - - def calibrate_loop(model: Any = None) -> None: - for batch_idx, batch in tqdm(enumerate(calib_dataloader)): - input_ids = batch["input_ids"][0] - - # Convert tensor to list of integers for vLLM compatibility - if torch.is_tensor(input_ids): - input_ids_list = input_ids.cpu().tolist() - else: - input_ids_list = list(input_ids) - - num_groups = len(self.model_runner.kv_cache_config.kv_cache_groups) - empty_block_ids = tuple([] for _ in range(num_groups)) - - req_id = f"req-{batch_idx}" - # Pass all possible parameters - the helper will filter based on vLLM version - new_req = _create_new_data_cls( - NewRequestData, - req_id=req_id, - prompt_token_ids=input_ids_list, - # Old API parameters - mm_kwargs=[], # TODO: remove this when vllm <= 0.11 is outdated - mm_hashes=[], # TODO: remove this when vllm <= 0.11 is outdated - mm_positions=[], # TODO: remove this when vllm <= 0.11 is outdated - # New API parameter - mm_features=[], - sampling_params=SamplingParams(max_tokens=1), - pooling_params=None, - block_ids=empty_block_ids, - num_computed_tokens=0, - lora_request=None, - ) - - scheduler_output = _create_new_data_cls( - SchedulerOutput, - scheduled_new_reqs=[new_req], - scheduled_cached_reqs=CachedRequestData.make_empty(), - num_scheduled_tokens={req_id: len(input_ids_list)}, - total_num_scheduled_tokens=len(input_ids_list), - scheduled_spec_decode_tokens={}, - scheduled_encoder_inputs={}, - num_common_prefix_blocks=[0] * num_groups, - finished_req_ids=set(), - free_encoder_mm_hashes=[], - kv_connector_metadata=None, - # Old API parameters - structured_output_request_ids={}, # TODO: remove this when vllm <= 0.11 is outdated - grammar_bitmask=None, # TODO: remove this when vllm <= 0.11 is outdated - ) - output = self.execute_model(scheduler_output) - if hasattr(self, "sample_tokens"): - if output is None: # TODO: make this default when vllm <= 0.11 is outdated - 
self.sample_tokens(None) - - quant_cfg = {} if quant_config["quant_cfg"] is None else getattr(mtq, quant_config["quant_cfg"]) - quant_kv_cfg = ( - {} if quant_config["kv_quant_cfg"] is None else getattr(mtq, quant_config["kv_quant_cfg"]) - ) - model = self.model_runner.model - if hasattr(model, "unwrap"): - model = model.unwrap() + if quant_config["modelopt_state_path"]: + print(f"Loading modelopt state from {quant_config['modelopt_state_path']}") + modelopt_state = torch.load(quant_config["modelopt_state_path"], weights_only=False) + modelopt_weights = modelopt_state.pop("modelopt_state_weights", None) + modelopt_state = convert_modelopt_state_to_vllm(modelopt_state) + restore_from_modelopt_state(model, modelopt_state) - # Check if model has MLA and update KV config accordingly - if quant_kv_cfg: - quant_kv_cfg["quant_cfg"] = update_kv_cfg_for_mla(model, quant_kv_cfg["quant_cfg"]) + if modelopt_weights is not None: + modelopt_weights = convert_dict_to_vllm(modelopt_weights) + mtq.utils.set_quantizer_state_dict(model, modelopt_weights) - if quant_kv_cfg: - quant_cfg = mtq.utils.update_quant_cfg_with_kv_cache_quant( - quant_cfg, quant_kv_cfg["quant_cfg"] + else: + if quant_config["quant_file_path"]: + print("Will load quant, so only do a single sample calibration") + quant_config["calib_size"] = 1 + calib_dataloader = get_dataset_dataloader( + dataset_name=quant_config["dataset"], + tokenizer=tokenizer, + batch_size=quant_config["calib_batch_size"], + num_samples=quant_config["calib_size"], + device=self.device, ) - with disable_compilation(model): - print("quantizing model...") - mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop) - - amax_file_path = quant_config["amax_file_path"] - if amax_file_path: - print(f"Loading amax values from {amax_file_path}") - saved_amax_dict = torch.load(amax_file_path) - # convert amax keys to vLLM format - if hasattr(self.model_runner.model, "hf_to_vllm_mapper"): - saved_amax_dict = 
self.model_runner.model.hf_to_vllm_mapper.apply_dict(saved_amax_dict) - saved_amax_dict = { - key.replace("quantizer_amax", "quantizer._amax"): value - for key, value in saved_amax_dict.items() - if key.endswith("quantizer_amax") - } - saved_amax_dict = convert_amax_hf2vllm(saved_amax_dict, fuse_experts=True) - - current_state_dict = model.state_dict() - # Count amax keys in checkpoint and model - checkpoint_amax_keys = [key for key in saved_amax_dict if key.endswith("_amax")] - model_amax_keys = [key for key in current_state_dict if key.endswith("_amax")] - for key in checkpoint_amax_keys: - if key not in model_amax_keys: - print(f"Key {key} not found in model state dict, but exists in checkpoint") - for key in model_amax_keys: - if key not in checkpoint_amax_keys: - raise ValueError( - f"Key {key} not found in checkpoint state dict, but exists in model" + def calibrate_loop(model: Any = None) -> None: + for batch_idx, batch in tqdm(enumerate(calib_dataloader)): + input_ids = batch["input_ids"][0] + + # Convert tensor to list of integers for vLLM compatibility + if torch.is_tensor(input_ids): + input_ids_list = input_ids.cpu().tolist() + else: + input_ids_list = list(input_ids) + + num_groups = len(self.model_runner.kv_cache_config.kv_cache_groups) + empty_block_ids = tuple([] for _ in range(num_groups)) + + req_id = f"req-{batch_idx}" + # Pass all possible parameters - the helper will filter based on vLLM version + new_req = _create_new_data_cls( + NewRequestData, + req_id=req_id, + prompt_token_ids=input_ids_list, + # Old API parameters + mm_kwargs=[], # TODO: remove this when vllm <= 0.11 is outdated + mm_hashes=[], # TODO: remove this when vllm <= 0.11 is outdated + mm_positions=[], # TODO: remove this when vllm <= 0.11 is outdated + # New API parameter + mm_features=[], + sampling_params=SamplingParams(max_tokens=1), + pooling_params=None, + block_ids=empty_block_ids, + num_computed_tokens=0, + lora_request=None, ) - checkpoint_amax_count = 
len(checkpoint_amax_keys) - model_amax_count = len(model_amax_keys) + scheduler_output = _create_new_data_cls( + SchedulerOutput, + scheduled_new_reqs=[new_req], + scheduled_cached_reqs=CachedRequestData.make_empty(), + num_scheduled_tokens={req_id: len(input_ids_list)}, + total_num_scheduled_tokens=len(input_ids_list), + scheduled_spec_decode_tokens={}, + scheduled_encoder_inputs={}, + num_common_prefix_blocks=[0] * num_groups, + finished_req_ids=set(), + free_encoder_mm_hashes=[], + kv_connector_metadata=None, + # Old API parameters + structured_output_request_ids={}, # TODO: remove this when vllm <= 0.11 is outdated + grammar_bitmask=None, # TODO: remove this when vllm <= 0.11 is outdated + ) + output = self.execute_model(scheduler_output) + if hasattr(self, "sample_tokens"): + if output is None: # TODO: make this default when vllm <= 0.11 is outdated + self.sample_tokens(None) + + quant_cfg = getattr(mtq, quant_config["quant_cfg"]) if quant_config["quant_cfg"] else {} + quant_kv_cfg = ( + getattr(mtq, quant_config["kv_quant_cfg"]) if quant_config["kv_quant_cfg"] else {} + ) + + if hasattr(model, "unwrap"): + model = model.unwrap() - # Ensure counts match - if checkpoint_amax_count != model_amax_count: - warnings.warn( - f"Mismatch in amax key counts: checkpoint has {checkpoint_amax_count} " - f"amax keys but model has {model_amax_count} amax keys. This can happen if the model is using PP." 
+ # Check if model has MLA and update KV config accordingly + if quant_kv_cfg: + quant_kv_cfg["quant_cfg"] = update_kv_cfg_for_mla(model, quant_kv_cfg["quant_cfg"]) + + if quant_kv_cfg: + quant_cfg = mtq.utils.update_quant_cfg_with_kv_cache_quant( + quant_cfg, quant_kv_cfg["quant_cfg"] ) - # Update amax values - for key, value in saved_amax_dict.items(): - if key in current_state_dict: - current_state_dict[key] = value.to(current_state_dict[key].device) + with disable_compilation(model): + print("quantizing model...") + mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop) + + quantizer_file_path = quant_config["quant_file_path"] + if quantizer_file_path: + print(f"Loading quantizer values from {quantizer_file_path}") + saved_quant_dict = torch.load(quantizer_file_path) + # convert quant keys to vLLM format + if hasattr(self.model_runner.model, "hf_to_vllm_mapper"): + saved_quant_dict = self.model_runner.model.hf_to_vllm_mapper.apply_dict( + saved_quant_dict + ) + saved_quant_dict = { + key.replace("quantizer_", "quantizer._"): value + for key, value in saved_quant_dict.items() + if key.endswith("quantizer_") + } + saved_quant_dict = convert_dict_to_vllm(saved_quant_dict) + + current_state_dict = model.state_dict() + # Count quant keys in checkpoint and model + checkpoint_quant_keys = [key for key in saved_quant_dict if "quantizer" in key] + model_quant_keys = [key for key in current_state_dict if "quantizer" in key] + for key in checkpoint_quant_keys: + if key not in model_quant_keys: + print(f"Key {key} not found in model state dict, but exists in checkpoint") + for key in model_quant_keys: + if key not in checkpoint_quant_keys: + raise ValueError( + f"Key {key} not found in checkpoint state dict, but exists in model" + ) + + checkpoint_quant_count = len(checkpoint_quant_keys) + model_quant_count = len(model_quant_keys) + + # Ensure counts match + if checkpoint_quant_count != model_quant_count: + warnings.warn( + f"Mismatch in quantizer state key counts: 
checkpoint has {checkpoint_quant_count} " + f"quant keys but model has {model_quant_count} quantizer state keys. " + f"This can happen if the model is using PP." + ) + + # Update quant values + for key, value in saved_quant_dict.items(): + if key in current_state_dict: + current_state_dict[key] = value.to(current_state_dict[key].device) - model.load_state_dict(current_state_dict) - torch.distributed.barrier() + model.load_state_dict(current_state_dict) + torch.distributed.barrier() if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0: mtq.print_quant_summary(model) @@ -345,6 +268,10 @@ def determine_available_memory(self) -> int: return super().determine_available_memory() def compile_or_warm_up_model(self) -> None: - if quant_config["quant_cfg"] or quant_config["kv_quant_cfg"]: + if ( + quant_config["quant_cfg"] + or quant_config["kv_quant_cfg"] + or quant_config["modelopt_state_path"] + ): _fakequant_run_prolog_worker(self) super().compile_or_warm_up_model() diff --git a/examples/vllm_serve/vllm_reload_utils.py b/examples/vllm_serve/vllm_reload_utils.py new file mode 100644 index 000000000..9bf8e6eb4 --- /dev/null +++ b/examples/vllm_serve/vllm_reload_utils.py @@ -0,0 +1,237 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import re +import torch + +from collections import defaultdict +from typing import Any, Callable + + +def _values_equal(v1: Any, v2: Any) -> bool: + """Compare values, handling dicts with tensors.""" + if isinstance(v1, dict) and isinstance(v2, dict): + if v1.keys() != v2.keys(): + return False + return all( + torch.equal(v1[k], v2[k]) if isinstance(v1[k], torch.Tensor) else v1[k] == v2[k] + for k in v1.keys() + ) + elif isinstance(v1, torch.Tensor) and isinstance(v2, torch.Tensor): + return torch.equal(v1, v2) + return v1 == v2 + + +def _convert_key_for_vllm(key: str, value: Any) -> tuple[str | None, str | None, Any]: + """ + Transform a single key from HuggingFace format to vLLM format. + + Returns: + Tuple of (action, new_key_or_group, value) where action is one of: + - "copy": Copy value to new_key directly + - "group": Add to merge group identified by new_key + - "skip": Skip this key entirely + """ + if "quantizer" not in key: + return ("copy", key, value) + + # Skip softmax_quantizer (not needed in vLLM) + if "softmax_quantizer" in key: + return ("skip", None, None) + + # Skip lm_head quantizers (not needed in vLLM) + if key.startswith("lm_head.") and "quantizer" in key: + return ("skip", None, None) + + # Check if this is a q/k/v projection that needs merging + qkv_match = re.search(r"(.*\.)([qkv])_proj\.([^.]+_quantizer)(\..+)?$", key) + if qkv_match: + suffix = qkv_match.group(4) or "" + group_key = qkv_match.group(1) + "qkv_proj." 
+ qkv_match.group(3) + suffix + return ("group", group_key, value) + + # Check if this is an expert gate/up projection + if "mixer" not in key: + expert_gate_up_match = re.search( + r"(.*\.experts)\.\d+\.(gate|up)_proj\.([^.]+_quantizer)(\..+)?$", key + ) + if expert_gate_up_match: + suffix = expert_gate_up_match.group(4) or "" + group_key = expert_gate_up_match.group(1) + ".w13_" + expert_gate_up_match.group(3) + suffix + return ("group", group_key, value) + + # Check if this is a non-expert gate/up projection that needs merging + if "mixer" not in key and "experts" not in key: + gate_up_match = re.search(r"(.*\.)(gate|up)_proj\.([^.]+_quantizer)(\..+)?$", key) + if gate_up_match: + suffix = gate_up_match.group(4) or "" + group_key = gate_up_match.group(1) + "gate_up_proj." + gate_up_match.group(3) + suffix + return ("group", group_key, value) + + # Check if this is an expert down_proj + if "mixer" not in key: + expert_down_match = re.search( + r"(.*\.experts)\.\d+\.down_proj\.([^.]+_quantizer)(\..+)?$", key + ) + if expert_down_match: + suffix = expert_down_match.group(3) or "" + group_key = expert_down_match.group(1) + ".w2_" + expert_down_match.group(2) + suffix + return ("group", group_key, value) + + # Transform bmm_quantizer keys: self_attn.q/k/v_bmm_quantizer -> self_attn.attn.q/k/v_bmm_quantizer + bmm_match = re.search(r"(.*\.self_attn)\.([qkv]_bmm_quantizer.*)$", key) + if bmm_match: + new_key = bmm_match.group(1) + ".attn." 
+ bmm_match.group(2) + # Debug: show device of amax values + if isinstance(value, dict): + for k, v in value.items(): + if isinstance(v, torch.Tensor): + print(f"Renamed {key} -> {new_key}, {k} device: {v.device}") + elif isinstance(value, torch.Tensor): + print(f"Renamed {key} -> {new_key}, device: {value.device}") + else: + print(f"Renamed {key} -> {new_key}") + return ("copy", new_key, value) + + # Copy other quantizer keys as-is (like o_proj, down_proj) + return ("copy", key, value) + + +def _group_keys_for_vllm( + state_dict: dict[str, Any] +) -> tuple[dict[str, Any], dict[str, list[tuple[str, Any]]]]: + """ + Process state dict and group keys that need merging. + + Returns: + Tuple of (direct_copy_dict, merge_groups) + """ + vllm_state_dict = {} + merge_groups = defaultdict(list) + + for key, value in state_dict.items(): + action, new_key, new_value = _convert_key_for_vllm(key, value) + + if action == "copy": + vllm_state_dict[new_key] = new_value + elif action == "group": + merge_groups[new_key].append((key, new_value)) + # action == "skip" does nothing + + return vllm_state_dict, merge_groups + + +def _merge_values_by_max_or_concat( + merged_key: str, key_value_pairs: list[tuple[str, Any]] +) -> Any: + """ + Merge values by taking max for amax, concatenating for others. + Used for quantizer state weights (tensor values). 
+ """ + values = [value for _, value in key_value_pairs] + + # Check if values are dicts (OrderedDict) containing tensors + if isinstance(values[0], dict): + merged_value = {} + for dict_key in values[0].keys(): + tensors = [v[dict_key] for v in values] + if "_amax" in dict_key: + merged_value[dict_key] = torch.stack(tensors).max(dim=0)[0] + else: + merged_value[dict_key] = torch.cat(tensors, dim=0) + return merged_value + else: + # Values are tensors directly + if "_amax" in merged_key: + merged_value = torch.stack(values).max(dim=0)[0] + else: + merged_value = torch.cat(values, dim=0) + return merged_value + + +def _merge_values_require_identical( + merged_key: str, key_value_pairs: list[tuple[str, Any]] +) -> Any: + """ + Merge values by requiring all values to be identical. + Used for quantizer state (config/metadata). + """ + keys = [k for k, _ in key_value_pairs] + values = [v for _, v in key_value_pairs] + first_value = values[0] + + for i, val in enumerate(values[1:], start=1): + if not _values_equal(val, first_value): + raise ValueError( + f"Cannot merge keys into '{merged_key}': values differ.\n" + f" '{keys[0]}' has value: {first_value}\n" + f" '{keys[i]}' has value: {val}" + ) + return first_value + + +def convert_dict_to_vllm( + state_dict: dict[str, Any], + merge_mode: str = "max_or_concat" +) -> dict[str, Any]: + """ + Common implementation for converting quantizer state from HF to vLLM format. 
+ + Args: + state_dict: Input state dict + fuse_experts: Whether to fuse expert projections + merge_mode: Mode to merge grouped values, "max_or_concat" or "require_identical" + """ + vllm_state_dict, merge_groups = _group_keys_for_vllm(state_dict) + + merge_fn = _merge_values_require_identical if merge_mode == "require_identical" else _merge_values_by_max_or_concat + + # Merge grouped values + for merged_key, key_value_pairs in merge_groups.items(): + if len(key_value_pairs) > 1: + merged_value = merge_fn(merged_key, key_value_pairs) + vllm_state_dict[merged_key] = merged_value + else: + # Single key, just rename it + _, value = key_value_pairs[0] + vllm_state_dict[merged_key] = value + + return vllm_state_dict + + +def convert_modelopt_state_to_vllm(modelopt_state: dict[str, Any]) -> dict[str, Any]: + """ + Convert modelopt state from HuggingFace format to vLLM compatible format. + + This function converts the quantizer state from HuggingFace format to vLLM compatible format. + + Args: + modelopt_state: HuggingFace modelopt state dict + + Returns: + vLLM compatible modelopt state dict + """ + modelopt_state_dict = modelopt_state.pop("modelopt_state_dict", []) + for idx, current_mode in enumerate(modelopt_state_dict): + current_mode_metadata = current_mode[1].pop("metadata", {}) + current_mode_quant_state = current_mode_metadata.pop("quantizer_state", {}) + if current_mode_quant_state: + current_mode_metadata["quantizer_state"] = convert_dict_to_vllm(current_mode_quant_state, merge_mode="require_identical") + else: + current_mode_metadata.pop("quantizer_state", None) + current_mode[1]['metadata'] = current_mode_metadata + modelopt_state_dict[idx] = (current_mode[0], current_mode[1]) + modelopt_state["modelopt_state_dict"] = modelopt_state_dict + return modelopt_state diff --git a/examples/vllm_serve/vllm_serve_fakequant.py b/examples/vllm_serve/vllm_serve_fakequant.py index 25483f2be..c32593005 100644 --- a/examples/vllm_serve/vllm_serve_fakequant.py +++ 
b/examples/vllm_serve/vllm_serve_fakequant.py @@ -74,8 +74,10 @@ "QUANT_DATASET", "QUANT_CALIB_SIZE", "QUANT_CFG", - "AMAX_FILE_PATH", + "QUANT_FILE_PATH", "KV_QUANT_CFG", + "MODELOPT_STATE_PATH", + "CALIB_BATCH_SIZE", } RayDistributedExecutor.ADDITIONAL_ENV_VARS.update(additional_env_vars) diff --git a/modelopt/torch/export/plugins/vllm_fakequant_hf.py b/modelopt/torch/export/plugins/vllm_fakequant_hf.py index 54987b40c..03b191346 100644 --- a/modelopt/torch/export/plugins/vllm_fakequant_hf.py +++ b/modelopt/torch/export/plugins/vllm_fakequant_hf.py @@ -19,7 +19,8 @@ import torch import torch.nn as nn -from modelopt.torch.export.layer_utils import is_quantlinear +import modelopt.torch.opt as mto +from modelopt.torch.export.layer_utils import is_attention, is_quantlinear from modelopt.torch.quantization.utils import get_quantizer_state_dict __all__ = ["export_hf_vllm_fq_checkpoint"] @@ -44,12 +45,11 @@ def export_hf_vllm_fq_checkpoint( export_dir = Path(export_dir) export_dir.mkdir(parents=True, exist_ok=True) - amax_dict = { - name + "._amax": param["_amax"].detach().clone().cpu() - for name, param in get_quantizer_state_dict(model).items() - if "_amax" in param - } + quantizer_state_dict = get_quantizer_state_dict(model) + modelopt_state = mto.modelopt_state(model) + modelopt_state["modelopt_state_weights"] = quantizer_state_dict + torch.save(modelopt_state, f"{export_dir}/vllm_fq_modelopt_state.pth") # remove quantizer from model for _, module in model.named_modules(): if is_quantlinear(module): @@ -57,6 +57,15 @@ def export_hf_vllm_fq_checkpoint( if hasattr(module, attr): delattr(module, attr) module.export() - torch.save(amax_dict, f"{export_dir}/quant_amax.pth") + if is_attention(module): + for attr in [ + "q_bmm_quantizer", + "k_bmm_quantizer", + "v_bmm_quantizer", + "softmax_quantizer", + ]: + if hasattr(module, attr): + delattr(module, attr) + # Save model model.save_pretrained(export_dir, state_dict=model.state_dict(), save_modelopt_state=False) diff 
--git a/modelopt/torch/export/plugins/vllm_fakequant_megatron.py b/modelopt/torch/export/plugins/vllm_fakequant_megatron.py index 3f69271b0..1fef361f0 100644 --- a/modelopt/torch/export/plugins/vllm_fakequant_megatron.py +++ b/modelopt/torch/export/plugins/vllm_fakequant_megatron.py @@ -22,6 +22,7 @@ from modelopt.torch.export.model_config import QUANTIZATION_NONE from modelopt.torch.export.unified_export_megatron import GPTModelExporter +from modelopt.torch.quantization.utils import get_quantizer_state_dict __all__ = ["export_mcore_gpt_to_hf_vllm_fq"] @@ -38,8 +39,8 @@ def gather_mcore_vllm_fq_quantized_state_dict( Returns: The state dictionary of the module without quantized state. """ - amax_state_dict = { - k: v.detach().clone().cpu() for k, v in state_dict.items() if k.endswith("_amax") + quantizer_state_dict = { + k: v.detach().clone().cpu() for k, v in state_dict.items() if "quantizer" in k } # Gather all amax dicts to rank 0 @@ -48,20 +49,19 @@ def gather_mcore_vllm_fq_quantized_state_dict( if rank == 0: # Rank 0 will collect all amax values - all_amax_dicts = [None] * world_size - torch.distributed.gather_object(amax_state_dict, all_amax_dicts, dst=0) + all_quantizer_state_dicts = [None] * world_size + torch.distributed.gather_object(quantizer_state_dict, all_quantizer_state_dicts, dst=0) - # Merge all amax dicts into one - merged_amax_dict = {} - for amax_dict in all_amax_dicts: - if amax_dict is not None: - merged_amax_dict.update(amax_dict) + # Merge all quantizer state dicts into one + merged_quantizer_state_dict = {} + for quantizer_state_dict in all_quantizer_state_dicts: + if quantizer_state_dict is not None: + merged_quantizer_state_dict.update(quantizer_state_dict) - print(f"Total amax entries from all ranks: {len(merged_amax_dict.keys())}") - torch.save(merged_amax_dict, save_directory + "/quant_amax.pth") + torch.save(merged_quantizer_state_dict, save_directory + "/quantizer_state.pth") else: # Other ranks just send their amax values - 
torch.distributed.gather_object(amax_state_dict, None, dst=0) + torch.distributed.gather_object(quantizer_state_dict, None, dst=0) torch.distributed.barrier() @@ -76,6 +76,13 @@ def save_pretrained( ): os.makedirs(save_directory, exist_ok=True) gather_mcore_vllm_fq_quantized_state_dict(self.model, self.state_dict, save_directory) + + # NOTE: `self.state_dict` is an OrderedDict; mutating it while iterating + # over its keys raises "OrderedDict mutated during iteration". + keys_to_remove = [k for k in self.state_dict if "quantizer" in k] + for k in keys_to_remove: + self.state_dict.pop(k, None) + assert not (self.is_multimodal and pretrained_model_name_or_path is not None), ( "Exporting weights in bf16 and amax values is not supported for multimodal models " "when pretrained_model_name_or_path is not None" @@ -88,6 +95,37 @@ def save_pretrained( def _get_quantization_format(self, module: torch.nn.Module): return QUANTIZATION_NONE + def _get_quantized_state( + self, + module: torch.nn.Module, + dtype: torch.dtype = torch.float16, + ) -> tuple[dict[str, torch.Tensor], str, int]: + """Return a state_dict, quantization format, and block_size of the module. + + Args: + module: The target module to perform real quantization. + dtype: The default data type. + + Returns: + Tuple: state_dict, quantization format, and block_size of the module. + """ + name_to_value = {} + qformat: str = self._get_quantization_format(module) + block_size = 0 + + if hasattr(module, "weight") and module.weight is not None: + weight = module.weight.to(dtype).cpu() + name_to_value["weight"] = weight + else: + return name_to_value, qformat, block_size + + if hasattr(module, "bias") and module.bias is not None: + name_to_value["bias"] = module.bias.to(dtype).cpu() + for name, param in get_quantizer_state_dict(module).items(): + for key, value in param.items(): + name_to_value[name + "." 
+ key] = value.to(dtype).cpu() + return name_to_value, qformat, block_size + def export_mcore_gpt_to_hf_vllm_fq( model: torch.nn.Module, diff --git a/modelopt/torch/quantization/conversion.py b/modelopt/torch/quantization/conversion.py index f7ef704ee..cffbda4dc 100644 --- a/modelopt/torch/quantization/conversion.py +++ b/modelopt/torch/quantization/conversion.py @@ -137,7 +137,7 @@ def restore_quantizer_state(model: nn.Module, config: QuantizeConfig, metadata: for name, module in model.named_modules(): if isinstance(module, QuantModule): name = get_unwrapped_name(name, model) - module.modelopt_post_restore(name) + module.modelopt_post_restore(name, model=model) return model diff --git a/modelopt/torch/quantization/nn/modules/quant_module.py b/modelopt/torch/quantization/nn/modules/quant_module.py index a792b6429..9c7cff54a 100644 --- a/modelopt/torch/quantization/nn/modules/quant_module.py +++ b/modelopt/torch/quantization/nn/modules/quant_module.py @@ -88,7 +88,7 @@ def _initialize_parallel_state(self): self.parallel_state = ParallelState(data_parallel_group=None) - def modelopt_post_restore(self, prefix: str = ""): + def modelopt_post_restore(self, prefix: str = "", model: "torch.nn.Module | None" = None): """Post-restore to correctly configure the TensorQuantizer states. TensorQuantizer states are restored to their shape before saving. Now we need to further configure them. @@ -97,6 +97,10 @@ def modelopt_post_restore(self, prefix: str = ""): 2. For sharded modules the restored states of TensorQuantizer could be incorrect. This is because parallelism such as TP might have been changed between saving and resoring. So we need to re-calculate the state shapes. Hence such modules should override this and implement their own logic. + + Args: + prefix: The module name prefix for error messages. + model: Optional main model to search for device if not found in this module. 
""" # Get a parameter or buffer that does not belong to a TensorQuantizer non_tq_param_or_buffer = None @@ -106,6 +110,18 @@ def modelopt_post_restore(self, prefix: str = ""): non_tq_param_or_buffer = param_or_buffer break + # If not found (e.g., container modules like vLLM's attn that only have child quantizers), + # traverse up to parent's parent to find a module with parameters + if model is not None: + parts = prefix.split(".") + parent_prefix = ".".join(parts[: len(parts) - 1]) + parent_module = model.get_submodule(parent_prefix) + # Look for any parameter in parent module (not just state_dict) + for param in parent_module.parameters(): + # Skip if param belongs to a TensorQuantizer + non_tq_param_or_buffer = param + break + if non_tq_param_or_buffer is None: warnings.warn( f"Could not identify the device for TensorQuantizer states of {prefix}. " From 14f7b8df30fa95a5141c12f6b2dffab12a5b22b5 Mon Sep 17 00:00:00 2001 From: Kinjal Patel Date: Wed, 21 Jan 2026 23:39:44 +0000 Subject: [PATCH 02/15] changelog update Signed-off-by: Kinjal Patel --- CHANGELOG.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 744238656..6a14c0df6 100755 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -9,6 +9,7 @@ NVIDIA Model Optimizer Changelog (Linux) - User does not need to manually register MOE modules to cover experts calibration coverage in PTQ workflow. - ``hf_ptq.py`` now saves the quantization summary and moe expert token count table to the export directory. - Add sparse attention optimization for transformer models (``modelopt.torch.sparsity.attention_sparsity``). This reduces computational cost by skipping attention computation. Supports calibration for threshold selection on HuggingFace models. See `examples/llm_sparsity/attention_sparsity/README.md `_ for usage. +- Add support for vLLM fakequant reload using ModelOpt state for HF models. See `examples/vllm_serve/README.md `_ for more details. 
0.42 (2026-02-xx) ^^^^^^^^^^^^^^^^^ From 2206212011d1d07df67d417eb92216bcb9d97367 Mon Sep 17 00:00:00 2001 From: Kinjal Patel Date: Thu, 22 Jan 2026 01:16:37 +0000 Subject: [PATCH 03/15] minor Signed-off-by: Kinjal Patel --- examples/vllm_serve/README.md | 2 + examples/vllm_serve/fakequant_worker.py | 2 +- examples/vllm_serve/vllm_reload_utils.py | 57 +++++++++++++----------- 3 files changed, 33 insertions(+), 28 deletions(-) diff --git a/examples/vllm_serve/README.md b/examples/vllm_serve/README.md index 64310fef4..74a1f2510 100644 --- a/examples/vllm_serve/README.md +++ b/examples/vllm_serve/README.md @@ -74,6 +74,7 @@ Step 1: export the model with bf16 weights and amax values. To export the model: export_dir, # The directory where the exported files will be stored. ) ``` + Or run the example script `examples/llm_ptq/hf_ptq.py` with the `--export_vllm_fq` **flag** to export a vLLM-fakequant-compatible ModelOpt state (it generates `vllm_fq_modelopt_state.pth`, which you can use via `MODELOPT_STATE_PATH`). - For **MCore** models, use `modelopt.torch.export.export_mcore_gpt_to_hf_vllm_fq`: @@ -87,6 +88,7 @@ Step 1: export the model with bf16 weights and amax values. To export the model: ) ``` + This generates `quantizer_state.pth`, which contains quantizer tensors for vLLM reload via `QUANT_FILE_PATH`. 
Step 2: use the exported artifacts when serving: diff --git a/examples/vllm_serve/fakequant_worker.py b/examples/vllm_serve/fakequant_worker.py index 81ea0379b..a5f1b0332 100644 --- a/examples/vllm_serve/fakequant_worker.py +++ b/examples/vllm_serve/fakequant_worker.py @@ -20,12 +20,12 @@ from typing import Any import torch -from vllm_reload_utils import convert_dict_to_vllm, convert_modelopt_state_to_vllm from tqdm import tqdm from transformers import AutoTokenizer from vllm.sampling_params import SamplingParams from vllm.v1.core.sched.output import CachedRequestData, NewRequestData, SchedulerOutput from vllm.v1.worker.gpu_worker import Worker as BaseWorker +from vllm_reload_utils import convert_dict_to_vllm, convert_modelopt_state_to_vllm import modelopt.torch.quantization as mtq from modelopt.torch.opt.conversion import restore_from_modelopt_state diff --git a/examples/vllm_serve/vllm_reload_utils.py b/examples/vllm_serve/vllm_reload_utils.py index 9bf8e6eb4..bbd77be5b 100644 --- a/examples/vllm_serve/vllm_reload_utils.py +++ b/examples/vllm_serve/vllm_reload_utils.py @@ -14,10 +14,10 @@ # limitations under the License. import re -import torch - from collections import defaultdict -from typing import Any, Callable +from typing import Any + +import torch def _values_equal(v1: Any, v2: Any) -> bool: @@ -27,14 +27,14 @@ def _values_equal(v1: Any, v2: Any) -> bool: return False return all( torch.equal(v1[k], v2[k]) if isinstance(v1[k], torch.Tensor) else v1[k] == v2[k] - for k in v1.keys() + for k in v1 ) elif isinstance(v1, torch.Tensor) and isinstance(v2, torch.Tensor): return torch.equal(v1, v2) return v1 == v2 -def _convert_key_for_vllm(key: str, value: Any) -> tuple[str | None, str | None, Any]: +def _convert_key_for_vllm(key: str, value: Any) -> tuple[str, str | None, Any]: """ Transform a single key from HuggingFace format to vLLM format. 
@@ -47,12 +47,8 @@ def _convert_key_for_vllm(key: str, value: Any) -> tuple[str | None, str | None, if "quantizer" not in key: return ("copy", key, value) - # Skip softmax_quantizer (not needed in vLLM) - if "softmax_quantizer" in key: - return ("skip", None, None) - - # Skip lm_head quantizers (not needed in vLLM) - if key.startswith("lm_head.") and "quantizer" in key: + # Skip softmax_quantizer and lm_head quantizers(not needed in vLLM) + if "softmax_quantizer" in key or (key.startswith("lm_head.") and "quantizer" in key): return ("skip", None, None) # Check if this is a q/k/v projection that needs merging @@ -69,7 +65,9 @@ def _convert_key_for_vllm(key: str, value: Any) -> tuple[str | None, str | None, ) if expert_gate_up_match: suffix = expert_gate_up_match.group(4) or "" - group_key = expert_gate_up_match.group(1) + ".w13_" + expert_gate_up_match.group(3) + suffix + group_key = ( + expert_gate_up_match.group(1) + ".w13_" + expert_gate_up_match.group(3) + suffix + ) return ("group", group_key, value) # Check if this is a non-expert gate/up projection that needs merging @@ -110,8 +108,8 @@ def _convert_key_for_vllm(key: str, value: Any) -> tuple[str | None, str | None, def _group_keys_for_vllm( - state_dict: dict[str, Any] -) -> tuple[dict[str, Any], dict[str, list[tuple[str, Any]]]]: + state_dict: dict[str, Any], +) -> tuple[dict[str, Any], defaultdict[str, list[tuple[str, Any]]]]: """ Process state dict and group keys that need merging. 
@@ -123,7 +121,11 @@ def _group_keys_for_vllm( for key, value in state_dict.items(): action, new_key, new_value = _convert_key_for_vllm(key, value) - + if new_key is None or new_value is None: + assert action == "skip", ( + f"Expected action to be 'skip' for key {key}, value {value}, got {action}" + ) + continue if action == "copy": vllm_state_dict[new_key] = new_value elif action == "group": @@ -133,9 +135,7 @@ def _group_keys_for_vllm( return vllm_state_dict, merge_groups -def _merge_values_by_max_or_concat( - merged_key: str, key_value_pairs: list[tuple[str, Any]] -) -> Any: +def _merge_values_by_max_or_concat(merged_key: str, key_value_pairs: list[tuple[str, Any]]) -> Any: """ Merge values by taking max for amax, concatenating for others. Used for quantizer state weights (tensor values). @@ -145,7 +145,7 @@ def _merge_values_by_max_or_concat( # Check if values are dicts (OrderedDict) containing tensors if isinstance(values[0], dict): merged_value = {} - for dict_key in values[0].keys(): + for dict_key in values[0]: tensors = [v[dict_key] for v in values] if "_amax" in dict_key: merged_value[dict_key] = torch.stack(tensors).max(dim=0)[0] @@ -161,9 +161,7 @@ def _merge_values_by_max_or_concat( return merged_value -def _merge_values_require_identical( - merged_key: str, key_value_pairs: list[tuple[str, Any]] -) -> Any: +def _merge_values_require_identical(merged_key: str, key_value_pairs: list[tuple[str, Any]]) -> Any: """ Merge values by requiring all values to be identical. Used for quantizer state (config/metadata). @@ -183,8 +181,7 @@ def _merge_values_require_identical( def convert_dict_to_vllm( - state_dict: dict[str, Any], - merge_mode: str = "max_or_concat" + state_dict: dict[str, Any], merge_mode: str = "max_or_concat" ) -> dict[str, Any]: """ Common implementation for converting quantizer state from HF to vLLM format. 
@@ -196,7 +193,11 @@ def convert_dict_to_vllm( """ vllm_state_dict, merge_groups = _group_keys_for_vllm(state_dict) - merge_fn = _merge_values_require_identical if merge_mode == "require_identical" else _merge_values_by_max_or_concat + merge_fn = ( + _merge_values_require_identical + if merge_mode == "require_identical" + else _merge_values_by_max_or_concat + ) # Merge grouped values for merged_key, key_value_pairs in merge_groups.items(): @@ -228,10 +229,12 @@ def convert_modelopt_state_to_vllm(modelopt_state: dict[str, Any]) -> dict[str, current_mode_metadata = current_mode[1].pop("metadata", {}) current_mode_quant_state = current_mode_metadata.pop("quantizer_state", {}) if current_mode_quant_state: - current_mode_metadata["quantizer_state"] = convert_dict_to_vllm(current_mode_quant_state, merge_mode="require_identical") + current_mode_metadata["quantizer_state"] = convert_dict_to_vllm( + current_mode_quant_state, merge_mode="require_identical" + ) else: current_mode_metadata.pop("quantizer_state", None) - current_mode[1]['metadata'] = current_mode_metadata + current_mode[1]["metadata"] = current_mode_metadata modelopt_state_dict[idx] = (current_mode[0], current_mode[1]) modelopt_state["modelopt_state_dict"] = modelopt_state_dict return modelopt_state From a1f7b5dc05f7e5fae9431527ff044cf70c8fbf22 Mon Sep 17 00:00:00 2001 From: Kinjal Patel Date: Thu, 22 Jan 2026 21:58:00 +0000 Subject: [PATCH 04/15] updated for TP>1 Signed-off-by: Kinjal Patel --- examples/vllm_serve/README.md | 4 +- examples/vllm_serve/fakequant_worker.py | 23 +++++++++-- examples/vllm_serve/vllm_reload_utils.py | 50 +++++++++++++++++++----- 3 files changed, 63 insertions(+), 14 deletions(-) diff --git a/examples/vllm_serve/README.md b/examples/vllm_serve/README.md index 74a1f2510..60002b747 100644 --- a/examples/vllm_serve/README.md +++ b/examples/vllm_serve/README.md @@ -26,7 +26,7 @@ You can either edit the `quant_config` dictionary in `vllm_serve_fakequant.py`, | QUANT_CFG | Quantization 
config | None | | KV_QUANT_CFG | KV-cache quantization config | None | | QUANT_FILE_PATH | Optional path to exported quantizer state dict `quantizer_state.pth` | None | -| MODELOPT_STATE_PATH | Optional path to exported `modelopt_state.pth` (restores ModelOpt mode + weights) | None | +| MODELOPT_STATE_PATH | Optional path to exported `vllm_fq_modelopt_state.pth` (restores quantizer state and parameters) | None | | CALIB_BATCH_SIZE | Calibration batch size | 1 | Set these variables in your shell or Docker environment as needed to customize calibration. @@ -110,3 +110,5 @@ QUANT_CFG= QUANT_FILE_PATH= python vllm_serve_fa ## Known Problems 1. **MCore reload does not use `MODELOPT_STATE_PATH`**; use `QUANT_FILE_PATH` and make sure `QUANT_CFG` matches the quantization recipe used for the original MCore model (otherwise quantizer keys/config won’t align). +2. AWQ reload is not supported yet +3. KV cache quantization export and reload is not supported in MCore yet. diff --git a/examples/vllm_serve/fakequant_worker.py b/examples/vllm_serve/fakequant_worker.py index a5f1b0332..fe822c877 100644 --- a/examples/vllm_serve/fakequant_worker.py +++ b/examples/vllm_serve/fakequant_worker.py @@ -25,7 +25,11 @@ from vllm.sampling_params import SamplingParams from vllm.v1.core.sched.output import CachedRequestData, NewRequestData, SchedulerOutput from vllm.v1.worker.gpu_worker import Worker as BaseWorker -from vllm_reload_utils import convert_dict_to_vllm, convert_modelopt_state_to_vllm +from vllm_reload_utils import ( + convert_dict_to_vllm, + convert_modelopt_state_to_vllm, + process_state_dict_for_tp, +) import modelopt.torch.quantization as mtq from modelopt.torch.opt.conversion import restore_from_modelopt_state @@ -106,7 +110,11 @@ def _fakequant_run_prolog_worker(self) -> None: model = self.model_runner.model if quant_config["modelopt_state_path"]: print(f"Loading modelopt state from {quant_config['modelopt_state_path']}") - modelopt_state = 
torch.load(quant_config["modelopt_state_path"], weights_only=False) + # Load on CPU to avoid failures when the checkpoint was saved from a different + # GPU mapping + modelopt_state = torch.load( + quant_config["modelopt_state_path"], weights_only=False, map_location="cpu" + ) modelopt_weights = modelopt_state.pop("modelopt_state_weights", None) modelopt_state = convert_modelopt_state_to_vllm(modelopt_state) restore_from_modelopt_state(model, modelopt_state) @@ -203,8 +211,11 @@ def calibrate_loop(model: Any = None) -> None: quantizer_file_path = quant_config["quant_file_path"] if quantizer_file_path: + self.model_runner._dummy_run(1) print(f"Loading quantizer values from {quantizer_file_path}") - saved_quant_dict = torch.load(quantizer_file_path) + # Load on CPU to avoid failures when the checkpoint was saved from a different + # GPU mapping + saved_quant_dict = torch.load(quantizer_file_path, map_location="cpu") # convert quant keys to vLLM format if hasattr(self.model_runner.model, "hf_to_vllm_mapper"): saved_quant_dict = self.model_runner.model.hf_to_vllm_mapper.apply_dict( @@ -242,12 +253,16 @@ def calibrate_loop(model: Any = None) -> None: ) # Update quant values + saved_quant_dict = process_state_dict_for_tp(saved_quant_dict, current_state_dict) for key, value in saved_quant_dict.items(): if key in current_state_dict: current_state_dict[key] = value.to(current_state_dict[key].device) model.load_state_dict(current_state_dict) - torch.distributed.barrier() + + # Only barrier if distributed is actually initialized (avoids deadlocks). 
+ if torch.distributed.is_initialized() and torch.distributed.get_world_size() > 1: + torch.distributed.barrier() if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0: mtq.print_quant_summary(model) diff --git a/examples/vllm_serve/vllm_reload_utils.py b/examples/vllm_serve/vllm_reload_utils.py index bbd77be5b..a72cb0b3c 100644 --- a/examples/vllm_serve/vllm_reload_utils.py +++ b/examples/vllm_serve/vllm_reload_utils.py @@ -18,6 +18,7 @@ from typing import Any import torch +from vllm.distributed.parallel_state import get_tp_group def _values_equal(v1: Any, v2: Any) -> bool: @@ -92,15 +93,6 @@ def _convert_key_for_vllm(key: str, value: Any) -> tuple[str, str | None, Any]: bmm_match = re.search(r"(.*\.self_attn)\.([qkv]_bmm_quantizer.*)$", key) if bmm_match: new_key = bmm_match.group(1) + ".attn." + bmm_match.group(2) - # Debug: show device of amax values - if isinstance(value, dict): - for k, v in value.items(): - if isinstance(v, torch.Tensor): - print(f"Renamed {key} -> {new_key}, {k} device: {v.device}") - elif isinstance(value, torch.Tensor): - print(f"Renamed {key} -> {new_key}, device: {value.device}") - else: - print(f"Renamed {key} -> {new_key}") return ("copy", new_key, value) # Copy other quantizer keys as-is (like o_proj, down_proj) @@ -238,3 +230,43 @@ def convert_modelopt_state_to_vllm(modelopt_state: dict[str, Any]) -> dict[str, modelopt_state_dict[idx] = (current_mode[0], current_mode[1]) modelopt_state["modelopt_state_dict"] = modelopt_state_dict return modelopt_state + + +def process_state_dict_for_tp(saved_qstate_dict, current_state_dict): + """Shard quantizer tensors for tensor parallelism by matching expected shapes.""" + tp_group = get_tp_group() + tp_rank = tp_group.rank_in_group + tp_world_size = tp_group.world_size + + result = {} + for key, value in saved_qstate_dict.items(): + if key in current_state_dict: + expected_shape = current_state_dict[key].shape + if value.shape != expected_shape: + # Find the dimension 
that was tensor-parallel sharded. + # We expect exactly one dimension to satisfy: + # checkpoint_dim == expected_dim * tp_world_size + shard_dims = [ + d + for d in range(len(expected_shape)) + if value.shape[d] == expected_shape[d] * tp_world_size + ] + if len(shard_dims) != 1: + raise ValueError( + f"Cannot infer TP shard dim for {key}: " + f"expected_shape={tuple(expected_shape)}, checkpoint_shape={tuple(value.shape)}, " + ) + + shard_dim = shard_dims[0] + shard_size = expected_shape[shard_dim] + start = tp_rank * shard_size + end = start + shard_size + if end > value.shape[shard_dim]: + raise ValueError( + f"TP shard out of bounds for {key}: " + f"expected_shape={tuple(expected_shape)}, checkpoint_shape={tuple(value.shape)})" + ) + value = value.narrow(shard_dim, start, shard_size).contiguous() + result[key] = value + + return result From 625848d9792e7d1d7cba7166b55f9a12cff515f9 Mon Sep 17 00:00:00 2001 From: Kinjal Patel Date: Thu, 22 Jan 2026 23:33:06 +0000 Subject: [PATCH 05/15] minor Signed-off-by: Kinjal Patel --- .../torch/quantization/nn/modules/quant_module.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/modelopt/torch/quantization/nn/modules/quant_module.py b/modelopt/torch/quantization/nn/modules/quant_module.py index 9c7cff54a..e516941b7 100644 --- a/modelopt/torch/quantization/nn/modules/quant_module.py +++ b/modelopt/torch/quantization/nn/modules/quant_module.py @@ -112,15 +112,18 @@ def modelopt_post_restore(self, prefix: str = "", model: "torch.nn.Module | None # If not found (e.g., container modules like vLLM's attn that only have child quantizers), # traverse up to parent's parent to find a module with parameters - if model is not None: + if non_tq_param_or_buffer is None and model is not None: parts = prefix.split(".") parent_prefix = ".".join(parts[: len(parts) - 1]) - parent_module = model.get_submodule(parent_prefix) + parent_module = model.get_submodule(parent_prefix) if parent_prefix else model # Look 
for any parameter in parent module (not just state_dict) - for param in parent_module.parameters(): - # Skip if param belongs to a TensorQuantizer - non_tq_param_or_buffer = param - break + for name, param in parent_module.named_parameters(): + # Skip params that belong to TensorQuantizer submodules + param_parent_name = name.rsplit(".", 1)[0] if "." in name else "" + param_parent = parent_module.get_submodule(param_parent_name) + if not isinstance(param_parent, TensorQuantizer): + non_tq_param_or_buffer = param + break if non_tq_param_or_buffer is None: warnings.warn( From 89a74e78dba71f563bbb0d882ed7d4c2237b2457 Mon Sep 17 00:00:00 2001 From: Kinjal Patel Date: Mon, 26 Jan 2026 04:26:57 +0000 Subject: [PATCH 06/15] updated test Signed-off-by: Kinjal Patel --- .../export/test_vllm_fakequant_hf_export.py | 38 +++++++++---------- 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/tests/gpu/torch/export/test_vllm_fakequant_hf_export.py b/tests/gpu/torch/export/test_vllm_fakequant_hf_export.py index a156ad126..2c2c44ef2 100644 --- a/tests/gpu/torch/export/test_vllm_fakequant_hf_export.py +++ b/tests/gpu/torch/export/test_vllm_fakequant_hf_export.py @@ -48,7 +48,7 @@ def forward_loop(model): model(input_ids) model = mtq.quantize(model, quant_cfg, forward_loop) - + quantizer_state_dict_before = mtq.utils.get_quantizer_state_dict(model) model_state_dict = deepcopy(model.state_dict()) # Export directory @@ -59,8 +59,10 @@ def forward_loop(model): export_hf_vllm_fq_checkpoint(model, export_dir=export_dir) # check if quant_amax.pth file exists - quant_amax_file = export_dir / "quant_amax.pth" - assert quant_amax_file.exists(), f"quant_amax.pth file should be created in {export_dir}" + modelopt_state_file = export_dir / "vllm_fq_modelopt_state.pth" + assert modelopt_state_file.exists(), ( + f"vllm_fq_modelopt_state.pth file should be created in {export_dir}" + ) # make sure hf_quant_config.json file does not exist hf_quant_config_file = export_dir / 
"hf_quant_config.json" @@ -73,21 +75,19 @@ def forward_loop(model): model_after = model_after.cuda() model_after.eval() model_after_state_dict = model_after.state_dict() - amax_state_dict = {} for key, param in model_state_dict.items(): - if key.endswith("_amax"): - amax_state_dict[key] = param - continue - - assert torch.allclose(param, model_after_state_dict[key], atol=1e-6), ( - f"Weight mismatch for {key}: " - f"before shape={param.shape}, after shape={model_after_state_dict[key].shape}, " - f"max diff={torch.abs(param - model_after_state_dict[key]).max()}" - ) - - # Verify amax values are correct - amax_dict = torch.load(quant_amax_file) - assert len(amax_dict) > 0, "amax_dict should not be empty" - assert amax_dict.keys() == amax_state_dict.keys(), ( - "amax keys mismatch between before and after export" + if "quantizer" not in key: + assert torch.allclose(param, model_after_state_dict[key], atol=1e-6), ( + f"Weight mismatch for {key}: " + f"before shape={param.shape}, after shape={model_after_state_dict[key].shape}, " + f"max diff={torch.abs(param - model_after_state_dict[key]).max()}" + ) + + # Verify quantizer state dict values are correct + quantizer_state_dict = torch.load(modelopt_state_file)["modelopt_state_weights"] + assert len(quantizer_state_dict) > 0, ( + f"modelopt_state_weights should not be empty in {modelopt_state_file}" + ) + assert quantizer_state_dict.keys() == quantizer_state_dict_before.keys(), ( + "quantizer state dict keys mismatch between before and after export" ) From b3b53b5e9a79b734dd968c7fd4291c4cca7198cb Mon Sep 17 00:00:00 2001 From: Kinjal Patel Date: Mon, 26 Jan 2026 15:41:56 +0000 Subject: [PATCH 07/15] test fix Signed-off-by: Kinjal Patel --- .../torch/export/test_vllm_fakequant_megatron_export.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/gpu_megatron/torch/export/test_vllm_fakequant_megatron_export.py b/tests/gpu_megatron/torch/export/test_vllm_fakequant_megatron_export.py index 
8e4578d7b..6336dce68 100644 --- a/tests/gpu_megatron/torch/export/test_vllm_fakequant_megatron_export.py +++ b/tests/gpu_megatron/torch/export/test_vllm_fakequant_megatron_export.py @@ -96,8 +96,8 @@ def forward_loop(model): ) # check if quant_amax.pth file exists - quant_amax_file = export_dir / "quant_amax.pth" - assert quant_amax_file.exists(), f"quant_amax.pth file should be created in {export_dir}" + quant_amax_file = export_dir / "quantizer_state.pth" + assert quant_amax_file.exists(), f"quantizer_state.pth file should be created in {export_dir}" # make sure hf_quant_config.json file does not exist hf_quant_config_file = export_dir / "hf_quant_config.json" From 6390599ab07b778ceedc1eeed1d58163bc1841fd Mon Sep 17 00:00:00 2001 From: Kinjal Patel Date: Mon, 26 Jan 2026 18:10:03 +0000 Subject: [PATCH 08/15] minor Signed-off-by: Kinjal Patel --- modelopt/torch/quantization/plugins/custom.py | 2 +- modelopt/torch/quantization/plugins/megatron.py | 4 ++-- modelopt/torch/quantization/plugins/transformer_engine.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/modelopt/torch/quantization/plugins/custom.py b/modelopt/torch/quantization/plugins/custom.py index 4200aadc7..09a91796d 100644 --- a/modelopt/torch/quantization/plugins/custom.py +++ b/modelopt/torch/quantization/plugins/custom.py @@ -114,7 +114,7 @@ def _setup(self): # the dtype can change later. self.original_weight_dtype = None if self.weight is None else self.weight.dtype - def modelopt_post_restore(self, prefix: str = ""): + def modelopt_post_restore(self, prefix: str = "", *args, **kwargs): """Post restore to correctly configure the TensorQuantizer states for MCore/distributed frameworks. 
ModelOpt restores the TensorQuantizer states such as `_amax` and `_pre_quant_scale` to their diff --git a/modelopt/torch/quantization/plugins/megatron.py b/modelopt/torch/quantization/plugins/megatron.py index e84735ae9..93c3e72ec 100644 --- a/modelopt/torch/quantization/plugins/megatron.py +++ b/modelopt/torch/quantization/plugins/megatron.py @@ -471,7 +471,7 @@ def _get_shard_axis_dict(self, state_dict): shard_axis_dict[k] = self._scale_tensor_shard_axis return shard_axis_dict - def modelopt_post_restore(self, prefix: str = ""): + def modelopt_post_restore(self, prefix: str = "", *args, **kwargs): """Post restore to correctly configure the realquant scales. ModelOpt restores the TensorQuantizer states such as `_amax` and `_pre_quant_scale` to their @@ -732,7 +732,7 @@ def forward(self, query, key, value, *args, **kwargs): value = self.v_bmm_quantizer(value) return super().forward(query, key, value, *args, **kwargs) - def modelopt_post_restore(self, name=""): + def modelopt_post_restore(self, name="", *args, **kwargs): """Restore quantizer states after model loading.""" for tq in [self.q_bmm_quantizer, self.k_bmm_quantizer, self.v_bmm_quantizer]: # TODO: Add support for non-scalar states such as diff --git a/modelopt/torch/quantization/plugins/transformer_engine.py b/modelopt/torch/quantization/plugins/transformer_engine.py index afc08211f..dbeb05a65 100644 --- a/modelopt/torch/quantization/plugins/transformer_engine.py +++ b/modelopt/torch/quantization/plugins/transformer_engine.py @@ -141,7 +141,7 @@ def _setup(self): # TODO: GroupedLinear supports weights split by `num_gemms`, to support quantization # with static parameters beyond per-tensor, we need to support a unique quantizer for each gemm. - def modelopt_post_restore(self, prefix: str = ""): + def modelopt_post_restore(self, prefix: str = "", *args, **kwargs): # GroupedMLP stores the weights as weight0, weight1, etc. 
To run post_restore in order to # initialize the quantizer states, self.weight is used to extract shape, dtype etc. Assigning # self.weight0 to self.weight to run the quantizer states initialization. From ce865716bedef7d52f50f3ca36648ab0494ef4f7 Mon Sep 17 00:00:00 2001 From: Kinjal Patel Date: Mon, 26 Jan 2026 21:35:33 +0000 Subject: [PATCH 09/15] created seperate script for vllm fq export Signed-off-by: Kinjal Patel --- examples/vllm_serve/hf_ptq_export.py | 314 ++++++++++++++++ examples/vllm_serve/vllm_fq_export.py | 337 ++++++++++++++++++ .../torch/export/plugins/vllm_fakequant_hf.py | 25 +- 3 files changed, 674 insertions(+), 2 deletions(-) create mode 100644 examples/vllm_serve/hf_ptq_export.py create mode 100644 examples/vllm_serve/vllm_fq_export.py diff --git a/examples/vllm_serve/hf_ptq_export.py b/examples/vllm_serve/hf_ptq_export.py new file mode 100644 index 000000000..7ee5c091d --- /dev/null +++ b/examples/vllm_serve/hf_ptq_export.py @@ -0,0 +1,314 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import argparse +import random +import warnings + +import numpy as np +import torch +import transformers +from accelerate import infer_auto_device_map, init_empty_weights +from accelerate.utils import get_max_memory +from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer + +import modelopt.torch.opt as mto +import modelopt.torch.quantization as mtq +from modelopt.torch.export import export_hf_vllm_fq_checkpoint +from modelopt.torch.quantization.utils import is_quantized +from modelopt.torch.utils.dataset_utils import ( + create_forward_loop, + get_dataset_dataloader, + get_max_batch_size, + get_supported_datasets, +) +from modelopt.torch.utils.memory_monitor import launch_memory_monitor + +RAND_SEED = 1234 + +mto.enable_huggingface_checkpointing() + + +def load_model( + ckpt_path, + device="cuda", + gpu_mem_percentage=0.8, + trust_remote_code=False, + use_seq_device_map=False, +): + print(f"Initializing model from {ckpt_path}") + + config_kwargs = {"trust_remote_code": trust_remote_code} if trust_remote_code else {} + try: + hf_config = AutoConfig.from_pretrained(ckpt_path, **config_kwargs) + except Exception as e: + raise RuntimeError(f"Failed to load model configuration from {ckpt_path}") from e + + # Pick the transformers model class to load. + architecture = hf_config.architectures[0] + use_auto_causallm = (not hasattr(transformers, architecture)) or ("Deepseek" in architecture) + if use_auto_causallm: + if not hasattr(transformers, architecture): + warnings.warn( + f"Architecture {architecture} not found in transformers: {transformers.__version__}. " + "Falling back to AutoModelForCausalLM." + ) + assert trust_remote_code, ( + "Please set trust_remote_code=True if you want to use this architecture" + ) + model_cls = AutoModelForCausalLM + from_config = model_cls.from_config + else: + model_cls = getattr(transformers, architecture) + from_config = model_cls._from_config + + # Decide device_map and optional memory cap. 
+ if device == "cpu": + device_map = "cpu" + elif use_seq_device_map: + device_map = "sequential" + else: + device_map = "auto" + + model_kwargs: dict[str, object] = dict(config_kwargs) + if device_map == "sequential": + max_memory = get_max_memory() + model_kwargs["max_memory"] = {k: v * gpu_mem_percentage for k, v in max_memory.items()} + + # Detect if the model would offload to CPU; if so, cap GPU memory for calibration. + with init_empty_weights(): + torch_dtype = getattr(hf_config, "torch_dtype", torch.bfloat16) + empty_kwargs: dict[str, object] = dict(model_kwargs, torch_dtype=torch_dtype) + empty_kwargs.pop("max_memory", None) # only used by from_pretrained dispatch + if model_cls is not AutoModelForCausalLM: + empty_kwargs.pop("trust_remote_code", None) + empty_model = from_config(hf_config, **empty_kwargs) + + max_memory = get_max_memory() + inferred_device_map = infer_auto_device_map(empty_model, max_memory=max_memory) + if "cpu" in inferred_device_map.values(): + for dev_id, mem in list(max_memory.items()): + if isinstance(dev_id, int): + max_memory[dev_id] = mem * gpu_mem_percentage + print( + "Model does not fit to the GPU mem. " + f"We apply the following memory limit for calibration: \n{max_memory}\n" + "If you hit GPU OOM issue, please adjust `gpu_mem_percentage` or " + "reduce the calibration `batch_size` manually." + ) + model_kwargs["max_memory"] = max_memory + + model = model_cls.from_pretrained(ckpt_path, device_map=device_map, **model_kwargs) + model.eval() + + # If device_map was disabled (None), manually move model to target device + if device_map is None and device != "cpu": + print(f"Moving model to {device} device...") + model = model.to(device) + + if device == "cuda" and not is_model_on_gpu(model): + print("Warning: Some parameters are not on a GPU. 
Calibration can be slow or hit OOM") + + return model + + +def is_model_on_gpu(model) -> bool: + """Returns if the model is fully loaded on GPUs.""" + return all("cuda" in str(param.device) for param in model.parameters()) + +def get_tokenizer(ckpt_path, trust_remote_code=False): + """Returns the tokenizer from the model ckpt_path.""" + print(f"Initializing tokenizer from {ckpt_path}") + tokenizer = AutoTokenizer.from_pretrained( + ckpt_path, + padding_side="left", + trust_remote_code=trust_remote_code, + ) + + # can't set attribute 'pad_token' for "" + if tokenizer.pad_token != "": + tokenizer.pad_token = tokenizer.eos_token + + return tokenizer + +def quantize_and_export_model( + args: argparse.Namespace, +): + model = load_model( + args.pyt_ckpt_path, + device=args.device, + gpu_mem_percentage=args.gpu_max_mem_percentage, + trust_remote_code=args.trust_remote_code, + use_seq_device_map=args.use_seq_device_map, + ) + + if args.batch_size == 0: + args.batch_size = get_max_batch_size( + model, + max_sample_length=args.calib_seq, + ) + args.batch_size = min(args.batch_size, sum(args.calib_size)) + + print(f"Use calib batch_size {args.batch_size}") + tokenizer = get_tokenizer(args.pyt_ckpt_path, trust_remote_code=args.trust_remote_code) + device = model.device + calib_dataloader = get_dataset_dataloader( + dataset_name=args.dataset, + tokenizer=tokenizer, + batch_size=args.batch_size, + num_samples=args.calib_size, + device=device, + include_labels=False, + ) + calibrate_loop = create_forward_loop(dataloader=calib_dataloader) + mtq_cfg = getattr(mtq, args.quant_cfg) + if args.kv_cache_quant_cfg is not None: + kv_cache_quant_cfg = getattr(mtq, args.kv_cache_quant_cfg) + mtq_cfg = mtq.utils.update_quant_cfg_with_kv_cache_quant( + mtq_cfg, kv_cache_quant_cfg["quant_cfg"] + ) + input_ids = next(iter(calib_dataloader))["input_ids"][0:1] + model_is_already_quantized = is_quantized(model) + if not model_is_already_quantized: + generated_str_before_ptq = 
tokenizer.decode(model.generate(input_ids)[0]) + quantized_model = mtq.quantize(model, mtq_cfg, calibrate_loop) + generated_str_after_ptq = tokenizer.decode(model.generate(input_ids)[0]) + else: + print("Model is already quantized, Skipping quantization...") + quantized_model = model + + mtq.print_quant_summary(quantized_model) + if not model_is_already_quantized: + print("--------") + print(f"example test input: {tokenizer.decode(input_ids[0])}") + print("--------") + print(f"example outputs before ptq: {generated_str_before_ptq}") + print("--------") + print(f"example outputs after ptq: {generated_str_after_ptq}") + + export_hf_vllm_fq_checkpoint(quantized_model, args.export_path) + # from modelopt.torch.quantization.utils import get_quantizer_state_dict + # quantized_model.save_pretrained(args.export_path, state_dict=quantized_model.state_dict(), save_modelopt_state=False) + # modelopt_state = mto.modelopt_state(quantized_model) + # modelopt_state["modelopt_state_weights"] = get_quantizer_state_dict(quantized_model) + # torch.save(modelopt_state, f"{args.export_path}/modelopt_state.pth") + tokenizer.save_pretrained(args.export_path) + print(f"Model exported to {args.export_path}") + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "--pyt_ckpt_path", + help="Specify where the PyTorch checkpoint path is", + required=True, + ) + parser.add_argument("--device", default="cuda") + parser.add_argument( + "--quant_cfg", + help="Quantization configuration.", + default="FP8_DEFAULT_CFG", + ) + parser.add_argument( + "--batch_size", + help="Batch size for calibration. Default to 0 as we calculate max batch size on-the-fly", + type=int, + default=0, + ) + parser.add_argument( + "--calib_size", + help=( + "Number of samples for calibration. If a comma separated list of values is provided, " + "each value will be used as the calibration size for the corresponding dataset. 
" + "This argument will be parsed and converted as a list of ints." + ), + type=str, + default="512", + ) + parser.add_argument( + "--calib_seq", + help="Maximum sequence length for calibration.", + type=int, + default=512, + ) + parser.add_argument("--export_path", default="exported_model") + parser.add_argument( + "--dataset", + help=( + f"name of a dataset, or a comma separated list of datasets. " + f"dataset choices are {get_supported_datasets()}" + ), + type=str, + default=None, + ) + parser.add_argument( + "--kv_cache_quant_cfg", + required=False, + default=None, + help="Specify KV cache quantization configuration, default to None if not provided", + ) + parser.add_argument( + "--trust_remote_code", + help="Set trust_remote_code for Huggingface models and tokenizers", + default=False, + action="store_true", + ) + parser.add_argument( + "--gpu_max_mem_percentage", + help=( + "Specify the percentage of available GPU memory to use for loading the model when " + "device_map is set to sequential. " + "By default, 80%% of the available GPU memory is used." + ), + type=float, + default=0.8, + ) + parser.add_argument( + "--use_seq_device_map", + help=( + "Use device_map=sequential to load the model onto GPUs. This ensures the model is loaded " + "utilizing the percentage of available GPU memory as specified by the value passed with gpu_max_mem flag." + "Helpful in cases where device_map=auto loads model unevenly on GPUs causing GPU OOM during quantization." + ), + default=False, + action="store_true", + ) + + return parser.parse_args() + + +def main(args: argparse.Namespace): + if not torch.cuda.is_available(): + raise OSError("GPU is required for inference.") + + random.seed(RAND_SEED) + np.random.seed(RAND_SEED) + + # launch a memory monitor to read the currently used GPU memory. + launch_memory_monitor() + + # Force eager execution for all model types. 
+ torch.compiler.set_stance("force_eager") + + # Quantize + quantize_and_export_model(args) + + +if __name__ == "__main__": + args = parse_args() + + args.dataset = args.dataset.split(",") if args.dataset else None + args.calib_size = [int(num_sample) for num_sample in args.calib_size.split(",")] + main(args) diff --git a/examples/vllm_serve/vllm_fq_export.py b/examples/vllm_serve/vllm_fq_export.py new file mode 100644 index 000000000..feeac3e92 --- /dev/null +++ b/examples/vllm_serve/vllm_fq_export.py @@ -0,0 +1,337 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import argparse +import random +import time +import warnings +from typing import Any + +import numpy as np +import torch +from accelerate.hooks import remove_hook_from_module +from example_utils import ( + build_quant_cfg, + copy_custom_model_files, + get_model, + get_processor, + get_tokenizer, + is_enc_dec, + is_nemotron_vl, + run_nemotron_vl_preview, +) +from torch.utils.data import DataLoader +import transformers +from transformers import ( + AutoConfig, + AutoModelForCausalLM, + AutoProcessor, + PreTrainedTokenizer, + PreTrainedTokenizerBase, + PreTrainedTokenizerFast, + ProcessorMixin, + WhisperProcessor, +) +from accelerate import infer_auto_device_map, init_empty_weights +from accelerate.utils import get_max_memory +import modelopt.torch.opt as mto +import modelopt.torch.quantization as mtq +import modelopt.torch.sparsity as mts +from modelopt.torch.export import ( + export_hf_checkpoint, + export_hf_vllm_fq_checkpoint, + export_tensorrt_llm_checkpoint, + get_model_type, +) +from modelopt.torch.export.model_utils import get_language_model_from_vl, is_multimodal_model +from modelopt.torch.quantization.config import _default_disabled_quantizer_cfg, need_calibration +from modelopt.torch.quantization.plugins.accelerate import init_quantized_weights +from modelopt.torch.quantization.utils import is_quantized +from modelopt.torch.utils.dataset_utils import ( + create_forward_loop, + get_dataset_dataloader, + get_max_batch_size, + get_supported_datasets, +) +from modelopt.torch.utils.image_processor import BaseImageProcessor, MllamaImageProcessor +from modelopt.torch.utils.memory_monitor import launch_memory_monitor +from modelopt.torch.utils.speech_dataset_utils import get_speech_dataset_dataloader +from modelopt.torch.utils.vlm_dataset_utils import get_vlm_dataset_dataloader + +RAND_SEED = 1234 + +mto.enable_huggingface_checkpointing() + +def load_model( + ckpt_path, + device="cuda", + gpu_mem_percentage=0.8, + trust_remote_code=False, + 
use_seq_device_map=False, +): + print(f"Initializing model from {ckpt_path}") + + device_map = "auto" + if device == "cpu": + device_map = "cpu" + + # Prepare config kwargs for loading + config_kwargs = {"trust_remote_code": trust_remote_code} if trust_remote_code else {} + + # Load config once + try: + hf_config = AutoConfig.from_pretrained(ckpt_path, **config_kwargs) + except Exception as e: + print(f"Error: Could not load config from {ckpt_path}: {e}") + raise RuntimeError(f"Failed to load model configuration from {ckpt_path}") from e + + model_kwargs = config_kwargs.copy() + + if use_seq_device_map: + device_map = "sequential" + # If we use sequential, set max_memory limit to ensure that the model does not occupy the full GPU + max_memory = get_max_memory() + max_memory = {key: value * gpu_mem_percentage for key, value in max_memory.items()} + model_kwargs["max_memory"] = max_memory + + architecture = hf_config.architectures[0] + + if not hasattr(transformers, architecture) or "Deepseek" in architecture: + if not hasattr(transformers, architecture): + warnings.warn( + f"Architecture {architecture} not found in transformers: {transformers.__version__}. " + "Falling back to AutoModelForCausalLM." + ) + assert trust_remote_code, ( + "Please set trust_remote_code to True if you want to use this architecture" + ) + + auto_model_module = AutoModelForCausalLM + from_config = auto_model_module.from_config + else: + auto_model_module = getattr(transformers, architecture) + from_config = auto_model_module._from_config + + with init_empty_weights(): + # When computing the device_map, assuming bfloat16 precision by default, + # unless specified by the hf_config. 
+ torch_dtype = getattr(hf_config, "torch_dtype", torch.bfloat16) + model_kwargs2 = model_kwargs.copy() + if auto_model_module != AutoModelForCausalLM: + model_kwargs2.pop("trust_remote_code", None) + model_kwargs2["torch_dtype"] = torch_dtype + model_kwargs2.pop("max_memory", None) + model = from_config(hf_config, **model_kwargs2) + + max_memory = get_max_memory() + inferred_device_map = infer_auto_device_map(model, max_memory=max_memory) + + on_cpu = "cpu" in inferred_device_map.values() + + if on_cpu: + for _device in max_memory: + if isinstance(_device, int): + max_memory[_device] *= gpu_mem_percentage + + print( + "Model does not fit to the GPU mem. " + f"We apply the following memory limit for calibration: \n{max_memory}\n" + "If you hit GPU OOM issue, please adjust `gpu_mem_percentage` or " + "reduce the calibration `batch_size` manually." + ) + model_kwargs["max_memory"] = max_memory + + model = auto_model_module.from_pretrained( + ckpt_path, + device_map=device_map, + **model_kwargs, + ) + model.eval() + + # If device_map was disabled (None), manually move model to target device + if device_map is None and device != "cpu": + print(f"Moving model to {device} device...") + model = model.to(device) + + if device == "cuda" and not is_model_on_gpu(model): + print("Warning: Some parameters are not on a GPU. 
Calibration can be slow or hit OOM") + + return model + + +def is_model_on_gpu(model) -> bool: + """Returns if the model is fully loaded on GPUs.""" + return all("cuda" in str(param.device) for param in model.parameters()) + + +def quantize_and_export_model( + args: argparse.Namespace, +): + model = load_model( args.pyt_ckpt_path, + device=args.device, + gpu_mem_percentage=args.gpu_max_mem_percentage, + trust_remote_code=args.trust_remote_code, + use_seq_device_map=args.use_seq_device_map,) + + args.batch_size = get_max_batch_size( + model, + max_sample_length=args.calib_seq, + ) + args.batch_size = min(args.batch_size, sum(args.calib_size)) + + print(f"Use calib batch_size {args.batch_size}") + tokenizer = get_tokenizer(args.pyt_ckpt_path, trust_remote_code=args.trust_remote_code) + device = model.device + calib_dataloader = get_dataset_dataloader( + dataset_name=args.dataset, + tokenizer=tokenizer, + batch_size=args.batch_size, + num_samples=args.calib_size, + device=device, + include_labels=False, + ) + calibrate_loop = create_forward_loop(dataloader=calib_dataloader) + mtq_cfg = getattr(mtq, args.quant_cfg) # type: ignore [arg-type] + if args.kv_cache_quant_cfg is not None: + kv_cache_quant_cfg = getattr(mtq, args.kv_cache_quant_cfg) # type: ignore [arg-type] + mtq_cfg = mtq.utils.update_quant_cfg_with_kv_cache_quant( + mtq_cfg["quant_cfg"], kv_cache_quant_cfg["quant_cfg"] + ) + input_str = tokenizer.decode(next(iter(calib_dataloader))["input_ids"][0]) + generated_str_before_ptq = model.run(input_str) + + quantized_model = mtq.quantize(model, mtq_cfg, calibrate_loop) + mtq.print_quant_summary(quantized_model) + generated_str_after_ptq = model.run(input_str) + + print("--------") + print(f"example test input: {input_str}") + print("--------") + print(f"example outputs before ptq: {generated_str_before_ptq}") + print("--------") + print(f"example outputs after ptq: {generated_str_after_ptq}") + + export_hf_vllm_fq_checkpoint(quantized_model, args.export_path) + 
print(f"Model exported to {args.export_path}") + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "--pyt_ckpt_path", + help="Specify where the PyTorch checkpoint path is", + required=True, + ) + parser.add_argument("--device", default="cuda") + parser.add_argument( + "--quant_cfg", + help=( + "Quantization configuration." + ), + default="FP8_DEFAULT_CFG", + ) + parser.add_argument( + "--batch_size", + help="Batch size for calibration. Default to 0 as we calculate max batch size on-the-fly", + type=int, + default=0, + ) + parser.add_argument( + "--calib_size", + help=( + "Number of samples for calibration. If a comma separated list of values is provided, " + "each value will be used as the calibration size for the corresponding dataset. " + "This argument will be parsed and converted as a list of ints." + ), + type=str, + default="512", + ) + parser.add_argument( + "--calib_seq", + help="Maximum sequence length for calibration.", + type=int, + default=512, + ) + parser.add_argument("--export_path", default="exported_model") + parser.add_argument( + "--dataset", + help=( + f"name of a dataset, or a comma separated list of datasets. " + f"dataset choices are {get_supported_datasets()}" + ), + type=str, + default=None, + ) + parser.add_argument( + "--kv_cache_quant_cfg", + required=False, + default=None, + help="Specify KV cache quantization configuration, default to None if not provided", + ) + parser.add_argument( + "--trust_remote_code", + help="Set trust_remote_code for Huggingface models and tokenizers", + default=False, + action="store_true", + ) + parser.add_argument( + "--gpu_max_mem_percentage", + help=( + "Specify the percentage of available GPU memory to use for loading the model when " + "device_map is set to sequential. " + "By default, 80%% of the available GPU memory is used." 
+ ), + type=float, + default=0.8, + ) + parser.add_argument( + "--use_seq_device_map", + help=( + "Use device_map=sequential to load the model onto GPUs. This ensures the model is loaded " + "utilizing the percentage of available GPU memory as specified by the value passed with gpu_max_mem flag." + "Helpful in cases where device_map=auto loads model unevenly on GPUs causing GPU OOM during quantization." + ), + default=False, + action="store_true", + ) + + return parser.parse_args() + + +def main(args: argparse.Namespace): + if not torch.cuda.is_available(): + raise OSError("GPU is required for inference.") + + random.seed(RAND_SEED) + np.random.seed(RAND_SEED) + + # launch a memory monitor to read the currently used GPU memory. + launch_memory_monitor() + + # Force eager execution for all model types. + torch.compiler.set_stance("force_eager") + + # Quantize + quantize_and_export_model( + args, + + ) + + +if __name__ == "__main__": + args = parse_args() + + args.dataset = args.dataset.split(",") if args.dataset else None + args.calib_size = [int(num_sample) for num_sample in args.calib_size.split(",")] + main(args) diff --git a/modelopt/torch/export/plugins/vllm_fakequant_hf.py b/modelopt/torch/export/plugins/vllm_fakequant_hf.py index 03b191346..b30e7530d 100644 --- a/modelopt/torch/export/plugins/vllm_fakequant_hf.py +++ b/modelopt/torch/export/plugins/vllm_fakequant_hf.py @@ -15,6 +15,7 @@ """Export HuggingFace model to vLLM fakequant checkpoint.""" from pathlib import Path +from typing import Any import torch import torch.nn as nn @@ -26,6 +27,25 @@ __all__ = ["export_hf_vllm_fq_checkpoint"] +def cleanup_for_torch_save(x: Any) -> Any: + """Drop callables / local closures (e.g. `.new_forward`) before torch.save. + + ModelOpt stored state dict may contain local closures like `.new_forward` + which are not picklable. So we need to cleanup the state dict before saving. 
+ """ + if isinstance(x, dict): + return { + k: cleanup_for_torch_save(v) + for k, v in x.items() + if not callable(v) and "" not in str(getattr(v, "__qualname__", "")) + } + if isinstance(x, list): + return [cleanup_for_torch_save(v) for v in x] + if isinstance(x, tuple): + return tuple(cleanup_for_torch_save(v) for v in x) + return x + + def export_hf_vllm_fq_checkpoint( model: nn.Module, export_dir: Path | str, @@ -48,8 +68,9 @@ def export_hf_vllm_fq_checkpoint( quantizer_state_dict = get_quantizer_state_dict(model) modelopt_state = mto.modelopt_state(model) - modelopt_state["modelopt_state_weights"] = quantizer_state_dict - torch.save(modelopt_state, f"{export_dir}/vllm_fq_modelopt_state.pth") + modelopt_state = cleanup_for_torch_save(modelopt_state) + modelopt_state["modelopt_state_weights"] = cleanup_for_torch_save(quantizer_state_dict) + torch.save(modelopt_state, export_dir / "vllm_fq_modelopt_state.pth") # remove quantizer from model for _, module in model.named_modules(): if is_quantlinear(module): From 56bb4d2f9c73e136248d988a4e7b8959c45c5c7e Mon Sep 17 00:00:00 2001 From: Kinjal Patel Date: Mon, 26 Jan 2026 21:36:45 +0000 Subject: [PATCH 10/15] removed vllm fq export from hf_ptq Signed-off-by: Kinjal Patel --- examples/llm_ptq/hf_ptq.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index 3119d3457..d7aadf994 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -51,7 +51,6 @@ import modelopt.torch.sparsity as mts from modelopt.torch.export import ( export_hf_checkpoint, - export_hf_vllm_fq_checkpoint, export_tensorrt_llm_checkpoint, get_model_type, save_expert_token_count_table, @@ -1127,12 +1126,6 @@ def parse_args() -> argparse.Namespace: "(sensitivity scores, costs, etc.). Only used when auto_quantize_bits is specified." 
), ) - parser.add_argument( - "--export_vllm_fq", - help="Export vLLM fakequant checkpoint.", - default=False, - action="store_true", - ) return parser.parse_args() From 9e249ee099ca8ec275d044a0154be45c6d116772 Mon Sep 17 00:00:00 2001 From: Kinjal Patel Date: Mon, 26 Jan 2026 21:38:03 +0000 Subject: [PATCH 11/15] minor Signed-off-by: Kinjal Patel --- examples/vllm_serve/README.md | 26 +- examples/vllm_serve/vllm_fq_export.py | 337 -------------------------- 2 files changed, 11 insertions(+), 352 deletions(-) delete mode 100644 examples/vllm_serve/vllm_fq_export.py diff --git a/examples/vllm_serve/README.md b/examples/vllm_serve/README.md index 60002b747..8c15d75e9 100644 --- a/examples/vllm_serve/README.md +++ b/examples/vllm_serve/README.md @@ -58,24 +58,20 @@ lm_eval --model local-completions --tasks gsm8k --model_args model=, ## Load QAT/PTQ model and serve in vLLM (WIP) -Overwrite the calibrated amax value with prepared values from either QAT/PTQ. +Step 1: export the model with bf16 weights and quantizer state. To export the model: -Step 1: export the model with bf16 weights and amax values. To export the model: +- For **HF** models, use `hf_ptq_export.py`: -- For **HF** models, you can use `modelopt.torch.export.export_hf_vllm_fq_checkpoint`: - - ```python - import torch - from modelopt.torch.export import export_hf_vllm_fq_checkpoint - - with torch.inference_mode(): - export_hf_vllm_fq_checkpoint( - model, # The quantized model. - export_dir, # The directory where the exported files will be stored. - ) - ``` +```bash +python hf_ptq_export.py\ + --pyt_ckpt_path \ + --quant_cfg NVFP4_DEFAULT_CFG \ + --export_path \ + --trust_remote_code +``` - Or run the example script `examples/llm_ptq/hf_ptq.py` with the `--export_vllm_fq` **flag** to export a vLLM-fakequant-compatible ModelOpt state (it generates `vllm_fq_modelopt_state.pth`, which you can use via `MODELOPT_STATE_PATH`). 
+ This creates `/vllm_fq_modelopt_state.pth` (ModelOpt quantizer state for vLLM fake-quant reload) and saves the HF-exported model under `` (config/tokenizer/weights). + Note: `--pyt_ckpt_path` can point to either an HF checkpoint or a ModelOpt-saved checkpoint (e.g., a QAT/QAD checkpoint produced by `examples/llm_qat/main.py`). If the input checkpoint is already quantized, the script will **skip re-quantization** and only export artifacts for vLLM fakequant reload. - For **MCore** models, use `modelopt.torch.export.export_mcore_gpt_to_hf_vllm_fq`: diff --git a/examples/vllm_serve/vllm_fq_export.py b/examples/vllm_serve/vllm_fq_export.py deleted file mode 100644 index feeac3e92..000000000 --- a/examples/vllm_serve/vllm_fq_export.py +++ /dev/null @@ -1,337 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import argparse -import random -import time -import warnings -from typing import Any - -import numpy as np -import torch -from accelerate.hooks import remove_hook_from_module -from example_utils import ( - build_quant_cfg, - copy_custom_model_files, - get_model, - get_processor, - get_tokenizer, - is_enc_dec, - is_nemotron_vl, - run_nemotron_vl_preview, -) -from torch.utils.data import DataLoader -import transformers -from transformers import ( - AutoConfig, - AutoModelForCausalLM, - AutoProcessor, - PreTrainedTokenizer, - PreTrainedTokenizerBase, - PreTrainedTokenizerFast, - ProcessorMixin, - WhisperProcessor, -) -from accelerate import infer_auto_device_map, init_empty_weights -from accelerate.utils import get_max_memory -import modelopt.torch.opt as mto -import modelopt.torch.quantization as mtq -import modelopt.torch.sparsity as mts -from modelopt.torch.export import ( - export_hf_checkpoint, - export_hf_vllm_fq_checkpoint, - export_tensorrt_llm_checkpoint, - get_model_type, -) -from modelopt.torch.export.model_utils import get_language_model_from_vl, is_multimodal_model -from modelopt.torch.quantization.config import _default_disabled_quantizer_cfg, need_calibration -from modelopt.torch.quantization.plugins.accelerate import init_quantized_weights -from modelopt.torch.quantization.utils import is_quantized -from modelopt.torch.utils.dataset_utils import ( - create_forward_loop, - get_dataset_dataloader, - get_max_batch_size, - get_supported_datasets, -) -from modelopt.torch.utils.image_processor import BaseImageProcessor, MllamaImageProcessor -from modelopt.torch.utils.memory_monitor import launch_memory_monitor -from modelopt.torch.utils.speech_dataset_utils import get_speech_dataset_dataloader -from modelopt.torch.utils.vlm_dataset_utils import get_vlm_dataset_dataloader - -RAND_SEED = 1234 - -mto.enable_huggingface_checkpointing() - -def load_model( - ckpt_path, - device="cuda", - gpu_mem_percentage=0.8, - trust_remote_code=False, - 
use_seq_device_map=False, -): - print(f"Initializing model from {ckpt_path}") - - device_map = "auto" - if device == "cpu": - device_map = "cpu" - - # Prepare config kwargs for loading - config_kwargs = {"trust_remote_code": trust_remote_code} if trust_remote_code else {} - - # Load config once - try: - hf_config = AutoConfig.from_pretrained(ckpt_path, **config_kwargs) - except Exception as e: - print(f"Error: Could not load config from {ckpt_path}: {e}") - raise RuntimeError(f"Failed to load model configuration from {ckpt_path}") from e - - model_kwargs = config_kwargs.copy() - - if use_seq_device_map: - device_map = "sequential" - # If we use sequential, set max_memory limit to ensure that the model does not occupy the full GPU - max_memory = get_max_memory() - max_memory = {key: value * gpu_mem_percentage for key, value in max_memory.items()} - model_kwargs["max_memory"] = max_memory - - architecture = hf_config.architectures[0] - - if not hasattr(transformers, architecture) or "Deepseek" in architecture: - if not hasattr(transformers, architecture): - warnings.warn( - f"Architecture {architecture} not found in transformers: {transformers.__version__}. " - "Falling back to AutoModelForCausalLM." - ) - assert trust_remote_code, ( - "Please set trust_remote_code to True if you want to use this architecture" - ) - - auto_model_module = AutoModelForCausalLM - from_config = auto_model_module.from_config - else: - auto_model_module = getattr(transformers, architecture) - from_config = auto_model_module._from_config - - with init_empty_weights(): - # When computing the device_map, assuming bfloat16 precision by default, - # unless specified by the hf_config. 
- torch_dtype = getattr(hf_config, "torch_dtype", torch.bfloat16) - model_kwargs2 = model_kwargs.copy() - if auto_model_module != AutoModelForCausalLM: - model_kwargs2.pop("trust_remote_code", None) - model_kwargs2["torch_dtype"] = torch_dtype - model_kwargs2.pop("max_memory", None) - model = from_config(hf_config, **model_kwargs2) - - max_memory = get_max_memory() - inferred_device_map = infer_auto_device_map(model, max_memory=max_memory) - - on_cpu = "cpu" in inferred_device_map.values() - - if on_cpu: - for _device in max_memory: - if isinstance(_device, int): - max_memory[_device] *= gpu_mem_percentage - - print( - "Model does not fit to the GPU mem. " - f"We apply the following memory limit for calibration: \n{max_memory}\n" - "If you hit GPU OOM issue, please adjust `gpu_mem_percentage` or " - "reduce the calibration `batch_size` manually." - ) - model_kwargs["max_memory"] = max_memory - - model = auto_model_module.from_pretrained( - ckpt_path, - device_map=device_map, - **model_kwargs, - ) - model.eval() - - # If device_map was disabled (None), manually move model to target device - if device_map is None and device != "cpu": - print(f"Moving model to {device} device...") - model = model.to(device) - - if device == "cuda" and not is_model_on_gpu(model): - print("Warning: Some parameters are not on a GPU. 
Calibration can be slow or hit OOM") - - return model - - -def is_model_on_gpu(model) -> bool: - """Returns if the model is fully loaded on GPUs.""" - return all("cuda" in str(param.device) for param in model.parameters()) - - -def quantize_and_export_model( - args: argparse.Namespace, -): - model = load_model( args.pyt_ckpt_path, - device=args.device, - gpu_mem_percentage=args.gpu_max_mem_percentage, - trust_remote_code=args.trust_remote_code, - use_seq_device_map=args.use_seq_device_map,) - - args.batch_size = get_max_batch_size( - model, - max_sample_length=args.calib_seq, - ) - args.batch_size = min(args.batch_size, sum(args.calib_size)) - - print(f"Use calib batch_size {args.batch_size}") - tokenizer = get_tokenizer(args.pyt_ckpt_path, trust_remote_code=args.trust_remote_code) - device = model.device - calib_dataloader = get_dataset_dataloader( - dataset_name=args.dataset, - tokenizer=tokenizer, - batch_size=args.batch_size, - num_samples=args.calib_size, - device=device, - include_labels=False, - ) - calibrate_loop = create_forward_loop(dataloader=calib_dataloader) - mtq_cfg = getattr(mtq, args.quant_cfg) # type: ignore [arg-type] - if args.kv_cache_quant_cfg is not None: - kv_cache_quant_cfg = getattr(mtq, args.kv_cache_quant_cfg) # type: ignore [arg-type] - mtq_cfg = mtq.utils.update_quant_cfg_with_kv_cache_quant( - mtq_cfg["quant_cfg"], kv_cache_quant_cfg["quant_cfg"] - ) - input_str = tokenizer.decode(next(iter(calib_dataloader))["input_ids"][0]) - generated_str_before_ptq = model.run(input_str) - - quantized_model = mtq.quantize(model, mtq_cfg, calibrate_loop) - mtq.print_quant_summary(quantized_model) - generated_str_after_ptq = model.run(input_str) - - print("--------") - print(f"example test input: {input_str}") - print("--------") - print(f"example outputs before ptq: {generated_str_before_ptq}") - print("--------") - print(f"example outputs after ptq: {generated_str_after_ptq}") - - export_hf_vllm_fq_checkpoint(quantized_model, args.export_path) - 
print(f"Model exported to {args.export_path}") - -def parse_args() -> argparse.Namespace: - parser = argparse.ArgumentParser(description=__doc__) - parser.add_argument( - "--pyt_ckpt_path", - help="Specify where the PyTorch checkpoint path is", - required=True, - ) - parser.add_argument("--device", default="cuda") - parser.add_argument( - "--quant_cfg", - help=( - "Quantization configuration." - ), - default="FP8_DEFAULT_CFG", - ) - parser.add_argument( - "--batch_size", - help="Batch size for calibration. Default to 0 as we calculate max batch size on-the-fly", - type=int, - default=0, - ) - parser.add_argument( - "--calib_size", - help=( - "Number of samples for calibration. If a comma separated list of values is provided, " - "each value will be used as the calibration size for the corresponding dataset. " - "This argument will be parsed and converted as a list of ints." - ), - type=str, - default="512", - ) - parser.add_argument( - "--calib_seq", - help="Maximum sequence length for calibration.", - type=int, - default=512, - ) - parser.add_argument("--export_path", default="exported_model") - parser.add_argument( - "--dataset", - help=( - f"name of a dataset, or a comma separated list of datasets. " - f"dataset choices are {get_supported_datasets()}" - ), - type=str, - default=None, - ) - parser.add_argument( - "--kv_cache_quant_cfg", - required=False, - default=None, - help="Specify KV cache quantization configuration, default to None if not provided", - ) - parser.add_argument( - "--trust_remote_code", - help="Set trust_remote_code for Huggingface models and tokenizers", - default=False, - action="store_true", - ) - parser.add_argument( - "--gpu_max_mem_percentage", - help=( - "Specify the percentage of available GPU memory to use for loading the model when " - "device_map is set to sequential. " - "By default, 80%% of the available GPU memory is used." 
- ), - type=float, - default=0.8, - ) - parser.add_argument( - "--use_seq_device_map", - help=( - "Use device_map=sequential to load the model onto GPUs. This ensures the model is loaded " - "utilizing the percentage of available GPU memory as specified by the value passed with gpu_max_mem flag." - "Helpful in cases where device_map=auto loads model unevenly on GPUs causing GPU OOM during quantization." - ), - default=False, - action="store_true", - ) - - return parser.parse_args() - - -def main(args: argparse.Namespace): - if not torch.cuda.is_available(): - raise OSError("GPU is required for inference.") - - random.seed(RAND_SEED) - np.random.seed(RAND_SEED) - - # launch a memory monitor to read the currently used GPU memory. - launch_memory_monitor() - - # Force eager execution for all model types. - torch.compiler.set_stance("force_eager") - - # Quantize - quantize_and_export_model( - args, - - ) - - -if __name__ == "__main__": - args = parse_args() - - args.dataset = args.dataset.split(",") if args.dataset else None - args.calib_size = [int(num_sample) for num_sample in args.calib_size.split(",")] - main(args) From 9b42a09dccd24db8a789131a9252746e6e3d7ea2 Mon Sep 17 00:00:00 2001 From: Kinjal Patel Date: Mon, 26 Jan 2026 21:55:52 +0000 Subject: [PATCH 12/15] cleanup Signed-off-by: Kinjal Patel --- examples/vllm_serve/hf_ptq_export.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/examples/vllm_serve/hf_ptq_export.py b/examples/vllm_serve/hf_ptq_export.py index 7ee5c091d..fda5a6dec 100644 --- a/examples/vllm_serve/hf_ptq_export.py +++ b/examples/vllm_serve/hf_ptq_export.py @@ -128,6 +128,7 @@ def is_model_on_gpu(model) -> bool: """Returns if the model is fully loaded on GPUs.""" return all("cuda" in str(param.device) for param in model.parameters()) + def get_tokenizer(ckpt_path, trust_remote_code=False): """Returns the tokenizer from the model ckpt_path.""" print(f"Initializing tokenizer from {ckpt_path}") @@ -143,6 +144,7 @@ def 
get_tokenizer(ckpt_path, trust_remote_code=False): return tokenizer + def quantize_and_export_model( args: argparse.Namespace, ): @@ -188,7 +190,7 @@ def quantize_and_export_model( else: print("Model is already quantized, Skipping quantization...") quantized_model = model - + mtq.print_quant_summary(quantized_model) if not model_is_already_quantized: print("--------") @@ -199,11 +201,6 @@ def quantize_and_export_model( print(f"example outputs after ptq: {generated_str_after_ptq}") export_hf_vllm_fq_checkpoint(quantized_model, args.export_path) - # from modelopt.torch.quantization.utils import get_quantizer_state_dict - # quantized_model.save_pretrained(args.export_path, state_dict=quantized_model.state_dict(), save_modelopt_state=False) - # modelopt_state = mto.modelopt_state(quantized_model) - # modelopt_state["modelopt_state_weights"] = get_quantizer_state_dict(quantized_model) - # torch.save(modelopt_state, f"{args.export_path}/modelopt_state.pth") tokenizer.save_pretrained(args.export_path) print(f"Model exported to {args.export_path}") From fa9b770eae699a742b5ed3e5cad0d92b206d7d65 Mon Sep 17 00:00:00 2001 From: Kinjal Patel Date: Mon, 23 Feb 2026 17:01:04 +0000 Subject: [PATCH 13/15] minor Signed-off-by: Kinjal Patel --- modelopt/torch/export/plugins/vllm_fakequant_megatron.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/modelopt/torch/export/plugins/vllm_fakequant_megatron.py b/modelopt/torch/export/plugins/vllm_fakequant_megatron.py index 1fef361f0..039e902fb 100644 --- a/modelopt/torch/export/plugins/vllm_fakequant_megatron.py +++ b/modelopt/torch/export/plugins/vllm_fakequant_megatron.py @@ -99,6 +99,7 @@ def _get_quantized_state( self, module: torch.nn.Module, dtype: torch.dtype = torch.float16, + prefix: str = "", ) -> tuple[dict[str, torch.Tensor], str, int]: """Return a state_dict, quantization format, and block_size of the module. 
@@ -111,6 +112,10 @@ def _get_quantized_state( """ name_to_value = {} qformat: str = self._get_quantization_format(module) + if qformat is None and "norm" not in prefix: + # Add exclude layers for vllm fakequant config. Note that if the prefix is not an empty + # string then it usually ends with "." which needs to be removed. + self.exclude_modules.append(prefix.removesuffix(".")) block_size = 0 if hasattr(module, "weight") and module.weight is not None: From 7196692ed83a62f44d711ffa035603cd9b3aa64f Mon Sep 17 00:00:00 2001 From: Kinjal Patel Date: Mon, 23 Feb 2026 19:25:30 +0000 Subject: [PATCH 14/15] removed cleanup_for_torch_save Signed-off-by: Kinjal Patel --- .../torch/export/plugins/vllm_fakequant_hf.py | 22 +------------------ 1 file changed, 1 insertion(+), 21 deletions(-) diff --git a/modelopt/torch/export/plugins/vllm_fakequant_hf.py b/modelopt/torch/export/plugins/vllm_fakequant_hf.py index b30e7530d..545cc9598 100644 --- a/modelopt/torch/export/plugins/vllm_fakequant_hf.py +++ b/modelopt/torch/export/plugins/vllm_fakequant_hf.py @@ -27,25 +27,6 @@ __all__ = ["export_hf_vllm_fq_checkpoint"] -def cleanup_for_torch_save(x: Any) -> Any: - """Drop callables / local closures (e.g. `.new_forward`) before torch.save. - - ModelOpt stored state dict may contain local closures like `.new_forward` - which are not picklable. So we need to cleanup the state dict before saving. 
- """ - if isinstance(x, dict): - return { - k: cleanup_for_torch_save(v) - for k, v in x.items() - if not callable(v) and "" not in str(getattr(v, "__qualname__", "")) - } - if isinstance(x, list): - return [cleanup_for_torch_save(v) for v in x] - if isinstance(x, tuple): - return tuple(cleanup_for_torch_save(v) for v in x) - return x - - def export_hf_vllm_fq_checkpoint( model: nn.Module, export_dir: Path | str, @@ -68,8 +49,7 @@ def export_hf_vllm_fq_checkpoint( quantizer_state_dict = get_quantizer_state_dict(model) modelopt_state = mto.modelopt_state(model) - modelopt_state = cleanup_for_torch_save(modelopt_state) - modelopt_state["modelopt_state_weights"] = cleanup_for_torch_save(quantizer_state_dict) + modelopt_state["modelopt_state_weights"] = quantizer_state_dict torch.save(modelopt_state, export_dir / "vllm_fq_modelopt_state.pth") # remove quantizer from model for _, module in model.named_modules(): From 0f201886d468c9267a802509f18bc42601492477 Mon Sep 17 00:00:00 2001 From: Kinjal Patel Date: Mon, 23 Feb 2026 19:39:38 +0000 Subject: [PATCH 15/15] minor Signed-off-by: Kinjal Patel --- modelopt/torch/export/plugins/vllm_fakequant_hf.py | 1 - modelopt/torch/export/plugins/vllm_fakequant_megatron.py | 2 +- .../torch/export/test_vllm_fakequant_megatron_export.py | 2 +- 3 files changed, 2 insertions(+), 3 deletions(-) diff --git a/modelopt/torch/export/plugins/vllm_fakequant_hf.py b/modelopt/torch/export/plugins/vllm_fakequant_hf.py index 545cc9598..fb3ceef17 100644 --- a/modelopt/torch/export/plugins/vllm_fakequant_hf.py +++ b/modelopt/torch/export/plugins/vllm_fakequant_hf.py @@ -15,7 +15,6 @@ """Export HuggingFace model to vLLM fakequant checkpoint.""" from pathlib import Path -from typing import Any import torch import torch.nn as nn diff --git a/modelopt/torch/export/plugins/vllm_fakequant_megatron.py b/modelopt/torch/export/plugins/vllm_fakequant_megatron.py index 039e902fb..c219f820d 100644 --- a/modelopt/torch/export/plugins/vllm_fakequant_megatron.py 
+++ b/modelopt/torch/export/plugins/vllm_fakequant_megatron.py @@ -58,7 +58,7 @@ def gather_mcore_vllm_fq_quantized_state_dict( if quantizer_state_dict is not None: merged_quantizer_state_dict.update(quantizer_state_dict) - torch.save(merged_quantizer_state_dict, save_directory + "/quantizer_state.pth") + torch.save(merged_quantizer_state_dict, save_directory + "/vllm_fq_modelopt_state.pth") else: # Other ranks just send their amax values torch.distributed.gather_object(quantizer_state_dict, None, dst=0) diff --git a/tests/gpu_megatron/torch/export/test_vllm_fakequant_megatron_export.py b/tests/gpu_megatron/torch/export/test_vllm_fakequant_megatron_export.py index 6336dce68..29989ef94 100644 --- a/tests/gpu_megatron/torch/export/test_vllm_fakequant_megatron_export.py +++ b/tests/gpu_megatron/torch/export/test_vllm_fakequant_megatron_export.py @@ -96,7 +96,7 @@ def forward_loop(model): ) # check if quant_amax.pth file exists - quant_amax_file = export_dir / "quantizer_state.pth" + quant_amax_file = export_dir / "vllm_fq_modelopt_state.pth" assert quant_amax_file.exists(), f"quantizer_state.pth file should be created in {export_dir}" # make sure hf_quant_config.json file does not exist