feat: TP+PP support for Gemma4 VLM (with tied lm_head fix and 31B recipes)#1904

Merged
akoumpa merged 6 commits into NVIDIA-NeMo:main from khazic:feat/gemma4-vlm-tp-pp
Apr 19, 2026

Conversation

@khazic
Contributor

@khazic khazic commented Apr 18, 2026

Summary

  • Add TP+PP support for Gemma4 ConditionalGeneration VLM:
    register the model in the parallelizer and add Gemma4-specific
    pipeline forwards (text backbone + VLM outer), including
    layer_types-based RoPE dispatch, safe tie_weights skipping for
    PP stages without embed_tokens/lm_head, and PP chunking that
    handles image_position_ids.
  • Fix tied lm_head loading on the PP last stage. PP last stages own
    lm_head but not embed_tokens, so the two weights cannot share storage
    locally. HF tied-embedding checkpoints (and DCP checkpoints saved
    before this fix) omit lm_head.weight entirely, which made the DCP
    planner fail with Missing key in checkpoint state_dict: lm_head.weight.
    The fix (sketched after this list):
    • Splits uses_tied_lm_head (config flag) from
      has_local_tied_lm_head (true only when local lm_head and
      embed_tokens actually share storage). Saved state dicts now keep
      lm_head.weight on PP last stages.
    • Adds materialize_missing_tied_lm_head /
      get_tied_lm_head_source_names helpers using
      _tied_weights_keys plus HF fallbacks.
    • In the safetensors fast load path, materializes the missing
      lm_head.weight from the embedding tensor.
    • In the standard DCP load path, when lm_head.weight is absent
      from checkpoint metadata, routes the embedding source key into
      the lm_head tensor and renames back after load. This now also
      applies on init load so a fresh HF tied-embedding base loads
      cleanly into a PP last stage.
  • Add two desensitized example recipes for Gemma4 31B VLM:
    • examples/vlm_finetune/gemma4/gemma4_31b_tp4_pp2.yaml — single
      node, 8 GPUs (TP4 x PP2 x DP1).
    • examples/vlm_finetune/gemma4/gemma4_31b_tp4_pp4.yaml — 4 nodes,
      32 GPUs (TP4 x PP4 x DP2), with a multi-node torchrun snippet.
    Both use google/gemma-4-31B-it and the public MedPix-VQA dataset.
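
A minimal sketch of the split between the two flags, assuming an HF-style layout (model.lm_head, model.model.embed_tokens) and the tie_word_embeddings config flag; attribute names follow this summary, not necessarily the exact implementation. A later commit in this PR relaxes the identity check; see the follow-up sketch further down.

```python
def uses_tied_lm_head(model) -> bool:
    # Config-level flag: the architecture declares lm_head tied to embed_tokens.
    return bool(getattr(model.config, "tie_word_embeddings", False))

def has_local_tied_lm_head(model) -> bool:
    # True only when this rank holds both modules and they genuinely alias the
    # same storage. A PP last stage owns lm_head without embed_tokens, so this
    # returns False there and lm_head.weight stays in the saved state dict.
    lm_head = getattr(model, "lm_head", None)
    embed = getattr(getattr(model, "model", model), "embed_tokens", None)
    return (
        uses_tied_lm_head(model)
        and lm_head is not None
        and embed is not None
        and lm_head.weight is embed.weight
    )
```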

Test plan

  • Added/updated unit tests in
    tests/unit_tests/checkpoint/test_checkpointing.py
    (has_local_tied_lm_head, materialize_missing_tied_lm_head,
    PP-last-stage save behavior).
  • Manual run of gemma4_31b_tp4_pp2.yaml (single node, 8 GPU).
  • Manual run of gemma4_31b_tp4_pp4.yaml (4 nodes, 32 GPU).
  • CI: ruff, unit tests, recipe smoke tests.

khazic added 3 commits April 18, 2026 15:04
Enable tensor- and pipeline-parallel fine-tuning of Gemma4
ConditionalGeneration models:

- Register Gemma4ForConditionalGeneration in `_extract_model_layers`
  and `validate_tp_mesh` so the parallelizer recognizes the VLM stack.
- Add pipeline forwards for the Gemma4 text backbone and VLM outer
  module, including:
  * RoPE dispatch via `config.layer_types[layer_idx]`
  * Skipping `tie_weights` on PP stages that lack `embed_tokens` or
    `lm_head`
  * Handling `image_position_ids` in PP VLM chunking

Signed-off-by: khazic <khazzz1c@gmail.com>
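
For the layer_types-based RoPE dispatch in the commit above, a hedged sketch of the idea, assuming Gemma-style configs where config.layer_types marks each layer as "sliding_attention" or "full_attention" and the backbone carries separate global and local rotary embeddings; the function and argument names are illustrative, not the patched code:

```python
def compute_position_embeddings(config, layer_idx, rotary_emb, rotary_emb_local,
                                hidden_states, position_ids):
    # Sliding-window layers use the local (short-context) RoPE; full-attention
    # layers use the global RoPE, mirroring Gemma's alternating layer types.
    if config.layer_types[layer_idx] == "sliding_attention":
        return rotary_emb_local(hidden_states, position_ids)
    return rotary_emb(hidden_states, position_ids)
```
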
PP last stages own `lm_head` but not `embed_tokens`, so the lm_head
tensor cannot share storage with its tied source. HF tied-embedding
checkpoints (and DCP checkpoints saved before this fix) omit
`lm_head.weight` entirely, which made the DCP planner fail with
`Missing key in checkpoint state_dict: lm_head.weight` whenever the
last PP stage tried to load.

Changes:
- Distinguish `uses_tied_lm_head` (config flag) from
  `has_local_tied_lm_head` (the local stage actually shares the same
  tensor) in `ModelState`. Saved state dicts now keep `lm_head.weight`
  on PP last stages.
- Add `materialize_missing_tied_lm_head` and
  `get_tied_lm_head_source_names` helpers that resolve the tied source
  via `_tied_weights_keys` with sensible HF fallbacks.
- In the safetensors fast load path, materialize the missing
  `lm_head.weight` from the embedding tensor before distributing.
- In the standard DCP load path, when the checkpoint metadata lacks
  `lm_head.weight`, route the embedding source key into the lm_head
  tensor and rename back after load. This also covers init load from
  HF tied-embedding base checkpoints.
- Add unit tests for the new helpers and PP-last-stage save behavior.

Signed-off-by: khazic <khazzz1c@gmail.com>
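
The rename-on-load idea from this commit, as a minimal sketch over plain dicts; the key names are the usual HF ones, and the helper shapes are illustrative rather than nemo_automodel's actual API:

```python
def reroute_missing_lm_head(state_dict, checkpoint_keys,
                            target="lm_head.weight",
                            source="model.embed_tokens.weight"):
    # On a PP last stage the local state dict has `target` but not `source`,
    # while a tied-embedding checkpoint has `source` but not `target`: rename
    # the local tensor so DCP loads the embedding bytes into lm_head.
    rerouted = (target in state_dict
                and target not in checkpoint_keys
                and source in checkpoint_keys)
    if rerouted:
        state_dict[source] = state_dict.pop(target)
    return rerouted

def restore_lm_head_key(state_dict, rerouted,
                        target="lm_head.weight",
                        source="model.embed_tokens.weight"):
    # Rename back after the load so downstream code sees lm_head.weight again.
    if rerouted:
        state_dict[target] = state_dict.pop(source)
```
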
Two desensitized example configs for fine-tuning Gemma4 31B VLM:

- `gemma4_31b_tp4_pp2.yaml`: single-node 8-GPU (TP4 x PP2 x DP1).
- `gemma4_31b_tp4_pp4.yaml`: 4-node 32-GPU (TP4 x PP4 x DP2), with a
  multi-node `torchrun` launch snippet.

Both use the public MedPix-VQA dataset and `google/gemma-4-31B-it`.

Signed-off-by: khazic <khazzz1c@gmail.com>
@copy-pr-bot

copy-pr-bot Bot commented Apr 18, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.


@khazic
Contributor Author

khazic commented Apr 18, 2026

Training log from a live run on the 4-node TP4 x PP4 x DP2 recipe (Gemma4 31B VLM) on our side — loss is converging and grad norm / tps look stable:

2026-04-18 06:57:50 | INFO | root | step 247 | epoch 0 | loss 0.2537 | grad_norm 12.3819 | lr 1.99e-05 | mem 29.18 GiB | tps 4006.84(2003.42/gpu) | num_label_tokens 138
2026-04-18 06:57:52 | INFO | root | step 248 | epoch 0 | loss 0.1538 | grad_norm 12.6614 | lr 1.99e-05 | mem 29.18 GiB | tps 3973.81(1986.90/gpu) | num_label_tokens 220
2026-04-18 06:57:54 | INFO | root | step 249 | epoch 0 | loss 0.2785 | grad_norm 9.9974  | lr 1.99e-05 | mem 29.18 GiB | tps 4010.05(2005.02/gpu) | num_label_tokens 216
2026-04-18 06:57:56 | INFO | root | step 250 | epoch 0 | loss 0.2776 | grad_norm 9.7785  | lr 1.99e-05 | mem 29.18 GiB | tps 3983.61(1991.81/gpu) | num_label_tokens 273
2026-04-18 06:57:59 | INFO | root | step 251 | epoch 0 | loss 0.3938 | grad_norm 16.7286 | lr 1.99e-05 | mem 29.18 GiB | tps 3296.05(1648.03/gpu) | num_label_tokens 270
2026-04-18 06:58:01 | INFO | root | step 252 | epoch 0 | loss 0.4191 | grad_norm 13.9407 | lr 1.99e-05 | mem 29.18 GiB | tps 3912.06(1956.03/gpu) | num_label_tokens 287
2026-04-18 06:58:03 | INFO | root | step 253 | epoch 0 | loss 0.0856 | grad_norm 18.3720 | lr 1.99e-05 | mem 29.18 GiB | tps 4002.60(2001.30/gpu) | num_label_tokens 284
2026-04-18 06:58:05 | INFO | root | step 254 | epoch 0 | loss 0.1639 | grad_norm 6.9901  | lr 1.99e-05 | mem 29.18 GiB | tps 3984.99(1992.49/gpu) | num_label_tokens 439

Config: 4 nodes x 8 GPUs (32 GPUs total), TP=4, PP=4, DP=2, pp_microbatch_size=1, 1f1b schedule, activation_checkpointing=true, bf16, Gemma4 31B VLM. Memory is steady at ~29 GiB/GPU and per-GPU tps is around 2k. Training started cleanly after the tied lm_head init fix in this PR — without it, the PP last stage failed with Missing key in checkpoint state_dict: lm_head.weight during base model load.

@HuiyingLi
Contributor

/ok to test 1ee4789

@HuiyingLi
Contributor

/claude review

Comment on lines +458 to +472
inner_model = getattr(model, "model", None)
text_backbone = (
    getattr(inner_model, "language_model", None)
    if inner_model is not None
    else None
)

if inner_model is not None and text_backbone is not None:
    # VLM with nested text backbone (e.g. Gemma4): patch text backbone and VLM outer
    if patch_inner_model:
        text_backbone.forward = types.MethodType(
            create_pipeline_forward_gemma4_text(), text_backbone
        )
    if patch_causal_lm_model:
        model.forward = types.MethodType(create_pipeline_forward_gemma4_vlm(), model)
Contributor

Bug: This language_model check is not specific to Gemma4 — many other VLMs (KimiVL, Mistral4, Qwen3VL MoE, Kimi K25 VL, LlavaOneVision, etc.) also have model.language_model. As written, all of those models will get the Gemma4-specific create_pipeline_forward_gemma4_text / create_pipeline_forward_gemma4_vlm patched onto them during PP, which will break them.

This needs to be gated on the model actually being Gemma4, e.g.:

from nemo_automodel.components.models.gemma4_moe.model import Gemma4ForConditionalGeneration

is_gemma4_vlm = isinstance(model, Gemma4ForConditionalGeneration)
# or check model.config.model_type == "gemma4" if you want to avoid the import

And then the branch becomes if inner_model is not None and is_gemma4_vlm:.

Contributor

Hi @khazic what do you think about this one?

Contributor Author

Hi @akoumpa — good catch, this was already addressed in a follow-up commit cf772a4d

Comment on lines +129 to +131
tied_keys = getattr(model, "_tied_weights_keys", None)
if isinstance(tied_keys, dict):
    for target_name, source_name in tied_keys.items():
Contributor

Nit/potential bug: HF upstream defines _tied_weights_keys as a list of strings (e.g., ["lm_head.weight"]), not a dict. The NeMo custom models use a dict convention ({"lm_head.weight": "model.embed_tokens.weight"}), but Gemma4ForConditionalGeneration doesn't define _tied_weights_keys at all, so it will inherit the HF parent's list.

This isinstance(tied_keys, dict) check will silently skip the HF list case. The hardcoded fallback candidates below save this in practice, but the code doesn't match its stated intent. Consider also handling list/set types here (where items are target names and you'd need to infer the source differently), or at minimum add a comment noting this is intentionally dict-only for NeMo models.
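
A hedged sketch of what handling both conventions could look like; the dict convention and the embedding fallback follow this thread, and the helper itself is illustrative, not the PR's code:

```python
def iter_tied_pairs(model, default_source="model.embed_tokens.weight"):
    tied_keys = getattr(model, "_tied_weights_keys", None)
    if isinstance(tied_keys, dict):
        # NeMo custom-model convention: {"lm_head.weight": "model.embed_tokens.weight"}
        yield from tied_keys.items()
    elif isinstance(tied_keys, (list, tuple, set)):
        # HF upstream convention: ["lm_head.weight"] — only targets are listed,
        # so the source has to be inferred (here: the standard embedding key).
        for target_name in tied_keys:
            yield target_name, default_source
```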



# =============================================================================
# Tests for _reinit_non_persistent_buffers
Contributor

Missing test coverage: The patch_hf_model_for_pp changes in hf_utils.py add a new VLM branch (model with model.language_model), but the existing TestPatchHfModelForPp tests don't cover this case. Given the regression risk (the new branch currently activates for all VLMs with language_model, not just Gemma4), this needs test coverage.

- Sort imports in checkpointing.py and parallelizer.py (ruff I001).
- Replace the import of the MoE-specific
  `nemo_automodel.components.models.gemma4_moe.model.Gemma4ForConditionalGeneration`
  in parallelizer.py with the HF base class from
  `transformers.models.gemma4.modeling_gemma4`, which is what dense
  Gemma4 VLM actually instantiates. This also lifts the tach
  `distributed -> checkpoint` contract violation introduced by the
  previous indirect import chain.
- Re-apply ruff format on two touched files in
  components/{checkpoint,distributed/pipelining}.
- Update unit tests for the renamed / stricter tied-lm-head semantics:
  * `test_model_state_disables_tied_embeddings_for_non_tied_models`
    asserts the new `uses_tied_lm_head` / `has_local_tied_lm_head`
    attributes instead of the old `is_tied_lm_head`.
  * `test_validate_model_with_tied_embeddings` and
    `test_validate_multiple_issues` now construct a MockModel with
    `lm_head` and `embed_tokens` whose weights actually share storage,
    so the validator's stricter "truly tied" check triggers.
  * Add a no-op `update_seq_len` to `_MockAutoPipeline` so the
    `TestForwardBackwardStepPP` tests can exercise the VLM PP path
    that now calls `self.pp.update_seq_len(input_ids.shape[1])`
    before each schedule step.

Signed-off-by: khazic <khazzz1c@gmail.com>
Two related fixes for the L2 integration test failures on the PR:

1. `checkpointing.py` `_maybe_build_consolidated_index` still accessed
   the renamed `model_state.is_tied_lm_head`, which raised
   `AttributeError` on save for any tied-embedding model and broke
   L2_Retrieval / L2_HF_DCP / L2_HF_Transformer_Finetune /
   L2_Pretrain_and_KD. Switch to
   `getattr(model_state, "has_local_tied_lm_head", False)` and clarify
   the intent with a comment.

2. `has_local_tied_lm_head` used an `is`-identity check between
   `lm_head.weight` and `embed_tokens.weight`. That was too strict:
   after FSDP2 / TP sharding both weights are wrapped in separate
   `DTensor`s and `is`-equality is broken even though HF's
   `tie_weights()` can still relink them on load. The single-rank
   Gemma3 VL 4B DCP test saw the `lm_head` key leak onto disk because
   the save path no longer recognized the model as locally tied.

   Replace the `is` check with `is_tied_word_embeddings(model) and
   both local weights exist`. This preserves the PP cases (last stage
   has `lm_head` but no `embed_tokens` => False; first stage has
   `embed_tokens` but no `lm_head` => False), fixes the single-rank
   FSDP-sharded case, and keeps the existing unit tests for the
   PP-last-stage-like partition passing (that mock has no
   `embed_tokens`, so the check still returns False for it).

Signed-off-by: khazic <khazzz1c@gmail.com>
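
A sketch of the relaxed check this commit describes, under the same assumed module layout as the earlier sketch; is_tied_word_embeddings stands in for the config-flag helper named in the commit message:

```python
def is_tied_word_embeddings(model) -> bool:
    return bool(getattr(model.config, "tie_word_embeddings", False))

def has_local_tied_lm_head(model) -> bool:
    # DTensor wrapping under FSDP2/TP breaks `weight is weight` identity, so
    # key off the config flag plus local existence of both modules instead.
    # PP last stage (lm_head, no embed_tokens) and PP first stage
    # (embed_tokens, no lm_head) both still return False.
    lm_head = getattr(model, "lm_head", None)
    embed = getattr(getattr(model, "model", model), "embed_tokens", None)
    return is_tied_word_embeddings(model) and lm_head is not None and embed is not None
```
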
@khazic
Contributor Author

khazic commented Apr 18, 2026

Pushed 7c34c9fb to address the L2 failures. Two issues rooted in the tied-lm_head rename:

  1. checkpointing.py:_maybe_build_consolidated_index still referenced model_state.is_tied_lm_head, which I renamed to uses_tied_lm_head / split into has_local_tied_lm_head. That raised AttributeError during save on every tied-embedding model — the root cause for L2_Retrieval / L2_HF_Transformer_Finetune / L2_Pretrain_and_KD, and for most of the L2_HF_DCP failures. Switched it to has_local_tied_lm_head (only drop lm_head from the save map when it is actually an alias of the embedding; PP last stages keep their own lm_head).

  2. has_local_tied_lm_head was too strict: it required lm_head.weight is embed_tokens.weight, which is broken by FSDP2 / TP sharding (DTensor wrapping breaks tensor identity). For single-rank Gemma3 VL 4B this made has_local_tied_lm_head=False after sharding, so the save path stopped dropping lm_head, and test_dcp_vlm::test_vlm_dcp_checkpoint saw a stray lm_head.weight key on disk. Relaxed to "config says tied AND both local weights exist"; that matches the old behavior for single-rank and still returns False for PP partitions missing one side.

Verified locally: ruff check ., ruff format --check ., lint-imports all clean; existing test_has_local_tied_lm_head_is_false_for_pp_last_stage_like_partition / test_model_state_keeps_pp_last_stage_lm_head_in_saved_state_dict still pass under the new definition.

/ok to test 7c34c9fb... when you have a moment — thanks!

@akoumpa
Contributor

akoumpa commented Apr 18, 2026

/ok to test 7c34c9f

Contributor

@akoumpa akoumpa left a comment

Thanks a lot @khazic !


step_scheduler:
  global_batch_size: 8
  local_batch_size: 1
Contributor

@khazic with pp2, the lbs shouldn't be 1, right?

Contributor Author

Good catch! Technically lbs=1 still runs without errors — the code checks n_microbatches < len(stages) where len(stages) is the per-rank stage count (= 1 for 1F1B), so no warning is triggered. But with only 1 microbatch for 2 PP stages there is zero pipeline overlap (100% bubble). Fixed to local_batch_size: 2 so we get 2 microbatches and proper 1F1B overlap.
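
Illustrative arithmetic behind that reply, using the standard 1F1B bubble estimate (the scheduler's exact cost model may differ):

```python
pp_stages = 2
pp_microbatch_size = 1  # per this recipe
for local_batch_size in (1, 2):
    n_microbatches = local_batch_size // pp_microbatch_size
    bubble = (pp_stages - 1) / (n_microbatches + pp_stages - 1)
    print(f"lbs={local_batch_size}: {n_microbatches} microbatch(es), "
          f"~{bubble:.0%} of stage time idle")
# lbs=1 -> 1 microbatch, ~50% idle (stages run strictly serially, zero overlap);
# lbs=2 -> 2 microbatches, ~33% idle with genuine 1F1B overlap.
```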

"multi_modal_projector",
"multimodal_projector",
"vision_projector",
"audio_projector",
Contributor

@khazic is embed_vision missing here for Gemma4 PP?

Contributor Author

Good catch! When testing I was only running pure-text inputs, so the multimodal path wasn't exercised and embed_vision slipped through. Fixed by adding it to MULTIMODAL_SUFFIXES in hf_utils.py.
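
What the fix amounts to, sketched against the hunk above (the real tuple in hf_utils.py may carry entries beyond those shown in the diff):

```python
MULTIMODAL_SUFFIXES = (
    "multi_modal_projector",
    "multimodal_projector",
    "vision_projector",
    "audio_projector",
    "embed_vision",  # Gemma4's vision embedding module, added by this fix
)
```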
