
[Gemma4] Add docstrings for Per-Layer Embeddings (PLE) pipeline #45207

Merged

stevhliu merged 4 commits into huggingface:main from w4nderlust:gemma4-ple-docs on Apr 14, 2026

Conversation

@w4nderlust
Contributor

Fixes #45206

What does this PR do?

Adds documentation for the Gemma4 Per-Layer Embeddings (PLE) system, which is currently pretty hard to reverse-engineer from the code alone.

I ran into this while implementing Gemma4 inference from scratch in Rust. The PLE system has several non-obvious aspects that aren't documented anywhere:

  1. hidden_size_per_layer_input (256) is the per-layer dimension, but the actual embedding weight is [vocab, num_layers * 256] = [262144, 8960] because all layers are packed
  2. The embedding is a Gemma4TextScaledWordEmbedding that silently multiplies by sqrt(256) = 16 - this took me a while to track down
  3. The full pipeline has a context-aware projection step (per_layer_model_projection + scale + RMSNorm) that combines with the token lookup before being passed to the layers, with specific scale factors (1/sqrt(hidden_size) and 1/sqrt(2)); the whole flow is sketched below
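
To make this concrete, here is my rough reading of the full pipeline in PyTorch-style pseudocode. The weight names follow modeling_gemma4.py, but the hidden size value and the plain RMSNorm are stand-ins, so treat this as a sketch rather than the exact implementation:

```python
import math
import torch
import torch.nn.functional as F

# Shapes from above: 262144-entry PLE vocab, 35 layers, 256 dims per layer.
NUM_LAYERS, PLE_DIM = 35, 256
HIDDEN = 2048  # stand-in for config.hidden_size; the real value is model-dependent

def per_layer_inputs(input_ids, inputs_embeds, packed_table, proj_weight, norm_weight):
    b, s = input_ids.shape
    # 1. Token-identity: look up the packed [262144, NUM_LAYERS * PLE_DIM] table;
    #    Gemma4TextScaledWordEmbedding silently multiplies by sqrt(PLE_DIM) = 16.
    tok = F.embedding(input_ids, packed_table) * math.sqrt(PLE_DIM)
    tok = tok.view(b, s, NUM_LAYERS, PLE_DIM)
    # 2. Context-aware: per_layer_model_projection, scaled by 1/sqrt(hidden_size),
    #    reshaped, then normalized by per_layer_projection_norm (plain RMSNorm here;
    #    the actual norm flavor may differ in detail).
    proj = (inputs_embeds @ proj_weight.T) * HIDDEN**-0.5
    proj = proj.view(b, s, NUM_LAYERS, PLE_DIM)
    proj = proj * torch.rsqrt(proj.pow(2).mean(-1, keepdim=True) + 1e-6) * norm_weight
    # 3. Sum the two components and scale by 1/sqrt(2); decoder layer i then
    #    receives its [b, s, PLE_DIM] slice, out[:, :, i].
    return (tok + proj) * 2**-0.5
```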

This PR adds:

  • Expanded config docstring for hidden_size_per_layer_input explaining the packed layout, scaling, and full pipeline
  • Docstrings for get_per_layer_inputs() and project_per_layer_inputs()
  • A comment on the PLE init block pointing to the pipeline methods

Hopefully this saves some pain for others implementing Gemma4 outside of transformers.

The PLE system is complex and underdocumented, which makes it hard
for third-party implementations (llama.cpp, candle, mlx, etc.) to
get right. This adds:

- Config docstring for hidden_size_per_layer_input explaining that
  the actual embedding dim is num_hidden_layers * hidden_size_per_layer_input,
  the embedding is scaled by sqrt(hidden_size_per_layer_input), and
  describing the full two-component pipeline

- Docstring for get_per_layer_inputs() explaining the token-identity
  component and the packed-to-4D reshape

- Docstring for project_per_layer_inputs() explaining the context-aware
  projection, normalization, and combination with scale factors

- Comment on the PLE init block pointing to the pipeline methods

Fixes huggingface#45206
@Rocketknight1
Member

This is a tricky one - I don't want to bloat the docstrings, but I agree Gemma4 is unusual and this bit can be hard to understand. cc @stevhliu WDYT?

@stevhliu
Member

stevhliu commented Apr 9, 2026

i agree @Rocketknight1, the config docstring should describe how a parameter works instead of how the model uses it. a better home for this is the gemma4 doc here. the config docstring can say:

hidden_size_per_layer_input (`int`, defaults to 256):
    Per-layer hidden dimension for the PLE system. The actual embedding weight has shape
    `[vocab_size_per_layer_input, num_hidden_layers * hidden_size_per_layer_input]`
    because all layers are packed into a single table. See the [Gemma4](https://huggingface.co/docs/transformers/main/en/model_doc/gemma4#gemma4-vision-model) docs
    for a description of the full PLE pipeline.

the docstrings for get_per_layer_inputs() and project_per_layer_inputs() are fine though :)

@w4nderlust
Contributor Author

Happy to make the changes you want, guys, just tell me exactly what you'd like and I'll modify the PR :)

@stevhliu left a comment (Member)


thanks for improving! 🤗

Comment on lines +97 to +102
Per-layer hidden dimension for the per-layer input embeddings (PLE). This is the
dimension of each layer's individual embedding slice. Note: the actual embedding
weight has shape `[vocab_size_per_layer_input, num_hidden_layers * hidden_size_per_layer_input]`
(e.g. `[262144, 8960]` for 35 layers x 256 dims) because all layers are packed into
a single embedding table. The embedding is also a `Gemma4TextScaledWordEmbedding`
that scales lookups by `sqrt(hidden_size_per_layer_input)`.
Suggested change
Per-layer hidden dimension for the per-layer input embeddings (PLE). This is the
dimension of each layer's individual embedding slice. Note: the actual embedding
weight has shape `[vocab_size_per_layer_input, num_hidden_layers * hidden_size_per_layer_input]`
(e.g. `[262144, 8960]` for 35 layers x 256 dims) because all layers are packed into
a single embedding table. The embedding is also a `Gemma4TextScaledWordEmbedding`
that scales lookups by `sqrt(hidden_size_per_layer_input)`.
Per-layer hidden dimension for the PLE system. The actual embedding weight has shape
`[vocab_size_per_layer_input, num_hidden_layers * hidden_size_per_layer_input]`
because all layers are packed into a single table. See the [Gemma4](https://huggingface.co/docs/transformers/main/en/model_doc/gemma4#gemma4-vision-model) docs
for a description of the full PLE pipeline.
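
For readers following the packed layout discussed in this thread, a minimal illustration (assuming PyTorch; shapes taken from the docstring above, and the slicing at the end is illustrative rather than transformers API):

```python
import torch

# Packed PLE table: 35 layers x 256 dims concatenated along the last axis.
num_layers, ple_dim = 35, 256
table = torch.randn(262144, num_layers * ple_dim)  # [vocab_size_per_layer_input, 8960]

input_ids = torch.tensor([[2, 17, 105]])            # [batch=1, seq=3]
packed = table[input_ids]                           # [1, 3, 8960]
per_layer = packed.view(1, 3, num_layers, ple_dim)  # [1, 3, 35, 256]
layer7 = per_layer[:, :, 7]                         # 256-dim slice decoder layer 7 gets
```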

Comment on lines +104 to +109
The full PLE pipeline in `Gemma4TextModel` combines two components:
1. **Token-identity**: `embed_tokens_per_layer(input_ids)` (scaled), reshaped to
`[batch, seq, num_hidden_layers, hidden_size_per_layer_input]`
2. **Context-aware**: `per_layer_model_projection(inputs_embeds)` scaled by
`1/sqrt(hidden_size)`, reshaped, then normalized by `per_layer_projection_norm`
These are summed and scaled by `1/sqrt(2)` before being passed to each decoder layer.
move this to a gemma.md :)

Move the detailed PLE pipeline description from the config docstring
to the Gemma4 model documentation page. The config docstring now just
describes the parameter shape and links to the full docs.
@w4nderlust
Contributor Author

Thanks for the review! I've pushed the changes:

  • Shortened the hidden_size_per_layer_input config docstring to just describe the parameter shape, with a link to the model docs
  • Moved the full PLE pipeline description to a new section in gemma4.md under "Per-Layer Embeddings (PLE)"
  • Kept the method docstrings for get_per_layer_inputs() and project_per_layer_inputs() as-is

Let me know if you'd like any further adjustments!


@stevhliu left a comment (Member)


just a few nits! make sure you edit the modular_gemma4.py file instead of the modeling_gemma4.py file to pass the CI :)

Comment thread: docs/source/en/model_doc/gemma4.md (Outdated)

### Per-Layer Embeddings (PLE)

Gemma 4 introduces a **Per-Layer Embeddings (PLE)** system that feeds an auxiliary residual signal into each decoder layer, rather than relying solely on a single shared embedding at the input.
Suggested change
Gemma 4 introduces a **Per-Layer Embeddings (PLE)** system that feeds an auxiliary residual signal into each decoder layer, rather than relying solely on a single shared embedding at the input.
Gemma 4 introduces a Per-Layer Embeddings (PLE) system that feeds an auxiliary residual signal into each decoder layer, rather than relying solely on a single shared embedding at the input.

Comment thread: docs/source/en/model_doc/gemma4.md (Outdated)
Comment on lines +64 to +69
#### Key config parameters

- `vocab_size_per_layer_input` (default 262144): vocabulary size for PLE.
- `hidden_size_per_layer_input` (default 256): per-layer embedding dimension. The actual embedding weight has shape `[vocab_size_per_layer_input, num_hidden_layers * hidden_size_per_layer_input]` (e.g. `[262144, 8960]` for 35 layers × 256 dims) because all layers are packed into a single embedding table.

#### Pipeline
the config params are covered in the docstrings so i don't think we need to repeat them here

Suggested change
#### Key config parameters
- `vocab_size_per_layer_input` (default 262144): vocabulary size for PLE.
- `hidden_size_per_layer_input` (default 256): per-layer embedding dimension. The actual embedding weight has shape `[vocab_size_per_layer_input, num_hidden_layers * hidden_size_per_layer_input]` (e.g. `[262144, 8960]` for 35 layers × 256 dims) because all layers are packed into a single embedding table.
#### Pipeline

Comment thread: docs/source/en/model_doc/gemma4.md (Outdated)

PLE combines two components that are summed and scaled by `1/√2` before being fed to each decoder layer:

1. **Token-identity** (`get_per_layer_inputs`): looks up `input_ids` in `embed_tokens_per_layer`, a `Gemma4TextScaledWordEmbedding` that multiplies by `√(hidden_size_per_layer_input)`. The packed output is reshaped from `[batch, seq, num_hidden_layers * hidden_size_per_layer_input]` to `[batch, seq, num_hidden_layers, hidden_size_per_layer_input]`.
Suggested change
1. **Token-identity** (`get_per_layer_inputs`): looks up `input_ids` in `embed_tokens_per_layer`, a `Gemma4TextScaledWordEmbedding` that multiplies by `√(hidden_size_per_layer_input)`. The packed output is reshaped from `[batch, seq, num_hidden_layers * hidden_size_per_layer_input]` to `[batch, seq, num_hidden_layers, hidden_size_per_layer_input]`.
1. Token-identity (`get_per_layer_inputs`): looks up `input_ids` in `embed_tokens_per_layer`, a `Gemma4TextScaledWordEmbedding` that multiplies by `√(hidden_size_per_layer_input)`. The packed output is reshaped from `[batch, seq, num_hidden_layers * hidden_size_per_layer_input]` to `[batch, seq, num_hidden_layers, hidden_size_per_layer_input]`.

Comment thread: docs/source/en/model_doc/gemma4.md (Outdated)
PLE combines two components that are summed and scaled by `1/√2` before being fed to each decoder layer:

1. **Token-identity** (`get_per_layer_inputs`): looks up `input_ids` in `embed_tokens_per_layer`, a `Gemma4TextScaledWordEmbedding` that multiplies by `√(hidden_size_per_layer_input)`. The packed output is reshaped from `[batch, seq, num_hidden_layers * hidden_size_per_layer_input]` to `[batch, seq, num_hidden_layers, hidden_size_per_layer_input]`.
2. **Context-aware** (`project_per_layer_inputs`): projects `inputs_embeds` through `per_layer_model_projection` (a Linear layer), scales by `1/√(hidden_size)`, reshapes to `[batch, seq, num_layers, ple_dim]`, and normalizes with `per_layer_projection_norm` (RMSNorm).
Suggested change
2. **Context-aware** (`project_per_layer_inputs`): projects `inputs_embeds` through `per_layer_model_projection` (a Linear layer), scales by `1/√(hidden_size)`, reshapes to `[batch, seq, num_layers, ple_dim]`, and normalizes with `per_layer_projection_norm` (RMSNorm).
2. Context-aware (`project_per_layer_inputs`): projects `inputs_embeds` through `per_layer_model_projection` (a Linear layer), scales by `1/√(hidden_size)`, reshapes to `[batch, seq, num_layers, ple_dim]`, and normalizes with `per_layer_projection_norm` (RMSNorm).

- Remove bold formatting and config params section from gemma4.md per review
- Move docstrings and PLE comment from modeling_gemma4.py to modular_gemma4.py
- Revert modeling_gemma4.py (CI regenerates it from modular)
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@stevhliu
Member

nice, i think the last step is to run make fix-repo to automatically align modeling_gemma4.py with modular_gemma4.py!

@github-actions
Contributor

[For maintainers] Suggested jobs to run (before merge)

run-slow: gemma4


@stevhliu left a comment (Member)


🤗

@stevhliu added this pull request to the merge queue on Apr 14, 2026
Merged via the queue into huggingface:main with commit 155db71 on Apr 14, 2026
16 checks passed
@w4nderlust deleted the gemma4-ple-docs branch on April 14, 2026, 20:03
sirzechs66 pushed a commit to sirzechs66/transformers that referenced this pull request on Apr 18, 2026:
[Gemma4] Add docstrings for Per-Layer Embeddings (PLE) pipeline (huggingface#45207)

* [Gemma4] Add docstrings for Per-Layer Embeddings (PLE) pipeline

The PLE system is complex and underdocumented, which makes it hard
for third-party implementations (llama.cpp, candle, mlx, etc.) to
get right. This adds:

- Config docstring for hidden_size_per_layer_input explaining that
  the actual embedding dim is num_hidden_layers * hidden_size_per_layer_input,
  the embedding is scaled by sqrt(hidden_size_per_layer_input), and
  describing the full two-component pipeline

- Docstring for get_per_layer_inputs() explaining the token-identity
  component and the packed-to-4D reshape

- Docstring for project_per_layer_inputs() explaining the context-aware
  projection, normalization, and combination with scale factors

- Comment on the PLE init block pointing to the pipeline methods

Fixes huggingface#45206

* Address review: move PLE details to model doc, shorten config docstring

Move the detailed PLE pipeline description from the config docstring
to the Gemma4 model documentation page. The config docstring now just
describes the parameter shape and links to the full docs.

* Address review nits: move edits to modular_gemma4.py, simplify gemma4.md

- Remove bold formatting and config params section from gemma4.md per review
- Move docstrings and PLE comment from modeling_gemma4.py to modular_gemma4.py
- Revert modeling_gemma4.py (CI regenerates it from modular)

* fix: run make fix-repo to align modeling_gemma4.py with modular_gemma4.py


Development

Successfully merging this pull request may close these issues.

Gemma4: PLE (Per-Layer Embeddings) implementation is underdocumented and config is misleading
