diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index f19e82c5d4..984243a320 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -301,6 +301,7 @@ def auto_quantize( auto_quantize_method="gradient", auto_quantize_score_size=128, auto_quantize_checkpoint=None, + full_model: torch.nn.Module | None = None, ): """Auto search quantization of multiple formats.""" @@ -338,23 +339,67 @@ def auto_quantize( for qformat in qformat_list ), "One or more quantization formats provided are not supported for unified checkpoint export" - def loss_func(output, data): - # For transformers AutoModelForCausalLM models, the outputs are wrapped in `CausalLMOutputWithPast` - # which contains the loss attribute. - return output.loss - - if auto_quantize_method == "gradient": - # For gradient-based method, return full output with loss - def forward_step(model, batch): - return model(**batch) - elif auto_quantize_method == "kl_div": - # For KL divergence method, return only logits - def forward_step(model, batch): - return model(**batch).logits + # For VLMs like Gemma3, the extracted language_model is a base text model without + # lm_head, so it cannot produce logits or loss directly. In that case, use the + # full_model's lm_head to compute logits/loss from the language model's hidden states.
+ is_base_model = ( + full_model is not None + and language_model is not full_model + and not hasattr(language_model, "lm_head") + and hasattr(full_model, "lm_head") + ) + + if is_base_model: + assert full_model is not None + lm_head = full_model.lm_head + + def loss_func(output, data): + logits = lm_head(output.last_hidden_state) + labels = data["labels"] + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + return torch.nn.functional.cross_entropy( + shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1) + ) + + if auto_quantize_method == "gradient": + + def forward_step(model, batch): + return model(**batch) + + elif auto_quantize_method == "kl_div": + + def forward_step(model, batch): + hidden_states = model(**batch).last_hidden_state + return lm_head(hidden_states) + + else: + raise ValueError( + f"Invalid auto_quantize_method: {auto_quantize_method}. Must be 'gradient' or 'kl_div'" + ) else: - raise ValueError( - f"Invalid auto_quantize_method: {auto_quantize_method}. Must be 'gradient' or 'kl_div'" - ) + + def loss_func(output, data): + # For transformers AutoModelForCausalLM models, the outputs are wrapped in `CausalLMOutputWithPast` + # which contains the loss attribute. + return output.loss + + if auto_quantize_method == "gradient": + # For gradient-based method, return full output with loss + + def forward_step(model, batch): + return model(**batch) + + elif auto_quantize_method == "kl_div": + # For KL divergence method, return only logits + + def forward_step(model, batch): + return model(**batch).logits + + else: + raise ValueError( + f"Invalid auto_quantize_method: {auto_quantize_method}. Must be 'gradient' or 'kl_div'" + ) language_model, _ = mtq.auto_quantize( language_model, @@ -1048,6 +1093,7 @@ def quantize_main( args, language_model, calib_dataloader, + full_model=full_model, ) else: