From 7fab0df8eaf5834b36bb2e3623fad644bf53eda4 Mon Sep 17 00:00:00 2001
From: Qubitium
Date: Thu, 2 Apr 2026 22:39:19 +0000
Subject: [PATCH 1/3] add gemma4
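Add Gemma 4 support: text-only and composite (multimodal) quantization
definitions, calibration capture of the per-layer adapter inputs that
Gemma 4 threads through every decoder layer, per-layer rotary kwarg
refresh during cached layer replay, and CPU fallbacks when Hessian
permutation or inversion hits CUDA OOM on the largest variants.

Minimal quantization sketch for the new model type, assuming the usual
GPTQModel load/quantize/save flow; the checkpoint id, config values and
calibration text below are illustrative only:

    from gptqmodel import GPTQModel, QuantizeConfig

    quant_config = QuantizeConfig(bits=4, group_size=128)
    # "gemma4" model_type resolves to Gemma4ForConditionalGenerationGPTQ
    model = GPTQModel.load("google/gemma-4-e2b-it", quant_config)
    calibration = [
        "gptqmodel is an llm model quantization toolkit",
        "gemma 4 mixes sliding and full attention layers",
    ]
    model.quantize(calibration)
    model.save("gemma-4-e2b-it-gptq-4bit")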
---
gptqmodel/looper/forward_executor.py | 7 +
gptqmodel/looper/stage_inputs_capture.py | 16 +-
gptqmodel/models/auto.py | 3 +
gptqmodel/models/base.py | 36 +++
gptqmodel/models/definitions/__init__.py | 1 +
gptqmodel/models/definitions/gemma4.py | 247 +++++++++++++++++++
gptqmodel/quantization/awq/quantize/scale.py | 8 +
gptqmodel/quantization/gptq.py | 61 ++++-
gptqmodel/utils/looper_helpers.py | 8 +
tests/models/test_gemma4_variants.py | 111 +++++++++
tests/test_gemma4_support.py | 184 ++++++++++++++
11 files changed, 670 insertions(+), 12 deletions(-)
create mode 100644 gptqmodel/models/definitions/gemma4.py
create mode 100644 tests/models/test_gemma4_variants.py
create mode 100644 tests/test_gemma4_support.py
diff --git a/gptqmodel/looper/forward_executor.py b/gptqmodel/looper/forward_executor.py
index 384fda36e..d1a2a3807 100644
--- a/gptqmodel/looper/forward_executor.py
+++ b/gptqmodel/looper/forward_executor.py
@@ -239,6 +239,12 @@ def run_single(
additional_inputs["kv_last_layer"] = nested_move_to(prev_kv, device=exec_device)
additional_inputs["use_cache"] = False
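+ # Let the model definition refresh layer-specific kwargs (e.g. Gemma 4 rotary embeddings and per-layer inputs) before cached replay of this layer.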
+ additional_inputs = self.looper.gptq_model.prepare_layer_replay_kwargs(
+ layer=module,
+ layer_input=layer_input,
+ additional_inputs=additional_inputs,
+ target_device=exec_device,
+ )
if not preserve_module_devices:
rehome_module_to_device(module, cur_layer_device, move_parameters=True, move_buffers=True)
@@ -489,6 +495,7 @@ def _replica_progress(idx: int, total: int, device: torch.device, step: str) ->
layer_input_kwargs[batch_idx],
attention_masks[batch_idx],
position_ids[batch_idx] if position_ids else None,
+ gptq_model=self.looper.gptq_model,
support_batch_quantize=self.looper.support_batch_quantize,
is_lm_head_module=is_lm_head_module,
need_output=need_outputs,
diff --git a/gptqmodel/looper/stage_inputs_capture.py b/gptqmodel/looper/stage_inputs_capture.py
index 3f9b33b49..d80e626bd 100644
--- a/gptqmodel/looper/stage_inputs_capture.py
+++ b/gptqmodel/looper/stage_inputs_capture.py
@@ -141,11 +141,11 @@ def store_input_hook(module, args, kwargs):
else:
batch_device = data_device
- layer_input: List[torch.Tensor] = []
- if kwargs.get("hidden_states") is not None:
- layer_input.append(move_to(kwargs["hidden_states"], device=batch_device))
- else:
- layer_input.append(move_to(args[0], device=batch_device))
+ layer_input = self.gptq_model.capture_first_layer_positional_inputs(
+ args=args,
+ kwargs=kwargs,
+ batch_device=batch_device,
+ )
layer_inputs.append(layer_input)
@@ -161,6 +161,12 @@ def store_input_hook(module, args, kwargs):
for (k, v) in kwargs.items():
if k not in ["hidden_states", "attention_mask", "position_ids"]:
one_kwargs[k] = nested_move_to(v, device=batch_device)
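+ # Give the model definition a chance to stash extra replay metadata (e.g. the Gemma 4 per-layer input tensor) alongside the captured kwargs.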
+ one_kwargs = self.gptq_model.capture_first_layer_input_kwargs(
+ args=args,
+ kwargs=kwargs,
+ batch_device=batch_device,
+ layer_input_kwargs=one_kwargs,
+ )
layer_input_kwargs.append(one_kwargs)
# In a normal repeating layer/subset, early stop happens on the last module forward
diff --git a/gptqmodel/models/auto.py b/gptqmodel/models/auto.py
index 3b65fc038..f18eaa227 100644
--- a/gptqmodel/models/auto.py
+++ b/gptqmodel/models/auto.py
@@ -92,6 +92,7 @@
from .definitions.falcon_h1 import FalconH1QModel # noqa: E402
from .definitions.gemma2 import Gemma2QModel # noqa: E402
from .definitions.gemma3 import Gemma3ForConditionalGenerationGPTQ, Gemma3QModel # noqa: E402
+from .definitions.gemma4 import Gemma4ForConditionalGenerationGPTQ, Gemma4TextQModel # noqa: E402
from .definitions.glm import GlmQModel # noqa: E402
from .definitions.glm4_moe import GLM4MoEGPTQ # noqa: E402
from .definitions.glm4_moe_lite import Glm4MoeLiteQModel # noqa: E402
@@ -210,6 +211,8 @@
"gemma2": Gemma2QModel,
"gemma3_text": Gemma3QModel,
"gemma3": Gemma3ForConditionalGenerationGPTQ,
+ "gemma4_text": Gemma4TextQModel,
+ "gemma4": Gemma4ForConditionalGenerationGPTQ,
"phi": PhiQModel,
"phi3": Phi3QModel,
"phi4mm": Phi4MMGPTQ,
diff --git a/gptqmodel/models/base.py b/gptqmodel/models/base.py
index 0f3f3e40d..4331f568d 100644
--- a/gptqmodel/models/base.py
+++ b/gptqmodel/models/base.py
@@ -1365,6 +1365,42 @@ def pre_quantize_generate_hook_end(self):
# offload_to_disk(model=self.model, module=self.get_base_modules(model=self.model), disk_path=self.quantize_config.offload_to_disk_path)
pass
+ def capture_first_layer_positional_inputs(
+ self,
+ args: tuple[Any, ...],
+ kwargs: Dict[str, Any],
+ batch_device: torch.device,
+ ) -> List[torch.Tensor]:
+ """Normalize first-layer positional inputs so cached forwards can replay decoder layers directly."""
+
+ if kwargs.get("hidden_states") is not None:
+ return [move_to(kwargs["hidden_states"], device=batch_device)]
+ if args:
+ return [move_to(args[0], device=batch_device)]
+ return []
+
+ def capture_first_layer_input_kwargs(
+ self,
+ args: tuple[Any, ...],
+ kwargs: Dict[str, Any],
+ batch_device: torch.device,
+ layer_input_kwargs: Dict[str, Any],
+ ) -> Dict[str, Any]:
+ """Allow model definitions to persist extra first-layer replay metadata during calibration capture."""
+
+ return layer_input_kwargs
+
+ def prepare_layer_replay_kwargs(
+ self,
+ layer: nn.Module,
+ layer_input: List[torch.Tensor],
+ additional_inputs: Dict[str, Any],
+ target_device: torch.device,
+ ) -> Dict[str, Any]:
+ """Allow model definitions to refresh layer-specific kwargs before cached layer replay."""
+
+ return additional_inputs
+
def lm_head_pre_quantize_generate_hook(self, inputs: List[List[torch.tensor]]) -> List[List[torch.tensor]]:
if self.pre_lm_head_norm_module:
norm, _ = get_module_by_name_prefix(self.model, [self.pre_lm_head_norm_module])
diff --git a/gptqmodel/models/definitions/__init__.py b/gptqmodel/models/definitions/__init__.py
index ebddb8653..ca09e47e8 100644
--- a/gptqmodel/models/definitions/__init__.py
+++ b/gptqmodel/models/definitions/__init__.py
@@ -26,6 +26,7 @@
from .ernie4_5_moe import Ernie4_5_MoeQModel
from .gemma2 import Gemma2QModel
from .gemma3 import Gemma3QModel
+from .gemma4 import Gemma4ForConditionalGenerationGPTQ, Gemma4TextQModel
from .glm import GlmQModel
from .gpt2 import GPT2QModel
from .gpt_bigcode import GptBigCodeQModel
diff --git a/gptqmodel/models/definitions/gemma4.py b/gptqmodel/models/definitions/gemma4.py
new file mode 100644
index 000000000..f0494ccb7
--- /dev/null
+++ b/gptqmodel/models/definitions/gemma4.py
@@ -0,0 +1,247 @@
+# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai
+# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai
+# SPDX-License-Identifier: Apache-2.0
+# Contact: qubitium@modelcloud.ai, x.com/qubitium
+
+from types import MethodType
+import torch
+
+from ..base import BaseQModel
+from ...utils.device import get_device
+from ...utils.model import get_module_by_name_prefix, move_to, nested_move_to
+from .llama import LlamaQModel
+
+
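+ # Private kwarg key carrying the full (batch, seq_len, num_layers, per_layer_dim) per-layer input tensor from calibration capture to cached layer replay.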
+_GEMMA4_ALL_PER_LAYER_INPUTS = "__gptqmodel_gemma4_all_per_layer_inputs"
+
+
+def _gemma4_module_tree():
+ """Return the Gemma 4 decoder traversal with optional attention and per-layer input modules."""
+
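+ # '#' expands to the per-layer index. Projection entries tagged ':0'/':1' surface as quantizable
+ # layer modules (see tests/test_gemma4_support.py); entries tagged ':!' are norm-style modules
+ # that are traversed but not quantized. Optional q/k norms and per-layer adapter modules stay
+ # listed here because layer_modules_strict = False lets individual variants omit them.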
+ return [
+ "model",
+ "layers",
+ "#",
+ {
+ "input_layernorm": ("input_layernorm:!",),
+ "self_attn": (
+ "q_norm:!",
+ "q_proj:0",
+ "k_norm:!",
+ "k_proj:0",
+ "v_norm:!",
+ "v_proj:0",
+ "o_proj:1",
+ ),
+ "post_attention_layernorm": ("post_attention_layernorm:!",),
+ "pre_feedforward_layernorm": ("pre_feedforward_layernorm:!",),
+ "mlp": ("gate_proj:0", "up_proj:0", "down_proj:1"),
+ "post_feedforward_layernorm": ("post_feedforward_layernorm:!",),
+ "per_layer_input_gate": ("per_layer_input_gate:0",),
+ "post_per_layer_input_norm": ("post_per_layer_input_norm:!",),
+ "per_layer_projection": ("per_layer_projection:1",),
+ },
+ ]
+
+
+def _capture_gemma4_positional_inputs(model_def, args, kwargs, batch_device):
+ """Preserve Gemma 4 per-layer adapter inputs that flow through decoder layers positionally."""
+
+ layer_input = super(type(model_def), model_def).capture_first_layer_positional_inputs(args, kwargs, batch_device)
+ per_layer_input = args[1] if len(args) > 1 else kwargs.get("per_layer_input")
+ if per_layer_input is not None:
+ layer_input.append(move_to(per_layer_input, device=batch_device))
+ return layer_input
+
+
+def _prepare_gemma4_replay_kwargs(model_def, layer, layer_input, additional_inputs, target_device):
+ """Refresh Gemma 4 rotary kwargs per layer so replay follows sliding/full attention boundaries."""
+
+ rotary_path = getattr(model_def, "rotary_embedding", None)
+ if not rotary_path or not layer_input:
+ return additional_inputs
+
+ rotary, _ = get_module_by_name_prefix(model_def.model, [rotary_path])
+ if rotary is None:
+ return additional_inputs
+
+ layer_type = getattr(getattr(layer, "self_attn", None), "layer_type", None)
+ if layer_type is None:
+ return additional_inputs
+
+ hidden_states = layer_input[0]
+ seq_len = hidden_states.shape[1] if hidden_states.dim() >= 2 else hidden_states.shape[0]
+ batch_dim = hidden_states.shape[0] if hidden_states.dim() >= 2 else 1
+
+ position_ids = additional_inputs.get("position_ids")
+ if position_ids is None or position_ids.shape[-1] != seq_len:
+ position_ids = torch.arange(seq_len, device=target_device, dtype=torch.long).unsqueeze(0).expand(batch_dim, -1)
+ additional_inputs["position_ids"] = position_ids
+
+ try:
+ rotary_device = get_device(rotary)
+ except Exception:
+ rotary_device = position_ids.device
+
+ rotary_position_ids = move_to(position_ids, device=rotary_device)
+ rotary_input = torch.empty(1, device=rotary_device, dtype=hidden_states.dtype)
+ additional_inputs["position_embeddings"] = nested_move_to(
+ rotary(rotary_input, rotary_position_ids, layer_type),
+ device=target_device,
+ )
+
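+ # Layers after the first replay with hidden states only, so slice this layer's share out of the
+ # cached full per-layer tensor captured during calibration; otherwise drop the private key so it
+ # never reaches the decoder layer's forward.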
+ if len(layer_input) == 1:
+ all_per_layer_inputs = additional_inputs.pop(_GEMMA4_ALL_PER_LAYER_INPUTS, None)
+ layer_index = getattr(getattr(layer, "self_attn", None), "layer_idx", None)
+ if all_per_layer_inputs is not None and layer_index is not None:
+ additional_inputs["per_layer_input"] = move_to(
+ all_per_layer_inputs[:, :, layer_index, :],
+ device=target_device,
+ )
+ else:
+ additional_inputs.pop(_GEMMA4_ALL_PER_LAYER_INPUTS, None)
+
+ return additional_inputs
+
+
+def _resolve_gemma4_language_model(model_def):
+ """Return the Gemma 4 text stack that owns per-layer input projection state."""
+
+ if hasattr(model_def.model, "model") and hasattr(model_def.model.model, "language_model"):
+ return model_def.model.model.language_model
+ return model_def.model.model
+
+
+def _patch_gemma4_per_layer_input_capture(model_def):
+ """Capture projected per-layer inputs during calibration so later decoder replays can slice them by layer."""
+
+ language_model = _resolve_gemma4_language_model(model_def)
+ if getattr(language_model, "_gptqmodel_project_per_layer_inputs_patched", False):
+ return
+
+ original = language_model.project_per_layer_inputs
+
+ def patched(self, inputs_embeds, per_layer_inputs=None):
+ result = original(inputs_embeds, per_layer_inputs)
+ setattr(self, "_gptqmodel_cached_all_per_layer_inputs", result)
+ return result
+
+ language_model._gptqmodel_original_project_per_layer_inputs = original
+ language_model.project_per_layer_inputs = MethodType(patched, language_model)
+ language_model._gptqmodel_project_per_layer_inputs_patched = True
+
+
+def _restore_gemma4_per_layer_input_capture(model_def):
+ """Restore Gemma 4 per-layer input helpers after calibration capture completes."""
+
+ language_model = _resolve_gemma4_language_model(model_def)
+ original = getattr(language_model, "_gptqmodel_original_project_per_layer_inputs", None)
+ if original is not None:
+ language_model.project_per_layer_inputs = original
+ delattr(language_model, "_gptqmodel_original_project_per_layer_inputs")
+ if hasattr(language_model, "_gptqmodel_project_per_layer_inputs_patched"):
+ delattr(language_model, "_gptqmodel_project_per_layer_inputs_patched")
+ if hasattr(language_model, "_gptqmodel_cached_all_per_layer_inputs"):
+ delattr(language_model, "_gptqmodel_cached_all_per_layer_inputs")
+
+
+class Gemma4TextQModel(LlamaQModel):
+ """Quantization definition for text-only Gemma 4 checkpoints."""
+
+ # Gemma 4 mixes optional KV projections and per-layer residual adapters across variants.
+ layer_modules_strict = False
+ # Gemma 4 input preparation uses per-layer embeddings, so batch quantization stays conservative.
+ support_batch_quantize = False
+ pre_lm_head_norm_module = "model.norm"
+ rotary_embedding = "model.rotary_emb"
+ module_tree = _gemma4_module_tree()
+
+ def capture_first_layer_positional_inputs(self, args, kwargs, batch_device):
+ """Keep Gemma 4 per-layer adapter inputs when decoder layers are replayed in isolation."""
+
+ return _capture_gemma4_positional_inputs(self, args, kwargs, batch_device)
+
+ def capture_first_layer_input_kwargs(self, args, kwargs, batch_device, layer_input_kwargs):
+ """Persist Gemma 4 per-layer adapter tensors for later decoder replays."""
+
+ layer_input_kwargs = super().capture_first_layer_input_kwargs(args, kwargs, batch_device, layer_input_kwargs)
+ language_model = _resolve_gemma4_language_model(self)
+ all_per_layer_inputs = getattr(language_model, "_gptqmodel_cached_all_per_layer_inputs", None)
+ if all_per_layer_inputs is not None:
+ layer_input_kwargs[_GEMMA4_ALL_PER_LAYER_INPUTS] = move_to(all_per_layer_inputs, device=batch_device)
+ return layer_input_kwargs
+
+ def prepare_layer_replay_kwargs(self, layer, layer_input, additional_inputs, target_device):
+ """Refresh Gemma 4 layer kwargs during cached replay."""
+
+ return _prepare_gemma4_replay_kwargs(self, layer, layer_input, additional_inputs, target_device)
+
+ def pre_quantize_generate_hook_start(self):
+ _patch_gemma4_per_layer_input_capture(self)
+
+ def pre_quantize_generate_hook_end(self):
+ _restore_gemma4_per_layer_input_capture(self)
+ super().pre_quantize_generate_hook_end()
+
+
+class Gemma4ForConditionalGenerationGPTQ(BaseQModel):
+ """Quantization definition for composite Gemma 4 checkpoints."""
+
+ # Gemma 4 composite checkpoints share the same decoder quirks as the text-only model.
+ layer_modules_strict = False
+ support_batch_quantize = False
+ pre_lm_head_norm_module = "model.language_model.norm"
+ rotary_embedding = "model.language_model.rotary_emb"
+
+ module_tree = [
+ "model",
+ "language_model",
+ "layers",
+ "#",
+ {
+ "input_layernorm": ("input_layernorm:!",),
+ "self_attn": (
+ "q_norm:!",
+ "q_proj:0",
+ "k_norm:!",
+ "k_proj:0",
+ "v_norm:!",
+ "v_proj:0",
+ "o_proj:1",
+ ),
+ "post_attention_layernorm": ("post_attention_layernorm:!",),
+ "pre_feedforward_layernorm": ("pre_feedforward_layernorm:!",),
+ "mlp": ("gate_proj:0", "up_proj:0", "down_proj:1"),
+ "post_feedforward_layernorm": ("post_feedforward_layernorm:!",),
+ "per_layer_input_gate": ("per_layer_input_gate:0",),
+ "post_per_layer_input_norm": ("post_per_layer_input_norm:!",),
+ "per_layer_projection": ("per_layer_projection:1",),
+ },
+ ]
+
+ def capture_first_layer_positional_inputs(self, args, kwargs, batch_device):
+ """Keep Gemma 4 per-layer adapter inputs when decoder layers are replayed in isolation."""
+
+ return _capture_gemma4_positional_inputs(self, args, kwargs, batch_device)
+
+ def capture_first_layer_input_kwargs(self, args, kwargs, batch_device, layer_input_kwargs):
+ """Persist Gemma 4 per-layer adapter tensors for later decoder replays."""
+
+ layer_input_kwargs = super().capture_first_layer_input_kwargs(args, kwargs, batch_device, layer_input_kwargs)
+ language_model = _resolve_gemma4_language_model(self)
+ all_per_layer_inputs = getattr(language_model, "_gptqmodel_cached_all_per_layer_inputs", None)
+ if all_per_layer_inputs is not None:
+ layer_input_kwargs[_GEMMA4_ALL_PER_LAYER_INPUTS] = move_to(all_per_layer_inputs, device=batch_device)
+ return layer_input_kwargs
+
+ def prepare_layer_replay_kwargs(self, layer, layer_input, additional_inputs, target_device):
+ """Refresh Gemma 4 layer kwargs during cached replay."""
+
+ return _prepare_gemma4_replay_kwargs(self, layer, layer_input, additional_inputs, target_device)
+
+ def pre_quantize_generate_hook_start(self):
+ _patch_gemma4_per_layer_input_capture(self)
+
+ def pre_quantize_generate_hook_end(self):
+ _restore_gemma4_per_layer_input_capture(self)
+ super().pre_quantize_generate_hook_end()
diff --git a/gptqmodel/quantization/awq/quantize/scale.py b/gptqmodel/quantization/awq/quantize/scale.py
index 9163e55c9..06e4ccf66 100644
--- a/gptqmodel/quantization/awq/quantize/scale.py
+++ b/gptqmodel/quantization/awq/quantize/scale.py
@@ -27,7 +27,15 @@
from gptqmodel.quantization.awq.utils.utils import get_best_device
+try:
+ from transformers.models.gemma4.modeling_gemma4 import Gemma4RMSNorm
+except Exception: # pragma: no cover - older transformers builds do not expose Gemma 4 yet
+ Gemma4RMSNorm = None
+
+
allowed_norms = [nn.LayerNorm, LlamaRMSNorm, GemmaRMSNorm, Gemma2RMSNorm, CohereLayerNorm]
+if Gemma4RMSNorm is not None:
+ allowed_norms.append(Gemma4RMSNorm)
allowed_act_fns = [
nn.GELU,
BloomGelu,
diff --git a/gptqmodel/quantization/gptq.py b/gptqmodel/quantization/gptq.py
index fc3347c35..480010450 100644
--- a/gptqmodel/quantization/gptq.py
+++ b/gptqmodel/quantization/gptq.py
@@ -964,8 +964,23 @@ def quantize(
if self.qcfg.desc_act and use_hessian:
perm = torch.argsort(torch.diag(self.H), descending=True)
- W = W[:, perm]
- self.H = self.H[perm][:, perm]
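+ # H[perm][:, perm] materializes full-size temporaries and can OOM on GPU for very wide modules;
+ # fall back to CPU for this module instead of aborting the run.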
+ try:
+ W = W[:, perm]
+ self.H = self.H[perm][:, perm]
+ except RuntimeError as exc:
+ if self.H.device.type != "cuda" or "out of memory" not in str(exc).lower():
+ raise
+
+ log.warn(
+ "Quantization: Module `%s` -> CUDA OOM during Hessian permutation on %s; retrying that module on CPU.",
+ self.name,
+ self.H.device,
+ )
+ cpu_device = torch.device("cpu")
+ perm = perm.to(device=cpu_device)
+ W = W.to(device=cpu_device)[:, perm]
+ self.H = self.H.to(device=cpu_device)[perm][:, perm]
+ self.quantizer.find_params(W, weight=True)
invperm = torch.argsort(perm)
elif self.qcfg.act_group_aware and use_hessian:
@@ -980,17 +995,49 @@ def quantize(
)
del local_values
final_perm = compose_final_perm(local_perms, global_perm, self.qcfg.group_size)
- W = W[:, final_perm]
- self.H = self.H[final_perm][:, final_perm]
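+ # Same CPU fallback for the act-group-aware permutation path; find_params is re-run on the CPU copy of W.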
+ try:
+ W = W[:, final_perm]
+ self.H = self.H[final_perm][:, final_perm]
+ except RuntimeError as exc:
+ if self.H.device.type != "cuda" or "out of memory" not in str(exc).lower():
+ raise
- Losses = torch.zeros_like(W)
- Q = torch.zeros_like(W)
+ log.warn(
+ "Quantization: Module `%s` -> CUDA OOM during act-group Hessian permutation on %s; retrying that module on CPU.",
+ self.name,
+ self.H.device,
+ )
+ cpu_device = torch.device("cpu")
+ final_perm = final_perm.to(device=cpu_device)
+ W = W.to(device=cpu_device)[:, final_perm]
+ self.H = self.H.to(device=cpu_device)[final_perm][:, final_perm]
+ self.quantizer.find_params(W, weight=True)
if use_hessian:
- Hinv, damp = self.hessian_inverse(self.H)
+ try:
+ Hinv, damp = self.hessian_inverse(self.H)
+ except RuntimeError as exc:
+ if self.H.device.type != "cuda" or "out of memory" not in str(exc).lower():
+ raise
+
+ # Full-attention blocks on very large models can exceed GPU memory during the
+ # dense Hessian inverse; finish that module on CPU instead of aborting the run.
+ log.warn(
+ "Quantization: Module `%s` -> CUDA OOM during Hessian inverse on %s; retrying quantization on CPU.",
+ self.name,
+ self.H.device,
+ )
+ cpu_device = torch.device("cpu")
+ self.H = self.H.to(device=cpu_device)
+ W = W.to(device=cpu_device)
+ self.quantizer.find_params(W, weight=True)
+ Hinv, damp = self.hessian_inverse(self.H)
else:
Hinv, damp = None, 0.0
+ Losses = torch.zeros_like(W)
+ Q = torch.zeros_like(W)
+
# Use simplified loop when mock_quantization is active
if self.qcfg.mock_quantization:
for i1 in range(0, self.columns, blocksize):
diff --git a/gptqmodel/utils/looper_helpers.py b/gptqmodel/utils/looper_helpers.py
index bfc2412d8..33784fc40 100644
--- a/gptqmodel/utils/looper_helpers.py
+++ b/gptqmodel/utils/looper_helpers.py
@@ -360,6 +360,7 @@ def forward_batch_worker(
attention_mask: Optional[torch.Tensor],
position_ids: Optional[torch.Tensor],
*,
+ gptq_model=None,
support_batch_quantize: bool,
is_lm_head_module: bool,
need_output: bool,
@@ -406,6 +407,13 @@ def forward_batch_worker(
# TODO: some models do not honor generate config.use_cache property so we are forced to hack this to false
additional_inputs["use_cache"] = False
+ if gptq_model is not None:
+ additional_inputs = gptq_model.prepare_layer_replay_kwargs(
+ layer=module,
+ layer_input=inputs,
+ additional_inputs=additional_inputs,
+ target_device=module_device,
+ )
module_output = None
kv_next = None
diff --git a/tests/models/test_gemma4_variants.py b/tests/models/test_gemma4_variants.py
new file mode 100644
index 000000000..a29a55d67
--- /dev/null
+++ b/tests/models/test_gemma4_variants.py
@@ -0,0 +1,111 @@
+# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai
+# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai
+# SPDX-License-Identifier: Apache-2.0
+# Contact: qubitium@modelcloud.ai, x.com/qubitium
+
+import os
+import unittest
+
+from huggingface_hub import snapshot_download
+
+from gptqmodel.quantization.config import GcMode, VramStrategy
+from gptqmodel.utils.backend import BACKEND
+
+# Keep Gemma 4 model tests inside the requested PCI bus ordered GPU pool.
+os.environ.setdefault("CUDA_DEVICE_ORDER", "PCI_BUS_ID")
+os.environ.setdefault("CUDA_VISIBLE_DEVICES", "0,1,2,4")
+
+from model_test import ModelTest
+
+
+def _ensure_local_model_dir(local_path: str, repo_id: str) -> str:
+ """Download the checkpoint into the shared local model cache when it is missing."""
+
+ if os.path.isdir(local_path):
+ return local_path
+
+ os.makedirs(local_path, exist_ok=True)
+ snapshot_download(
+ repo_id=repo_id,
+ local_dir=local_path,
+ local_dir_use_symlinks=False,
+ resume_download=True,
+ )
+ return local_path
+
+
+class _Gemma4VariantModelTest(ModelTest):
+ """Shared Gemma 4 model-test harness tuned for fast variant coverage."""
+
+ # Allow the harness to refresh expectations from the current native model when these baselines drift.
+ DISABLE_NATIVE_BASELINE_FALLBACK = False
+ TRUST_REMOTE_CODE = False
+ TORCH_DTYPE = "bfloat16"
+ # The local env does not ship Marlin runtime kernels, so validation reloads must stay on Torch.
+ LOAD_BACKEND = BACKEND.TORCH
+ # Gemma 4 full-attention layers expand to 512-dim heads, which FlashAttention cannot execute.
+ USE_FLASH_ATTN = False
+ # Gemma 4 variants differ most at the tail: KV sharing, full-attention-only layers, and per-layer adapters.
+ MODEL_COMPAT_FAST_LAYER_COUNT = 1
+ MODEL_COMPAT_FAST_LAYER_POSITION = "last"
+ DATASET_SIZE = 128
+ DATASET_CONCAT_SIZE = 1024
+ EVAL_BATCH_SIZE = 4
+ EVAL_TASKS_SLOW = {
+ "arc_challenge": {
+ "chat_template": True,
+ "acc": {"value": 0.30, "floor_pct": 0.35, "ceil_pct": 1.0},
+ "acc_norm": {"value": 0.33, "floor_pct": 0.35, "ceil_pct": 1.0},
+ },
+ }
+ EVAL_TASKS_FAST = ModelTest.derive_fast_eval_tasks(EVAL_TASKS_SLOW)
+ HF_MODEL_ID = None
+
+ @classmethod
+ def setUpClass(cls):
+ if isinstance(getattr(cls, "NATIVE_MODEL_ID", None), str):
+ model_path = cls.NATIVE_MODEL_ID.strip()
+ if os.path.isabs(model_path) and not os.path.isdir(model_path):
+ if not cls.HF_MODEL_ID:
+ raise unittest.SkipTest(f"Model path missing and no HF repo configured: {model_path}")
+ cls.NATIVE_MODEL_ID = _ensure_local_model_dir(model_path, cls.HF_MODEL_ID)
+ super().setUpClass()
+
+
+class TestGemma4E2B(_Gemma4VariantModelTest):
+ NATIVE_MODEL_ID = "/monster/data/model/gemma-4-E2B"
+ HF_MODEL_ID = "google/gemma-4-e2b-it"
+ PIN_CUDA_DEVICE = 0
+ EVAL_BATCH_SIZE = 8
+
+ def test_gemma4_e2b(self):
+ self.quant_lm_eval()
+
+
+class TestGemma4E4BIt(_Gemma4VariantModelTest):
+ NATIVE_MODEL_ID = "/monster/data/model/gemma-4-E4B-it"
+ HF_MODEL_ID = "google/gemma-4-e4b-it"
+ PIN_CUDA_DEVICE = 1
+ EVAL_BATCH_SIZE = 4
+
+ def test_gemma4_e4b_it(self):
+ self.quant_lm_eval()
+
+
+class TestGemma431BIt(_Gemma4VariantModelTest):
+ NATIVE_MODEL_ID = "/monster/data/model/gemma-4-31B-it"
+ HF_MODEL_ID = "google/gemma-4-31b-it"
+ # Visible index 3 maps to physical GPU 4 under CUDA_VISIBLE_DEVICES=0,1,2,4.
+ PIN_CUDA_DEVICE = 3
+ EVAL_BATCH_SIZE = 1
+ VRAM_STRATEGY = VramStrategy.BALANCED
+
+ def _build_quantize_config(self):
+ quantize_config = super()._build_quantize_config()
+ # 31B full-attention q_proj hits a very large Hessian inverse; flush prior finalizers before the next stage.
+ quantize_config.wait_for_submodule_finalizers = True
+ quantize_config.gc_mode = GcMode.ON_STAGE_END
+ return quantize_config
+
+ def test_gemma4_31b_it(self):
+ self.quant_lm_eval()
diff --git a/tests/test_gemma4_support.py b/tests/test_gemma4_support.py
new file mode 100644
index 000000000..2092ba1d6
--- /dev/null
+++ b/tests/test_gemma4_support.py
@@ -0,0 +1,184 @@
+from types import SimpleNamespace
+
+import pytest
+import torch
+from torch import nn
+from transformers import AutoConfig
+
+from gptqmodel.models import auto
+from gptqmodel.models.definitions.gemma4 import Gemma4ForConditionalGenerationGPTQ, Gemma4TextQModel
+
+
+GEMMA4_VARIANTS = [
+ "/monster/data/model/gemma-4-E2B",
+ "/monster/data/model/gemma-4-E4B-it",
+ "/monster/data/model/gemma-4-31B-it",
+]
+
+
+@pytest.mark.parametrize("model_path", GEMMA4_VARIANTS)
+def test_gemma4_local_variants_select_multimodal_definition(model_path):
+ config = AutoConfig.from_pretrained(model_path)
+
+ assert config.model_type == "gemma4"
+ assert auto.check_and_get_model_definition(model_path) is Gemma4ForConditionalGenerationGPTQ
+
+
+def test_gemma4_text_model_type_selects_text_definition(monkeypatch):
+ fake_config = SimpleNamespace(model_type="gemma4_text")
+
+ monkeypatch.setattr(auto, "resolve_trust_remote_code", lambda path, trust_remote_code=False: trust_remote_code)
+ monkeypatch.setattr(auto.AutoConfig, "from_pretrained", lambda *args, **kwargs: fake_config)
+
+ assert auto.check_and_get_model_definition("/tmp/gemma4-text") is Gemma4TextQModel
+
+
+def test_gemma4_module_tree_keeps_optional_variant_paths_non_strict():
+ layer_modules = Gemma4TextQModel.simple_layer_modules(
+ model_config=SimpleNamespace(),
+ quantize_config=SimpleNamespace(dynamic=None),
+ )
+ flat_modules = {name for block in layer_modules for name in block}
+
+ assert Gemma4TextQModel.layer_modules_strict is False
+ assert "self_attn.q_proj" in flat_modules
+ assert "self_attn.k_proj" in flat_modules
+ assert "self_attn.v_proj" in flat_modules
+ assert "self_attn.o_proj" in flat_modules
+ assert "per_layer_input_gate" in flat_modules
+ assert "per_layer_projection" in flat_modules
+
+
+def test_gemma4_multimodal_base_modules_include_per_layer_helpers():
+ class _LanguageModel(nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.layers = nn.ModuleList([nn.Identity()])
+ self.embed_tokens = nn.Embedding(4, 4)
+ self.embed_tokens_per_layer = nn.Embedding(4, 4)
+ self.per_layer_model_projection = nn.Linear(4, 4, bias=False)
+ self.per_layer_projection_norm = nn.LayerNorm(4)
+ self.norm = nn.LayerNorm(4)
+ self.rotary_emb = nn.Identity()
+
+ class _Gemma4Core(nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.language_model = _LanguageModel()
+ self.vision_tower = nn.Identity()
+ self.embed_vision = nn.Identity()
+ self.audio_tower = nn.Identity()
+ self.embed_audio = nn.Identity()
+
+ class _Gemma4Wrapper(nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.model = _Gemma4Core()
+ self.lm_head = nn.Linear(4, 4, bias=False)
+
+ model = _Gemma4Wrapper()
+ base_modules = set(Gemma4ForConditionalGenerationGPTQ.get_base_modules(model))
+
+ assert Gemma4ForConditionalGenerationGPTQ.extract_layers_node() == ["model.language_model.layers"]
+ assert "model.vision_tower" in base_modules
+ assert "model.embed_vision" in base_modules
+ assert "model.audio_tower" in base_modules
+ assert "model.embed_audio" in base_modules
+ assert "model.language_model.embed_tokens" in base_modules
+ assert "model.language_model.embed_tokens_per_layer" in base_modules
+ assert "model.language_model.per_layer_model_projection" in base_modules
+ assert "model.language_model.per_layer_projection_norm" in base_modules
+
+
+def test_gemma4_capture_preserves_per_layer_input():
+ model_def = object.__new__(Gemma4ForConditionalGenerationGPTQ)
+ hidden_states = torch.randn(1, 4, 8)
+ per_layer_input = torch.randn(1, 4, 2)
+
+ captured = model_def.capture_first_layer_positional_inputs(
+ args=(hidden_states, per_layer_input),
+ kwargs={},
+ batch_device=torch.device("cpu"),
+ )
+
+ assert len(captured) == 2
+ assert torch.equal(captured[0], hidden_states)
+ assert torch.equal(captured[1], per_layer_input)
+
+
+def test_gemma4_replay_kwargs_refresh_position_embeddings():
+ class _FakeRotary(nn.Module):
+ def forward(self, x, position_ids, layer_type=None):
+ marker = 7.0 if layer_type == "full_attention" else 3.0
+ shape = (position_ids.shape[0], position_ids.shape[1], 1)
+ value = torch.full(shape, marker, dtype=x.dtype, device=x.device)
+ return value, value + 1
+
+ class _LanguageModel(nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.rotary_emb = _FakeRotary()
+
+ class _Gemma4Core(nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.language_model = _LanguageModel()
+
+ class _Gemma4Wrapper(nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.model = _Gemma4Core()
+
+ model_def = object.__new__(Gemma4ForConditionalGenerationGPTQ)
+ nn.Module.__init__(model_def)
+ model_def.model = _Gemma4Wrapper()
+
+ layer = SimpleNamespace(self_attn=SimpleNamespace(layer_type="full_attention"))
+ hidden_states = torch.randn(1, 4, 8)
+ refreshed = model_def.prepare_layer_replay_kwargs(
+ layer=layer,
+ layer_input=[hidden_states],
+ additional_inputs={
+ "position_ids": torch.arange(4).unsqueeze(0),
+ "position_embeddings": ("stale",),
+ },
+ target_device=torch.device("cpu"),
+ )
+
+ cos, sin = refreshed["position_embeddings"]
+ assert cos.shape == (1, 4, 1)
+ assert sin.shape == (1, 4, 1)
+ assert torch.all(cos == 7)
+ assert torch.all(sin == 8)
+
+
+def test_gemma4_capture_kwargs_preserve_all_per_layer_inputs():
+ class _LanguageModel(nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.rotary_emb = nn.Identity()
+ self._gptqmodel_cached_all_per_layer_inputs = torch.randn(1, 4, 3, 2)
+
+ class _Gemma4Core(nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.language_model = _LanguageModel()
+
+ class _Gemma4Wrapper(nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.model = _Gemma4Core()
+
+ model_def = object.__new__(Gemma4ForConditionalGenerationGPTQ)
+ nn.Module.__init__(model_def)
+ model_def.model = _Gemma4Wrapper()
+
+ captured = model_def.capture_first_layer_input_kwargs(
+ args=(),
+ kwargs={},
+ batch_device=torch.device("cpu"),
+ layer_input_kwargs={},
+ )
+
+ assert "__gptqmodel_gemma4_all_per_layer_inputs" in captured
+ assert captured["__gptqmodel_gemma4_all_per_layer_inputs"].shape == (1, 4, 3, 2)
From e0c62a1eaa24f31ebcbda1d49ffd225abb7ae45d Mon Sep 17 00:00:00 2001
From: Qubitium-ModelCloud
Date: Fri, 3 Apr 2026 06:52:05 +0800
Subject: [PATCH 2/3] Update README.md
---
README.md | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/README.md b/README.md
index c430854e4..90c4b4723 100644
--- a/README.md
+++ b/README.md
@@ -20,7 +20,7 @@
## Latest News
-* 04/02/2026 [6.0.0](https://github.com/ModelCloud/GPTQModel/releases/tag/v6.0.0): 🎉 New quantization methods: `ParoQuant`, `GGUF`, `FP8`, `EXL3`, and `FOEM: First-Order Error Matters`. Added PrismML/Bonsai 1bit model quantization (inference only), faster ParoQuant/AWQ kernels, ParoQuant `optimization scope` control: `module` (Paro Lite) or `layer` (Paro reference), plus `MiniCPM-O`, `MiniCPM-V`, and `GLM4 MOE lite` model support.
+* 04/03/2026 [6.0.2](https://github.com/ModelCloud/GPTQModel/releases/tag/v6.0.2): 🎉 New quantization methods: `ParoQuant`, `GGUF`, `FP8`, `EXL3`, and `FOEM: First-Order Error Matters`. Added PrismML/Bonsai 1bit model quantization (inference only), faster ParoQuant/AWQ kernels, ParoQuant `optimization scope` control: `module` (Paro Lite) or `layer` (Paro reference), plus `Gemma4`, `MiniCPM-O`, `MiniCPM-V`, and `GLM4 MOE lite` model support.
* 03/19/2026 [5.8.0](https://github.com/ModelCloud/GPTQModel/releases/tag/v5.8.0): ✨HF Transformers 5.3.0 support with auto-defusing of `fused` models via pypi pkg: [Defuser](https://github.com/ModelCloud/Defuser). Qwen 3.5 family support added. New fast HF `cpu` kernels for GPTQ/AWQ added. Experimental INT8 `cpu` kernel added for GPTQ.
* 03/09/2026 [main]: ✨Qwen 3.5 MoE model support added. New HF Kernel support added for AWQ.
HF Kernel for both gptq/awq are now used by default for cpu devices for best performance. New INT8 kernel ported from Intel for gptq.
From 9db3fa95b141f1ac7f45f52fd57236f396a8f676 Mon Sep 17 00:00:00 2001
From: Qubitium-ModelCloud
Date: Fri, 3 Apr 2026 06:52:27 +0800
Subject: [PATCH 3/3] Bump version from 6.0.0 to 6.0.2
---
gptqmodel/version.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/gptqmodel/version.py b/gptqmodel/version.py
index b6d1a8a49..f7a6e896e 100644
--- a/gptqmodel/version.py
+++ b/gptqmodel/version.py
@@ -7,4 +7,4 @@
# even minor versions are release
# 5.2.0 => release, 5.1.0 => devel
# micro version (5.2.x) denotes patch fix, i.e. 5.2.1 is a patch fix release
-__version__ = "6.0.0"
+__version__ = "6.0.2"