2 changes: 1 addition & 1 deletion README.md
@@ -20,7 +20,7 @@
</p>

## Latest News
* 04/02/2026 [6.0.1](https://github.com/ModelCloud/GPTQModel/releases/tag/v6.0.1): 🎉 New quantization methods: `ParoQuant`, `GGUF`, `FP8`, `EXL3`, and `FOEM: First-Order Error Matters`. Added PrismML/Bonsai 1bit model quantization (inference only), faster ParoQuant/AWQ kernels, ParoQuant `optimization scope` control: `module` (Paro Lite) or `layer` (Paro reference), plus `MiniCPM-O`, `MiniCPM-V`, and `GLM4 MOE lite` model support.
* 04/03/2026 [6.0.2](https://github.com/ModelCloud/GPTQModel/releases/tag/v6.0.2): 🎉 New quantization methods: `ParoQuant`, `GGUF`, `FP8`, `EXL3`, and `FOEM: First-Order Error Matters`. Added PrismML/Bonsai 1bit model quantization (inference only), faster ParoQuant/AWQ kernels, ParoQuant `optimization scope` control: `module` (Paro Lite) or `layer` (Paro reference), plus `Gemma4`, `MiniCPM-O`, `MiniCPM-V`, and `GLM4 MOE lite` model support.
* 03/19/2026 [5.8.0](https://github.com/ModelCloud/GPTQModel/releases/tag/v5.8.0): ✨HF Transformers 5.3.0 support with auto-defusing of `fused` models via pypi pkg: [Defuser](https://github.com/ModelCloud/Defuser). Qwen 3.5 family support added. New fast HF `cpu` kernels for GPTQ/AWQ added. Experimental INT8 `cpu` kernel added for GPTQ.
* 03/09/2026 [main]: ✨Qwen 3.5 MoE model support added. New HF Kernel support added for AWQ.
HF Kernels for both gptq/awq are now used by default on cpu devices for best performance. New INT8 kernel ported from Intel for gptq.
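For context on the Gemma 4 entry above, a minimal sketch of how the new support could be exercised through GPTQModel's quantize/save flow; the checkpoint id and calibration text below are placeholders, not part of this PR:

```python
# Sketch only: the checkpoint id and calibration text are illustrative placeholders.
from gptqmodel import GPTQModel, QuantizeConfig

calibration_dataset = [
    "GPTQModel quantizes decoder layers one at a time from captured calibration activations.",
]

quant_config = QuantizeConfig(bits=4, group_size=128)

model = GPTQModel.load("google/gemma-4-text-preview", quant_config)  # hypothetical model id
model.quantize(calibration_dataset)
model.save("./gemma4-gptq-4bit")
```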
7 changes: 7 additions & 0 deletions gptqmodel/looper/forward_executor.py
@@ -239,6 +239,12 @@ def run_single(
additional_inputs["kv_last_layer"] = nested_move_to(prev_kv, device=exec_device)

additional_inputs["use_cache"] = False
additional_inputs = self.looper.gptq_model.prepare_layer_replay_kwargs(
layer=module,
layer_input=layer_input,
additional_inputs=additional_inputs,
target_device=exec_device,
)

if not preserve_module_devices:
rehome_module_to_device(module, cur_layer_device, move_parameters=True, move_buffers=True)
@@ -489,6 +495,7 @@ def _replica_progress(idx: int, total: int, device: torch.device, step: str) ->
layer_input_kwargs[batch_idx],
attention_masks[batch_idx],
position_ids[batch_idx] if position_ids else None,
gptq_model=self.looper.gptq_model,
support_batch_quantize=self.looper.support_batch_quantize,
is_lm_head_module=is_lm_head_module,
need_output=need_outputs,
16 changes: 11 additions & 5 deletions gptqmodel/looper/stage_inputs_capture.py
@@ -141,11 +141,11 @@ def store_input_hook(module, args, kwargs):
else:
batch_device = data_device

layer_input: List[torch.Tensor] = []
if kwargs.get("hidden_states") is not None:
layer_input.append(move_to(kwargs["hidden_states"], device=batch_device))
else:
layer_input.append(move_to(args[0], device=batch_device))
layer_input = self.gptq_model.capture_first_layer_positional_inputs(
args=args,
kwargs=kwargs,
batch_device=batch_device,
)

layer_inputs.append(layer_input)

@@ -161,6 +161,12 @@ def store_input_hook(module, args, kwargs):
for (k, v) in kwargs.items():
if k not in ["hidden_states", "attention_mask", "position_ids"]:
one_kwargs[k] = nested_move_to(v, device=batch_device)
one_kwargs = self.gptq_model.capture_first_layer_input_kwargs(
args=args,
kwargs=kwargs,
batch_device=batch_device,
layer_input_kwargs=one_kwargs,
)
layer_input_kwargs.append(one_kwargs)

# In a normal repeating layer/subset, early stop happens on the last module forward
3 changes: 3 additions & 0 deletions gptqmodel/models/auto.py
@@ -92,6 +92,7 @@
from .definitions.falcon_h1 import FalconH1QModel # noqa: E402
from .definitions.gemma2 import Gemma2QModel # noqa: E402
from .definitions.gemma3 import Gemma3ForConditionalGenerationGPTQ, Gemma3QModel # noqa: E402
from .definitions.gemma4 import Gemma4ForConditionalGenerationGPTQ, Gemma4TextQModel # noqa: E402
from .definitions.glm import GlmQModel # noqa: E402
from .definitions.glm4_moe import GLM4MoEGPTQ # noqa: E402
from .definitions.glm4_moe_lite import Glm4MoeLiteQModel # noqa: E402
@@ -210,6 +211,8 @@
"gemma2": Gemma2QModel,
"gemma3_text": Gemma3QModel,
"gemma3": Gemma3ForConditionalGenerationGPTQ,
"gemma4_text": Gemma4TextQModel,
"gemma4": Gemma4ForConditionalGenerationGPTQ,
"phi": PhiQModel,
"phi3": Phi3QModel,
"phi4mm": Phi4MMGPTQ,
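The two new `gemma4` entries extend the `model_type` -> definition lookup. A rough sketch of how that dispatch resolves, assuming the key comes from the checkpoint's `config.model_type`; the helper name is illustrative, not code added by this PR:

```python
# Sketch of the model_type -> definition dispatch; `resolve_qmodel_cls` is illustrative.
from transformers import AutoConfig

from gptqmodel.models.definitions import (
    Gemma4ForConditionalGenerationGPTQ,
    Gemma4TextQModel,
)

MODEL_MAP = {
    "gemma4_text": Gemma4TextQModel,
    "gemma4": Gemma4ForConditionalGenerationGPTQ,
    # ... remaining architectures elided
}


def resolve_qmodel_cls(model_id: str):
    config = AutoConfig.from_pretrained(model_id)
    if config.model_type not in MODEL_MAP:
        raise NotImplementedError(f"model_type `{config.model_type}` is not supported")
    return MODEL_MAP[config.model_type]
```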
36 changes: 36 additions & 0 deletions gptqmodel/models/base.py
@@ -1365,6 +1365,42 @@ def pre_quantize_generate_hook_end(self):
# offload_to_disk(model=self.model, module=self.get_base_modules(model=self.model), disk_path=self.quantize_config.offload_to_disk_path)
pass

def capture_first_layer_positional_inputs(
self,
args: tuple[Any, ...],
kwargs: Dict[str, Any],
batch_device: torch.device,
) -> List[torch.Tensor]:
"""Normalize first-layer positional inputs so cached forwards can replay decoder layers directly."""

if kwargs.get("hidden_states") is not None:
return [move_to(kwargs["hidden_states"], device=batch_device)]
if args:
return [move_to(args[0], device=batch_device)]
return []

def capture_first_layer_input_kwargs(
self,
args: tuple[Any, ...],
kwargs: Dict[str, Any],
batch_device: torch.device,
layer_input_kwargs: Dict[str, Any],
) -> Dict[str, Any]:
"""Allow model definitions to persist extra first-layer replay metadata during calibration capture."""

return layer_input_kwargs

def prepare_layer_replay_kwargs(
self,
layer: nn.Module,
layer_input: List[torch.Tensor],
additional_inputs: Dict[str, Any],
target_device: torch.device,
) -> Dict[str, Any]:
"""Allow model definitions to refresh layer-specific kwargs before cached layer replay."""

return additional_inputs

def lm_head_pre_quantize_generate_hook(self, inputs: List[List[torch.tensor]]) -> List[List[torch.tensor]]:
if self.pre_lm_head_norm_module:
norm, _ = get_module_by_name_prefix(self.model, [self.pre_lm_head_norm_module])
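The three hooks added above are deliberate pass-throughs; a model definition overrides only what it needs. A minimal sketch of such an override, assuming a hypothetical model class that rebuilds `position_ids` before cached replay, mirroring what the Gemma 4 definition does:

```python
# Illustrative override of the new replay hook; the class name is hypothetical.
from typing import Any, Dict, List

import torch
from torch import nn

from gptqmodel.models.base import BaseQModel


class MyModelQModel(BaseQModel):
    def prepare_layer_replay_kwargs(
        self,
        layer: nn.Module,
        layer_input: List[torch.Tensor],
        additional_inputs: Dict[str, Any],
        target_device: torch.device,
    ) -> Dict[str, Any]:
        # Rebuild position_ids from the cached hidden states so replay matches capture.
        hidden_states = layer_input[0]
        seq_len = hidden_states.shape[1] if hidden_states.dim() >= 2 else hidden_states.shape[0]
        additional_inputs["position_ids"] = torch.arange(
            seq_len, device=target_device, dtype=torch.long
        ).unsqueeze(0)
        return additional_inputs
```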
1 change: 1 addition & 0 deletions gptqmodel/models/definitions/__init__.py
@@ -26,6 +26,7 @@
from .ernie4_5_moe import Ernie4_5_MoeQModel
from .gemma2 import Gemma2QModel
from .gemma3 import Gemma3QModel
from .gemma4 import Gemma4ForConditionalGenerationGPTQ, Gemma4TextQModel
from .glm import GlmQModel
from .gpt2 import GPT2QModel
from .gpt_bigcode import GptBigCodeQModel
247 changes: 247 additions & 0 deletions gptqmodel/models/definitions/gemma4.py
@@ -0,0 +1,247 @@
# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai
# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai
# SPDX-License-Identifier: Apache-2.0
# Contact: qubitium@modelcloud.ai, x.com/qubitium

import torch
from types import MethodType

from ..base import BaseQModel
from ...utils.device import get_device
from ...utils.model import get_module_by_name_prefix, move_to, nested_move_to
from . import LlamaQModel


_GEMMA4_ALL_PER_LAYER_INPUTS = "__gptqmodel_gemma4_all_per_layer_inputs"


def _gemma4_module_tree():
"""Return the Gemma 4 decoder traversal with optional attention and per-layer input modules."""

return [
"model",
"layers",
"#",
{
"input_layernorm": ("input_layernorm:!",),
"self_attn": (
"q_norm:!",
"q_proj:0",
"k_norm:!",
"k_proj:0",
"v_norm:!",
"v_proj:0",
"o_proj:1",
),
"post_attention_layernorm": ("post_attention_layernorm:!",),
"pre_feedforward_layernorm": ("pre_feedforward_layernorm:!",),
"mlp": ("gate_proj:0", "up_proj:0", "down_proj:1"),
"post_feedforward_layernorm": ("post_feedforward_layernorm:!",),
"per_layer_input_gate": ("per_layer_input_gate:0",),
"post_per_layer_input_norm": ("post_per_layer_input_norm:!",),
"per_layer_projection": ("per_layer_projection:1",),
},
]


def _capture_gemma4_positional_inputs(model_def, args, kwargs, batch_device):
"""Preserve Gemma 4 per-layer adapter inputs that flow through decoder layers positionally."""

layer_input = super(type(model_def), model_def).capture_first_layer_positional_inputs(args, kwargs, batch_device)
per_layer_input = args[1] if len(args) > 1 else kwargs.get("per_layer_input")
if per_layer_input is not None:
layer_input.append(move_to(per_layer_input, device=batch_device))
return layer_input


def _prepare_gemma4_replay_kwargs(model_def, layer, layer_input, additional_inputs, target_device):
"""Refresh Gemma 4 rotary kwargs per layer so replay follows sliding/full attention boundaries."""

rotary_path = getattr(model_def, "rotary_embedding", None)
if not rotary_path or not layer_input:
return additional_inputs

rotary, _ = get_module_by_name_prefix(model_def.model, [rotary_path])
if rotary is None:
return additional_inputs

layer_type = getattr(getattr(layer, "self_attn", None), "layer_type", None)
if layer_type is None:
return additional_inputs

hidden_states = layer_input[0]
seq_len = hidden_states.shape[1] if hidden_states.dim() >= 2 else hidden_states.shape[0]
batch_dim = hidden_states.shape[0] if hidden_states.dim() >= 2 else 1

position_ids = additional_inputs.get("position_ids")
if position_ids is None or position_ids.shape[-1] != seq_len:
position_ids = torch.arange(seq_len, device=target_device, dtype=torch.long).unsqueeze(0).expand(batch_dim, -1)
additional_inputs["position_ids"] = position_ids

try:
rotary_device = get_device(rotary)
except Exception:
rotary_device = position_ids.device

rotary_position_ids = move_to(position_ids, device=rotary_device)
rotary_input = torch.empty(1, device=rotary_device, dtype=hidden_states.dtype)
additional_inputs["position_embeddings"] = nested_move_to(
rotary(rotary_input, rotary_position_ids, layer_type),
device=target_device,
)

if len(layer_input) == 1:
all_per_layer_inputs = additional_inputs.pop(_GEMMA4_ALL_PER_LAYER_INPUTS, None)
layer_index = getattr(getattr(layer, "self_attn", None), "layer_idx", None)
if all_per_layer_inputs is not None and layer_index is not None:
additional_inputs["per_layer_input"] = move_to(
all_per_layer_inputs[:, :, layer_index, :],
device=target_device,
)
else:
additional_inputs.pop(_GEMMA4_ALL_PER_LAYER_INPUTS, None)

return additional_inputs


def _resolve_gemma4_language_model(model_def):
"""Return the Gemma 4 text stack that owns per-layer input projection state."""

if hasattr(model_def.model, "model") and hasattr(model_def.model.model, "language_model"):
return model_def.model.model.language_model
return model_def.model.model


def _patch_gemma4_per_layer_input_capture(model_def):
"""Capture projected per-layer inputs during calibration so later decoder replays can slice them by layer."""

language_model = _resolve_gemma4_language_model(model_def)
if getattr(language_model, "_gptqmodel_project_per_layer_inputs_patched", False):
return

original = language_model.project_per_layer_inputs

def patched(self, inputs_embeds, per_layer_inputs=None):
result = original(inputs_embeds, per_layer_inputs)
setattr(self, "_gptqmodel_cached_all_per_layer_inputs", result)
return result

language_model._gptqmodel_original_project_per_layer_inputs = original
language_model.project_per_layer_inputs = MethodType(patched, language_model)
language_model._gptqmodel_project_per_layer_inputs_patched = True


def _restore_gemma4_per_layer_input_capture(model_def):
"""Restore Gemma 4 per-layer input helpers after calibration capture completes."""

language_model = _resolve_gemma4_language_model(model_def)
original = getattr(language_model, "_gptqmodel_original_project_per_layer_inputs", None)
if original is not None:
language_model.project_per_layer_inputs = original
delattr(language_model, "_gptqmodel_original_project_per_layer_inputs")
if hasattr(language_model, "_gptqmodel_project_per_layer_inputs_patched"):
delattr(language_model, "_gptqmodel_project_per_layer_inputs_patched")
if hasattr(language_model, "_gptqmodel_cached_all_per_layer_inputs"):
delattr(language_model, "_gptqmodel_cached_all_per_layer_inputs")


class Gemma4TextQModel(LlamaQModel):
"""Quantization definition for text-only Gemma 4 checkpoints."""

# Gemma 4 mixes optional KV projections and per-layer residual adapters across variants.
layer_modules_strict = False
# Gemma 4 input preparation uses per-layer embeddings, so batch quantization stays conservative.
support_batch_quantize = False
pre_lm_head_norm_module = "model.norm"
rotary_embedding = "model.rotary_emb"
module_tree = _gemma4_module_tree()

def capture_first_layer_positional_inputs(self, args, kwargs, batch_device):
"""Keep Gemma 4 per-layer adapter inputs when decoder layers are replayed in isolation."""

return _capture_gemma4_positional_inputs(self, args, kwargs, batch_device)

def capture_first_layer_input_kwargs(self, args, kwargs, batch_device, layer_input_kwargs):
"""Persist Gemma 4 per-layer adapter tensors for later decoder replays."""

layer_input_kwargs = super().capture_first_layer_input_kwargs(args, kwargs, batch_device, layer_input_kwargs)
language_model = _resolve_gemma4_language_model(self)
all_per_layer_inputs = getattr(language_model, "_gptqmodel_cached_all_per_layer_inputs", None)
if all_per_layer_inputs is not None:
layer_input_kwargs[_GEMMA4_ALL_PER_LAYER_INPUTS] = move_to(all_per_layer_inputs, device=batch_device)
return layer_input_kwargs

def prepare_layer_replay_kwargs(self, layer, layer_input, additional_inputs, target_device):
"""Refresh Gemma 4 layer kwargs during cached replay."""

return _prepare_gemma4_replay_kwargs(self, layer, layer_input, additional_inputs, target_device)

def pre_quantize_generate_hook_start(self):
_patch_gemma4_per_layer_input_capture(self)

def pre_quantize_generate_hook_end(self):
_restore_gemma4_per_layer_input_capture(self)
super().pre_quantize_generate_hook_end()


class Gemma4ForConditionalGenerationGPTQ(BaseQModel):
"""Quantization definition for composite Gemma 4 checkpoints."""

# Gemma 4 composite checkpoints share the same decoder quirks as the text-only model.
layer_modules_strict = False
support_batch_quantize = False
pre_lm_head_norm_module = "model.language_model.norm"
rotary_embedding = "model.language_model.rotary_emb"

module_tree = [
"model",
"language_model",
"layers",
"#",
{
"input_layernorm": ("input_layernorm:!",),
"self_attn": (
"q_norm:!",
"q_proj:0",
"k_norm:!",
"k_proj:0",
"v_norm:!",
"v_proj:0",
"o_proj:1",
),
"post_attention_layernorm": ("post_attention_layernorm:!",),
"pre_feedforward_layernorm": ("pre_feedforward_layernorm:!",),
"mlp": ("gate_proj:0", "up_proj:0", "down_proj:1"),
"post_feedforward_layernorm": ("post_feedforward_layernorm:!",),
"per_layer_input_gate": ("per_layer_input_gate:0",),
"post_per_layer_input_norm": ("post_per_layer_input_norm:!",),
"per_layer_projection": ("per_layer_projection:1",),
},
]

def capture_first_layer_positional_inputs(self, args, kwargs, batch_device):
"""Keep Gemma 4 per-layer adapter inputs when decoder layers are replayed in isolation."""

return _capture_gemma4_positional_inputs(self, args, kwargs, batch_device)

def capture_first_layer_input_kwargs(self, args, kwargs, batch_device, layer_input_kwargs):
"""Persist Gemma 4 per-layer adapter tensors for later decoder replays."""

layer_input_kwargs = super().capture_first_layer_input_kwargs(args, kwargs, batch_device, layer_input_kwargs)
language_model = _resolve_gemma4_language_model(self)
all_per_layer_inputs = getattr(language_model, "_gptqmodel_cached_all_per_layer_inputs", None)
if all_per_layer_inputs is not None:
layer_input_kwargs[_GEMMA4_ALL_PER_LAYER_INPUTS] = move_to(all_per_layer_inputs, device=batch_device)
return layer_input_kwargs

def prepare_layer_replay_kwargs(self, layer, layer_input, additional_inputs, target_device):
"""Refresh Gemma 4 layer kwargs during cached replay."""

return _prepare_gemma4_replay_kwargs(self, layer, layer_input, additional_inputs, target_device)

def pre_quantize_generate_hook_start(self):
_patch_gemma4_per_layer_input_capture(self)

def pre_quantize_generate_hook_end(self):
_restore_gemma4_per_layer_input_capture(self)
super().pre_quantize_generate_hook_end()
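To make the per-layer adapter handling concrete: the patched `project_per_layer_inputs` caches the full projected tensor once during calibration capture, and each decoder layer later receives only its slice along the layer axis. A self-contained sketch of that cache-then-slice idea, with illustrative shapes:

```python
# Self-contained sketch of the cache-then-slice pattern; shapes are illustrative.
import torch

batch, seq_len, num_layers, per_layer_dim = 2, 8, 4, 16

# Captured once by the patched project_per_layer_inputs during calibration.
all_per_layer_inputs = torch.randn(batch, seq_len, num_layers, per_layer_dim)

# At replay time each decoder layer receives only its own slice.
layer_idx = 1
per_layer_input = all_per_layer_inputs[:, :, layer_idx, :]
assert per_layer_input.shape == (batch, seq_len, per_layer_dim)
```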
8 changes: 8 additions & 0 deletions gptqmodel/quantization/awq/quantize/scale.py
@@ -27,7 +27,15 @@
from gptqmodel.quantization.awq.utils.utils import get_best_device


try:
from transformers.models.gemma4.modeling_gemma4 import Gemma4RMSNorm
except Exception: # pragma: no cover - older transformers builds do not expose Gemma 4 yet
Gemma4RMSNorm = None


allowed_norms = [nn.LayerNorm, LlamaRMSNorm, GemmaRMSNorm, Gemma2RMSNorm, CohereLayerNorm]
if Gemma4RMSNorm is not None:
allowed_norms.append(Gemma4RMSNorm)
allowed_act_fns = [
nn.GELU,
BloomGelu,
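The allow-list above gates which preceding norm modules AWQ scale fusion will touch; a tiny sketch of the kind of isinstance check such a list typically backs (the helper is illustrative, not part of this PR):

```python
# Illustrative gate over an allow-list like the one above; `can_scale_into_norm` is hypothetical.
import torch.nn as nn

allowed_norms = [nn.LayerNorm]  # the real list also holds model-specific RMSNorm classes


def can_scale_into_norm(module: nn.Module) -> bool:
    return any(isinstance(module, norm_cls) for norm_cls in allowed_norms)


print(can_scale_into_norm(nn.LayerNorm(8)))   # True
print(can_scale_into_norm(nn.Linear(8, 8)))   # False
```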