[Gemma4] Add docstrings for Per-Layer Embeddings (PLE) pipeline #45207
stevhliu merged 4 commits into huggingface:main
Conversation
The PLE system is complex and underdocumented, which makes it hard for third-party implementations (llama.cpp, candle, mlx, etc.) to get right. This adds:

- Config docstring for `hidden_size_per_layer_input` explaining that the actual embedding dim is `num_hidden_layers * hidden_size_per_layer_input`, the embedding is scaled by `sqrt(hidden_size_per_layer_input)`, and describing the full two-component pipeline
- Docstring for `get_per_layer_inputs()` explaining the token-identity component and the packed-to-4D reshape
- Docstring for `project_per_layer_inputs()` explaining the context-aware projection, normalization, and combination with scale factors
- Comment on the PLE init block pointing to the pipeline methods

Fixes huggingface#45206
This is a tricky one - I don't want to bloat the docstrings, but I agree Gemma4 is unusual and this bit can be hard to understand. cc @stevhliu WDYT?
i agree @Rocketknight1, the config docstring should describe how a parameter works instead of how the model uses it. a better home for this is the gemma4 doc here. the config docstring can say:

    hidden_size_per_layer_input (`int`, defaults to 256):
        Per-layer hidden dimension for the PLE system. The actual embedding weight has shape
        `[vocab_size_per_layer_input, num_hidden_layers * hidden_size_per_layer_input]`
        because all layers are packed into a single table. See the [Gemma4](https://huggingface.co/docs/transformers/main/en/model_doc/gemma4#gemma4-vision-model) docs
        for a description of the full PLE pipeline.

the docstrings for
Happy to make the changes you want guys, just tell me exactly and I'll modify the PR :)
    Per-layer hidden dimension for the per-layer input embeddings (PLE). This is the
    dimension of each layer's individual embedding slice. Note: the actual embedding
    weight has shape `[vocab_size_per_layer_input, num_hidden_layers * hidden_size_per_layer_input]`
    (e.g. `[262144, 8960]` for 35 layers x 256 dims) because all layers are packed into
    a single embedding table. The embedding is also a `Gemma4TextScaledWordEmbedding`
    that scales lookups by `sqrt(hidden_size_per_layer_input)`.
Suggested change:

    Per-layer hidden dimension for the PLE system. The actual embedding weight has shape
    `[vocab_size_per_layer_input, num_hidden_layers * hidden_size_per_layer_input]`
    because all layers are packed into a single table. See the [Gemma4](https://huggingface.co/docs/transformers/main/en/model_doc/gemma4#gemma4-vision-model) docs
    for a description of the full PLE pipeline.
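To make the packed layout and scaling concrete, here is a toy standalone sketch (a plain `torch.nn.Embedding` stands in for `Gemma4TextScaledWordEmbedding`, and the vocabulary is shrunk so it runs cheaply; the real table is `[262144, 8960]`):

```python
import math

import torch
import torch.nn as nn

vocab_size_per_layer_input = 1_000   # real value: 262_144 (shrunk so the sketch runs cheaply)
num_hidden_layers = 35
hidden_size_per_layer_input = 256

# All layers are packed into a single table: 35 * 256 = 8960 columns
packed_dim = num_hidden_layers * hidden_size_per_layer_input
embed_tokens_per_layer = nn.Embedding(vocab_size_per_layer_input, packed_dim)

input_ids = torch.randint(0, vocab_size_per_layer_input, (1, 7))
# The scaled word embedding multiplies lookups by sqrt(hidden_size_per_layer_input) = 16
per_layer_inputs = embed_tokens_per_layer(input_ids) * math.sqrt(hidden_size_per_layer_input)
print(per_layer_inputs.shape)  # torch.Size([1, 7, 8960])
```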
    The full PLE pipeline in `Gemma4TextModel` combines two components:
    1. **Token-identity**: `embed_tokens_per_layer(input_ids)` (scaled), reshaped to
       `[batch, seq, num_hidden_layers, hidden_size_per_layer_input]`
    2. **Context-aware**: `per_layer_model_projection(inputs_embeds)` scaled by
       `1/sqrt(hidden_size)`, reshaped, then normalized by `per_layer_projection_norm`
    These are summed and scaled by `1/sqrt(2)` before being passed to each decoder layer.
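A minimal sketch of that final combination step, with random tensors standing in for the two components:

```python
import torch

batch, seq, num_layers, ple_dim = 1, 7, 35, 256

# Stand-ins for the two components described above
token_identity = torch.randn(batch, seq, num_layers, ple_dim)  # from get_per_layer_inputs
context_aware = torch.randn(batch, seq, num_layers, ple_dim)   # from project_per_layer_inputs

# Summed and scaled by 1/sqrt(2) before each decoder layer consumes its slice
per_layer_inputs = (token_identity + context_aware) * (2.0 ** -0.5)
print(per_layer_inputs.shape)  # torch.Size([1, 7, 35, 256])
```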
Move the detailed PLE pipeline description from the config docstring to the Gemma4 model documentation page. The config docstring now just describes the parameter shape and links to the full docs.
Thanks for the review! I've pushed the changes:
Let me know if you'd like any further adjustments!
stevhliu left a comment
just a few nits! make sure you edit the modular_gemma4.py file instead of the modeling_gemma4.py file to pass the CI :)
### Per-Layer Embeddings (PLE)
Gemma 4 introduces a **Per-Layer Embeddings (PLE)** system that feeds an auxiliary residual signal into each decoder layer, rather than relying solely on a single shared embedding at the input.
Suggested change:

Gemma 4 introduces a Per-Layer Embeddings (PLE) system that feeds an auxiliary residual signal into each decoder layer, rather than relying solely on a single shared embedding at the input.
#### Key config parameters

- `vocab_size_per_layer_input` (default 262144): vocabulary size for PLE.
- `hidden_size_per_layer_input` (default 256): per-layer embedding dimension. The actual embedding weight has shape `[vocab_size_per_layer_input, num_hidden_layers * hidden_size_per_layer_input]` (e.g. `[262144, 8960]` for 35 layers × 256 dims) because all layers are packed into a single embedding table.

#### Pipeline
the config params are covered in the docstrings so i don't think we need to repeat them here
Suggested change:

#### Pipeline
PLE combines two components that are summed and scaled by `1/√2` before being fed to each decoder layer:
1. **Token-identity** (`get_per_layer_inputs`): looks up `input_ids` in `embed_tokens_per_layer`, a `Gemma4TextScaledWordEmbedding` that multiplies by `√(hidden_size_per_layer_input)`. The packed output is reshaped from `[batch, seq, num_hidden_layers * hidden_size_per_layer_input]` to `[batch, seq, num_hidden_layers, hidden_size_per_layer_input]`.
Suggested change:

1. Token-identity (`get_per_layer_inputs`): looks up `input_ids` in `embed_tokens_per_layer`, a `Gemma4TextScaledWordEmbedding` that multiplies by `√(hidden_size_per_layer_input)`. The packed output is reshaped from `[batch, seq, num_hidden_layers * hidden_size_per_layer_input]` to `[batch, seq, num_hidden_layers, hidden_size_per_layer_input]`.
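A minimal sketch of just that packed-to-4D reshape, with a random tensor in place of the scaled lookup:

```python
import torch

batch, seq, num_layers, ple_dim = 1, 7, 35, 256

# Packed result of the scaled lookup: [batch, seq, num_layers * ple_dim]
packed = torch.randn(batch, seq, num_layers * ple_dim)

# Unpacked so each decoder layer gets its own slice: [batch, seq, num_layers, ple_dim]
per_layer = packed.reshape(batch, seq, num_layers, ple_dim)
print(per_layer.shape)  # torch.Size([1, 7, 35, 256])
```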
2. **Context-aware** (`project_per_layer_inputs`): projects `inputs_embeds` through `per_layer_model_projection` (a Linear layer), scales by `1/√(hidden_size)`, reshapes to `[batch, seq, num_layers, ple_dim]`, and normalizes with `per_layer_projection_norm` (RMSNorm).
Suggested change:

2. Context-aware (`project_per_layer_inputs`): projects `inputs_embeds` through `per_layer_model_projection` (a Linear layer), scales by `1/√(hidden_size)`, reshapes to `[batch, seq, num_layers, ple_dim]`, and normalizes with `per_layer_projection_norm` (RMSNorm).
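A minimal sketch of this context-aware branch, assuming a hypothetical `hidden_size` of 2048 and using `torch.nn.RMSNorm` (PyTorch 2.4+) as a stand-in for the model's RMSNorm:

```python
import torch
import torch.nn as nn

batch, seq = 1, 7
hidden_size = 2048                 # hypothetical text hidden size for this sketch
num_layers, ple_dim = 35, 256

per_layer_model_projection = nn.Linear(hidden_size, num_layers * ple_dim, bias=False)
per_layer_projection_norm = nn.RMSNorm(ple_dim)  # stand-in for the model's RMSNorm

inputs_embeds = torch.randn(batch, seq, hidden_size)

# Project, scale by 1/sqrt(hidden_size), reshape into per-layer slices, then normalize
projection = per_layer_model_projection(inputs_embeds) * hidden_size ** -0.5
projection = projection.reshape(batch, seq, num_layers, ple_dim)
projection = per_layer_projection_norm(projection)
print(projection.shape)  # torch.Size([1, 7, 35, 256])
```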
- Remove bold formatting and config params section from gemma4.md per review
- Move docstrings and PLE comment from modeling_gemma4.py to modular_gemma4.py
- Revert modeling_gemma4.py (CI regenerates it from modular)
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
nice, i think the last step is to run
[For maintainers] Suggested jobs to run (before merge): run-slow: gemma4
…ingface#45207)

* [Gemma4] Add docstrings for Per-Layer Embeddings (PLE) pipeline

  The PLE system is complex and underdocumented, which makes it hard for third-party implementations (llama.cpp, candle, mlx, etc.) to get right. This adds:

  - Config docstring for hidden_size_per_layer_input explaining that the actual embedding dim is num_hidden_layers * hidden_size_per_layer_input, the embedding is scaled by sqrt(hidden_size_per_layer_input), and describing the full two-component pipeline
  - Docstring for get_per_layer_inputs() explaining the token-identity component and the packed-to-4D reshape
  - Docstring for project_per_layer_inputs() explaining the context-aware projection, normalization, and combination with scale factors
  - Comment on the PLE init block pointing to the pipeline methods

  Fixes huggingface#45206

* Address review: move PLE details to model doc, shorten config docstring

  Move the detailed PLE pipeline description from the config docstring to the Gemma4 model documentation page. The config docstring now just describes the parameter shape and links to the full docs.

* Address review nits: move edits to modular_gemma4.py, simplify gemma4.md

  - Remove bold formatting and config params section from gemma4.md per review
  - Move docstrings and PLE comment from modeling_gemma4.py to modular_gemma4.py
  - Revert modeling_gemma4.py (CI regenerates it from modular)

* fix: run make fix-repo to align modeling_gemma4.py with modular_gemma4.py
Fixes #45206
What does this PR do?
Adds documentation for the Gemma4 Per-Layer Embeddings (PLE) system, which is currently pretty hard to reverse-engineer from the code alone.
I ran into this while implementing Gemma4 inference from scratch in Rust. The PLE system has several non-obvious aspects that aren't documented anywhere:
- `hidden_size_per_layer_input` (256) is the per-layer dimension, but the actual embedding weight is `[vocab, num_layers * 256]` = `[262144, 8960]` because all layers are packed
- The embedding is a `Gemma4TextScaledWordEmbedding` that silently multiplies by `sqrt(256) = 16` - this took me a while to track down
- There is a second, context-aware component (`per_layer_model_projection` + scale + RMSNorm) that combines with the token lookup before being passed to layers, with specific scale factors (`1/sqrt(hidden_size)` and `1/sqrt(2)`)

This PR adds:
- Config docstring for `hidden_size_per_layer_input` explaining the packed layout, scaling, and full pipeline
- Docstrings for `get_per_layer_inputs()` and `project_per_layer_inputs()`

Hopefully this saves some pain for others implementing Gemma4 outside of transformers.
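For anyone reimplementing this outside transformers, here is a toy end-to-end sketch of the PLE path as described in this thread (hypothetical sizes, plain PyTorch modules standing in for the Gemma4 classes; a sketch, not the actual implementation):

```python
import math

import torch
import torch.nn as nn

# Toy sizes: the real vocab_size_per_layer_input is 262144 and hidden_size is model-dependent
vocab_ple, num_layers, ple_dim, hidden_size = 1_000, 35, 256, 2048
batch, seq = 1, 7

embed_tokens_per_layer = nn.Embedding(vocab_ple, num_layers * ple_dim)
per_layer_model_projection = nn.Linear(hidden_size, num_layers * ple_dim, bias=False)
per_layer_projection_norm = nn.RMSNorm(ple_dim)

input_ids = torch.randint(0, vocab_ple, (batch, seq))
inputs_embeds = torch.randn(batch, seq, hidden_size)

# 1) Token-identity: scaled lookup (x sqrt(256) = 16), then unpack to per-layer slices
tok = embed_tokens_per_layer(input_ids) * math.sqrt(ple_dim)
tok = tok.reshape(batch, seq, num_layers, ple_dim)

# 2) Context-aware: project, scale by 1/sqrt(hidden_size), reshape, RMSNorm
ctx = per_layer_model_projection(inputs_embeds) * hidden_size ** -0.5
ctx = per_layer_projection_norm(ctx.reshape(batch, seq, num_layers, ple_dim))

# 3) Combine: sum and scale by 1/sqrt(2) before each decoder layer consumes its slice
per_layer_inputs = (tok + ctx) * (2.0 ** -0.5)
print(per_layer_inputs.shape)  # torch.Size([1, 7, 35, 256])
```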