
Add heterogeneous config support (per-layer configuration)#45333

Open
eladsegal wants to merge 4 commits intohuggingface:mainfrom
eladsegal:heterogeneous-config

Conversation

@eladsegal
Contributor

What does this PR do?

Adds heterogeneous model support - the ability for individual layers to differ from the global config (e.g., different intermediate_size, num_key_value_heads) and to skip sub-modules entirely (MLP, attention, etc.). This enables models where layers are not uniform, as in pruned, distilled, or NAS-derived architectures.

Examples of such models:

Model                                        Derived from
nvidia/Llama-3_3-Nemotron-Super-49B-v1_5     meta-llama/Llama-3.3-70B-Instruct
nvidia/Llama-3_1-Nemotron-Ultra-253B-v1      meta-llama/Llama-3.1-405B-Instruct
nvidia/gpt-oss-puzzle-88B                    openai/gpt-oss-120b

These models previously required trust_remote_code=True to modify the classes of the model they are derived from. With this PR (and the follow-up modeling PR #45332), heterogeneous support requires just a few lines.

This PR contains the configuration layer only. The modeling, cache, and masking changes that consume this config are in #45332.

How it works

Configuration (per_layer_config)

A new per_layer_config parameter on PreTrainedConfig maps layer indices to attribute overrides:

from transformers import LlamaConfig

config = LlamaConfig(
    ...,  # other global arguments
    per_layer_config={
        0: {"intermediate_size": 64},                            # layer 0: smaller MLP
        2: {"intermediate_size": 96, "skip_attention": True},    # layer 2: smaller MLP, attention skipped
    },
)

Under the hood, apply_heterogeneous_config validates the overrides, computes fallback values, and stores a HeterogeneitySpec on the config. When all layers agree on an attribute value, it is promoted back to a global attribute for clarity. The config supports full save_pretrained / from_pretrained round-trips (keys are zero-padded for correct JSON sort order).
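
For example, a save/load round-trip might look like this (an illustrative sketch; the on-disk key format is handled internally, and the directory name is just a placeholder):

from transformers import LlamaConfig

config = LlamaConfig(per_layer_config={0: {"intermediate_size": 64}})
config.save_pretrained("heterogeneous-llama")       # per-layer keys are serialized (zero-padded)

reloaded = LlamaConfig.from_pretrained("heterogeneous-llama")
# The per-layer overrides survive the round-trip (illustrative assertion):
assert reloaded.per_layer_config == config.per_layer_config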

Accessing a per-layer attribute on the global config raises AttributeError with a helpful message directing to config.get_full_layer_config(i).
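
For example (an illustrative sketch; it assumes get_full_layer_config returns a per-layer view with attribute access, which may differ slightly from the final API):

# Layer 0 has an override, layer 1 falls back to the global value:
layer0 = config.get_full_layer_config(0)
layer1 = config.get_full_layer_config(1)
print(layer0.intermediate_size)   # 64 (per-layer override)
print(layer1.intermediate_size)   # global fallback value

# Reading a per-layer attribute on the global config fails loudly:
try:
    config.intermediate_size
except AttributeError as e:
    print(e)   # message points to config.get_full_layer_config(i)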

Key changes

  • New src/transformers/heterogeneity/ package - configuration utilities (LayerConfig, HeterogeneitySpec, validation, serialization)
  • configuration_utils.py - per_layer_config property, is_heterogeneous, get_full_layer_config(), serialization hooks, and __getattribute__ guard for per-layer attributes

Tests

A test suite in tests/heterogeneity/test_configuration_utils.py covers per-layer overrides and fallback, uniform-value promotion, validation, and save/load round-trips.

Who can review?

@ArthurZucker
@hmellor

@eladsegal
Contributor Author

@askliar
Related to vllm-project/vllm#36512

Collaborator

@ArthurZucker left a comment


Reviewed this first: #45332 (review)!

I think we want stuff to be explicit and "simple".

We basically have 2 choices:

  1. per_layer_config: List[PreTrainedConfig] - per-layer configs would be quite simple; we find a way to have a minimal serialization that only serializes the actual per-layer overrides. I like that a bit less, but it's the most convenient for modeling changes / inheritance.

  2. per_layer_hidden_size: List[int] - with this, PreTrainedConfig can just init the per-layer configs based on the lists. This gives easy-to-understand serialization, and you just parse each per_layer_<key> into Config(<key> = value) (see the sketch below).
WDYT?
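
A rough sketch of option 2 (illustrative only; the per_layer_<key> fields and the parsing helper are hypothetical, not part of this PR):

from transformers import LlamaConfig

# Hypothetical serialized form for option 2: one flat list per overridden attribute.
raw = {
    "num_hidden_layers": 3,
    "hidden_size": 128,
    "per_layer_intermediate_size": [64, 256, 96],   # one entry per layer
}

def build_layer_configs(raw):
    # Illustrative helper: parse per_layer_<key> lists into one config per layer.
    base = {k: v for k, v in raw.items() if not k.startswith("per_layer_")}
    configs = []
    for i in range(raw["num_hidden_layers"]):
        overrides = {
            k.removeprefix("per_layer_"): v[i]
            for k, v in raw.items()
            if k.startswith("per_layer_")
        }
        configs.append(LlamaConfig(**base, **overrides))
    return configs

per_layer = build_layer_configs(raw)
print(per_layer[0].intermediate_size)   # 64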

