From 7fab0df8eaf5834b36bb2e3623fad644bf53eda4 Mon Sep 17 00:00:00 2001
From: Qubitium
Date: Thu, 2 Apr 2026 22:39:19 +0000
Subject: [PATCH 1/3] add gemma4

---
 gptqmodel/looper/forward_executor.py         |   7 +
 gptqmodel/looper/stage_inputs_capture.py     |  16 +-
 gptqmodel/models/auto.py                     |   3 +
 gptqmodel/models/base.py                     |  36 +++
 gptqmodel/models/definitions/__init__.py     |   1 +
 gptqmodel/models/definitions/gemma4.py       | 247 +++++++++++++++++++
 gptqmodel/quantization/awq/quantize/scale.py |   8 +
 gptqmodel/quantization/gptq.py               |  61 ++++-
 gptqmodel/utils/looper_helpers.py            |   8 +
 tests/models/test_gemma4_variants.py         | 111 +++++++++
 tests/test_gemma4_support.py                 | 184 ++++++++++++++
 11 files changed, 670 insertions(+), 12 deletions(-)
 create mode 100644 gptqmodel/models/definitions/gemma4.py
 create mode 100644 tests/models/test_gemma4_variants.py
 create mode 100644 tests/test_gemma4_support.py

diff --git a/gptqmodel/looper/forward_executor.py b/gptqmodel/looper/forward_executor.py
index 384fda36e..d1a2a3807 100644
--- a/gptqmodel/looper/forward_executor.py
+++ b/gptqmodel/looper/forward_executor.py
@@ -239,6 +239,12 @@ def run_single(
             additional_inputs["kv_last_layer"] = nested_move_to(prev_kv, device=exec_device)
         additional_inputs["use_cache"] = False
+        additional_inputs = self.looper.gptq_model.prepare_layer_replay_kwargs(
+            layer=module,
+            layer_input=layer_input,
+            additional_inputs=additional_inputs,
+            target_device=exec_device,
+        )
         if not preserve_module_devices:
             rehome_module_to_device(module, cur_layer_device, move_parameters=True, move_buffers=True)
@@ -489,6 +495,7 @@ def _replica_progress(idx: int, total: int, device: torch.device, step: str) ->
                     layer_input_kwargs[batch_idx],
                     attention_masks[batch_idx],
                     position_ids[batch_idx] if position_ids else None,
+                    gptq_model=self.looper.gptq_model,
                     support_batch_quantize=self.looper.support_batch_quantize,
                     is_lm_head_module=is_lm_head_module,
                     need_output=need_outputs,
diff --git a/gptqmodel/looper/stage_inputs_capture.py b/gptqmodel/looper/stage_inputs_capture.py
index 3f9b33b49..d80e626bd 100644
--- a/gptqmodel/looper/stage_inputs_capture.py
+++ b/gptqmodel/looper/stage_inputs_capture.py
@@ -141,11 +141,11 @@ def store_input_hook(module, args, kwargs):
             else:
                 batch_device = data_device
-            layer_input: List[torch.Tensor] = []
-            if kwargs.get("hidden_states") is not None:
-                layer_input.append(move_to(kwargs["hidden_states"], device=batch_device))
-            else:
-                layer_input.append(move_to(args[0], device=batch_device))
+            layer_input = self.gptq_model.capture_first_layer_positional_inputs(
+                args=args,
+                kwargs=kwargs,
+                batch_device=batch_device,
+            )
             layer_inputs.append(layer_input)
@@ -161,6 +161,12 @@ def store_input_hook(module, args, kwargs):
             for (k, v) in kwargs.items():
                 if k not in ["hidden_states", "attention_mask", "position_ids"]:
                     one_kwargs[k] = nested_move_to(v, device=batch_device)
+            one_kwargs = self.gptq_model.capture_first_layer_input_kwargs(
+                args=args,
+                kwargs=kwargs,
+                batch_device=batch_device,
+                layer_input_kwargs=one_kwargs,
+            )
             layer_input_kwargs.append(one_kwargs)

             # In normal repeating layer/subset runs, early stop happens on the last module forward
diff --git a/gptqmodel/models/auto.py b/gptqmodel/models/auto.py
index 3b65fc038..f18eaa227 100644
--- a/gptqmodel/models/auto.py
+++ b/gptqmodel/models/auto.py
@@ -92,6 +92,7 @@
 from .definitions.falcon_h1 import FalconH1QModel  # noqa: E402
 from .definitions.gemma2 import Gemma2QModel  # noqa: E402
 from .definitions.gemma3 import Gemma3ForConditionalGenerationGPTQ, Gemma3QModel  # noqa: E402
+from .definitions.gemma4 import Gemma4ForConditionalGenerationGPTQ, Gemma4TextQModel  # noqa: E402
 from .definitions.glm import GlmQModel  # noqa: E402
 from .definitions.glm4_moe import GLM4MoEGPTQ  # noqa: E402
 from .definitions.glm4_moe_lite import Glm4MoeLiteQModel  # noqa: E402
@@ -210,6 +211,8 @@
     "gemma2": Gemma2QModel,
     "gemma3_text": Gemma3QModel,
     "gemma3": Gemma3ForConditionalGenerationGPTQ,
+    "gemma4_text": Gemma4TextQModel,
+    "gemma4": Gemma4ForConditionalGenerationGPTQ,
     "phi": PhiQModel,
     "phi3": Phi3QModel,
     "phi4mm": Phi4MMGPTQ,
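Once `gemma4` and `gemma4_text` are registered in `MODEL_MAP`, the standard load/quantize flow routes Gemma 4 checkpoints to the new definitions with no further wiring. A minimal end-to-end sketch, assuming the public `GPTQModel.load`/`quantize`/`save` API; the repo id (taken from the tests below) and the calibration text are placeholders, not part of this patch:

```python
# Sketch only: usage under the assumptions above.
from gptqmodel import GPTQModel, QuantizeConfig

quant_config = QuantizeConfig(bits=4, group_size=128)

# config.model_type == "gemma4" selects Gemma4ForConditionalGenerationGPTQ.
model = GPTQModel.load("google/gemma-4-e2b-it", quant_config)
model.quantize(["GPTQModel is an easy-to-use LLM quantization toolkit."])
model.save("./gemma-4-e2b-it-gptq-4bit")
```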
diff --git a/gptqmodel/models/base.py b/gptqmodel/models/base.py
index 0f3f3e40d..4331f568d 100644
--- a/gptqmodel/models/base.py
+++ b/gptqmodel/models/base.py
@@ -1365,6 +1365,42 @@ def pre_quantize_generate_hook_end(self):
         # offload_to_disk(model=self.model, module=self.get_base_modules(model=self.model), disk_path=self.quantize_config.offload_to_disk_path)
         pass

+    def capture_first_layer_positional_inputs(
+        self,
+        args: tuple[Any, ...],
+        kwargs: Dict[str, Any],
+        batch_device: torch.device,
+    ) -> List[torch.Tensor]:
+        """Normalize first-layer positional inputs so cached forwards can replay decoder layers directly."""
+
+        if kwargs.get("hidden_states") is not None:
+            return [move_to(kwargs["hidden_states"], device=batch_device)]
+        if args:
+            return [move_to(args[0], device=batch_device)]
+        return []
+
+    def capture_first_layer_input_kwargs(
+        self,
+        args: tuple[Any, ...],
+        kwargs: Dict[str, Any],
+        batch_device: torch.device,
+        layer_input_kwargs: Dict[str, Any],
+    ) -> Dict[str, Any]:
+        """Allow model definitions to persist extra first-layer replay metadata during calibration capture."""
+
+        return layer_input_kwargs
+
+    def prepare_layer_replay_kwargs(
+        self,
+        layer: nn.Module,
+        layer_input: List[torch.Tensor],
+        additional_inputs: Dict[str, Any],
+        target_device: torch.device,
+    ) -> Dict[str, Any]:
+        """Allow model definitions to refresh layer-specific kwargs before cached layer replay."""
+
+        return additional_inputs
+
     def lm_head_pre_quantize_generate_hook(self, inputs: List[List[torch.tensor]]) -> List[List[torch.tensor]]:
         if self.pre_lm_head_norm_module:
             norm, _ = get_module_by_name_prefix(self.model, [self.pre_lm_head_norm_module])
diff --git a/gptqmodel/models/definitions/__init__.py b/gptqmodel/models/definitions/__init__.py
index ebddb8653..ca09e47e8 100644
--- a/gptqmodel/models/definitions/__init__.py
+++ b/gptqmodel/models/definitions/__init__.py
@@ -26,6 +26,7 @@
 from .ernie4_5_moe import Ernie4_5_MoeQModel
 from .gemma2 import Gemma2QModel
 from .gemma3 import Gemma3QModel
+from .gemma4 import Gemma4ForConditionalGenerationGPTQ, Gemma4TextQModel
 from .glm import GlmQModel
 from .gpt2 import GPT2QModel
 from .gpt_bigcode import GptBigCodeQModel
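The three hooks above are pass-throughs in `BaseQModel`, so existing model definitions are unaffected; only architectures whose decoder layers consume extra positional inputs need overrides. A minimal sketch of the override pattern (the `ToyAdapterQModel` class and its second positional tensor are hypothetical, not part of this patch):

```python
from gptqmodel.models.base import BaseQModel
from gptqmodel.utils.model import move_to


class ToyAdapterQModel(BaseQModel):
    def capture_first_layer_positional_inputs(self, args, kwargs, batch_device):
        # Keep the default hidden-states capture, then preserve an extra
        # positional tensor so cached replays can feed it back to the layer.
        layer_input = super().capture_first_layer_positional_inputs(args, kwargs, batch_device)
        if len(args) > 1 and args[1] is not None:
            layer_input.append(move_to(args[1], device=batch_device))
        return layer_input
```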
diff --git a/gptqmodel/models/definitions/gemma4.py b/gptqmodel/models/definitions/gemma4.py
new file mode 100644
index 000000000..f0494ccb7
--- /dev/null
+++ b/gptqmodel/models/definitions/gemma4.py
@@ -0,0 +1,247 @@
+# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai
+# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai
+# SPDX-License-Identifier: Apache-2.0
+# Contact: qubitium@modelcloud.ai, x.com/qubitium
+
+import torch
+from types import MethodType
+
+from ..base import BaseQModel
+from ...utils.device import get_device
+from ...utils.model import get_module_by_name_prefix, move_to, nested_move_to
+from . import LlamaQModel
+
+
+_GEMMA4_ALL_PER_LAYER_INPUTS = "__gptqmodel_gemma4_all_per_layer_inputs"
+
+
+def _gemma4_module_tree():
+    """Return the Gemma 4 decoder traversal with optional attention and per-layer input modules."""
+
+    return [
+        "model",
+        "layers",
+        "#",
+        {
+            "input_layernorm": ("input_layernorm:!",),
+            "self_attn": (
+                "q_norm:!",
+                "q_proj:0",
+                "k_norm:!",
+                "k_proj:0",
+                "v_norm:!",
+                "v_proj:0",
+                "o_proj:1",
+            ),
+            "post_attention_layernorm": ("post_attention_layernorm:!",),
+            "pre_feedforward_layernorm": ("pre_feedforward_layernorm:!",),
+            "mlp": ("gate_proj:0", "up_proj:0", "down_proj:1"),
+            "post_feedforward_layernorm": ("post_feedforward_layernorm:!",),
+            "per_layer_input_gate": ("per_layer_input_gate:0",),
+            "post_per_layer_input_norm": ("post_per_layer_input_norm:!",),
+            "per_layer_projection": ("per_layer_projection:1",),
+        },
+    ]
+
+
+def _capture_gemma4_positional_inputs(model_def, args, kwargs, batch_device):
+    """Preserve Gemma 4 per-layer adapter inputs that flow through decoder layers positionally."""
+
+    layer_input = super(type(model_def), model_def).capture_first_layer_positional_inputs(args, kwargs, batch_device)
+    per_layer_input = args[1] if len(args) > 1 else kwargs.get("per_layer_input")
+    if per_layer_input is not None:
+        layer_input.append(move_to(per_layer_input, device=batch_device))
+    return layer_input
+
+
+def _prepare_gemma4_replay_kwargs(model_def, layer, layer_input, additional_inputs, target_device):
+    """Refresh Gemma 4 rotary kwargs per layer so replay follows sliding/full attention boundaries."""
+
+    rotary_path = getattr(model_def, "rotary_embedding", None)
+    if not rotary_path or not layer_input:
+        return additional_inputs
+
+    rotary, _ = get_module_by_name_prefix(model_def.model, [rotary_path])
+    if rotary is None:
+        return additional_inputs
+
+    layer_type = getattr(getattr(layer, "self_attn", None), "layer_type", None)
+    if layer_type is None:
+        return additional_inputs
+
+    hidden_states = layer_input[0]
+    seq_len = hidden_states.shape[1] if hidden_states.dim() >= 2 else hidden_states.shape[0]
+    batch_dim = hidden_states.shape[0] if hidden_states.dim() >= 2 else 1
+
+    position_ids = additional_inputs.get("position_ids")
+    if position_ids is None or position_ids.shape[-1] != seq_len:
+        position_ids = torch.arange(seq_len, device=target_device, dtype=torch.long).unsqueeze(0).expand(batch_dim, -1)
+        additional_inputs["position_ids"] = position_ids
+
+    try:
+        rotary_device = get_device(rotary)
+    except Exception:
+        rotary_device = position_ids.device
+
+    rotary_position_ids = move_to(position_ids, device=rotary_device)
+    rotary_input = torch.empty(1, device=rotary_device, dtype=hidden_states.dtype)
+    additional_inputs["position_embeddings"] = nested_move_to(
+        rotary(rotary_input, rotary_position_ids, layer_type),
+        device=target_device,
+    )
+
+    if len(layer_input) == 1:
+        all_per_layer_inputs = additional_inputs.pop(_GEMMA4_ALL_PER_LAYER_INPUTS, None)
+        layer_index = getattr(getattr(layer, "self_attn", None), "layer_idx", None)
+        if all_per_layer_inputs is not None and layer_index is not None:
+            additional_inputs["per_layer_input"] = move_to(
+                all_per_layer_inputs[:, :, layer_index, :],
+                device=target_device,
+            )
+    else:
+        additional_inputs.pop(_GEMMA4_ALL_PER_LAYER_INPUTS, None)
+
+    return additional_inputs
+
+
+def _resolve_gemma4_language_model(model_def):
+    """Return the Gemma 4 text stack that owns per-layer input projection state."""
+
+    if hasattr(model_def.model, "model") and hasattr(model_def.model.model, "language_model"):
+        return model_def.model.model.language_model
+    return model_def.model.model
+
+
+def _patch_gemma4_per_layer_input_capture(model_def):
+    """Capture projected per-layer inputs during calibration so later decoder replays can slice them by layer."""
+
+    language_model = _resolve_gemma4_language_model(model_def)
+    if getattr(language_model, "_gptqmodel_project_per_layer_inputs_patched", False):
+        return
+
+    original = language_model.project_per_layer_inputs
+
+    def patched(self, inputs_embeds, per_layer_inputs=None):
+        result = original(inputs_embeds, per_layer_inputs)
+        setattr(self, "_gptqmodel_cached_all_per_layer_inputs", result)
+        return result
+
+    language_model._gptqmodel_original_project_per_layer_inputs = original
+    language_model.project_per_layer_inputs = MethodType(patched, language_model)
+    language_model._gptqmodel_project_per_layer_inputs_patched = True
+
+
+def _restore_gemma4_per_layer_input_capture(model_def):
+    """Restore Gemma 4 per-layer input helpers after calibration capture completes."""
+
+    language_model = _resolve_gemma4_language_model(model_def)
+    original = getattr(language_model, "_gptqmodel_original_project_per_layer_inputs", None)
+    if original is not None:
+        language_model.project_per_layer_inputs = original
+        delattr(language_model, "_gptqmodel_original_project_per_layer_inputs")
+    if hasattr(language_model, "_gptqmodel_project_per_layer_inputs_patched"):
+        delattr(language_model, "_gptqmodel_project_per_layer_inputs_patched")
+    if hasattr(language_model, "_gptqmodel_cached_all_per_layer_inputs"):
+        delattr(language_model, "_gptqmodel_cached_all_per_layer_inputs")
+
+
+class Gemma4TextQModel(LlamaQModel):
+    """Quantization definition for text-only Gemma 4 checkpoints."""
+
+    # Gemma 4 mixes optional KV projections and per-layer residual adapters across variants.
+    layer_modules_strict = False
+    # Gemma 4 input preparation uses per-layer embeddings, so batch quantization stays conservative.
+    support_batch_quantize = False
+    pre_lm_head_norm_module = "model.norm"
+    rotary_embedding = "model.rotary_emb"
+    module_tree = _gemma4_module_tree()
+
+    def capture_first_layer_positional_inputs(self, args, kwargs, batch_device):
+        """Keep Gemma 4 per-layer adapter inputs when decoder layers are replayed in isolation."""
+
+        return _capture_gemma4_positional_inputs(self, args, kwargs, batch_device)
+
+    def capture_first_layer_input_kwargs(self, args, kwargs, batch_device, layer_input_kwargs):
+        """Persist Gemma 4 per-layer adapter tensors for later decoder replays."""
+
+        layer_input_kwargs = super().capture_first_layer_input_kwargs(args, kwargs, batch_device, layer_input_kwargs)
+        language_model = _resolve_gemma4_language_model(self)
+        all_per_layer_inputs = getattr(language_model, "_gptqmodel_cached_all_per_layer_inputs", None)
+        if all_per_layer_inputs is not None:
+            layer_input_kwargs[_GEMMA4_ALL_PER_LAYER_INPUTS] = move_to(all_per_layer_inputs, device=batch_device)
+        return layer_input_kwargs
+
+    def prepare_layer_replay_kwargs(self, layer, layer_input, additional_inputs, target_device):
+        """Refresh Gemma 4 layer kwargs during cached replay."""
+
+        return _prepare_gemma4_replay_kwargs(self, layer, layer_input, additional_inputs, target_device)
+
+    def pre_quantize_generate_hook_start(self):
+        _patch_gemma4_per_layer_input_capture(self)
+
+    def pre_quantize_generate_hook_end(self):
+        _restore_gemma4_per_layer_input_capture(self)
+        super().pre_quantize_generate_hook_end()
+
+
+class Gemma4ForConditionalGenerationGPTQ(BaseQModel):
+    """Quantization definition for composite Gemma 4 checkpoints."""
+
+    # Gemma 4 composite checkpoints share the same decoder quirks as the text-only model.
+    layer_modules_strict = False
+    support_batch_quantize = False
+    pre_lm_head_norm_module = "model.language_model.norm"
+    rotary_embedding = "model.language_model.rotary_emb"
+
+    module_tree = [
+        "model",
+        "language_model",
+        "layers",
+        "#",
+        {
+            "input_layernorm": ("input_layernorm:!",),
+            "self_attn": (
+                "q_norm:!",
+                "q_proj:0",
+                "k_norm:!",
+                "k_proj:0",
+                "v_norm:!",
+                "v_proj:0",
+                "o_proj:1",
+            ),
+            "post_attention_layernorm": ("post_attention_layernorm:!",),
+            "pre_feedforward_layernorm": ("pre_feedforward_layernorm:!",),
+            "mlp": ("gate_proj:0", "up_proj:0", "down_proj:1"),
+            "post_feedforward_layernorm": ("post_feedforward_layernorm:!",),
+            "per_layer_input_gate": ("per_layer_input_gate:0",),
+            "post_per_layer_input_norm": ("post_per_layer_input_norm:!",),
+            "per_layer_projection": ("per_layer_projection:1",),
+        },
+    ]
+
+    def capture_first_layer_positional_inputs(self, args, kwargs, batch_device):
+        """Keep Gemma 4 per-layer adapter inputs when decoder layers are replayed in isolation."""
+
+        return _capture_gemma4_positional_inputs(self, args, kwargs, batch_device)
+
+    def capture_first_layer_input_kwargs(self, args, kwargs, batch_device, layer_input_kwargs):
+        """Persist Gemma 4 per-layer adapter tensors for later decoder replays."""
+
+        layer_input_kwargs = super().capture_first_layer_input_kwargs(args, kwargs, batch_device, layer_input_kwargs)
+        language_model = _resolve_gemma4_language_model(self)
+        all_per_layer_inputs = getattr(language_model, "_gptqmodel_cached_all_per_layer_inputs", None)
+        if all_per_layer_inputs is not None:
+            layer_input_kwargs[_GEMMA4_ALL_PER_LAYER_INPUTS] = move_to(all_per_layer_inputs, device=batch_device)
+        return layer_input_kwargs
+
+    def prepare_layer_replay_kwargs(self, layer, layer_input, additional_inputs, target_device):
+        """Refresh Gemma 4 layer kwargs during cached replay."""
+
+        return _prepare_gemma4_replay_kwargs(self, layer, layer_input, additional_inputs, target_device)
+
+    def pre_quantize_generate_hook_start(self):
+        _patch_gemma4_per_layer_input_capture(self)
+
+    def pre_quantize_generate_hook_end(self):
+        _restore_gemma4_per_layer_input_capture(self)
+        super().pre_quantize_generate_hook_end()
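A quick shape check of the per-layer slicing performed in `_prepare_gemma4_replay_kwargs`: the cached projection holds one adapter vector per (token, layer) pair, and replay extracts the current layer's slice. All dimensions below are illustrative:

```python
import torch

batch, seq_len, num_layers, adapter_dim = 2, 8, 4, 16
all_per_layer_inputs = torch.randn(batch, seq_len, num_layers, adapter_dim)

layer_index = 3  # would come from layer.self_attn.layer_idx during replay
per_layer_input = all_per_layer_inputs[:, :, layer_index, :]
assert per_layer_input.shape == (batch, seq_len, adapter_dim)
```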
"language_model"): + return model_def.model.model.language_model + return model_def.model.model + + +def _patch_gemma4_per_layer_input_capture(model_def): + """Capture projected per-layer inputs during calibration so later decoder replays can slice them by layer.""" + + language_model = _resolve_gemma4_language_model(model_def) + if getattr(language_model, "_gptqmodel_project_per_layer_inputs_patched", False): + return + + original = language_model.project_per_layer_inputs + + def patched(self, inputs_embeds, per_layer_inputs=None): + result = original(inputs_embeds, per_layer_inputs) + setattr(self, "_gptqmodel_cached_all_per_layer_inputs", result) + return result + + language_model._gptqmodel_original_project_per_layer_inputs = original + language_model.project_per_layer_inputs = MethodType(patched, language_model) + language_model._gptqmodel_project_per_layer_inputs_patched = True + + +def _restore_gemma4_per_layer_input_capture(model_def): + """Restore Gemma 4 per-layer input helpers after calibration capture completes.""" + + language_model = _resolve_gemma4_language_model(model_def) + original = getattr(language_model, "_gptqmodel_original_project_per_layer_inputs", None) + if original is not None: + language_model.project_per_layer_inputs = original + delattr(language_model, "_gptqmodel_original_project_per_layer_inputs") + if hasattr(language_model, "_gptqmodel_project_per_layer_inputs_patched"): + delattr(language_model, "_gptqmodel_project_per_layer_inputs_patched") + if hasattr(language_model, "_gptqmodel_cached_all_per_layer_inputs"): + delattr(language_model, "_gptqmodel_cached_all_per_layer_inputs") + + +class Gemma4TextQModel(LlamaQModel): + """Quantization definition for text-only Gemma 4 checkpoints.""" + + # Gemma 4 mixes optional KV projections and per-layer residual adapters across variants. + layer_modules_strict = False + # Gemma 4 input preparation uses per-layer embeddings, so batch quantization stays conservative. 
diff --git a/gptqmodel/quantization/gptq.py b/gptqmodel/quantization/gptq.py
index fc3347c35..480010450 100644
--- a/gptqmodel/quantization/gptq.py
+++ b/gptqmodel/quantization/gptq.py
@@ -964,8 +964,23 @@ def quantize(
         if self.qcfg.desc_act and use_hessian:
             perm = torch.argsort(torch.diag(self.H), descending=True)
-            W = W[:, perm]
-            self.H = self.H[perm][:, perm]
+            try:
+                W = W[:, perm]
+                self.H = self.H[perm][:, perm]
+            except RuntimeError as exc:
+                if self.H.device.type != "cuda" or "out of memory" not in str(exc).lower():
+                    raise
+
+                log.warn(
+                    "Quantization: Module `%s` -> CUDA OOM during Hessian permutation on %s; retrying that module on CPU.",
+                    self.name,
+                    self.H.device,
+                )
+                cpu_device = torch.device("cpu")
+                perm = perm.to(device=cpu_device)
+                W = W.to(device=cpu_device)[:, perm]
+                self.H = self.H.to(device=cpu_device)[perm][:, perm]
+                self.quantizer.find_params(W, weight=True)
             invperm = torch.argsort(perm)
         elif self.qcfg.act_group_aware and use_hessian:
@@ -980,17 +995,49 @@
             )
             del local_values
             final_perm = compose_final_perm(local_perms, global_perm, self.qcfg.group_size)
-            W = W[:, final_perm]
-            self.H = self.H[final_perm][:, final_perm]
+            try:
+                W = W[:, final_perm]
+                self.H = self.H[final_perm][:, final_perm]
+            except RuntimeError as exc:
+                if self.H.device.type != "cuda" or "out of memory" not in str(exc).lower():
+                    raise

-        Losses = torch.zeros_like(W)
-        Q = torch.zeros_like(W)
+                log.warn(
+                    "Quantization: Module `%s` -> CUDA OOM during act-group Hessian permutation on %s; retrying that module on CPU.",
+                    self.name,
+                    self.H.device,
+                )
+                cpu_device = torch.device("cpu")
+                final_perm = final_perm.to(device=cpu_device)
+                W = W.to(device=cpu_device)[:, final_perm]
+                self.H = self.H.to(device=cpu_device)[final_perm][:, final_perm]
+                self.quantizer.find_params(W, weight=True)

         if use_hessian:
-            Hinv, damp = self.hessian_inverse(self.H)
+            try:
+                Hinv, damp = self.hessian_inverse(self.H)
+            except RuntimeError as exc:
+                if self.H.device.type != "cuda" or "out of memory" not in str(exc).lower():
+                    raise
+
+                # Full-attention blocks on very large models can exceed GPU memory during the
+                # dense Hessian inverse; finish that module on CPU instead of aborting the run.
+                log.warn(
+                    "Quantization: Module `%s` -> CUDA OOM during Hessian inverse on %s; retrying quantization on CPU.",
+                    self.name,
+                    self.H.device,
+                )
+                cpu_device = torch.device("cpu")
+                self.H = self.H.to(device=cpu_device)
+                W = W.to(device=cpu_device)
+                self.quantizer.find_params(W, weight=True)
+                Hinv, damp = self.hessian_inverse(self.H)
         else:
             Hinv, damp = None, 0.0

+        Losses = torch.zeros_like(W)
+        Q = torch.zeros_like(W)
+
         # Use simplified loop when mock_quantization is active
         if self.qcfg.mock_quantization:
             for i1 in range(0, self.columns, blocksize):
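The three fallback branches above share one pattern: treat only CUDA out-of-memory `RuntimeError`s as recoverable, re-raise everything else, and redo the failed step on CPU copies. Reduced to a standalone sketch (the helper name is hypothetical):

```python
import torch


def op_with_cpu_fallback(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    try:
        return a @ b
    except RuntimeError as exc:
        # Only a CUDA OOM is recoverable; any other failure propagates.
        if a.device.type != "cuda" or "out of memory" not in str(exc).lower():
            raise
        return a.cpu() @ b.cpu()
```

In the patch itself the fallback also re-runs `self.quantizer.find_params(W, weight=True)` so the quantizer's cached parameters follow `W` to the CPU before quantization continues.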
diff --git a/gptqmodel/utils/looper_helpers.py b/gptqmodel/utils/looper_helpers.py
index bfc2412d8..33784fc40 100644
--- a/gptqmodel/utils/looper_helpers.py
+++ b/gptqmodel/utils/looper_helpers.py
@@ -360,6 +360,7 @@ def forward_batch_worker(
     attention_mask: Optional[torch.Tensor],
     position_ids: Optional[torch.Tensor],
     *,
+    gptq_model=None,
     support_batch_quantize: bool,
     is_lm_head_module: bool,
     need_output: bool,
@@ -406,6 +407,13 @@ def forward_batch_worker(
     # TODO: some models do not honor the generate config.use_cache property, so we are forced to hack this to False
     additional_inputs["use_cache"] = False

+    if gptq_model is not None:
+        additional_inputs = gptq_model.prepare_layer_replay_kwargs(
+            layer=module,
+            layer_input=inputs,
+            additional_inputs=additional_inputs,
+            target_device=module_device,
+        )
     module_output = None
     kv_next = None
diff --git a/tests/models/test_gemma4_variants.py b/tests/models/test_gemma4_variants.py
new file mode 100644
index 000000000..a29a55d67
--- /dev/null
+++ b/tests/models/test_gemma4_variants.py
@@ -0,0 +1,111 @@
+# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai
+# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai
+# SPDX-License-Identifier: Apache-2.0
+# Contact: qubitium@modelcloud.ai, x.com/qubitium
+
+import os
+import unittest
+
+from huggingface_hub import snapshot_download
+
+from gptqmodel.quantization.config import GcMode, VramStrategy
+from gptqmodel.utils.backend import BACKEND
+
+# Keep Gemma 4 model tests inside the requested PCI bus ordered GPU pool.
+os.environ.setdefault("CUDA_DEVICE_ORDER", "PCI_BUS_ID")
+os.environ.setdefault("CUDA_VISIBLE_DEVICES", "0,1,2,4")
+
+from model_test import ModelTest
+
+
+def _ensure_local_model_dir(local_path: str, repo_id: str) -> str:
+    """Download the checkpoint into the shared local model cache when it is missing."""
+
+    if os.path.isdir(local_path):
+        return local_path
+
+    os.makedirs(local_path, exist_ok=True)
+    snapshot_download(
+        repo_id=repo_id,
+        local_dir=local_path,
+        local_dir_use_symlinks=False,
+        resume_download=True,
+    )
+    return local_path
+
+
+class _Gemma4VariantModelTest(ModelTest):
+    """Shared Gemma 4 model-test harness tuned for fast variant coverage."""
+
+    # Allow the harness to refresh expectations from the current native model when these baselines drift.
+    DISABLE_NATIVE_BASELINE_FALLBACK = False
+    TRUST_REMOTE_CODE = False
+    TORCH_DTYPE = "bfloat16"
+    # The local env does not ship Marlin runtime kernels, so validation reloads must stay on Torch.
+    LOAD_BACKEND = BACKEND.TORCH
+    # Gemma 4 full-attention layers expand to 512-dim heads, which FlashAttention cannot execute.
+    USE_FLASH_ATTN = False
+    # Gemma 4 variants differ most at the tail: KV sharing, full-attention-only layers, and per-layer adapters.
+    MODEL_COMPAT_FAST_LAYER_COUNT = 1
+    MODEL_COMPAT_FAST_LAYER_POSITION = "last"
+    DATASET_SIZE = 128
+    DATASET_CONCAT_SIZE = 1024
+    EVAL_BATCH_SIZE = 4
+    EVAL_TASKS_SLOW = {
+        "arc_challenge": {
+            "chat_template": True,
+            "acc": {"value": 0.30, "floor_pct": 0.35, "ceil_pct": 1.0},
+            "acc_norm": {"value": 0.33, "floor_pct": 0.35, "ceil_pct": 1.0},
+        },
+    }
+    EVAL_TASKS_FAST = ModelTest.derive_fast_eval_tasks(EVAL_TASKS_SLOW)
+    HF_MODEL_ID = None
+
+    @classmethod
+    def setUpClass(cls):
+        if isinstance(getattr(cls, "NATIVE_MODEL_ID", None), str):
+            model_path = cls.NATIVE_MODEL_ID.strip()
+            if os.path.isabs(model_path) and not os.path.isdir(model_path):
+                if not cls.HF_MODEL_ID:
+                    raise unittest.SkipTest(f"Model path missing and no HF repo configured: {model_path}")
+                cls.NATIVE_MODEL_ID = _ensure_local_model_dir(model_path, cls.HF_MODEL_ID)
+        super().setUpClass()
+
+
+class TestGemma4E2B(_Gemma4VariantModelTest):
+    NATIVE_MODEL_ID = "/monster/data/model/gemma-4-E2B"
+    HF_MODEL_ID = "google/gemma-4-e2b-it"
+    PIN_CUDA_DEVICE = 0
+    EVAL_BATCH_SIZE = 8
+
+    def test_gemma4_e2b(self):
+        self.quant_lm_eval()
+
+
+class TestGemma4E4BIt(_Gemma4VariantModelTest):
+    NATIVE_MODEL_ID = "/monster/data/model/gemma-4-E4B-it"
+    HF_MODEL_ID = "google/gemma-4-e4b-it"
+    PIN_CUDA_DEVICE = 1
+    EVAL_BATCH_SIZE = 4
+
+    def test_gemma4_e4b_it(self):
+        self.quant_lm_eval()
+
+
+class TestGemma431BIt(_Gemma4VariantModelTest):
+    NATIVE_MODEL_ID = "/monster/data/model/gemma-4-31B-it"
+    HF_MODEL_ID = "google/gemma-4-31b-it"
+    # Visible index 3 maps to physical GPU 4 under CUDA_VISIBLE_DEVICES=0,1,2,4.
+    PIN_CUDA_DEVICE = 3
+    EVAL_BATCH_SIZE = 1
+    VRAM_STRATEGY = VramStrategy.BALANCED
+
+    def _build_quantize_config(self):
+        quantize_config = super()._build_quantize_config()
+        # 31B full-attention q_proj hits a very large Hessian inverse; flush prior finalizers before the next stage.
+        quantize_config.wait_for_submodule_finalizers = True
+        quantize_config.gc_mode = GcMode.ON_STAGE_END
+        return quantize_config
+
+    def test_gemma4_31b_it(self):
+        self.quant_lm_eval()
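A sanity check of the GPU pinning above: because `CUDA_VISIBLE_DEVICES` is set before torch initializes CUDA, visible ordinals remap onto the listed physical ids, which is why `PIN_CUDA_DEVICE = 3` in the 31B test lands on physical GPU 4. Illustrative only, no GPU required:

```python
import os

visible = os.environ.get("CUDA_VISIBLE_DEVICES", "0,1,2,4").split(",")
pin_cuda_device = 3  # torch.device("cuda:3") under this mask
assert visible[pin_cuda_device] == "4"  # ...resolves to physical GPU 4
```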
diff --git a/tests/test_gemma4_support.py b/tests/test_gemma4_support.py
new file mode 100644
index 000000000..2092ba1d6
--- /dev/null
+++ b/tests/test_gemma4_support.py
@@ -0,0 +1,184 @@
+from types import SimpleNamespace
+
+import pytest
+import torch
+from torch import nn
+from transformers import AutoConfig
+
+from gptqmodel.models import auto
+from gptqmodel.models.definitions.gemma4 import Gemma4ForConditionalGenerationGPTQ, Gemma4TextQModel
+
+
+GEMMA4_VARIANTS = [
+    "/monster/data/model/gemma-4-E2B",
+    "/monster/data/model/gemma-4-E4B-it",
+    "/monster/data/model/gemma-4-31B-it",
+]
+
+
+@pytest.mark.parametrize("model_path", GEMMA4_VARIANTS)
+def test_gemma4_local_variants_select_multimodal_definition(model_path):
+    config = AutoConfig.from_pretrained(model_path)
+
+    assert config.model_type == "gemma4"
+    assert auto.check_and_get_model_definition(model_path) is Gemma4ForConditionalGenerationGPTQ
+
+
+def test_gemma4_text_model_type_selects_text_definition(monkeypatch):
+    fake_config = SimpleNamespace(model_type="gemma4_text")
+
+    monkeypatch.setattr(auto, "resolve_trust_remote_code", lambda path, trust_remote_code=False: trust_remote_code)
+    monkeypatch.setattr(auto.AutoConfig, "from_pretrained", lambda *args, **kwargs: fake_config)
+
+    assert auto.check_and_get_model_definition("/tmp/gemma4-text") is Gemma4TextQModel
+
+
+def test_gemma4_module_tree_keeps_optional_variant_paths_non_strict():
+    layer_modules = Gemma4TextQModel.simple_layer_modules(
+        model_config=SimpleNamespace(),
+        quantize_config=SimpleNamespace(dynamic=None),
+    )
+    flat_modules = {name for block in layer_modules for name in block}
+
+    assert Gemma4TextQModel.layer_modules_strict is False
+    assert "self_attn.q_proj" in flat_modules
+    assert "self_attn.k_proj" in flat_modules
+    assert "self_attn.v_proj" in flat_modules
+    assert "self_attn.o_proj" in flat_modules
+    assert "per_layer_input_gate" in flat_modules
+    assert "per_layer_projection" in flat_modules
+
+
+def test_gemma4_multimodal_base_modules_include_per_layer_helpers():
+    class _LanguageModel(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.layers = nn.ModuleList([nn.Identity()])
+            self.embed_tokens = nn.Embedding(4, 4)
+            self.embed_tokens_per_layer = nn.Embedding(4, 4)
+            self.per_layer_model_projection = nn.Linear(4, 4, bias=False)
+            self.per_layer_projection_norm = nn.LayerNorm(4)
+            self.norm = nn.LayerNorm(4)
+            self.rotary_emb = nn.Identity()
+
+    class _Gemma4Core(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.language_model = _LanguageModel()
+            self.vision_tower = nn.Identity()
+            self.embed_vision = nn.Identity()
+            self.audio_tower = nn.Identity()
+            self.embed_audio = nn.Identity()
+
+    class _Gemma4Wrapper(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.model = _Gemma4Core()
+            self.lm_head = nn.Linear(4, 4, bias=False)
+
+    model = _Gemma4Wrapper()
+    base_modules = set(Gemma4ForConditionalGenerationGPTQ.get_base_modules(model))
+
+    assert Gemma4ForConditionalGenerationGPTQ.extract_layers_node() == ["model.language_model.layers"]
+    assert "model.vision_tower" in base_modules
+    assert "model.embed_vision" in base_modules
+    assert "model.audio_tower" in base_modules
+    assert "model.embed_audio" in base_modules
+    assert "model.language_model.embed_tokens" in base_modules
+    assert "model.language_model.embed_tokens_per_layer" in base_modules
+    assert "model.language_model.per_layer_model_projection" in base_modules
+    assert "model.language_model.per_layer_projection_norm" in base_modules
+
+
+def test_gemma4_capture_preserves_per_layer_input():
+    model_def = object.__new__(Gemma4ForConditionalGenerationGPTQ)
+    hidden_states = torch.randn(1, 4, 8)
+    per_layer_input = torch.randn(1, 4, 2)
+
+    captured = model_def.capture_first_layer_positional_inputs(
+        args=(hidden_states, per_layer_input),
+        kwargs={},
+        batch_device=torch.device("cpu"),
+    )
+
+    assert len(captured) == 2
+    assert torch.equal(captured[0], hidden_states)
+    assert torch.equal(captured[1], per_layer_input)
+
+
+def test_gemma4_replay_kwargs_refresh_position_embeddings():
+    class _FakeRotary(nn.Module):
+        def forward(self, x, position_ids, layer_type=None):
+            marker = 7.0 if layer_type == "full_attention" else 3.0
+            shape = (position_ids.shape[0], position_ids.shape[1], 1)
+            value = torch.full(shape, marker, dtype=x.dtype, device=x.device)
+            return value, value + 1
+
+    class _LanguageModel(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.rotary_emb = _FakeRotary()
+
+    class _Gemma4Core(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.language_model = _LanguageModel()
+
+    class _Gemma4Wrapper(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.model = _Gemma4Core()
+
+    model_def = object.__new__(Gemma4ForConditionalGenerationGPTQ)
+    nn.Module.__init__(model_def)
+    model_def.model = _Gemma4Wrapper()
+
+    layer = SimpleNamespace(self_attn=SimpleNamespace(layer_type="full_attention"))
+    hidden_states = torch.randn(1, 4, 8)
+    refreshed = model_def.prepare_layer_replay_kwargs(
+        layer=layer,
+        layer_input=[hidden_states],
+        additional_inputs={
+            "position_ids": torch.arange(4).unsqueeze(0),
+            "position_embeddings": ("stale",),
+        },
+        target_device=torch.device("cpu"),
+    )
+
+    cos, sin = refreshed["position_embeddings"]
+    assert cos.shape == (1, 4, 1)
+    assert sin.shape == (1, 4, 1)
+    assert torch.all(cos == 7)
+    assert torch.all(sin == 8)
+
+
+def test_gemma4_capture_kwargs_preserve_all_per_layer_inputs():
+    class _LanguageModel(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.rotary_emb = nn.Identity()
+            self._gptqmodel_cached_all_per_layer_inputs = torch.randn(1, 4, 3, 2)
+
+    class _Gemma4Core(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.language_model = _LanguageModel()
+
+    class _Gemma4Wrapper(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.model = _Gemma4Core()
+
+    model_def = object.__new__(Gemma4ForConditionalGenerationGPTQ)
+    nn.Module.__init__(model_def)
+    model_def.model = _Gemma4Wrapper()
+
+    captured = model_def.capture_first_layer_input_kwargs(
+        args=(),
+        kwargs={},
+        batch_device=torch.device("cpu"),
+        layer_input_kwargs={},
+    )
+
+    assert "__gptqmodel_gemma4_all_per_layer_inputs" in captured
+    assert captured["__gptqmodel_gemma4_all_per_layer_inputs"].shape == (1, 4, 3, 2)
From e0c62a1eaa24f31ebcbda1d49ffd225abb7ae45d Mon Sep 17 00:00:00 2001
From: Qubitium-ModelCloud
Date: Fri, 3 Apr 2026 06:52:05 +0800
Subject: [PATCH 2/3] Update README.md

---
 README.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/README.md b/README.md
index c430854e4..90c4b4723 100644
--- a/README.md
+++ b/README.md
@@ -20,7 +20,7 @@

 ## Latest News
-* 04/02/2026 [6.0.0](https://github.com/ModelCloud/GPTQModel/releases/tag/v6.0.0): 🎉 New quantization methods: `ParoQuant`, `GGUF`, `FP8`, `EXL3`, and `FOEM: First-Order Error Matters`. Added PrismML/Bonsai 1bit model quantization (inference only), faster ParoQuant/AWQ kernels, ParoQuant `optimization scope` control: `module` (Paro Lite) or `layer` (Paro reference), plus `MiniCPM-O`, `MiniCPM-V`, and `GLM4 MOE lite` model support.
+* 04/03/2026 [6.0.2](https://github.com/ModelCloud/GPTQModel/releases/tag/v6.0.2): 🎉 New quantization methods: `ParoQuant`, `GGUF`, `FP8`, `EXL3`, and `FOEM: First-Order Error Matters`. Added PrismML/Bonsai 1bit model quantization (inference only), faster ParoQuant/AWQ kernels, ParoQuant `optimization scope` control: `module` (Paro Lite) or `layer` (Paro reference), plus `Gemma4`, `MiniCPM-O`, `MiniCPM-V`, and `GLM4 MOE lite` model support.
 * 03/19/2026 [5.8.0](https://github.com/ModelCloud/GPTQModel/releases/tag/v5.8.0): ✨HF Transformers 5.3.0 support with auto-defusing of `fused` models via pypi pkg: [Defuser](https://github.com/ModelCloud/Defuser). Qwen 3.5 family support added. New fast HF `cpu` kernels for GPTQ/AWQ added. Experimental INT8 `cpu` kernel added for GPTQ.
 * 03/09/2026 [main]: ✨Qwen 3.5 MoE model support added. New HF Kernel support added for AWQ. HF Kernel for both gptq/awq are now used by default for cpu devices for best performance. New INT8 kernel ported from Intel for gptq.

From 9db3fa95b141f1ac7f45f52fd57236f396a8f676 Mon Sep 17 00:00:00 2001
From: Qubitium-ModelCloud
Date: Fri, 3 Apr 2026 06:52:27 +0800
Subject: [PATCH 3/3] Bump version from 6.0.0 to 6.0.2

---
 gptqmodel/version.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/gptqmodel/version.py b/gptqmodel/version.py
index b6d1a8a49..f7a6e896e 100644
--- a/gptqmodel/version.py
+++ b/gptqmodel/version.py
@@ -7,4 +7,4 @@
 # even minor versions are release
 # 5.2.0 => release, 5.1.0 => devel
 # micro version (5.2.x) denotes patch fix, i.e. 5.2.1 is a patch fix release
-__version__ = "6.0.0"
+__version__ = "6.0.2"