Allow for "pure" linear attention based Qwen3.5 models #45146

@HallerPatrick

Feature request

This feature request proposes allowing the creation of "pure" linear attention Qwen3.5 models, meaning that every layer should be allowed to be a Gated DeltaNet token mixer.

The following code should therefore be allowed:

import torch

from transformers.models import Qwen3_5ForCausalLM, Qwen3_5TextConfig

config = Qwen3_5TextConfig(
    full_attention_interval=0  # This means only GDN layers
)  # This would crash due to a zero division error

model = Qwen3_5ForCausalLM(config).cuda()

input_ids = torch.tensor([[1, 2, 3, 4, 5]]).cuda()

model(input_ids) # This would crash due to accessing the transformer_layers mapping in the cache.
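
For reference, the zero division comes from the modulo used to build the layer pattern in the config (shown under point 1 below). A minimal standalone sketch of that failure mode, with the config fields inlined as plain variables (this is not the actual configuration code), looks like this:

num_hidden_layers = 8
full_attention_interval = 0  # requesting a pure linear-attention model

try:
    layer_types = [
        "linear_attention" if bool((i + 1) % full_attention_interval) else "full_attention"
        for i in range(num_hidden_layers)
    ]
except ZeroDivisionError as err:
    print(f"Config construction fails: {err}")  # integer modulo by zero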

Motivation

Qwen3.5 introduces a hybrid architecture with interleaved Gated DeltaNet and softmax attention layers. All published Qwen models contain some softmax attention layers, and the architecture is therefore implemented that way.

With the first model implementation that supports Gated DeltaNet as a token mixer (and possibly other types of linear attention backbones), the community might be interested in building other "pure" GDN models on top of the qwen3_5 architecture.

Your contribution

The changes to the code would be quite easy to implement.

  1. Adjust the init of Qwen3_5TextConfig for the case where full_attention_interval is 0 (a sketch of the resulting layer patterns follows below point 2):
# configuration_qwen3_5.py, line 108
interval_pattern = kwargs.pop("full_attention_interval", 4)
if interval_pattern <= 0:  # Edge case to support full linear attention
    self.layer_types = ["linear_attention"] * self.num_hidden_layers
else:
    self.layer_types = [
        "linear_attention" if bool((i + 1) % interval_pattern) else "full_attention"
        for i in range(self.num_hidden_layers)
    ]
  2. Adjust the logic for the cache that figures out the current sequence length (see the fallback sketch below):

OUTDATED:

# modeling_qwen3_5.py, line 136
layer_idx = (
    (self.transformer_layers[0] if len(self.transformer_layers) > 0 else 0)
    if layer_idx not in self.transformer_layers
    else layer_idx
)

UPDATE:
The current dev branch already uses a transformers-wide cache implementation; in the PR (#45148) I changed this accordingly.
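
To illustrate the effect of point 1, here is a small sketch of the resulting layer patterns with the config logic inlined as a plain function (the function name is illustrative, not part of the codebase):

def build_layer_types(num_hidden_layers, full_attention_interval):
    # Mirrors the proposed config logic: an interval <= 0 means every layer is linear attention.
    if full_attention_interval <= 0:
        return ["linear_attention"] * num_hidden_layers
    return [
        "linear_attention" if bool((i + 1) % full_attention_interval) else "full_attention"
        for i in range(num_hidden_layers)
    ]

# Default hybrid pattern: every 4th layer is full attention.
print(build_layer_types(8, 4))
# ['linear_attention', 'linear_attention', 'linear_attention', 'full_attention',
#  'linear_attention', 'linear_attention', 'linear_attention', 'full_attention']

# Pure GDN model: no full-attention layers at all.
print(build_layer_types(8, 0))
# ['linear_attention', 'linear_attention', ..., 'linear_attention']

For point 2, a standalone sketch of the fallback that the (now outdated) snippet implements, assuming a cache that keeps a transformer_layers list of full-attention layer indices (the helper name is hypothetical):

def resolve_layer_idx(layer_idx, transformer_layers):
    # If the requested layer is not a full-attention layer, fall back to the first
    # full-attention layer, or to layer 0 when there are none (the pure GDN case).
    if layer_idx in transformer_layers:
        return layer_idx
    return transformer_layers[0] if transformer_layers else 0

print(resolve_layer_idx(5, [3, 7]))  # -> 3 (hybrid model: fall back to the first full-attention layer)
print(resolve_layer_idx(5, []))      # -> 0 (pure GDN model: no crash)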
