fix: Gemma4 PP follow-ups — embed_vision stage-0 assignment and lbs=2 for PP2 recipe #1910
Closed
khazic wants to merge 8 commits into NVIDIA-NeMo:main
Conversation
Enable tensor- and pipeline-parallel fine-tuning of Gemma4ForConditionalGeneration models:
- Register Gemma4ForConditionalGeneration in `_extract_model_layers`
and `validate_tp_mesh` so the parallelizer recognizes the VLM stack.
- Add pipeline forwards for the Gemma4 text backbone and VLM outer
module, including:
* RoPE dispatch via `config.layer_types[layer_idx]`
* Skipping `tie_weights` on PP stages that lack `embed_tokens` or
`lm_head`
* Handling `image_position_ids` in PP VLM chunking
Signed-off-by: khazic <khazzz1c@gmail.com>
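A minimal sketch of the two PP-specific behaviors described above, assuming HF Gemma-style attribute names (`rotary_emb`, `rotary_emb_local`, the `"sliding_attention"` layer type); this is illustrative, not the patch itself:

```python
# Hedged sketch, not the actual patch: per-layer RoPE dispatch driven by
# config.layer_types[layer_idx], and skipping tie_weights on stages that do
# not own both ends of the tie. Attribute names are assumptions.
def rope_for_layer(text_model, config, layer_idx, hidden_states, position_ids):
    if config.layer_types[layer_idx] == "sliding_attention":
        # sliding-window layers use the local rotary embedding
        return text_model.rotary_emb_local(hidden_states, position_ids)
    # full-attention layers use the global rotary embedding
    return text_model.rotary_emb(hidden_states, position_ids)


def maybe_tie_weights(model):
    # A PP stage that lacks embed_tokens or lm_head has nothing to tie, so
    # calling tie_weights() there would fail or tie the wrong tensors; skip it.
    has_embed = getattr(model.model, "embed_tokens", None) is not None
    has_head = getattr(model, "lm_head", None) is not None
    if has_embed and has_head:
        model.tie_weights()
```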
PP last stages own `lm_head` but not `embed_tokens`, so the lm_head tensor cannot share storage with its tied source. HF tied-embedding checkpoints (and DCP checkpoints saved before this fix) omit `lm_head.weight` entirely, which made the DCP planner fail with `Missing key in checkpoint state_dict: lm_head.weight` whenever the last PP stage tried to load.
Changes:
- Distinguish `uses_tied_lm_head` (config flag) from `has_local_tied_lm_head` (the local stage actually shares the same tensor) in `ModelState`. Saved state dicts now keep `lm_head.weight` on PP last stages.
- Add `materialize_missing_tied_lm_head` and `get_tied_lm_head_source_names` helpers that resolve the tied source via `_tied_weights_keys` with sensible HF fallbacks.
- In the safetensors fast load path, materialize the missing `lm_head.weight` from the embedding tensor before distributing.
- In the standard DCP load path, when the checkpoint metadata lacks `lm_head.weight`, route the embedding source key into the lm_head tensor and rename back after load. This also covers init load from HF tied-embedding base checkpoints.
- Add unit tests for the new helpers and PP-last-stage save behavior.
Signed-off-by: khazic <khazzz1c@gmail.com>
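The core idea of the materialization step can be sketched as follows; the helper name comes from the commit, but the key names and body here are illustrative assumptions, not the repo's implementation:

```python
# Hedged sketch of "materialize the missing lm_head.weight from the embedding
# tensor before distributing". Key names are assumptions.
def materialize_missing_tied_lm_head(
    state_dict,
    embed_key="model.embed_tokens.weight",
    lm_head_key="lm_head.weight",
):
    # HF tied-embedding checkpoints omit lm_head.weight. A PP last stage owns
    # lm_head but not embed_tokens, so fill the missing key from the embedding
    # tensor before the DCP planner sees the state dict.
    if lm_head_key not in state_dict and embed_key in state_dict:
        state_dict[lm_head_key] = state_dict[embed_key].clone()
    return state_dict
```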
Two desensitized example configs for fine-tuning Gemma4 31B VLM:
- `gemma4_31b_tp4_pp2.yaml`: single-node 8-GPU (TP4 x PP2 x DP1).
- `gemma4_31b_tp4_pp4.yaml`: 4-node 32-GPU (TP4 x PP4 x DP2), with a multi-node `torchrun` launch snippet.
Both use the public MedPix-VQA dataset and `google/gemma-4-31B-it`.
Signed-off-by: khazic <khazzz1c@gmail.com>
- Sort imports in checkpointing.py and parallelizer.py (ruff I001).
- Replace the import of the MoE-specific
`nemo_automodel.components.models.gemma4_moe.model.Gemma4ForConditionalGeneration`
in parallelizer.py with the HF base class from
`transformers.models.gemma4.modeling_gemma4`, which is what dense
Gemma4 VLM actually instantiates. This also lifts the tach
`distributed -> checkpoint` contract violation introduced by the
previous indirect import chain.
- Re-apply ruff format on two touched files in
components/{checkpoint,distributed/pipelining}.
- Update unit tests for the renamed / stricter tied-lm-head semantics:
* `test_model_state_disables_tied_embeddings_for_non_tied_models`
asserts the new `uses_tied_lm_head` / `has_local_tied_lm_head`
attributes instead of the old `is_tied_lm_head`.
* `test_validate_model_with_tied_embeddings` and
`test_validate_multiple_issues` now construct a MockModel with
`lm_head` and `embed_tokens` whose weights actually share storage,
so the validator's stricter "truly tied" check triggers.
* Add a no-op `update_seq_len` to `_MockAutoPipeline` so the
`TestForwardBackwardStepPP` tests can exercise the VLM PP path
that now calls `self.pp.update_seq_len(input_ids.shape[1])`
before each schedule step.
Signed-off-by: khazic <khazzz1c@gmail.com>
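The stricter "truly tied" check exercised by the updated validator tests needs a mock whose weights genuinely share storage. A self-contained sketch (an assumption, not the repo's actual test code):

```python
# Illustrative mock whose lm_head and embed_tokens share one tensor, so an
# identity-based "truly tied" check fires.
import torch.nn as nn


class MockTiedModel(nn.Module):
    def __init__(self, vocab=128, hidden=16):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab, hidden)
        self.lm_head = nn.Linear(hidden, vocab, bias=False)
        # Share storage: both attributes point at the same Parameter.
        self.lm_head.weight = self.embed_tokens.weight


model = MockTiedModel()
assert model.lm_head.weight is model.embed_tokens.weight
```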
Three review fixes:
1. `patch_hf_model_for_pp` previously used
``model.model.language_model is not None`` as the gate for applying
the Gemma4-specific pipeline forward. Many other HF VLMs (KimiVL,
Mistral4, Qwen3 VL MoE, LlavaOneVision, Kimi K25 VL, ...) also nest
a text backbone at that path, so they would have been incorrectly
given Gemma4's sliding/full-attention RoPE dispatch and
final-logit-softcapping logic. Gate on ``config.model_type ==
'gemma4'`` (also checking ``config.text_config.model_type`` for
VLMs that only set it on the inner text config) so unrelated VLMs
fall through to the generic CausalLM forward instead.
2. Document the two shapes of ``_tied_weights_keys`` in
``get_tied_lm_head_source_names``. HF upstream uses a list of
target FQNs tied to the input embedding, while NeMo custom models
set an explicit target->source dict. The list case is already
handled implicitly via ``get_input_embeddings()`` + hardcoded
fallbacks; add a comment so the ``isinstance(tied_keys, dict)``
check doesn't look like an oversight.
3. Add unit tests for the new Gemma4 PP path:
- Gemma4 (via ``config.model_type``) gets the Gemma4 forwards.
- Gemma4 (via ``config.text_config.model_type``) gets the same.
- A non-Gemma4 VLM with ``model.language_model`` falls back to the
generic CausalLM forwards (regression guard).
Signed-off-by: khazic <khazzz1c@gmail.com>
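The gate described in point 1 boils down to a model_type check with a text_config fallback. A hedged sketch (the helper name is invented here for illustration; only the config attributes are taken from the commit text):

```python
# Sketch of the Gemma4 gate: match config.model_type, falling back to
# config.text_config.model_type for VLMs that only set it on the inner config.
def _is_gemma4(config) -> bool:
    if getattr(config, "model_type", None) == "gemma4":
        return True
    text_config = getattr(config, "text_config", None)
    return getattr(text_config, "model_type", None) == "gemma4"
```

Any non-Gemma4 VLM that merely nests a `model.language_model` now falls through this gate to the generic CausalLM pipeline forward.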
Two related fixes for the L2 integration test failures on the PR:
1. `checkpointing.py` `_maybe_build_consolidated_index` still accessed the renamed `model_state.is_tied_lm_head`, which raised `AttributeError` on save for any tied-embedding model and broke L2_Retrieval / L2_HF_DCP / L2_HF_Transformer_Finetune / L2_Pretrain_and_KD. Switch to `getattr(model_state, "has_local_tied_lm_head", False)` and clarify the intent with a comment.
2. `has_local_tied_lm_head` used an `is`-identity check between `lm_head.weight` and `embed_tokens.weight`. That was too strict: after FSDP2 / TP sharding both weights are wrapped in separate `DTensor`s and `is`-equality breaks even though HF's `tie_weights()` can still relink them on load. The single-rank Gemma3 VL 4B DCP test saw the `lm_head` key leak onto disk because the save path no longer recognized the model as locally tied. Replace the `is` check with `is_tied_word_embeddings(model)` plus a check that both local weights exist. This preserves the PP cases (last stage has `lm_head` but no `embed_tokens` => False; first stage has `embed_tokens` but no `lm_head` => False), fixes the single-rank FSDP-sharded case, and keeps the existing unit tests for the PP-last-stage-like partition passing (that mock has no `embed_tokens`, so the check still returns False for it).
Signed-off-by: khazic <khazzz1c@gmail.com>
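A hedged sketch of the relaxed check in point 2; `is_tied_word_embeddings` is named in the commit, but its body and the attribute lookups here are assumptions:

```python
# Rely on the config-level tie flag plus the presence of both local weights,
# not tensor identity, which FSDP2/TP DTensor wrapping breaks.
def is_tied_word_embeddings(model) -> bool:
    return bool(getattr(model.config, "tie_word_embeddings", False))


def has_local_tied_lm_head(model) -> bool:
    lm_head = getattr(model, "lm_head", None)
    embed = getattr(getattr(model, "model", model), "embed_tokens", None)
    if lm_head is None or embed is None:
        # PP first/last stages own only one side of the tie -> not locally tied.
        return False
    return is_tied_word_embeddings(model)
```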
Summary
Two small follow-up fixes to #1904, caught in review by @HuiyingLi.
1. `local_batch_size: 1 → 2` in `gemma4_31b_tp4_pp2.yaml`.
   `lbs=1` technically runs without errors — the pipeline schedule checks `n_microbatches < len(stages)`, where `len(stages)` is the per-rank stage count (= 1 for 1F1B), so no warning is triggered. However, with only 1 microbatch across 2 PP stages there is zero pipeline overlap (100% bubble). My local runs were always done with `lbs=2`; the recipe was mistakenly submitted with `lbs=1` for consistency with the other configs. Fixed to `local_batch_size: 2` so we get 2 microbatches and proper 1F1B overlap (see the sketch after this list).
2. Add `embed_vision` to `MULTIMODAL_SUFFIXES` in `hf_utils.py`.
   `embed_vision` is Gemma4's vision-feature projection layer (inside `self.model`) and must live on PP Stage 0 alongside `vision_tower`. It was missing from `MULTIMODAL_SUFFIXES` because the original PP testing was done on pure-text inputs — the multimodal path was never exercised, so the omission slipped through. A sketch of the suffix gating follows below.
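Point 1 in numbers — a back-of-envelope sketch, assuming (as the recipe text implies) that one microbatch is produced per sample in the local batch:

```python
# The 1F1B schedule only warns when n_microbatches < len(stages), and
# len(stages) is the per-rank stage count (1 here), so lbs=1 stays silent even
# though a single microbatch leaves the two PP stages running strictly in turn.
stages_per_rank = 1  # 1F1B: one stage per rank
for local_batch_size in (1, 2):
    n_microbatches = local_batch_size  # assumption: one microbatch per sample
    would_warn = n_microbatches < stages_per_rank
    has_overlap = n_microbatches > 1   # stage 0 and stage 1 can work concurrently
    print(f"lbs={local_batch_size}: warn={would_warn}, overlap={has_overlap}")
# lbs=1: warn=False, overlap=False   <- silent but no pipeline overlap
# lbs=2: warn=False, overlap=True    <- proper 1F1B overlap
```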
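For point 2, an illustrative sketch of the suffix gating (not the literal `hf_utils.py` contents; only `embed_vision` is taken from this PR, the other entries and the helper are assumptions):

```python
# Modules whose names end with one of these suffixes are pinned to PP stage 0,
# alongside the input embeddings that consume their features.
MULTIMODAL_SUFFIXES = (
    "vision_tower",
    "multi_modal_projector",
    "embed_vision",  # Gemma4's vision-feature projection, added by this fix
)


def is_stage0_multimodal_module(fqn: str) -> bool:
    return any(fqn.endswith(suffix) for suffix in MULTIMODAL_SUFFIXES)
```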