
Add heterogeneous model support (per-layer config and modeling) #45332

Open
eladsegal wants to merge 9 commits into huggingface:main from eladsegal:heterogeneous

Conversation


@eladsegal eladsegal commented Apr 9, 2026

What does this PR do?

Adds heterogeneous model support - the ability for individual layers to differ from the global config (e.g., different intermediate_size, num_key_value_heads) and to skip sub-modules entirely (MLP, attention, etc.). This enables models where layers are not uniform, as in pruned, distilled, or NAS-derived architectures.

Examples of such models:

  • nvidia/Llama-3_3-Nemotron-Super-49B-v1_5 (derived from meta-llama/Llama-3.3-70B-Instruct)
  • nvidia/Llama-3_1-Nemotron-Ultra-253B-v1 (derived from meta-llama/Llama-3.1-405B-Instruct)
  • nvidia/gpt-oss-puzzle-88B (derived from openai/gpt-oss-120b)

These models previously required trust_remote_code=True to modify the classes of the model they are derived from. With this PR, adding heterogeneous support requires just a few lines.

How it works

Configuration (per_layer_config)

A new per_layer_config parameter on PreTrainedConfig maps layer indices to attribute overrides:

from transformers import LlamaConfig

config = LlamaConfig(
    ...,
    per_layer_config={
        0: {"intermediate_size": 64},
        2: {"intermediate_size": 96, "skip_attention": True},
    },
)

Under the hood, apply_heterogeneous_config validates the overrides, computes fallback values, and stores a HeterogeneitySpec on the config. When all layers agree on an attribute value, it is promoted back to a global attribute for clarity. The config supports full save_pretrained / from_pretrained round-trips (keys are zero-padded for correct JSON sort order).
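
For illustration, a round-trip along the lines described above might look like this (a minimal sketch assuming the per_layer_config kwarg from this PR; the save directory name is arbitrary):

from transformers import LlamaConfig

# Build a config where layer 0 overrides the global intermediate_size.
config = LlamaConfig(
    num_hidden_layers=4,
    per_layer_config={0: {"intermediate_size": 64}},
)

# Per-layer overrides are serialized with zero-padded keys and survive reload.
config.save_pretrained("llama-heterogeneous")
reloaded = LlamaConfig.from_pretrained("llama-heterogeneous")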

Accessing a per-layer attribute on the global config raises AttributeError with a helpful message directing to config.get_full_layer_config(i).
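
Continuing the sketch above, the accessor and the guard would behave roughly as follows (hedged illustration; the exact error message is defined by the PR):

# Per-layer view: global values merged with this layer's overrides.
layer0 = config.get_full_layer_config(0)
print(layer0.intermediate_size)  # 64, the per-layer override

layer1 = config.get_full_layer_config(1)
print(layer1.intermediate_size)  # the global fallback value

# Direct access on the global config is guarded when the attribute varies per layer.
try:
    config.intermediate_size
except AttributeError as err:
    print(err)  # points the caller to config.get_full_layer_config(i)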

Modeling (automatic layer patching)

To opt into heterogeneity, a model sets _layer_cls on its PreTrainedModel subclass. At model init time, the framework monkey-patches layer_cls.__init__ to:

  1. Resolve the layer index (from function arguments or the call stack)
  2. Call config.get_full_layer_config(layer_idx) to merge global + per-layer overrides
  3. Pass the resolved config to the original __init__
  4. For layers with skip_* flags, replace the corresponding sub-modules with no-op replacements

The patching uses a ContextVar to pass per-model context to the layer __init__ wrapper. Cleanup happens automatically after model init (via __init_subclass__ wrapping).
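
Schematically, the patching idea looks like this (a standalone sketch, not the PR's implementation; _layer_ctx and patch_layer_init are illustrative names):

import contextvars
import functools

# Per-model context made visible to the wrapped layer __init__.
_layer_ctx = contextvars.ContextVar("_layer_ctx", default=None)

def patch_layer_init(layer_cls):
    original_init = layer_cls.__init__

    @functools.wraps(original_init)
    def wrapped_init(self, config, layer_idx, *args, **kwargs):
        model_config = _layer_ctx.get()
        if model_config is not None:
            # Merge the global config with this layer's overrides before building the layer.
            config = model_config.get_full_layer_config(layer_idx)
        original_init(self, config, layer_idx, *args, **kwargs)
        # Skip handling (swapping sub-modules for no-op replacements) would hook in here.

    layer_cls.__init__ = wrapped_init
    return original_init  # kept so the patch can be undone after model init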

For sub-module skipping, a model additionally defines _skip_descriptors:

from transformers.heterogeneity import ReturnEntry, get_skip_replacement

class LlamaPreTrainedModel(PreTrainedModel):
    ...
    _layer_cls = LlamaDecoderLayer
    _skip_descriptors = {
        "attention": {
            "input_layernorm": get_skip_replacement(
                LlamaRMSNorm, ReturnEntry(arg_name="hidden_states", transform=lambda x: x)
            ),
            "self_attn": get_skip_replacement(
                LlamaAttention,
                [ReturnEntry(arg_name="hidden_states", transform=torch.zeros_like), None],
            ),
        },
        "mlp": {
            "post_attention_layernorm": get_skip_replacement(
                LlamaRMSNorm, ReturnEntry(arg_name="hidden_states", transform=lambda x: x)
            ),
            "mlp": get_skip_replacement(
                LlamaMLP, ReturnEntry(arg_name="x", transform=torch.zeros_like)
            ),
        },
    }

Note: _layer_cls alone is sufficient for attribute heterogeneity (varying dimensions across layers). _skip_descriptors is only needed if the model also wants to support skipping entire sub-modules.
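
Conceptually, a skip replacement is a drop-in module that ignores its parameters and returns the designated argument, optionally transformed. A minimal illustration of the idea (not the PR's actual generated class):

import torch
import torch.nn as nn

class SkippedSelfAttn(nn.Module):
    # Stand-in for a skipped attention block, mirroring the ReturnEntry spec above:
    # contribute zeros to the residual stream and return None where attention
    # weights would normally go.
    def forward(self, hidden_states, *args, **kwargs):
        return torch.zeros_like(hidden_states), None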

Cache and masking

When sliding_window or attention_chunk_size varies per layer:

  • DynamicCache and StaticCache create per-layer sliding window layers with the correct window sizes
  • create_sliding_window_causal_mask / create_chunked_causal_mask return a dict[int, Tensor] keyed by distinct window sizes (deduplicated), and each layer's forward selects its own mask
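
For illustration, the per-layer mask selection could look like this (hedged sketch; the dict contents and names below are made up, only the dict-keyed-by-window-size shape comes from the PR):

import torch

# Stand-ins for the deduplicated masks returned by the masking helpers.
masks = {
    128: torch.ones(1, 1, 8, 8, dtype=torch.bool),   # mask for layers with a 128-token window
    4096: torch.ones(1, 1, 8, 8, dtype=torch.bool),  # mask for layers with a 4096-token window
}
per_layer_window = {0: 128, 1: 4096, 2: 128}  # e.g. resolved from per_layer_config

def mask_for_layer(layer_idx: int) -> torch.Tensor:
    # Each layer's forward picks the mask matching its own window size.
    return masks[per_layer_window[layer_idx]]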

Key changes

  • New src/transformers/heterogeneity/ package - configuration utilities (LayerConfig, HeterogeneitySpec, validation, serialization) and modeling utilities (layer init patching, skip replacements via get_skip_replacement/ReturnEntry, per-layer forward patching to select the correct mask from the mask dict)
  • configuration_utils.py - per_layer_config property, is_heterogeneous, get_full_layer_config(), serialization hooks, and __getattribute__ guard for per-layer attributes
  • modeling_utils.py - auto-calls apply_heterogeneous_modeling at model init, wraps subclass __init__ for cleanup via __init_subclass__
  • cache_utils.py - DynamicCache and StaticCache resolve per-layer sliding window sizes for heterogeneous configs
  • masking_utils.py - sliding window and chunked causal mask functions return a dict of masks when the window/chunk size varies across layers
  • models/gpt_oss/modeling_gpt_oss.py and integrations/mxfp4.py - Fix homogeneity assumptions in order to support heterogeneous configurations

Tests

Comprehensive test suite in tests/heterogeneity/ covering 4 architectures (Llama, GptOss, Llama4, NemotronH): structure verification, forward/generate equivalence against manually constructed reference models, config and model save/load round-trips, heterogeneous cache and mask behavior, and edge cases.


Concrete example: Llama-3.1-8B-Instruct with 4 layers of attention skipped

Code snippet:
import torch

from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig, pipeline
from transformers.heterogeneity import ReturnEntry, get_skip_replacement
from transformers.models.llama.modeling_llama import (
    LlamaAttention,
    LlamaDecoderLayer,
    LlamaPreTrainedModel,
    LlamaRMSNorm,
)


LlamaPreTrainedModel._layer_cls = LlamaDecoderLayer
LlamaPreTrainedModel._skip_descriptors = {
    "attention": {
        "input_layernorm": get_skip_replacement(
            LlamaRMSNorm, ReturnEntry(arg_name="hidden_states", transform=lambda x: x)
        ),
        "self_attn": get_skip_replacement(
            LlamaAttention,
            [ReturnEntry(arg_name="hidden_states", transform=torch.zeros_like), None],
        ),
    },
}

base_model_id = "meta-llama/Llama-3.1-8B-Instruct"

# 1. Load model with per-layer overrides
model = AutoModelForCausalLM.from_pretrained(
    base_model_id,
    device_map="auto",
    dtype="auto",
    per_layer_config={
        19: {"skip_attention": True},
        20: {"skip_attention": True},
        22: {"skip_attention": True},
        23: {"skip_attention": True},
    },
)

# 2. Generate
tokenizer = AutoTokenizer.from_pretrained(base_model_id)
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
)
generation_config = GenerationConfig.from_pretrained(base_model_id)
generation_config.max_new_tokens = 256

messages = [
    {"role": "user", "content": "Hey, how are you doing today?"},
]

outputs = pipe(
    messages,
    generation_config=generation_config,
)

print(outputs[0]["generated_text"][-1]["content"])

The snippet sets _layer_cls and _skip_descriptors on LlamaPreTrainedModel at runtime; a model shipping heterogeneity support would define them on its PreTrainedModel subclass directly (see Modeling above).

Who can review?

@ArthurZucker
@hmellor

@eladsegal (Contributor, Author)

@askliar
Related to vllm-project/vllm#36512

@ArthurZucker (Collaborator) left a comment


Hey! Thanks for the PR!
One thing that's super important for staying aligned with the transformers philosophy: this is not something we want to ship to old models! We'll be super happy to create new models with modular that do have heterogeneous configs per layer, but adding this to Llama breaks Llama. GPT OSS does NOT have this, so it should not support it. Happy to have heterogen_gpt_oss!

This removes code paths and the if is_heterogeneous etc.
Given this, I think it would mostly be a change in the config to have per_layer_config.

We can potentially find a good way to serialize all args in order to reduce the bloat (meaning all attributes are lists; at init time this creates a list of configs per layer). This simplifies code changes in modeling_heterogenous_gpt_oss, for example, which would just get per_layer_config[i].

Furthermore, I don't understand the skip idea. Can you give a concrete example? You want to skip an entire layer, but couldn't you just set it to nn.Identity, for example? Or just re-index them?

@github-actions

[For maintainers] Suggested jobs to run (before merge)

run-slow: gpt_oss

@github-actions

View the CircleCI Test Summary for this PR:

https://huggingface.co/spaces/transformers-community/circle-ci-viz?pr=45332&sha=7fcf9a
