feat: TP+PP support for Gemma4 VLM (with tied lm_head fix and 31B recipes) #1904
akoumpa merged 6 commits into NVIDIA-NeMo:main
Conversation
Enable tensor- and pipeline-parallel fine-tuning of Gemma4ForConditionalGeneration models:
- Register Gemma4ForConditionalGeneration in `_extract_model_layers`
and `validate_tp_mesh` so the parallelizer recognizes the VLM stack.
- Add pipeline forwards for the Gemma4 text backbone and VLM outer
module, including:
* RoPE dispatch via `config.layer_types[layer_idx]`
* Skipping `tie_weights` on PP stages that lack `embed_tokens` or
`lm_head`
* Handling `image_position_ids` in PP VLM chunking
Signed-off-by: khazic <khazzz1c@gmail.com>
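To make the `config.layer_types[layer_idx]` RoPE dispatch mentioned in the commit above concrete, here is a minimal sketch of the idea; the `"sliding_attention"` layer type and the global/local rotary-embedding names are assumptions, not the exact Gemma4 implementation.

```python
# Sketch only: per-layer RoPE selection keyed off config.layer_types, as described
# in the commit above. The layer-type string and rotary-embedding names are assumed.
def pick_rope(config, rotary_emb_global, rotary_emb_local, layer_idx, hidden_states, position_ids):
    layer_type = config.layer_types[layer_idx]
    rope = rotary_emb_local if layer_type == "sliding_attention" else rotary_emb_global
    # Each rotary embedding returns the (cos, sin) pair consumed by that decoder layer.
    return rope(hidden_states, position_ids)
```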
PP last stages own `lm_head` but not `embed_tokens`, so the lm_head tensor cannot share storage with its tied source. HF tied-embedding checkpoints (and DCP checkpoints saved before this fix) omit `lm_head.weight` entirely, which made the DCP planner fail with `Missing key in checkpoint state_dict: lm_head.weight` whenever the last PP stage tried to load.

Changes:
- Distinguish `uses_tied_lm_head` (config flag) from `has_local_tied_lm_head` (the local stage actually shares the same tensor) in `ModelState`. Saved state dicts now keep `lm_head.weight` on PP last stages.
- Add `materialize_missing_tied_lm_head` and `get_tied_lm_head_source_names` helpers that resolve the tied source via `_tied_weights_keys` with sensible HF fallbacks.
- In the safetensors fast load path, materialize the missing `lm_head.weight` from the embedding tensor before distributing.
- In the standard DCP load path, when the checkpoint metadata lacks `lm_head.weight`, route the embedding source key into the lm_head tensor and rename back after load. This also covers init load from HF tied-embedding base checkpoints.
- Add unit tests for the new helpers and PP-last-stage save behavior.

Signed-off-by: khazic <khazzz1c@gmail.com>
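A rough sketch of how the two helpers described above could resolve the tied source; the helper names come from the commit message, but the signatures, the dict convention, and the fallback keys are assumptions.

```python
def get_tied_lm_head_source_names(model) -> list[str]:
    """Candidate state-dict keys whose tensor can back a missing lm_head.weight."""
    candidates: list[str] = []
    tied_keys = getattr(model, "_tied_weights_keys", None)
    if isinstance(tied_keys, dict):  # NeMo custom-model convention: {target: source}
        for target_name, source_name in tied_keys.items():
            if target_name.endswith("lm_head.weight"):
                candidates.append(source_name)
    # Sensible HF fallbacks for tied-embedding checkpoints (assumed, not exhaustive).
    for fallback in ("model.embed_tokens.weight", "model.language_model.embed_tokens.weight"):
        if fallback not in candidates:
            candidates.append(fallback)
    return candidates


def materialize_missing_tied_lm_head(state_dict: dict, source_names: list[str]) -> None:
    """If lm_head.weight is missing, alias it to the first available tied source tensor."""
    if "lm_head.weight" in state_dict:
        return
    for name in source_names:
        if name in state_dict:
            state_dict["lm_head.weight"] = state_dict[name]
            return
```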
Two desensitized example configs for fine-tuning Gemma4 31B VLM:
- `gemma4_31b_tp4_pp2.yaml`: single-node 8-GPU (TP4 x PP2 x DP1).
- `gemma4_31b_tp4_pp4.yaml`: 4-node 32-GPU (TP4 x PP4 x DP2), with a multi-node `torchrun` launch snippet.

Both use the public MedPix-VQA dataset and `google/gemma-4-31B-it`.

Signed-off-by: khazic <khazzz1c@gmail.com>
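For reference, the GPU-count arithmetic behind the two recipes; a plain sanity check, not code taken from the configs.

```python
def world_size(tp: int, pp: int, dp: int) -> int:
    # The device mesh is the product of the tensor-, pipeline- and data-parallel sizes.
    return tp * pp * dp

assert world_size(tp=4, pp=2, dp=1) == 8    # gemma4_31b_tp4_pp2.yaml: 1 node x 8 GPUs
assert world_size(tp=4, pp=4, dp=2) == 32   # gemma4_31b_tp4_pp4.yaml: 4 nodes x 8 GPUs
```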
Training log from a live run on the 4-node TP4 x PP4 x DP2 recipe (Gemma4 31B VLM) on our side — loss is converging and grad norm / tps look stable. Config: 4 nodes x 8 GPUs (32 GPUs total), TP=4, PP=4, DP=2.
/ok to test 1ee4789
/claude review
inner_model = getattr(model, "model", None)
text_backbone = (
    getattr(inner_model, "language_model", None)
    if inner_model is not None
    else None
)

if inner_model is not None and text_backbone is not None:
    # VLM with nested text backbone (e.g. Gemma4): patch text backbone and VLM outer
    if patch_inner_model:
        text_backbone.forward = types.MethodType(
            create_pipeline_forward_gemma4_text(), text_backbone
        )
    if patch_causal_lm_model:
        model.forward = types.MethodType(create_pipeline_forward_gemma4_vlm(), model)
Bug: This language_model check is not specific to Gemma4 — many other VLMs (KimiVL, Mistral4, Qwen3VL MoE, Kimi K25 VL, LlavaOneVision, etc.) also have model.language_model. As written, all of those models will get the Gemma4-specific create_pipeline_forward_gemma4_text / create_pipeline_forward_gemma4_vlm patched onto them during PP, which will break them.
This needs to be gated on the model actually being Gemma4, e.g.:
from nemo_automodel.components.models.gemma4_moe.model import Gemma4ForConditionalGeneration
is_gemma4_vlm = isinstance(model, Gemma4ForConditionalGeneration)
# or check model.config.model_type == "gemma4" if you want to avoid the import
And then the branch becomes `if inner_model is not None and is_gemma4_vlm:`.
Hi @akoumpa — good catch, this was already addressed in a follow-up commit cf772a4d
tied_keys = getattr(model, "_tied_weights_keys", None)
if isinstance(tied_keys, dict):
    for target_name, source_name in tied_keys.items():
Nit/potential bug: HF upstream defines _tied_weights_keys as a list of strings (e.g., ["lm_head.weight"]), not a dict. The NeMo custom models use a dict convention ({"lm_head.weight": "model.embed_tokens.weight"}), but Gemma4ForConditionalGeneration doesn't define _tied_weights_keys at all, so it will inherit the HF parent's list.
This isinstance(tied_keys, dict) check will silently skip the HF list case. The hardcoded fallback candidates below save this in practice, but the code doesn't match its stated intent. Consider also handling list/set types here (where items are target names and you'd need to infer the source differently), or at minimum add a comment noting this is intentionally dict-only for NeMo models.
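For illustration, one way the check could cover both conventions discussed above; this is a sketch, not the merged code, and the embed_tokens fallback for the list case is an assumption.

```python
def resolve_tied_lm_head_source(model) -> str | None:
    tied_keys = getattr(model, "_tied_weights_keys", None)
    if isinstance(tied_keys, dict):
        # NeMo custom-model convention: {"lm_head.weight": "model.embed_tokens.weight"}.
        for target_name, source_name in tied_keys.items():
            if target_name.endswith("lm_head.weight"):
                return source_name
    elif isinstance(tied_keys, (list, tuple, set)):
        # HF upstream convention: items are target names only, e.g. ["lm_head.weight"];
        # the source has to be inferred (here: the usual embed_tokens location).
        if any(str(key).endswith("lm_head.weight") for key in tied_keys):
            return "model.embed_tokens.weight"
    return None
```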
# =============================================================================
# Tests for _reinit_non_persistent_buffers
Missing test coverage: The patch_hf_model_for_pp changes in hf_utils.py add a new VLM branch (model with model.language_model), but the existing TestPatchHfModelForPp tests don't cover this case. Given the regression risk (the new branch currently activates for all VLMs with language_model, not just Gemma4), this needs test coverage.
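A sketch of the kind of coverage being asked for, assuming a gating predicate along the lines of the `is_gemma4_vlm` check suggested earlier; the helper and the mocked model types here are hypothetical.

```python
from unittest.mock import MagicMock


def _is_gemma4_vlm(model) -> bool:
    # Hypothetical gate mirroring the earlier suggestion: key off config.model_type
    # rather than the mere presence of model.language_model.
    return getattr(getattr(model, "config", None), "model_type", None) == "gemma4"


def test_non_gemma4_vlm_does_not_take_gemma4_branch():
    other_vlm = MagicMock()
    other_vlm.config.model_type = "qwen3_vl_moe"
    other_vlm.model.language_model = MagicMock()  # has language_model, but is not Gemma4
    assert not _is_gemma4_vlm(other_vlm)


def test_gemma4_vlm_takes_gemma4_branch():
    gemma4 = MagicMock()
    gemma4.config.model_type = "gemma4"
    assert _is_gemma4_vlm(gemma4)
```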
- Sort imports in checkpointing.py and parallelizer.py (ruff I001).
- Replace the import of the MoE-specific
`nemo_automodel.components.models.gemma4_moe.model.Gemma4ForConditionalGeneration`
in parallelizer.py with the HF base class from
`transformers.models.gemma4.modeling_gemma4`, which is what dense
Gemma4 VLM actually instantiates. This also resolves the tach
`distributed -> checkpoint` contract violation introduced by the
previous indirect import chain.
- Re-apply ruff format on two touched files in
components/{checkpoint,distributed/pipelining}.
- Update unit tests for the renamed / stricter tied-lm-head semantics:
* `test_model_state_disables_tied_embeddings_for_non_tied_models`
asserts the new `uses_tied_lm_head` / `has_local_tied_lm_head`
attributes instead of the old `is_tied_lm_head`.
* `test_validate_model_with_tied_embeddings` and
`test_validate_multiple_issues` now construct a MockModel with
`lm_head` and `embed_tokens` whose weights actually share storage,
so the validator's stricter "truly tied" check triggers.
* Add a no-op `update_seq_len` to `_MockAutoPipeline` so the
`TestForwardBackwardStepPP` tests can exercise the VLM PP path
that now calls `self.pp.update_seq_len(input_ids.shape[1])`
before each schedule step.
Signed-off-by: khazic <khazzz1c@gmail.com>
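To illustrate the mock change in the last bullet, a minimal sketch; the real `_MockAutoPipeline` has more members, only the new hook is shown and its no-op body is an assumption.

```python
class _MockAutoPipeline:
    """Trimmed-down stand-in; only the newly required hook is sketched here."""

    def update_seq_len(self, seq_len: int) -> None:
        # No-op: the VLM PP path calls self.pp.update_seq_len(input_ids.shape[1])
        # before each schedule step, so the mock only needs the method to exist.
        pass
```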
Two related fixes for the L2 integration test failures on the PR:

1. `checkpointing.py` `_maybe_build_consolidated_index` still accessed the renamed `model_state.is_tied_lm_head`, which raised `AttributeError` on save for any tied-embedding model and broke L2_Retrieval / L2_HF_DCP / L2_HF_Transformer_Finetune / L2_Pretrain_and_KD. Switch to `getattr(model_state, "has_local_tied_lm_head", False)` and clarify the intent with a comment.

2. `has_local_tied_lm_head` used an `is`-identity check between `lm_head.weight` and `embed_tokens.weight`. That was too strict: after FSDP2 / TP sharding both weights are wrapped in separate `DTensor`s and `is`-equality is broken even though HF's `tie_weights()` can still relink them on load. The single-rank Gemma3 VL 4B DCP test saw the `lm_head` key leak onto disk because the save path no longer recognized the model as locally tied. Replace the `is` check with `is_tied_word_embeddings(model) and both local weights exist`. This preserves the PP cases (last stage has `lm_head` but no `embed_tokens` => False; first stage has `embed_tokens` but no `lm_head` => False), fixes the single-rank FSDP-sharded case, and keeps the existing unit tests for the PP-last-stage-like partition passing (that mock has no `embed_tokens`, so the check still returns False for it).

Signed-off-by: khazic <khazzz1c@gmail.com>
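A compressed sketch of the final check described in fix 2; `is_tied_word_embeddings` is the existing helper the commit refers to (shown here as a simple config-flag check), and the attribute lookup paths are assumptions.

```python
def is_tied_word_embeddings(model) -> bool:
    # Existing helper referenced by the commit, approximated as a config-flag check.
    return bool(getattr(getattr(model, "config", None), "tie_word_embeddings", False))


def has_local_tied_lm_head(model) -> bool:
    """True only when the config ties embeddings AND this rank holds both local weights."""
    if not is_tied_word_embeddings(model):
        return False
    lm_head = getattr(model, "lm_head", None)
    embed_tokens = getattr(getattr(model, "model", model), "embed_tokens", None)
    # Existence check rather than `is`-identity: after FSDP2/TP sharding the two weights
    # live in separate DTensors, so identity breaks even though tie_weights() can relink
    # them on load. PP last stage (no embed_tokens) and first stage (no lm_head) => False.
    return (
        lm_head is not None
        and embed_tokens is not None
        and getattr(lm_head, "weight", None) is not None
        and getattr(embed_tokens, "weight", None) is not None
    )
```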
Pushed
Verified locally:
/ok to test 7c34c9f
step_scheduler:
  global_batch_size: 8
  local_batch_size: 1
Good catch! Technically lbs=1 still runs without errors — the code checks n_microbatches < len(stages) where len(stages) is the per-rank stage count (= 1 for 1F1B), so no warning is triggered. But with only 1 microbatch for 2 PP stages there is zero pipeline overlap (100% bubble). Fixed to local_batch_size: 2 so we get 2 microbatches and proper 1F1B overlap.
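The microbatch arithmetic behind that change, written out; this assumes a microbatch size of 1, so the per-rank microbatch count equals local_batch_size.

```python
def n_microbatches(local_batch_size: int, microbatch_size: int = 1) -> int:
    return local_batch_size // microbatch_size

# local_batch_size: 1 -> 1 microbatch: with PP=2 only one stage is ever busy (no overlap).
# local_batch_size: 2 -> 2 microbatches: 1F1B can keep both PP stages working at once.
assert n_microbatches(1) == 1
assert n_microbatches(2) == 2
```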
| "multi_modal_projector", | ||
| "multimodal_projector", | ||
| "vision_projector", | ||
| "audio_projector", |
@khazic missing embed_vision for gemma4 for pp?
Good catch! When testing I was only running pure-text inputs, so the multimodal path wasn't exercised and embed_vision slipped through. Fixed by adding it to MULTIMODAL_SUFFIXES in hf_utils.py.
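Roughly how a suffix list like this gets consulted when carving up a VLM for PP; `MULTIMODAL_SUFFIXES` and `embed_vision` come from the discussion above, but the matching helper itself is illustrative, not the actual hf_utils.py code.

```python
MULTIMODAL_SUFFIXES = (
    "multi_modal_projector",
    "multimodal_projector",
    "vision_projector",
    "audio_projector",
    "embed_vision",  # the Gemma4 module that previously slipped through
)

def is_multimodal_module(module_name: str) -> bool:
    # Modules whose last path component matches a known multimodal suffix are kept
    # whole (not split as decoder layers) when assigning PP stages.
    return module_name.rsplit(".", 1)[-1] in MULTIMODAL_SUFFIXES
```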
Summary
- Register the model in the parallelizer and add Gemma4-specific pipeline forwards (text backbone + VLM outer), including `layer_types`-based RoPE dispatch, safe `tie_weights` skipping for PP stages without `embed_tokens`/`lm_head`, and PP chunking that handles `image_position_ids`.
- `lm_head` loading on the PP last stage: PP last stages own `lm_head` but not `embed_tokens`, so they cannot share storage locally. HF tied-embedding checkpoints (and DCP checkpoints saved before this fix) omit `lm_head.weight` entirely, which made the DCP planner fail with `Missing key in checkpoint state_dict: lm_head.weight`. The fix:
  - distinguishes `uses_tied_lm_head` (config flag) from `has_local_tied_lm_head` (true only when local `lm_head` and `embed_tokens` actually share storage). Saved state dicts now keep `lm_head.weight` on PP last stages.
  - adds `materialize_missing_tied_lm_head` / `get_tied_lm_head_source_names` helpers using `_tied_weights_keys` plus HF fallbacks.
  - in the safetensors fast load path, materializes the missing `lm_head.weight` from the embedding tensor.
  - in the standard DCP load path, when `lm_head.weight` is absent from checkpoint metadata, routes the embedding source key into the lm_head tensor and renames back after load. This now also applies on init load so a fresh HF tied-embedding base loads cleanly into a PP last stage.
- Recipes:
  - `examples/vlm_finetune/gemma4/gemma4_31b_tp4_pp2.yaml` — single node, 8 GPUs (TP4 x PP2 x DP1).
  - `examples/vlm_finetune/gemma4/gemma4_31b_tp4_pp4.yaml` — 4 nodes, 32 GPUs (TP4 x PP4 x DP2), with a multi-node `torchrun` snippet.
  Both use `google/gemma-4-31B-it` and the public MedPix-VQA dataset.

Test plan
- `tests/unit_tests/checkpoint/test_checkpointing.py` (`has_local_tied_lm_head`, `materialize_missing_tied_lm_head`, PP-last-stage save behavior).
- `gemma4_31b_tp4_pp2.yaml` (single node, 8 GPU).
- `gemma4_31b_tp4_pp4.yaml` (4 nodes, 32 GPU).