
Add GGUF loading support for Qwen3-Next (qwen3_next) architecture #44070

Open
rudybear wants to merge 1 commit into huggingface:main from rudybear:add-qwen3next-gguf-support

Conversation

@rudybear

Summary

  • Add GGUF config mapping, defaults, and tokenizer converter for qwen3_next (Qwen3-Coder-Next, hybrid DeltaNet+Attention MoE, 80B total / 3B active)
  • Add Qwen3NextTensorProcessor handling DeltaNet-specific tensor transforms: QKVZ merge, A_log, conv1d unsqueeze, norm weight offsets, dt_bias mapping, and MoE experts (inherited from Qwen2MoeTensorProcessor); the simpler element-wise transforms are sketched after this list
  • Add architecture name mappings and post-processing for linear_value_head_dim and rope_parameters
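The QKVZ merge and MoE expert handling involve per-head reshuffles and are best read in the diff itself; the element-wise transforms boil down to the following. This is a minimal sketch with illustrative function names, not the processor's actual API:

```python
import torch

def a_log_from_gguf(ssm_a: torch.Tensor) -> torch.Tensor:
    # GGUF stores the (negative) decay A directly; the HF parameter is
    # A_log = log(-A).
    return torch.log(-ssm_a)

def conv1d_from_gguf(w: torch.Tensor) -> torch.Tensor:
    # GGUF stores the depthwise conv weight as 2-D (assumed
    # (channels, kernel)); HF's Conv1d expects 3-D (channels, 1, kernel),
    # hence the unsqueeze.
    return w.unsqueeze(1)

def norm_from_gguf(w: torch.Tensor) -> torch.Tensor:
    # Norm weights carry a +1 offset in the GGUF export (ssm_norm is the
    # exception and passes through unchanged), so subtract it back.
    return w - 1.0
```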

Motivation

This enables loading quantized GGUF files (e.g. Q4_K_XL from llama.cpp) directly into Qwen3NextForCausalLM, which unblocks vLLM GGUF support for this architecture.
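For reference, loading goes through the standard `gguf_file` argument of `from_pretrained`; the repo and file names below are the ones used in this PR's (skipped) integration test:

```python
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen3-Coder-Next-GGUF",
    gguf_file="Qwen3-Coder-Next-UD-Q4_K_XL.gguf",
    device_map="auto",  # the full 80B model needs >160GB of memory
)
```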

Verification

Tested against Qwen3-Coder-Next-UD-Q4_K_XL.gguf (46 GB, 843 tensors):

  • Config loading: all 21 GGUF metadata keys mapped correctly
  • Weight map: 843 GGUF tensors → 759 model parameters, 100% coverage
  • Tensor shapes: 20 representative tensors dequantized and verified
  • Tensor transforms: QKVZ roundtrip (exact reconstruction), A_log, conv1d, norms, MoE experts all verified on real data (a minimal version of the A_log check is sketched after this list)
  • Forward pass: 4-layer subset loaded (66/66 params), no NaN/Inf, reasonable logit range
  • Full 80B model inference requires >160GB memory (not feasible on test hardware)
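A minimal, self-contained version of the A_log roundtrip check (synthetic data, not the PR's actual test code):

```python
import torch

# A_log = log(-A), so recovering A via -exp(A_log) must reproduce the
# original values up to floating-point tolerance.
a = -torch.rand(64) - 0.1   # synthetic negative decay values
a_log = torch.log(-a)       # the load-time transform
assert torch.allclose(-torch.exp(a_log), a)
```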

Tests

  • test_qwen3_next_config_mapping: verifies all 21 config keys, defaults, and tokenizer converter
  • test_qwen3_next_tensor_processor: verifies processor registration and key transforms (conv1d unsqueeze, A_log, norm -1, ssm_norm passthrough)
  • test_qwen3_next_q4_k_xl: skipped (80B model, >160GB memory required)

Dependencies

Requires gguf-py with MODEL_ARCH.QWEN3NEXT support (available in llama.cpp's upstream gguf-py, not yet in a PyPI release).
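A quick way to check whether the installed gguf-py already has it (assuming gguf-py's public MODEL_ARCH enum in gguf.constants):

```python
from gguf.constants import MODEL_ARCH

print("QWEN3NEXT" in MODEL_ARCH.__members__)  # True on llama.cpp's upstream gguf-py
```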

🤖 Generated with Claude Code

@github-actions
Contributor

[For maintainers] Suggested jobs to run (before merge)

run-slow: ggml

Enable loading quantized GGUF files for the qwen3_next hybrid
DeltaNet+Attention MoE architecture into Qwen3NextForCausalLM.

Changes in ggml.py:
- Add GGUF_CONFIG_MAPPING for qwen3_next (21 metadata keys including
  MoE, SSM/DeltaNet, and rope parameters)
- Add GGUF_CONFIG_DEFAULTS_MAPPING (norm_topk_prob=True)
- Add GGUF_TO_FAST_CONVERTERS (GGUFQwen2Converter tokenizer)

Changes in modeling_gguf_pytorch_utils.py:
- Add Qwen3NextTensorProcessor handling:
  - attn_qkv + attn_gate -> in_proj_qkvz (reverse split+reshuffle)
  - ssm_a -> A_log (log(-weights))
  - ssm_conv1d unsqueeze (2D -> 3D for Conv1d)
  - norm weights -1 (except ssm_norm)
  - dt_bias fallback mapping
  - MoE experts via inherited Qwen2MoeTensorProcessor
- Add architecture name mappings (qwen3next <-> qwen3_next)
- Add post-processing for linear_value_head_dim and rope_parameters

Tests:
- test_qwen3_next_config_mapping: verify all 21 config keys, defaults,
  and tokenizer converter registration
- test_qwen3_next_tensor_processor: verify processor registration and
  key transforms (conv1d unsqueeze, A_log, norm -1, ssm_norm passthrough)
- test_qwen3_next_q4_k_xl: skipped (80B model, >160GB memory required)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@lexasub

lexasub commented Apr 16, 2026

@rudybear when merge?

@Rocketknight1
Member

cc @SunMarc maybe?

@SunMarc
Member

SunMarc left a comment

Couple of comments!

Comment on lines +1132 to +1141

def test_qwen3_next_config_mapping(self):
    """Test that Qwen3-Next GGUF config mapping is correctly applied."""
    from transformers.integrations.ggml import (
        GGUF_CONFIG_DEFAULTS_MAPPING,
        GGUF_CONFIG_MAPPING,
        GGUF_TO_FAST_CONVERTERS,
        GGUFQwen2Converter,
    )

Can you rework your tests? Base them on the other tests, please.

Comment on lines +1182 to +1192
def test_qwen3_next_tensor_processor(self):
    """Test that Qwen3-Next tensor processor is registered and handles key transforms."""
    from transformers.modeling_gguf_pytorch_utils import TENSOR_PROCESSORS, Qwen3NextTensorProcessor

    self.assertIn("qwen3next", TENSOR_PROCESSORS)
    self.assertEqual(TENSOR_PROCESSORS["qwen3next"], Qwen3NextTensorProcessor)

    # Test tensor transforms with synthetic data
    import numpy as np

    config = {
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should compare the processor with the HF version to have a better test.

Comment on lines +1222 to +1235
@unittest.skip(reason="Qwen3-Next is 80B params, requires >160GB memory")
def test_qwen3_next_q4_k_xl(self):
    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")
    model = AutoModelForCausalLM.from_pretrained(
        "Qwen/Qwen3-Coder-Next-GGUF",
        gguf_file="Qwen3-Coder-Next-UD-Q4_K_XL.gguf",
        device_map="auto",
        dtype=torch.float16,
    )

    text = tokenizer(self.example_text, return_tensors="pt").to(torch_device)
    out = model.generate(**text, max_new_tokens=10)
    # Expected text to be determined when model can be loaded on suitable hardware
    self.assertIsNotNone(tokenizer.decode(out[0], skip_special_tokens=True))

Show the expected text. Just checking that there is an output is not enough.
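One way to address this, following the pattern of the other GGUF tests in this file; the expected string is left as a placeholder since the model has not yet been run on suitable hardware:

```python
EXPECTED_TEXT = "..."  # TODO: capture real output on a machine with >160GB memory
out = model.generate(**text, max_new_tokens=10)
self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)
```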
