HF's FineGrainedFP8Config loader materializes the full dequantized BF16
model on every rank before PP split, OOM-ing Devstral under TP+PP on 80 GB
H100 (~77 GB/rank observed at TP=2 PP=2). Introduce a state-dict adapter
modelled on deepseek_v3/state_dict_adapter.py so FP8 dequant runs inside
the standard Checkpointer DCP load path, never materializing the full BF16
model on any rank.
New components:
- nemo_automodel/components/models/devstral/:
- state_dict_adapter.py: DevstralFP8StateDictAdapter. to_hf adds
FP8-typed placeholders + scalar weight_scale_inv so DCP reads FP8
bytes verbatim; from_hf pairs them, dequantizes to bf16, and strips
the VLM language_model. prefix for the 24B variant.
- model.py: Devstral24BFP8TextForCausalLM (24B VLM text backbone,
key_prefix="language_model.") and Devstral123BFP8ForCausalLM (123B
dense, key_prefix=""). Both subclass Ministral3ForCausalLM.
- __init__.py: installs a _resolve_custom_model_cls_for_config hook so
FP8-native Mistral3/Ministral3 configs dispatch to our classes. The
registry is NOT overridden — supports_config gates on quant_method=fp8
so non-FP8 Mistral3 VLMs keep the stock mistral4.model path.
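The adapter's FP8 dequant pairing can be sketched as follows. This is a minimal sketch based on the description above; the helper names and the exact return shape are assumptions, but the key convention (a `weight` tensor paired with a scalar `weight_scale_inv`) is the one the commit describes:

```python
import torch

def _dequantize_from_fp8(weight, scale_inv, dtype=torch.bfloat16):
    # Per-tensor scalar scale: upcast once, multiply, downcast to bf16.
    return (weight.to(torch.float32) * scale_inv.to(torch.float32)).to(dtype)

def from_hf(hf_state_dict):
    # Pair each quantized weight with its scalar weight_scale_inv and
    # emit a plain bf16 tensor; the scale keys themselves are dropped.
    native = {}
    for key, tensor in hf_state_dict.items():
        if key.endswith("_scale_inv"):
            continue  # consumed alongside its paired weight below
        scale = hf_state_dict.get(key + "_scale_inv")
        native[key] = _dequantize_from_fp8(tensor, scale) if scale is not None else tensor
    return native
```

Because the dequant runs per tensor inside the DCP load path, no rank ever holds more than one dequantized weight at a time beyond the model itself.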
Checkpointer / infra changes:
- checkpointing.py: skip HF's initialize_weights() when the model sets
_skip_init_weights_on_load=True (avoids DTensor collective divergence
across PP stages). Add ministral3 to _MODELS_REQUIRING_BUFFER_REINIT so
non-persistent RoPE inv_freq is recomputed after meta materialization.
- mistral3/model.py: expose rope_init_fn and rope_kwargs on
Ministral3RotaryEmbedding so Pattern-1 reinit can recompute inv_freq.
- parallelizer.py: string-keyed fallback for Mistral3ForConditionalGeneration
layer extraction (class identity can mismatch after transformers reimport).
- recipes/llm/train_ft.py: import devstral module as a side effect so the
resolver hook is installed before the registry is queried.
- quantization/fp8_streaming.py: deprecated shim (empty) — replaced by
the adapter pattern.
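The checkpointer gate described above can be sketched like this; the wrapper function is hypothetical, only the `_skip_init_weights_on_load` attribute name comes from the change itself:

```python
def maybe_initialize_weights(model):
    # Models loaded through the FP8 state-dict adapter set
    # _skip_init_weights_on_load=True to opt out of HF's
    # initialize_weights(), whose DTensor collectives can diverge
    # across PP stages.
    if getattr(model, "_skip_init_weights_on_load", False):
        return False  # skipped: the adapter load supplies real weights
    model.initialize_weights()
    return True
```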
Recipes & launchers:
- examples/llm_finetune/devstral/devstral2_small_2512_squad_tp2pp2.yaml:
full-FT SQuAD at TP=2 PP=2 (no quantization_config — adapter owns FP8).
- examples/llm_finetune/devstral/devstral2_123b_hellaswag_tp_pp.yaml:
full-FT HellaSwag at TP=8 PP=8 DP=1, local_batch_size=8 (required for
PP fill), global_batch_size=256, sequence_parallel=false (SP+PP fails
c10d::send on DTensor activations).
- devstral_24b_tp_pp_2node.sub / devstral_123b.sub: multi-node launchers.
Verified end-to-end on 1 node 8xH100 at TP=2 PP=2 for 24B:
step 0: loss 0.8235, grad_norm 149
step 1: loss 0.5015, grad_norm 87
Adapter dequant is bit-identical (max |diff| = 0.000) to HF's
FineGrainedFP8Config(dequantize=True) path across all 363 text-stack
tensors.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
Generalize the Devstral-specific FP8 adapter into a family-wide loader that
covers three checkpoint key layouts via one class + one adapter, with
auto-detection from the safetensors weight map:
Layout           Disk keys (text)          lm_head key                Examples
---------------  ------------------------  -------------------------  ------------------------
devstral_vlm     language_model.model.X    language_model.lm_head.X   Devstral-Small-2-24B VLM
dawn_ridge_vlm   model.language_model.X    lm_head.weight             Mistral-3.5 128B VLM
dense            model.X                   lm_head.weight             Devstral-2-123B dense
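The per-layout key rewrites in the table above reduce to simple prefix edits. A hedged sketch (the function name is illustrative, not the shipped API; the native text-backbone keys are assumed to be `model.*` / `lm_head.*`):

```python
def hf_to_native_key(key, layout):
    if layout == "devstral_vlm":
        # language_model.model.X -> model.X ; language_model.lm_head.X -> lm_head.X
        prefix = "language_model."
        return key[len(prefix):] if key.startswith(prefix) else key
    if layout == "dawn_ridge_vlm":
        # model.language_model.X -> model.X ; lm_head.weight is already native
        prefix = "model.language_model."
        return "model." + key[len(prefix):] if key.startswith(prefix) else key
    return key  # dense: keys already match the native layout
```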
Code changes:
- components/models/devstral/state_dict_adapter.py: Mistral3FP8StateDictAdapter
(renamed from DevstralFP8StateDictAdapter). Replaces the key_prefix knob
with (native_to_hf, hf_to_native) callables, preserving the FP8 dequant
logic. Three factory classmethods — for_devstral_vlm / for_dawn_ridge_vlm
/ for_dense — each install the right rewrites. Layout-agnostic dequant
using per-tensor scalar scale_inv (same format across all three
checkpoints).
- components/models/devstral/model.py: single Mistral3FP8ForCausalLM
replaces the two size-specific classes. Accepts Mistral3Config (VLM) or
Ministral3Config (dense). _detect_layout() samples
model.safetensors.index.json at __init__ to pick the adapter factory.
Devstral24BFP8TextForCausalLM and Devstral123BFP8ForCausalLM remain as
backwards-compat aliases.
- components/models/devstral/__init__.py: resolver hook now checks a single
Mistral3FP8ForCausalLM.supports_config that matches any FP8-native
ministral3 (inner or outer) config.
Recipe & launcher for Mistral-3.5 128B ("dawn-ridge"):
- examples/llm_finetune/devstral/mistral3p5_128b_hellaswag_tp_pp.yaml:
full-FT HellaSwag on 8 nodes x 8 H100-80GB, TP=8 PP=8, DP=1,
local_batch_size=8, global_batch_size=256, sequence_parallel=false.
No quantization_config — adapter owns FP8.
- mistral3p5_128b.sub: matching multi-node launcher.
Verified end-to-end on 8 nodes x 8 H100-80GB (job 11298192) — TEXT training
only; vision_tower / multi_modal_projector weights on disk are ignored
because Mistral3FP8ForCausalLM extends Ministral3ForCausalLM (text-only):
detected layout='dawn_ridge_vlm' from dawn-ridge-medium-3p5-128b-hf_vv1
dequantized 77 FP8 weights per PP stage (11 layers * 7 projs)
MEM after load: 3.92 GB / peak 9.32 GB per rank
step 0: loss 2.7985, grad_norm 243.52
step 1: loss 2.2894, grad_norm 58.27
step 2: loss 2.0106, grad_norm 23.51
Adapter correctness regression check on Devstral-24B still passes — all 363
text-stack tensors bit-identical to HF's FineGrainedFP8Config(dequantize=True)
(max |diff| = 0.000).
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
The 128B dawn-ridge VLM was fine-tuning with a broken initial forward:
step-0 loss landed at ~6.6 (per-token) with grad_norm ~8650, then fell
back to ~2 after a handful of steps. Root cause: our PP path had never
been wired for Mistral3 VLM end-to-end.
Three gaps are fixed here, each independently necessary to make the run
match the HF reference forward on the same samples:
1. Mistral3 VLM pipeline forward + chunked pixel_values retrieval
(pipelining/hf_utils.py). patch_hf_model_for_pp had a Gemma4-only VLM
branch; Mistral3 fell through to the generic causal_lm forward which
never invokes vision_tower, so image tokens got plain text-embedding
lookups at image_token_id=10 (garbage). Add a Mistral3 VLM forward
that runs vision_tower + multi_modal_projector on stage 0 and routes
hidden states through the text backbone on subsequent stages. Also
retrieve pixel_values/image_sizes from stage0_model._vlm_pixel_values_chunks
mirroring the Gemma4 pattern, since the training loop pops them from
the batch before schedule.step and pre-chunks them per microbatch.
2. Mistral3 VLM TP plan (optimized_tp_plans.py). The parallelizer's
optimized-plan lookup keyed on the qualified class name didn't match
Mistral3ForConditionalGeneration or our FP8 VLM subclass, so it fell
through to a default plan whose module paths (model.layers.*) don't
exist on a VLM (model.language_model.layers.*). Weights stayed
unsharded across TP and the FP8 dequant peaked at 672 MiB per op,
OOMing load time. Register _parallelize_mistral3_vlm for both the
native HF class and the FP8 VLM subclass.
3. Resolver-hook + registry plumbing: entry_cls propagation in
_resolve_custom_model_cls_for_config so AutoModelForImageTextToText
dispatches to the VLM-aware Mistral3FP8VLMForConditionalGeneration
(not the text-only class); MRO walk in _extract_model_layers so the
FP8 subclass inherits HF's layer-path registration; side-effect
import in recipes/vlm/finetune.py so registrations fire before the
model factory is queried. model.py adds the VLM class itself, with a
per-rotary forward pre-hook that recomputes inv_freq on first call
(Ministral3RotaryEmbedding and PixtralRotaryEmbedding precompute in
__init__, so meta-init + to_empty leaves them uninitialized).
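The per-rotary reinit pre-hook from item 3 can be sketched as a one-shot `register_forward_pre_hook`; the hook factory name is hypothetical, but the swallow-and-mark behavior matches the description:

```python
import torch
from torch import nn

def make_rotary_reinit_hook(init_fn):
    # One-shot forward pre-hook: rotary modules that precompute inv_freq
    # in __init__ come out of meta-init + to_empty() uninitialized, so
    # recompute the buffer the first time the module actually runs.
    def hook(module, args):
        if getattr(module, "_inv_freq_ready", False):
            return
        try:
            init_fn(module)  # recompute inv_freq in place
        except Exception:
            pass  # swallow init errors, as the shipped hook does
        module._inv_freq_ready = True
    return hook
```

Usage would be `rotary.register_forward_pre_hook(make_rotary_reinit_hook(fn))` on every rotary submodule after materialization.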
Verified:
* Single-rank HF vs ours, 4-layer: bit-identical forward + logits.
* TP=4 PP=2 layer dump on production's first sample (idx=1004) with
production's chunk-flow emulated: loss 12.3132 vs HF 12.3131
(mean per-layer |HF-ours| ≈ 4e-4, textbook bf16 TP drift).
* 8-node production run step 0 after fix: loss 3.2004, grad_norm 932.
Before fix: 6.6255 / 8566. HF reference on the same first batch of
8 samples: 3.4685 — our 3.20 is in the expected ±0.5 band.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
…lm_fp8
The FP8 custom-model module was originally added to cover Devstral (24B
and 123B text-only checkpoints) alongside the Mistral-3.5 128B VLM
(dawn-ridge). We no longer need the Devstral variants, so keep only the
VLM class and rename the module to match its actual scope.
Removed:
- Mistral3FP8ForCausalLM (text-only wrapper around Ministral3ForCausalLM)
- Mistral3FP8StateDictAdapter factories: for_dense, for_devstral_vlm,
  for_dawn_ridge_vlm (and their key-rewrite helpers)
- _detect_layout / _resolve_snapshot_dir (checkpoint-layout sniffing)
- Text-only branch of the resolver hook (the VLM branch remains)
- examples/llm_finetune/devstral/devstral2_* YAMLs
- examples/llm_finetune/devstral/mistral3p5_128b_hellaswag_tp_pp.yaml
  (text-only variant of the 128B recipe)
- devstral_123b.sub, devstral_24b_tp_pp_2node.sub, mistral3p5_128b.sub
- Devstral side-effect import in recipes/llm/train_ft.py
Kept:
- Mistral3FP8VLMForConditionalGeneration and the for_vlm_full adapter
- Resolver hook dispatching the VLM via NeMoAutoModelForImageTextToText
- MedPix recipe + sbatch (mistral3p5_128b_medpix.sub) — that's the live
  training path with verified step-0 loss 3.20 (HF ref 3.47)
All PP/TP wiring lands on the new module path
``nemo_automodel.components.models.mistral3_vlm_fp8``; the registered
TP-plan qualname is updated to match.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
Shorten the directory name; the "_fp8" suffix was redundant with the
class's own Mistral3FP8* prefix. Update imports in the module itself,
the resolver hook attribute, the TP-plan qualname registration, and the
side-effect import in recipes/vlm/finetune.py.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
The entry_cls parameter on _resolve_custom_model_cls_for_config existed
to let our resolver hook branch text-only vs VLM dispatch for the same
FP8 Mistral3 checkpoint. With the text-only Devstral path gone the
branch is moot — there's only one custom class left to claim, and the
VLM recipe is the sole entry point that imports this module. Revert
model_init.py to the original signature and simplify the resolver hook
to unconditionally claim matching FP8 Mistral3 VLM configs.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
The _mem_mark helper + call sites were diagnostic scaffolding added
while hunting the OOM-at-checkpoint-load issue on the 128B VLM. They
were gated behind AUTOMODEL_DEBUG_MEM=1, so dead under the default env,
and unrelated to the shipped path — the VLM fix is the
pipeline_forward_mistral3_vlm retrieval, not memory tracing. Clean up
to keep the LLM trainer focused.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
Three small cleanups now that text-only Devstral is gone:
- components/models/mistral3/model.py: drop `rope_init_fn` /
  `rope_kwargs` attribute exposure. They were added for the
  checkpointer's _reinit_non_persistent_buffers Pattern-1 path, which
  was activated via the "ministral3" entry in
  _MODELS_REQUIRING_BUFFER_REINIT below. Our VLM class handles its own
  `inv_freq` reinit via per-rotary forward pre-hooks, a different code
  path, so these attributes were orphaned after the Devstral removal.
- components/checkpoint/checkpointing.py: drop "ministral3" from
  _MODELS_REQUIRING_BUFFER_REINIT. No shipped recipe reaches that path
  anymore; the VLM handles rotary reinit inside
  Mistral3FP8VLMForConditionalGeneration.__init__. Keep the
  `_skip_init_weights_on_load` gate — that's still load-bearing for the
  VLM path.
- components/distributed/parallelizer.py: drop the string fallback
  entry "Mistral3ForConditionalGeneration" in MODEL_CLS_TO_LAYERS. The
  MRO walk we added already handles class-identity mismatch by name
  matching, so the separate string entry was double-covered.
Verified with the 1-node 4-layer smoke: step-0 loss 12.4135, 5 steps
all complete, identical to the pre-cleanup run.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
Superseded by the DCP state-dict-adapter load path. No runtime code
imports fp8_streaming anymore — the only remaining reference is a debug
probe under logs/ (not in the source tree).
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
Ship the live training recipe for the Mistral-3.5 128B (dawn-ridge) VLM:
- examples/vlm_finetune/mistral3p5/mistral3p5_128b_medpix.yaml
  Full 8-node config (TP=8 PP=8 DP=1, 64 H100-80GB). Runs against the
  on-disk FP8-native checkpoint via
  Mistral3FP8VLMForConditionalGeneration + the for_vlm_full state-dict
  adapter. Vision tower + mm_projector frozen; language_model trained.
  Verified: step-0 loss 3.20 vs HF reference 3.47 on the matching first
  batch.
Also strengthen the _skip_init_weights_on_load comment in the VLM class
with an empirical note: without the gate, HF's initialize_weights()
hangs indefinitely under PP (verified via a 4-layer smoke that never
reached the adapter-load stage within 300 s).
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
The two Mistral3 entries in PARALLELIZE_FUNCTIONS were string-literal
qualnames; switch them to _get_class_qualname(Cls) calls to match the
prevailing pattern of every other entry. Add the corresponding eager
imports for
transformers.models.mistral3.Mistral3ForConditionalGeneration and the
Automodel Mistral3FP8VLMForConditionalGeneration subclass.
Behavior-preserving — verified by re-running the 4-layer single-node
VLM smoke (step-0 loss 12.4135, unchanged).
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
Drop the text_config fallback. HF's Mistral3Config sets
``model_type = "mistral3"`` as a class attribute, so any real Mistral3
VLM (HF native or our subclass) hits the outer-config check first.
Behavior-preserving — verified with the 4-layer smoke (step-0 loss
12.4135, unchanged).
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
…e.py
The mistral3_vlm package's __init__.py installs a resolver hook for FP8
Mistral3 configs at import time. Previously we needed an explicit
side-effect import in recipes/vlm/finetune.py to ensure that hook fired
before the model factory ran. Now that optimized_tp_plans.py imports
Mistral3FP8VLMForConditionalGeneration eagerly (for the
PARALLELIZE_FUNCTIONS qualname registration), Python imports the parent
mistral3_vlm package as a side effect — installing the resolver hook
automatically along the parallelizer load chain. The explicit import in
finetune.py is now redundant.
Verified with the 4-layer smoke: resolver still claims the config, TP
plan dispatches, step-0 loss 12.4135 unchanged.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
Drop the generic MRO walk in _extract_model_layers added for Mistral3
VLM and replace it with the existing pattern used by Gemma4: a
string-keyed entry "Mistral3FP8VLMForConditionalGeneration" in
MODEL_CLS_TO_LAYERS. Restores the original direct/__name__ lookup.
Why a string key (and not a class import): NeMo Auto's
_get_mixin_wrapped_class wraps custom model classes via
type(name, (HFCheckpointingMixin, model_class), {}) — the wrapper has
the same __name__/__module__ but distinct identity, so direct class
membership in MODEL_CLS_TO_LAYERS misses. The elif `__name__ in MAP`
check catches the string key. Same rationale as the existing
"Gemma4ForConditionalGeneration" string fallback above.
Verified: 4-layer single-node smoke step-0 loss 12.4135 (unchanged).
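The identity mismatch described above is easy to reproduce in plain Python; a sketch of the lookup with hypothetical names (the real map and mixin live in parallelizer.py and the checkpointing code):

```python
MODEL_CLS_TO_LAYERS = {}

def lookup_layers(model_cls):
    # Direct class-identity hit first.
    if model_cls in MODEL_CLS_TO_LAYERS:
        return MODEL_CLS_TO_LAYERS[model_cls]
    # The checkpointing mixin wraps custom classes via
    # type(name, (Mixin, Base), {}): same __name__, distinct identity,
    # so a string key catches the wrapped clone.
    return MODEL_CLS_TO_LAYERS.get(model_cls.__name__)
```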
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
63 unit tests covering every code change in the Mistral-3.5 128B FP8 VLM PR:
- tests/unit_tests/models/mistral3_vlm/test_state_dict_adapter.py (28 tests)
* _is_fp8_weight_key gating: layer Linear weights, _NON_QUANTIZED_SUFFIXES
exclusions, not_fp8_prefixes prefix-with-dot semantics, substring guard.
* _dequantize_from_fp8 numerical correctness (per-tensor scalar scale).
* for_vlm_full() factory: layout name, not_fp8_prefixes, identity rewrites.
* from_hf: FP8 dequant + scale_inv drop, BF16 passthrough,
activation_scale drop, FP8-without-scale defensive passthrough.
* to_hf: quantization off/on, scale_inv placeholder for FP8 keys,
no placeholder for vision/mm_projector/lm_head/embed, exclude_key_regex.
* convert_single_tensor_to_hf: 1-pair vs 2-pair behavior.
- tests/unit_tests/models/mistral3_vlm/test_model.py (14 tests)
* supports_config matrix (FP8 mistral3 / non-FP8 / non-mistral3 /
dict vs object quantization_config).
* _skip_init_weights_on_load class attribute = True (load-bearing for
the checkpointing PP-deadlock gate).
* __init__ flips quantization_config.dequantize=True (dict + obj forms).
* State-dict adapter attached on construction (vlm_full layout).
* Per-rotary forward pre-hook registered on every inv_freq submodule.
* _rotary_reinit_self_hook: idempotent reinit, swallows init errors.
- tests/unit_tests/models/mistral3_vlm/test_resolver_hook.py (5 tests)
* Resolver hook installation marker + idempotence on re-import.
* Hook claims FP8 mistral3 configs, passes through others.
* supports_config exception is swallowed and original resolver runs.
- tests/unit_tests/distributed/pipelining/test_hf_utils.py (+8 tests)
* _is_mistral3_vlm: model_type='mistral3' / other / no-config.
* patch_hf_model_for_pp dispatches Mistral3 VLM to the right forwards
(outer = pipeline_forward_mistral3_vlm, inner = pipeline_forward).
* pipeline_forward_mistral3_vlm:
- Chunk retrieval fires when pixel_values=None and chunks pre-staged
(this was the original step-0 loss 6.6 → 3.2 fix).
- vision_feature_layer is resolved from config when None (HF outer
forward's @merge_with_config_defaults equivalent).
- Chunk retrieval skipped when input_ids has no image_token_id.
- Non-first stage promotes float input_ids → inputs_embeds.
- tests/unit_tests/distributed/test_optimized_tp_plans.py (+4 tests)
* _parallelize_mistral3_vlm: every text-decoder rule scoped under
model.language_model.* (without this prefix scoping, MLP weights
stayed unsharded under TP and FP8 dequant OOMed).
* Q/K/V/gate/up colwise + O/down rowwise (Ministral3 GQA).
* lm_head colwise at top level (not nested).
* Both Mistral3ForConditionalGeneration AND
Mistral3FP8VLMForConditionalGeneration qualnames registered.
- tests/unit_tests/distributed/test_parallelizer.py (+1 test)
* "Mistral3FP8VLMForConditionalGeneration" string-key entry catches
classes whose runtime identity differs from the imported class
(NeMo Auto's HFCheckpointingMixin wrap creates a fresh type with
same __name__ but distinct identity — direct class match misses).
- tests/unit_tests/checkpoint/test_checkpointing.py (+3 tests)
* _skip_init_weights_on_load=True takes the skip branch (no
initialize_weights() call, _is_hf_initialized untouched).
* False / missing attr falls through to the default init path.
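The prefix scoping those TP-plan tests assert can be illustrated with a dependency-free fragment. The parallel styles are shown as strings to keep the sketch self-contained (the real plan uses tensor-parallel style objects), and the rule set below is reconstructed from the commit text, not copied from the source:

```python
# Every text-decoder rule lives under model.language_model.* (the VLM
# nesting), not model.*; lm_head sits at the top level of the VLM.
MISTRAL3_VLM_PLAN = {
    "model.language_model.layers.*.self_attn.q_proj": "colwise",
    "model.language_model.layers.*.self_attn.k_proj": "colwise",
    "model.language_model.layers.*.self_attn.v_proj": "colwise",
    "model.language_model.layers.*.self_attn.o_proj": "rowwise",
    "model.language_model.layers.*.mlp.gate_proj": "colwise",
    "model.language_model.layers.*.mlp.up_proj": "colwise",
    "model.language_model.layers.*.mlp.down_proj": "rowwise",
    "lm_head": "colwise",  # top level, not nested under language_model
}
```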
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
Cosmetic-only reflow (no semantic change). Verified by re-running the
63 unit tests added in the previous commit — all pass.
Files reformatted:
- nemo_automodel/components/distributed/parallelizer.py
- nemo_automodel/components/distributed/pipelining/hf_utils.py
- nemo_automodel/components/models/mistral3_vlm/model.py
- nemo_automodel/components/models/mistral3_vlm/state_dict_adapter.py
(checkpointing.py, optimized_tp_plans.py, mistral3_vlm/__init__.py
already conform — left unchanged.) ruff check --fix reported no lint
fixes needed.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
Switch model + processor pretrained_model_name_or_path from the local
dawn-ridge snapshot to the public release id
mistralai/Mistral-Medium-3.5-128B so the recipe runs out of the box for
anyone with HF cache access.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
/ok to test 18f1f2d
…H_MAPPING
The CPU unit test test_all_model_folders_registered_in_auto_map
enforces that every model folder containing a model.py is referenced by
at least one entry in MODEL_ARCH_MAPPING. The mistral3_vlm package
routes via a resolver hook on _resolve_custom_model_cls_for_config (so
non-FP8 Mistral3 VLMs keep the mistral4 path), which left the folder
absent from the static map. Add a class-name entry under the actual
exported class so the static check passes. The resolver hook remains
the live routing path for FP8 configs.
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
/ok to test bd82ec5
akoumpa approved these changes Apr 29, 2026
thomasdhc approved these changes Apr 29, 2026
https://wandb.ai/Nemo-automodel/huiyingl_workspace/runs/rxh8b2y6?nw=nwuserhuiyingl