
fix: Gemma4 PP follow-ups — embed_vision stage-0 assignment and lbs=2 for PP2 recipe #1910

Closed

khazic wants to merge 8 commits into NVIDIA-NeMo:main from khazic:feat/gemma4-vlm-tp-pp

Conversation

@khazic
Contributor

@khazic khazic commented Apr 20, 2026

Summary

Two small follow-up fixes to #1904, caught in review by @HuiyingLi.

1. local_batch_size: 1 → 2 in gemma4_31b_tp4_pp2.yaml

lbs=1 technically runs without errors: the pipeline schedule only warns
when n_microbatches < len(stages), and len(stages) is the per-rank stage
count (1 for 1F1B), so the warning never fires. However, with a single
microbatch across 2 PP stages there is no pipeline overlap at all
(a 100% bubble).

My local runs were always done with lbs=2; submitting the recipe with
lbs=1 was a mistake made while aligning it with other configs. Fixed to
local_batch_size: 2 so we get 2 microbatches and proper 1F1B overlap.
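
For reference, the changed value (an excerpt only; the key's exact nesting inside the recipe is not reproduced here):

```yaml
# gemma4_31b_tp4_pp2.yaml -- excerpt; every other key is unchanged and omitted.
local_batch_size: 2   # was 1; gives 2 microbatches so the 2 PP stages overlap under 1F1B
```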

2. Add embed_vision to MULTIMODAL_SUFFIXES in hf_utils.py

embed_vision is Gemma4's vision-feature projection layer (inside
self.model) and must live on PP Stage 0 alongside vision_tower. It was
missing from MULTIMODAL_SUFFIXES because the original PP testing was done
on pure-text inputs — the multimodal path was never exercised, so the
omission slipped through.
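
A minimal sketch of what the suffix list is for (illustrative only; apart from MULTIMODAL_SUFFIXES, vision_tower, and embed_vision, the names and matching logic below are assumptions, not the actual hf_utils.py code):

```python
# Illustrative sketch, not the real hf_utils.py implementation.
MULTIMODAL_SUFFIXES = (
    "vision_tower",
    "embed_vision",  # Gemma4's vision-feature projection layer, added by this PR
)

def belongs_on_stage0(module_fqn: str) -> bool:
    """Pin vision/multimodal submodules to PP stage 0 by their FQN suffix."""
    return module_fqn.rsplit(".", 1)[-1] in MULTIMODAL_SUFFIXES

assert belongs_on_stage0("model.embed_vision")
assert not belongs_on_stage0("model.layers.12")
```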

khazic added 8 commits April 18, 2026 15:04
Enable tensor- and pipeline-parallel fine-tuning of Gemma4
ConditionalGeneration models:

- Register Gemma4ForConditionalGeneration in `_extract_model_layers`
  and `validate_tp_mesh` so the parallelizer recognizes the VLM stack.
- Add pipeline forwards for the Gemma4 text backbone and VLM outer
  module, including:
  * RoPE dispatch via `config.layer_types[layer_idx]`
  * Skipping `tie_weights` on PP stages that lack `embed_tokens` or
    `lm_head`
  * Handling `image_position_ids` in PP VLM chunking
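
A minimal sketch of the per-layer RoPE dispatch listed above (illustrative only; the rotary-embedding attribute names and the layer-type strings follow the usual HF sliding/full-attention convention and are assumptions, not the actual pipeline forward):

```python
# Sketch: pick the rotary embedding per layer from config.layer_types, as the
# Gemma4 pipeline forward does for a mixed sliding/full attention stack.
def position_embeddings_for_layer(model, layer_idx, hidden_states, position_ids):
    layer_type = model.config.layer_types[layer_idx]
    if layer_type == "sliding_attention":
        rotary = model.rotary_emb_local   # assumed attribute name
    else:
        rotary = model.rotary_emb         # assumed attribute name
    return rotary(hidden_states, position_ids)  # -> (cos, sin) in the HF convention
```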

Signed-off-by: khazic <khazzz1c@gmail.com>
PP last stages own `lm_head` but not `embed_tokens`, so the lm_head
tensor cannot share storage with its tied source. HF tied-embedding
checkpoints (and DCP checkpoints saved before this fix) omit
`lm_head.weight` entirely, which made the DCP planner fail with
`Missing key in checkpoint state_dict: lm_head.weight` whenever the
last PP stage tried to load.

Changes:
- Distinguish `uses_tied_lm_head` (config flag) from
  `has_local_tied_lm_head` (the local stage actually shares the same
  tensor) in `ModelState`. Saved state dicts now keep `lm_head.weight`
  on PP last stages.
- Add `materialize_missing_tied_lm_head` and
  `get_tied_lm_head_source_names` helpers that resolve the tied source
  via `_tied_weights_keys` with sensible HF fallbacks.
- In the safetensors fast load path, materialize the missing
  `lm_head.weight` from the embedding tensor before distributing.
- In the standard DCP load path, when the checkpoint metadata lacks
  `lm_head.weight`, route the embedding source key into the lm_head
  tensor and rename back after load. This also covers init load from
  HF tied-embedding base checkpoints.
- Add unit tests for the new helpers and PP-last-stage save behavior.
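
A minimal sketch of the fast-path materialization listed above (the real materialize_missing_tied_lm_head signature and the default key names are assumptions for illustration, not taken from this PR):

```python
# Sketch only: fill in lm_head.weight from its tied source when an HF
# tied-embedding checkpoint omits it, so the PP last stage has a real tensor
# for the DCP planner to match.
def materialize_missing_tied_lm_head(
    state_dict: dict,
    source_key: str = "model.embed_tokens.weight",  # illustrative default
    target_key: str = "lm_head.weight",
) -> None:
    if target_key not in state_dict and source_key in state_dict:
        # A copy, not an alias: on the last PP stage there is no embed_tokens
        # module for lm_head to share storage with.
        state_dict[target_key] = state_dict[source_key].clone()
```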

Signed-off-by: khazic <khazzz1c@gmail.com>
Two sanitized example configs for fine-tuning Gemma4 31B VLM:

- `gemma4_31b_tp4_pp2.yaml`: single-node 8-GPU (TP4 x PP2 x DP1).
- `gemma4_31b_tp4_pp4.yaml`: 4-node 32-GPU (TP4 x PP4 x DP2), with a
  multi-node `torchrun` launch snippet.

Both use the public MedPix-VQA dataset and `google/gemma-4-31B-it`.

Signed-off-by: khazic <khazzz1c@gmail.com>
- Sort imports in checkpointing.py and parallelizer.py (ruff I001).
- Replace the import of the MoE-specific
  `nemo_automodel.components.models.gemma4_moe.model.Gemma4ForConditionalGeneration`
  in parallelizer.py with the HF base class from
  `transformers.models.gemma4.modeling_gemma4`, which is what dense
  Gemma4 VLM actually instantiates. This also clears the tach
  `distributed -> checkpoint` contract violation introduced by the
  previous indirect import chain.
- Re-apply ruff format on two touched files in
  components/{checkpoint,distributed/pipelining}.
- Update unit tests for the renamed / stricter tied-lm-head semantics:
  * `test_model_state_disables_tied_embeddings_for_non_tied_models`
    asserts the new `uses_tied_lm_head` / `has_local_tied_lm_head`
    attributes instead of the old `is_tied_lm_head`.
  * `test_validate_model_with_tied_embeddings` and
    `test_validate_multiple_issues` now construct a MockModel with
    `lm_head` and `embed_tokens` whose weights actually share storage,
    so the validator's stricter "truly tied" check triggers.
  * Add a no-op `update_seq_len` to `_MockAutoPipeline` so the
    `TestForwardBackwardStepPP` tests can exercise the VLM PP path
    that now calls `self.pp.update_seq_len(input_ids.shape[1])`
    before each schedule step.
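
The last bullet amounts to roughly the following (a sketch; the rest of the mock's attributes are assumptions):

```python
class _MockAutoPipeline:
    ...  # existing mock attributes elided

    def update_seq_len(self, seq_len: int) -> None:
        # No-op: the VLM PP step calls self.pp.update_seq_len(input_ids.shape[1])
        # before each schedule step, so the mock only needs the method to exist.
        pass
```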

Signed-off-by: khazic <khazzz1c@gmail.com>
Three review fixes:

1. `patch_hf_model_for_pp` previously used
   ``model.model.language_model is not None`` as the gate for applying
   the Gemma4-specific pipeline forward. Many other HF VLMs (KimiVL,
   Mistral4, Qwen3 VL MoE, LlavaOneVision, Kimi K25 VL, ...) also nest
   a text backbone at that path, so they would have been incorrectly
   given Gemma4's sliding/full-attention RoPE dispatch and
   final-logit-softcapping logic. Gate on ``config.model_type ==
   'gemma4'`` (also checking ``config.text_config.model_type`` for
   VLMs that only set it on the inner text config) so unrelated VLMs
   fall through to the generic CausalLM forward instead (see the
   sketch after this list).

2. Document the two shapes of ``_tied_weights_keys`` in
   ``get_tied_lm_head_source_names``. HF upstream uses a list of
   target FQNs tied to the input embedding, while NeMo custom models
   set an explicit target->source dict. The list case is already
   handled implicitly via ``get_input_embeddings()`` + hardcoded
   fallbacks; add a comment so the ``isinstance(tied_keys, dict)``
   check doesn't look like an oversight.

3. Add unit tests for the new Gemma4 PP path:
   - Gemma4 (via ``config.model_type``) gets the Gemma4 forwards.
   - Gemma4 (via ``config.text_config.model_type``) gets the same.
   - A non-Gemma4 VLM with ``model.language_model`` falls back to the
     generic CausalLM forwards (regression guard).
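
A minimal sketch of the gate from item 1 (the helper name _is_gemma4_config is made up for illustration; only config.model_type and config.text_config.model_type come from this PR):

```python
def _is_gemma4_config(config) -> bool:
    """Illustrative only: the model_type gate described in item 1 above."""
    if getattr(config, "model_type", None) == "gemma4":
        return True
    # Some VLM configs only set model_type on the inner text config.
    text_config = getattr(config, "text_config", None)
    return getattr(text_config, "model_type", None) == "gemma4"

# Non-Gemma4 VLMs that still nest model.model.language_model return False and
# keep the generic CausalLM pipeline forward.
```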

Signed-off-by: khazic <khazzz1c@gmail.com>
Two related fixes for the L2 integration test failures on the PR:

1. `checkpointing.py` `_maybe_build_consolidated_index` still accessed
   the renamed `model_state.is_tied_lm_head`, which raised
   `AttributeError` on save for any tied-embedding model and broke
   L2_Retrieval / L2_HF_DCP / L2_HF_Transformer_Finetune /
   L2_Pretrain_and_KD. Switch to
   `getattr(model_state, "has_local_tied_lm_head", False)` and clarify
   the intent with a comment.

2. `has_local_tied_lm_head` used an `is`-identity check between
   `lm_head.weight` and `embed_tokens.weight`. That was too strict:
   after FSDP2 / TP sharding, both weights are wrapped in separate
   `DTensor`s and the `is` check no longer holds, even though HF's
   `tie_weights()` can still re-link them on load. The single-rank
   Gemma3 VL 4B DCP test saw the `lm_head` key leak onto disk because
   the save path no longer recognized the model as locally tied.

   Replace the `is` check with `is_tied_word_embeddings(model) and
   both local weights exist`. This preserves the PP cases (last stage
   has `lm_head` but no `embed_tokens` => False; first stage has
   `embed_tokens` but no `lm_head` => False), fixes the single-rank
   FSDP-sharded case, and keeps the existing unit tests for the
   PP-last-stage-like partition passing (that mock has no
   `embed_tokens`, so the check still returns False for it).
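
A minimal sketch of the replacement check (is_tied_word_embeddings is used as described above but stubbed here with the standard HF tie_word_embeddings config flag; the attribute lookups are simplified assumptions about the real ModelState code):

```python
def is_tied_word_embeddings(model) -> bool:
    # Stand-in for the project helper of the same name: config-level tying flag.
    return bool(getattr(model.config, "tie_word_embeddings", False))

def has_local_tied_lm_head(model) -> bool:
    if not is_tied_word_embeddings(model):
        return False
    lm_head = getattr(model, "lm_head", None)
    embed_tokens = getattr(getattr(model, "model", model), "embed_tokens", None)
    # Both halves of the tie must exist locally: the PP last stage (lm_head only)
    # and the PP first stage (embed_tokens only) return False, while a single-rank
    # FSDP2/TP-sharded model returns True even though its two DTensor weights no
    # longer alias each other.
    return lm_head is not None and embed_tokens is not None
```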

Signed-off-by: khazic <khazzz1c@gmail.com>
@copy-pr-bot

copy-pr-bot Bot commented Apr 20, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@khazic khazic closed this Apr 20, 2026
@khazic khazic reopened this Apr 20, 2026
@khazic khazic closed this Apr 20, 2026
