
fix: AC silently skipped on all registered VLMs — flatten ModuleList #1941

Merged
akoumpa merged 1 commit into NVIDIA-NeMo:main from khazic:fix/ac-modulelist-flatten
Apr 21, 2026
Conversation

@khazic
Contributor

@khazic khazic commented Apr 21, 2026

What does this PR do ?

Fixes a silent activation-checkpointing (AC) regression introduced in #1904 that caused Gemma4 31B to OOM when trained with plain FSDP2 + AC (gemma4_31b.yaml), tracked in #1927.

Root Cause

Why #1904 (TP+PP) broke AC without touching AC code

_extract_model_layers serves a dual purpose: both TP sharding and AC consume the same returned layers list:

layers = _extract_model_layers(model)

if tp_mesh.size() > 1:
    for layer in layers:
        parallelize_module(layer, tp_mesh, tp_plan)   # TP uses it

if activation_checkpointing:
    for layer in layers:
        apply_activation_checkpointing(layer, ...)    # AC uses it

To generate a TP sharding plan for Gemma4, #1904 registered it in MODEL_CLS_TO_LAYERS:

Gemma4ForConditionalGeneration: ["model.language_model.layers"],

Before #1904, Gemma4 was not registered, so it fell through to the heuristic path:

elif hasattr(model, "model") and hasattr(model.model, "layers"):
    layers.extend(model.model.layers)   # iterates ModuleList directly → individual layers ✓

After #1904, Gemma4 hits the registered path instead:

layers.extend(_reduce_attrs(model, ["model.language_model.layers"]))

Why the registered path silently breaks AC

_reduce_attrs traverses the model by FQN and returns the module at the end of each path as a single item:

# _reduce_attrs returns:
[<nn.ModuleList of 62 decoder layers>]
#  ↑ a list with ONE element — the whole ModuleList, not individual decoder layers
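To see why the path resolves to the container rather than its children, here is a minimal, hypothetical sketch of this kind of FQN traversal (`resolve_fqn` and the toy module classes are illustrative stand-ins, not the actual `_reduce_attrs` code):

```python
from functools import reduce

import torch.nn as nn


def resolve_fqn(model: nn.Module, fqn: str) -> nn.Module:
    """Walk a dotted attribute path, e.g. "model.language_model.layers"."""
    return reduce(getattr, fqn.split("."), model)


class Inner(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([nn.Linear(4, 4) for _ in range(3)])


class Outer(nn.Module):
    def __init__(self):
        super().__init__()
        self.language_model = Inner()


class Wrapper(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = Outer()


m = Wrapper()
resolved = resolve_fqn(m, "model.language_model.layers")
print(type(resolved).__name__)  # "ModuleList" — the container, not its children
```

The traversal necessarily stops at the last attribute in the path, so the caller receives the ModuleList as one object.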

layers.extend([ModuleList]) adds the entire ModuleList as a single entry. The AC code then iterates layers and looks for self_attn / mlp on each element — nn.ModuleList has neither, so all 62 layers are silently skipped. No checkpointing happens, all activations are retained in memory, and 31B OOMs at step 1.
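The failure mode is easy to reproduce with toy modules (`DecoderLayer` here is a stand-in for a real decoder block, not NeMo code):

```python
import torch.nn as nn


class DecoderLayer(nn.Module):
    def __init__(self):
        super().__init__()
        self.self_attn = nn.Linear(4, 4)
        self.mlp = nn.Linear(4, 4)


block = nn.ModuleList([DecoderLayer() for _ in range(62)])

# Registered path before the fix: the whole ModuleList lands as ONE entry.
broken = []
broken.extend([block])
print(len(broken))                      # 1, not 62
print(hasattr(broken[0], "self_attn"))  # False -> AC silently skips it

# Heuristic path: iterating the ModuleList yields individual layers.
layers = []
layers.extend(block)
print(len(layers))                      # 62
print(hasattr(layers[0], "self_attn"))  # True -> AC wraps each layer
```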

This bug was pre-existing for all registered VLMs

The same layers.extend(_reduce_attrs(...)) pattern was already in place before #1904 for every model in MODEL_CLS_TO_LAYERS (Qwen2VL, LlavaNext, Mistral3, Llama4, etc.). Their AC had also been silently failing, but with TP/PP, memory is distributed across many GPUs, so the failure never caused an OOM and went unnoticed. Gemma4 31B on plain FSDP2+AC (8×80GB, no TP/PP) was the first configuration where the memory pressure was high enough to surface it.

Fix

Introduce a small helper _extend_layers that flattens any nn.ModuleList results from _reduce_attrs into individual layers:

def _extend_layers(layers, modules):
    for m in modules:
        if isinstance(m, nn.ModuleList):
            layers.extend(m)   # flatten → individual decoder layers
        else:
            layers.append(m)   # non-ModuleList passthrough

This replaces both layers.extend(_reduce_attrs(...)) call sites, fixing AC for all registered model types in one change.
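A self-contained sketch of the helper's behavior (restating _extend_layers from the PR, with toy modules standing in for real decoder layers):

```python
import torch.nn as nn


def _extend_layers(layers, modules):
    # As in the PR: flatten ModuleList results, pass everything else through.
    for m in modules:
        if isinstance(m, nn.ModuleList):
            layers.extend(m)   # flatten -> individual decoder layers
        else:
            layers.append(m)   # non-ModuleList passthrough


decoders = nn.ModuleList([nn.Linear(4, 4) for _ in range(62)])
extra = nn.LayerNorm(4)

layers = []
_extend_layers(layers, [decoders, extra])
print(len(layers))  # 63: 62 flattened decoder layers + 1 passthrough module
```

Non-ModuleList entries are kept as single items, which preserves the existing behavior for shapes like the post-PP-split ModuleDict.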

Validation

Reproduced the OOM on gemma4_31b.yaml (8× H100, FSDP2 + AC, tp=1 cp=1) on main before this fix: training crashes at step 1 with torch.OutOfMemoryError in v_norm (peak memory jumps to 44.96 GiB, then OOM).

After this fix, training runs stably with memory settling at ~40 GiB:

step 0  | loss 3.0765 | grad_norm 107.5000 | mem 36.37 GiB | tps  172.09/gpu
step 1  | loss 3.2145 | grad_norm  91.0000 | mem 40.35 GiB | tps 1296.97/gpu
step 2  | loss 2.6151 | grad_norm  69.5000 | mem 41.82 GiB | tps 1055.56/gpu
step 3  | loss 2.8683 | grad_norm  69.0000 | mem 40.43 GiB | tps 1129.31/gpu
step 4  | loss 3.2896 | grad_norm 134.0000 | mem 40.34 GiB | tps 1530.38/gpu
step 5  | loss 2.3572 | grad_norm  73.5000 | mem 40.34 GiB | tps 1163.57/gpu
step 6  | loss 2.8681 | grad_norm  76.5000 | mem 40.34 GiB | tps 1585.34/gpu
step 7  | loss 3.0434 | grad_norm 194.0000 | mem 40.37 GiB | tps 1511.83/gpu
step 8  | loss 2.4902 | grad_norm 113.0000 | mem 40.34 GiB | tps 1447.45/gpu
step 9  | loss 2.6543 | grad_norm  69.0000 | mem 40.34 GiB | tps 1484.71/gpu
step 10 | loss 2.8736 | grad_norm  66.5000 | mem 40.34 GiB | tps 1360.82/gpu
step 11 | loss 2.3015 | grad_norm  33.0000 | mem 40.37 GiB | tps 1246.52/gpu
step 12 | loss 2.2872 | grad_norm  34.2500 | mem 40.39 GiB | tps 1512.90/gpu
step 13 | loss 2.2140 | grad_norm  38.7500 | mem 40.34 GiB | tps 1472.76/gpu
step 14 | loss 2.0944 | grad_norm  40.7500 | mem 40.41 GiB | tps 1390.58/gpu
step 15 | loss 2.6619 | grad_norm  37.2500 | mem 40.45 GiB | tps 1446.46/gpu
step 16 | loss 2.1999 | grad_norm  44.2500 | mem 40.36 GiB | tps 1448.05/gpu
step 17 | loss 1.7065 | grad_norm 1272.000 | mem 40.33 GiB | tps 1450.97/gpu
step 18 | loss 2.0800 | grad_norm 195.0000 | mem 40.33 GiB | tps 1427.04/gpu
step 19 | loss 2.0801 | grad_norm  26.2500 | mem 40.35 GiB | tps 1374.45/gpu
step 20 | loss 2.3384 | grad_norm  32.2500 | mem 40.34 GiB | tps 1490.00/gpu

Changelog

  • Fix _extract_model_layers to flatten nn.ModuleList objects returned by _reduce_attrs into individual layers, restoring correct AC behavior for all registered VLM model types (Gemma4, Qwen2VL, LlavaNext, Mistral3, Llama4, etc.).

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you add or update any necessary documentation?

Additional Information

…ividual layers

_reduce_attrs returns ModuleList objects as single items; extending layers
with them meant AC code never found self_attn/mlp on a ModuleList and
silently skipped all checkpointing. Flatten any ModuleList results so
layers contains individual decoder layers, matching the heuristic path.

Signed-off-by: khazic <khazzz1c@gmail.com>
@copy-pr-bot

copy-pr-bot Bot commented Apr 21, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@khazic khazic changed the title fix: flatten ModuleList in _extract_model_layers so AC applies to individual layers fix: restore activation checkpointing for FSDP2 VLM training broken by #1904 Apr 21, 2026
@khazic
Contributor Author

khazic commented Apr 21, 2026

I want to clarify how this bug slipped through in #1904: when testing with TP+PP, the memory pressure per GPU was already low enough (TP4×PP2 spreads the model across 8 GPUs) that training ran for hundreds of steps without OOM — so I assumed AC was working correctly. In reality, AC had already silently failed at that point. I apologize for not catching this earlier. The bug has been fixed in this PR, and I will be more careful to validate AC behavior explicitly in future work, rather than relying on the absence of OOM as a proxy.

@HuiyingLi
Contributor

/ok to test cdd56e5

@HuiyingLi HuiyingLi changed the title fix: restore activation checkpointing for FSDP2 VLM training broken by #1904 fix: AC silently skipped on all registered VLMs — flatten ModuleList in _extract_model_layers Apr 21, 2026
@HuiyingLi HuiyingLi changed the title fix: AC silently skipped on all registered VLMs — flatten ModuleList in _extract_model_layers fix: AC silently skipped on all registered VLMs — flatten ModuleList Apr 21, 2026
@akoumpa
Contributor

akoumpa commented Apr 21, 2026

/claude review

Comment thread nemo_automodel/components/distributed/parallelizer.py
@akoumpa akoumpa added the r0.4.0 Auto-cherrypick to release branch. Apply before merge; cherrypick happens after merge. label Apr 21, 2026
@akoumpa akoumpa merged commit 14f14cd into NVIDIA-NeMo:main Apr 21, 2026
65 of 70 checks passed
HuiyingLi added a commit that referenced this pull request Apr 21, 2026
codecov/patch flagged #1941 at 14.28% (1/7 diff lines hit): every existing
test mocks _extract_model_layers, so the new _extend_layers helper and the
two modified call sites were unexecuted. Add six tests over the real
function covering: class-keyed single FQN (GPT2), string-keyed arm
(NemotronH name match), multi-FQN (Qwen2.5-VL), non-ModuleList element
kept as a single entry (ModuleDict post-PP-split shape), and both
ModuleList/ModuleDict fallback branches as regression guards.

Uses Cls.__new__ + nn.Module.__init__ to produce instances whose
type(model) matches the exact class in MODEL_CLS_TO_LAYERS (identity
lookup — subclasses miss the dict) without HF's config-dependent
__init__.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
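The instantiation trick described in that commit can be sketched in isolation (ConfigHeavyModel is a hypothetical stand-in for an HF model class whose __init__ requires a full config):

```python
import torch.nn as nn


class ConfigHeavyModel(nn.Module):
    # Stand-in for an HF model whose __init__ needs a config object.
    def __init__(self, config):
        super().__init__()
        self.layers = nn.ModuleList(
            [nn.Linear(config.hidden, config.hidden)]
        )


# Allocate the instance without running the config-dependent __init__,
# then run only nn.Module.__init__ so submodule registration works.
m = ConfigHeavyModel.__new__(ConfigHeavyModel)
nn.Module.__init__(m)
m.layers = nn.ModuleList([nn.Linear(4, 4) for _ in range(2)])

print(type(m) is ConfigHeavyModel)  # True -> exact-class dict lookup matches
print(len(m.layers))                # 2
```

Because MODEL_CLS_TO_LAYERS is keyed by exact class identity, this yields a test instance that hits the registered path without constructing a real HF config.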
akoumpa pushed a commit that referenced this pull request Apr 21, 2026
…List (1941)` into `r0.4.0` (#1958)

fix: AC silently skipped on all registered VLMs — flatten ModuleList  (#1941)

fix: flatten ModuleList in _extract_model_layers so AC applies to individual layers

Signed-off-by: khazic <khazzz1c@gmail.com>
Signed-off-by: NeMo Bot <nemo-bot@nvidia.com>
Co-authored-by: khazzz1c <khazzz1c@gmail.com>
HuiyingLi added a commit that referenced this pull request Apr 21, 2026
#1941 flattens each ModuleList returned from _reduce_attrs, so
_extract_model_layers now yields individual decoder modules instead
of the containing ModuleLists. Update the two fallback/None-safety
assertions added in #1859 to isinstance-check the inner nn.Linear.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
akoumpa pushed a commit that referenced this pull request Apr 21, 2026
* fix: flatten ModuleList in _extract_model_layers so AC applies to individual layers

* test: cover _extract_model_layers flatten branches

* style: ruff format + rename ambiguous `l` loop var

Fix E741 and apply ruff format on the test file.

* test(qwen3_5): expect flattened per-layer modules after #1941

---------

Signed-off-by: khazic <khazzz1c@gmail.com>
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
Co-authored-by: khazic <khazzz1c@gmail.com>
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
linnanwang pushed a commit that referenced this pull request Apr 24, 2026
…1941)

fix: flatten ModuleList in _extract_model_layers so AC applies to individual layers

Signed-off-by: khazic <khazzz1c@gmail.com>
linnanwang pushed a commit that referenced this pull request Apr 24, 2026
* fix: flatten ModuleList in _extract_model_layers so AC applies to individual layers

* test: cover _extract_model_layers flatten branches

* style: ruff format + rename ambiguous `l` loop var

* test(qwen3_5): expect flattened per-layer modules after #1941

---------

Signed-off-by: khazic <khazzz1c@gmail.com>
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
Co-authored-by: khazic <khazzz1c@gmail.com>
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

Labels

community-request r0.4.0 Auto-cherrypick to release branch. Apply before merge; cherrypick happens after merge.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants