Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 62 additions & 16 deletions examples/llm_ptq/hf_ptq.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,7 @@ def auto_quantize(
auto_quantize_method="gradient",
auto_quantize_score_size=128,
auto_quantize_checkpoint=None,
full_model: torch.nn.Module | None = None,
):
"""Auto search quantization of multiple formats."""

Expand Down Expand Up @@ -338,23 +339,67 @@ def auto_quantize(
for qformat in qformat_list
), "One or more quantization formats provided are not supported for unified checkpoint export"

def loss_func(output, data):
# For transformers AutoModelForCausalLM models, the outputs are wrapped in `CausalLMOutputWithPast`
# which contains the loss attribute.
return output.loss

if auto_quantize_method == "gradient":
# For gradient-based method, return full output with loss
def forward_step(model, batch):
return model(**batch)
elif auto_quantize_method == "kl_div":
# For KL divergence method, return only logits
def forward_step(model, batch):
return model(**batch).logits
# For VLMs like Gemma4, the extracted language_model is a base text model without
# lm_head, so it cannot produce logits or loss directly. In that case, use the
# full_model's lm_head to compute logits/loss from the language model's hidden states.
is_base_model = (
full_model is not None
and language_model is not full_model
and not hasattr(language_model, "lm_head")
and hasattr(full_model, "lm_head")
)

if is_base_model:
assert full_model is not None
lm_head = full_model.lm_head

def loss_func(output, data):
logits = lm_head(output.last_hidden_state)
labels = data["labels"]
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
return torch.nn.functional.cross_entropy(
shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
)

if auto_quantize_method == "gradient":

def forward_step(model, batch):
return model(**batch)

elif auto_quantize_method == "kl_div":

def forward_step(model, batch):
hidden_states = model(**batch).last_hidden_state
return lm_head(hidden_states)

else:
raise ValueError(
f"Invalid auto_quantize_method: {auto_quantize_method}. Must be 'gradient' or 'kl_div'"
)
else:
raise ValueError(
f"Invalid auto_quantize_method: {auto_quantize_method}. Must be 'gradient' or 'kl_div'"
)

def loss_func(output, data):
# For transformers AutoModelForCausalLM models, the outputs are wrapped in `CausalLMOutputWithPast`
# which contains the loss attribute.
return output.loss

if auto_quantize_method == "gradient":
# For gradient-based method, return full output with loss

def forward_step(model, batch):
return model(**batch)

elif auto_quantize_method == "kl_div":
# For KL divergence method, return only logits

def forward_step(model, batch):
return model(**batch).logits

else:
raise ValueError(
f"Invalid auto_quantize_method: {auto_quantize_method}. Must be 'gradient' or 'kl_div'"
)

language_model, _ = mtq.auto_quantize(
language_model,
Expand Down Expand Up @@ -1048,6 +1093,7 @@ def quantize_main(
args,
language_model,
calib_dataloader,
full_model=full_model,
)

else:
Expand Down
Loading