From 26242e378553fb3a08394c54421d0e362ebba2e9 Mon Sep 17 00:00:00 2001
From: James Shen
Date: Wed, 8 Apr 2026 18:53:38 -0700
Subject: [PATCH 1/2] Add VLM base model support for auto_quantize in hf_ptq

For VLMs like Gemma4 where the extracted language_model lacks lm_head,
use the full_model's lm_head to compute logits/loss from hidden states.

How to run:
cd /opt/Model-Optimizer/examples/llm_ptq && python hf_ptq.py \
  --pyt_ckpt_path /lustre/fsw/portfolios/coreai/users/yueshen/models/gemma-4-31B-it \
  --qformat nvfp4,fp8 \
  --auto_quantize_bits 6.0 \
  --calib_size 512 \
  --dataset cnn_dailymail \
  --export_path /lustre/fsw/portfolios/coreai/users/yueshen/models/gemma-4-31B-it-autoquant-6.0

Co-Authored-By: Claude Opus 4.6 (1M context)
Signed-off-by: James Shen
---
 examples/llm_ptq/hf_ptq.py | 68 +++++++++++++++++++++++++++++---------
 1 file changed, 52 insertions(+), 16 deletions(-)

diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py
index f19e82c5d4..5b067fed64 100755
--- a/examples/llm_ptq/hf_ptq.py
+++ b/examples/llm_ptq/hf_ptq.py
@@ -301,6 +301,7 @@ def auto_quantize(
     auto_quantize_method="gradient",
     auto_quantize_score_size=128,
     auto_quantize_checkpoint=None,
+    full_model: torch.nn.Module | None = None,
 ):
     """Auto search quantization of multiple formats."""
 
@@ -338,23 +339,57 @@ def auto_quantize(
         for qformat in qformat_list
     ), "One or more quantization formats provided are not supported for unified checkpoint export"
 
-    def loss_func(output, data):
-        # For transformers AutoModelForCausalLM models, the outputs are wrapped in `CausalLMOutputWithPast`
-        # which contains the loss attribute.
-        return output.loss
-
-    if auto_quantize_method == "gradient":
-        # For gradient-based method, return full output with loss
-        def forward_step(model, batch):
-            return model(**batch)
-    elif auto_quantize_method == "kl_div":
-        # For KL divergence method, return only logits
-        def forward_step(model, batch):
-            return model(**batch).logits
+    # For VLMs like Gemma4, the extracted language_model is a base text model without
+    # lm_head, so it cannot produce logits or loss directly. In that case, use the
+    # full_model's lm_head to compute logits/loss from the language model's hidden states.
+    is_base_model = (
+        full_model is not None
+        and language_model is not full_model
+        and not hasattr(language_model, "lm_head")
+        and hasattr(full_model, "lm_head")
+    )
+
+    if is_base_model:
+        lm_head = full_model.lm_head
+
+        def loss_func(output, data):
+            logits = lm_head(output.last_hidden_state)
+            labels = data["labels"]
+            shift_logits = logits[..., :-1, :].contiguous()
+            shift_labels = labels[..., 1:].contiguous()
+            return torch.nn.functional.cross_entropy(
+                shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
+            )
+
+        if auto_quantize_method == "gradient":
+            def forward_step(model, batch):
+                return model(**batch)
+        elif auto_quantize_method == "kl_div":
+            def forward_step(model, batch):
+                hidden_states = model(**batch).last_hidden_state
+                return lm_head(hidden_states)
+        else:
+            raise ValueError(
+                f"Invalid auto_quantize_method: {auto_quantize_method}. Must be 'gradient' or 'kl_div'"
+            )
     else:
-        raise ValueError(
-            f"Invalid auto_quantize_method: {auto_quantize_method}. Must be 'gradient' or 'kl_div'"
-        )
+        def loss_func(output, data):
+            # For transformers AutoModelForCausalLM models, the outputs are wrapped in `CausalLMOutputWithPast`
+            # which contains the loss attribute.
+            return output.loss
+
+        if auto_quantize_method == "gradient":
+            # For gradient-based method, return full output with loss
+            def forward_step(model, batch):
+                return model(**batch)
+        elif auto_quantize_method == "kl_div":
+            # For KL divergence method, return only logits
+            def forward_step(model, batch):
+                return model(**batch).logits
+        else:
+            raise ValueError(
+                f"Invalid auto_quantize_method: {auto_quantize_method}. Must be 'gradient' or 'kl_div'"
+            )
 
     language_model, _ = mtq.auto_quantize(
         language_model,
@@ -1048,6 +1083,7 @@ def quantize_main(
             args,
             language_model,
             calib_dataloader,
+            full_model=full_model,
         )
 
     else:

From 3022a34d845f1d16eaf55bc79fe027e2f08ba4ac Mon Sep 17 00:00:00 2001
From: James Shen
Date: Wed, 8 Apr 2026 20:46:40 -0700
Subject: [PATCH 2/2] Fix code quality: add mypy assert and ruff blank lines

Add assert for full_model to satisfy mypy union-attr check, and add
blank lines before nested def statements per ruff formatting rules.

Co-Authored-By: Claude Opus 4.6 (1M context)
Signed-off-by: James Shen
---
 examples/llm_ptq/hf_ptq.py | 10 ++++++++++
 1 file changed, 10 insertions(+)

diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py
index 5b067fed64..984243a320 100755
--- a/examples/llm_ptq/hf_ptq.py
+++ b/examples/llm_ptq/hf_ptq.py
@@ -350,6 +350,7 @@ def auto_quantize(
     )
 
     if is_base_model:
+        assert full_model is not None
         lm_head = full_model.lm_head
 
         def loss_func(output, data):
@@ -362,17 +363,22 @@ def loss_func(output, data):
             )
 
         if auto_quantize_method == "gradient":
+
             def forward_step(model, batch):
                 return model(**batch)
+
         elif auto_quantize_method == "kl_div":
+
             def forward_step(model, batch):
                 hidden_states = model(**batch).last_hidden_state
                 return lm_head(hidden_states)
+
         else:
             raise ValueError(
                 f"Invalid auto_quantize_method: {auto_quantize_method}. Must be 'gradient' or 'kl_div'"
             )
     else:
+
         def loss_func(output, data):
             # For transformers AutoModelForCausalLM models, the outputs are wrapped in `CausalLMOutputWithPast`
             # which contains the loss attribute.
@@ -380,12 +386,16 @@ def loss_func(output, data):
 
         if auto_quantize_method == "gradient":
             # For gradient-based method, return full output with loss
+
             def forward_step(model, batch):
                 return model(**batch)
+
         elif auto_quantize_method == "kl_div":
            # For KL divergence method, return only logits
+
             def forward_step(model, batch):
                 return model(**batch).logits
+
         else:
             raise ValueError(
                 f"Invalid auto_quantize_method: {auto_quantize_method}. Must be 'gradient' or 'kl_div'"