Add heterogeneous model support (per-layer config and modeling) #45332
eladsegal wants to merge 9 commits into huggingface:main from
Conversation
ArthurZucker left a comment
Hey! thanks for the PR!
One thing that's super important to stay aligned with the transformers philosophy is that this is not something we want to ship to old models! We'll be super happy to create new models with modular that do have heterogeneous configs per layer, but adding this to llama breaks llama. GPT OSS does NOT have this, as such it should not support it. Happy to have `heterogen_gpt_oss`!
This removes the extra code paths, the `if is_heterogeneous` checks, etc.
Given this, I think it would mostly be a change in the config to have `per_layer_config`.
We can potentially find a good way to serialize all args in order to reduce the bloat (meaning all attributes are lists, and at init time this creates a list of configs, one per layer). This simplifies the code changes in `modeling_heterogenous_gpt_oss`, for example, which would just read `per_layer_config[i]`.
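A minimal sketch of the alternative being suggested here, under assumed names (`HeterogenGptOssConfig` and its keyword arguments are hypothetical, not part of the PR or the library):

```python
from transformers import PretrainedConfig

# Sketch only: every heterogeneous argument is passed as a list, and __init__ expands
# them into one override dict per layer, so modeling code can just read per_layer_config[i].
class HeterogenGptOssConfig(PretrainedConfig):
    model_type = "heterogen_gpt_oss"

    def __init__(self, intermediate_size=(2880, 2880), num_key_value_heads=(8, 8), **kwargs):
        super().__init__(**kwargs)
        self.per_layer_config = [
            {"intermediate_size": i, "num_key_value_heads": h}
            for i, h in zip(intermediate_size, num_key_value_heads)
        ]

config = HeterogenGptOssConfig(intermediate_size=(2880, 1440), num_key_value_heads=(8, 4))
print(config.per_layer_config[1])  # {'intermediate_size': 1440, 'num_key_value_heads': 4}
```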
Furthermore, I don't understand the skip idea, can you give a concrete example? You want to skip an entire layer, but you can just set it to `nn.Identity` for example? Or just re-index them?
[For maintainers] Suggested jobs to run (before merge): run-slow: gpt_oss
View the CircleCI Test Summary for this PR: https://huggingface.co/spaces/transformers-community/circle-ci-viz?pr=45332&sha=7fcf9a
What does this PR do?
Adds heterogeneous model support - the ability for individual layers to differ from the global config (e.g., different `intermediate_size`, `num_key_value_heads`) and to skip sub-modules entirely (MLP, attention, etc.). This enables models where layers are not uniform, as in pruned, distilled, or NAS-derived architectures. Examples of such models:
These models previously required `trust_remote_code=True` to modify the classes of the model they are derived from. With this PR, adding heterogeneous support requires just a few lines.

How it works
Configuration (`per_layer_config`)

A new `per_layer_config` parameter on `PreTrainedConfig` maps layer indices to attribute overrides:
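A hedged sketch of what this might look like (the `per_layer_config` parameter is what the PR adds; the key type and the `skip_mlp` flag name shown here are assumptions for illustration):

```python
from transformers import LlamaConfig

# Illustrative only: per-layer overrides of otherwise global attributes.
config = LlamaConfig(
    num_hidden_layers=4,
    intermediate_size=14336,
    num_key_value_heads=8,
    per_layer_config={
        1: {"intermediate_size": 4096},   # layer 1 uses a smaller MLP
        2: {"num_key_value_heads": 2},    # layer 2 uses fewer KV heads
        3: {"skip_mlp": True},            # layer 3 drops its MLP entirely (flag name assumed)
    },
)
```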
Under the hood, `apply_heterogeneous_config` validates the overrides, computes fallback values, and stores a `HeterogeneitySpec` on the config. When all layers agree on an attribute value, it is promoted back to a global attribute, for clarity. The config supports full `save_pretrained`/`from_pretrained` round-trips (keys are zero-padded for correct JSON sort order). Accessing a per-layer attribute on the global config raises `AttributeError` with a helpful message directing to `config.get_full_layer_config(i)`.
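Continuing the sketch above, the round-trip and accessor behaviour described in this paragraph would look roughly like this (only meaningful with this PR applied; illustrative, not verified against the diff):

```python
# `config` is the heterogeneous LlamaConfig built in the previous sketch.
config.save_pretrained("./hetero-llama")
reloaded = LlamaConfig.from_pretrained("./hetero-llama")   # per-layer overrides survive

layer1 = reloaded.get_full_layer_config(1)   # global values merged with layer-1 overrides

try:
    reloaded.intermediate_size               # varies across layers, so the global read is guarded
except AttributeError as err:
    print(err)                               # message points at get_full_layer_config(i)
```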
Modeling (automatic layer patching)
To opt into heterogeneity, a model sets `_layer_cls` on its `PreTrainedModel` subclass. At model init time, the framework monkey-patches `layer_cls.__init__` to:
- call `config.get_full_layer_config(layer_idx)` to merge global + per-layer overrides
- run the original `__init__` with the merged per-layer config
- for `skip_*` flags, replace the corresponding sub-modules with no-op replacements

The patching uses a `ContextVar` to pass per-model context to the layer `__init__` wrapper. Cleanup happens automatically after model init (via `__init_subclass__` wrapping).
For sub-module skipping, a model additionally defines `_skip_descriptors`:
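A hedged sketch of the opt-in (`_layer_cls` and `_skip_descriptors` are the hooks named in this PR; the descriptor format shown here is an assumption, since the PR implements skipping via `get_skip_replacement`/`ReturnEntry`):

```python
from transformers.models.llama.modeling_llama import (
    LlamaDecoderLayer,
    LlamaPreTrainedModel,
)

# Illustrative only: how a model might opt into heterogeneity.
class HeterogeneousLlamaPreTrainedModel(LlamaPreTrainedModel):
    # layer class whose __init__ gets patched to receive the merged per-layer config
    _layer_cls = LlamaDecoderLayer
    # maps skip_* flags in per_layer_config to the sub-modules they replace with no-ops
    _skip_descriptors = {
        "skip_attention": "self_attn",
        "skip_mlp": "mlp",
    }
```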
Note: `_layer_cls` alone is sufficient for attribute heterogeneity (varying dimensions across layers). `_skip_descriptors` is only needed if the model also wants to support skipping entire sub-modules.

Cache and masking
When `sliding_window` or `attention_chunk_size` varies per layer:
- `DynamicCache` and `StaticCache` create per-layer sliding window layers with the correct window sizes
- `create_sliding_window_causal_mask`/`create_chunked_causal_mask` return a `dict[int, Tensor]` keyed by distinct window sizes (deduplicated), and each layer's forward selects its own mask (see the sketch after this list)
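A small sketch of the dict-of-masks idea, using a toy mask builder rather than the PR's actual helpers (everything here is illustrative):

```python
import torch

seq_len = 8

def sliding_window_mask(window: int) -> torch.Tensor:
    # True where position j is visible to position i: causal and within the window
    i = torch.arange(seq_len)
    return (i[:, None] >= i[None, :]) & (i[:, None] - i[None, :] < window)

per_layer_window = {0: 4, 1: 8, 2: 4}                                        # layer_idx -> sliding_window
masks = {w: sliding_window_mask(w) for w in set(per_layer_window.values())}  # deduplicated dict[int, Tensor]
mask_for_layer_1 = masks[per_layer_window[1]]                                # each layer selects its own mask
```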
Key changes
- `src/transformers/heterogeneity/` package - configuration utilities (`LayerConfig`, `HeterogeneitySpec`, validation, serialization) and modeling utilities (layer init patching, skip replacements via `get_skip_replacement`/`ReturnEntry`, per-layer forward patching to select the correct mask from the mask dict)
- `configuration_utils.py` - `per_layer_config` property, `is_heterogeneous`, `get_full_layer_config()`, serialization hooks, and `__getattribute__` guard for per-layer attributes
- `modeling_utils.py` - auto-calls `apply_heterogeneous_modeling` at model init, wraps subclass `__init__` for cleanup via `__init_subclass__`
- `cache_utils.py` - `DynamicCache` and `StaticCache` resolve per-layer sliding window sizes for heterogeneous configs
- `masking_utils.py` - sliding window and chunked causal mask functions return a `dict` of masks when the window/chunk size varies across layers
- `models/gpt_oss/modeling_gpt_oss.py` and `integrations/mxfp4.py` - fix homogeneity assumptions in order to support heterogeneous configurations

Tests
Comprehensive test suite in `tests/heterogeneity/` covering 4 architectures (Llama, GptOss, Llama4, NemotronH). Covers structure verification, forward/generate equivalence against manually-constructed reference models, config and model save/load round-trips, heterogeneous cache and mask behavior, and edge cases.
Concrete example: Llama-3.1-8B-Instruct with 4 layers of attention skipped
Code snippet:
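A hedged reconstruction of what such a snippet might look like (the `skip_attention` flag name and the chosen layer indices are assumptions, not taken from the PR):

```python
from transformers import AutoConfig, AutoModelForCausalLM

model_id = "meta-llama/Llama-3.1-8B-Instruct"

# Skip the attention sub-module in 4 of the 32 layers via per_layer_config.
config = AutoConfig.from_pretrained(model_id)
config.per_layer_config = {i: {"skip_attention": True} for i in (24, 26, 28, 30)}

model = AutoModelForCausalLM.from_pretrained(model_id, config=config)
```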
This assumes `LlamaPreTrainedModel` has `_layer_cls` and `_skip_descriptors` defined (see Modeling above).

Who can review?
@ArthurZucker
@hmellor