Merged
132 changes: 63 additions & 69 deletions docker/common/uv-pytorch.lock

Large diffs are not rendered by default.

100 changes: 91 additions & 9 deletions nemo_automodel/components/checkpoint/checkpointing.py
@@ -485,9 +485,15 @@ def initialize_model_weights(
device: Target device for materialized parameters.
peft_init_method: Initialization method for PEFT adapters (e.g. "xavier").
"""
to_empty_parameters_only(model, device=device)
# Only materialize parameters that are actually on the meta device.
# When the caller sets is_meta_device=True but the model was already
# constructed on a real device (e.g. ContextManagers was patched to
# a no-op), calling to_empty_parameters_only would replace valid
# weights with uninitialized CUDA memory.
has_meta_params = any(p.device.type == "meta" for p in model.parameters())
if has_meta_params:
to_empty_parameters_only(model, device=device)

# to_empty_parameters_only only materializes parameters, not buffers.
# Buffers (e.g. RoPE inv_freq) may still be on meta device. Move them
# to *device* with uninitialized storage so that the subsequent
# initialize_weights() call can overwrite them with proper values
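As a reviewer aid, here is a minimal standalone sketch (not part of this diff) of the failure mode the guard above prevents. The toy `nn.Linear` and the bare `Module.to_empty` call are stand-ins for the PR's `to_empty_parameters_only` helper:

```python
import torch
import torch.nn as nn

# A model constructed normally already holds valid weights.
real = nn.Linear(4, 4)

# A model constructed under the meta device has shapes but no storage.
with torch.device("meta"):
    lazy = nn.Linear(4, 4)

for m in (real, lazy):
    # Calling to_empty unconditionally would replace real.weight with
    # uninitialized memory; the guard materializes only modules that
    # actually contain meta parameters, mirroring has_meta_params above.
    if any(p.device.type == "meta" for p in m.parameters()):
        m.to_empty(device="cpu")

assert not real.weight.is_meta  # untouched, still holds valid weights
assert not lazy.weight.is_meta  # materialized, still needs init
```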
@@ -580,10 +586,11 @@ def load_base_model(
model_name: Name of the model or an absolute path to a snapshot
load_base_model: If True, restore from HF base checkpoint
"""
model_type = getattr(getattr(model, "config", None), "model_type", None)

if load_base_model:
assert model_name is not None, "model_name is required when loading base model"
# Get combined key mapping from model attribute and model-type specific conversions
model_type = getattr(getattr(model, "config", None), "model_type", None)
model_key_mapping = getattr(model, "_checkpoint_conversion_mapping", None)
key_mapping = get_combined_key_mapping(model_type, model_key_mapping)
# NemotronH remote code (trust_remote_code) uses backbone.* params matching checkpoint keys
Expand All @@ -599,7 +606,7 @@ def load_base_model(
key_mapping=key_mapping,
)

_reinit_rope_buffers(model, device)
_reinit_non_persistent_buffers(model, device, model_type=model_type)

is_tied_lm_head = is_tied_word_embeddings(model)
self.config.original_model_root_dir = root_dir
@@ -1025,18 +1032,48 @@ def _init_peft_adapters(model: nn.Module, peft_init_method: str) -> None:
logging.warning(f"Failed to initialize weights for PEFT adapter `{module.__class__.__name__}`: {e}")


def _reinit_rope_buffers(model: nn.Module, device: torch.device) -> None:
_MODELS_REQUIRING_BUFFER_REINIT: frozenset[str] = frozenset(
{
"gemma3",
"nemotron-nas",
}
)


def _reinit_non_persistent_buffers(model: nn.Module, device: torch.device, model_type: str | None = None) -> None:
"""
Recompute non-persistent RoPE ``inv_freq`` buffers for Nemotron-NAS models.
Recompute non-persistent buffers that are not saved in checkpoints.

Non-persistent buffers are not saved in checkpoints, so after meta-device
materialization they contain uninitialized CUDA memory. When
``initialize_weights()`` is skipped (e.g. for Gemma3 to avoid DTensor
issues), these buffers must be recomputed explicitly.

Only runs for models listed in ``_MODELS_REQUIRING_BUFFER_REINIT`` to
avoid unexpected side-effects on arbitrary HF Hub models.

Handles four patterns:

1. **Standard RoPE** — single ``inv_freq`` buffer with ``rope_init_fn`` +
``rope_kwargs`` (e.g. Nemotron-NAS).
2. **Per-layer-type RoPE** — ``{layer_type}_inv_freq`` buffers via
``compute_default_rope_parameters`` (e.g. Gemma3RotaryEmbedding).
3. **Scaled embedding** — ``embed_scale`` buffer on ``ScaledWordEmbedding``
modules (Gemma family), recomputed from ``scalar_embed_scale``.
4. **Vision position IDs** — ``position_ids`` buffer on vision embedding
modules (SigLIP), recomputed from ``num_positions``.

Args:
model: Model to reinitialize RoPE buffers for.
model: Model to reinitialize non-persistent buffers for.
device: Device to create the new buffers on.
model_type: The ``config.model_type`` string. If not in
``_MODELS_REQUIRING_BUFFER_REINIT`` the function is a no-op.
"""
model_type = getattr(getattr(model, "config", None), "model_type", None)
if model_type not in ("nemotron-nas",):
if model_type not in _MODELS_REQUIRING_BUFFER_REINIT:
return

for name, module in model.named_modules():
# Pattern 1: standard RoPE with rope_init_fn + rope_kwargs (Nemotron-NAS)
if hasattr(module, "rope_init_fn") and hasattr(module, "inv_freq") and hasattr(module, "rope_kwargs"):
try:
inv_freq, _ = module.rope_init_fn(module.config, device, **module.rope_kwargs)
@@ -1047,6 +1084,51 @@ def _reinit_rope_buffers(model: nn.Module, device: torch.device) -> None:
except Exception as e:
logging.warning(f"Failed to reinitialize RoPE inv_freq for {name}: {e}")

# Pattern 2: per-layer-type RoPE (Gemma3RotaryEmbedding and similar)
elif hasattr(module, "layer_types") and hasattr(module, "rope_type") and hasattr(module, "config"):
rope_config = getattr(module, "config", None)
rope_parameters = getattr(rope_config, "rope_parameters", None)
if rope_parameters is None:
continue
for layer_type in getattr(module, "layer_types", []):
inv_freq_attr = f"{layer_type}_inv_freq"
if not hasattr(module, inv_freq_attr):
continue
try:
rope_init_fn = getattr(module, "compute_default_rope_parameters", None)
if rope_init_fn is None:
continue
rope_type = module.rope_type.get(layer_type, "default")
if rope_type != "default":
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS

rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type]
curr_inv_freq, curr_attention_scaling = rope_init_fn(rope_config, device, layer_type=layer_type)
setattr(module, inv_freq_attr, curr_inv_freq)
orig_attr = f"{layer_type}_original_inv_freq"
if hasattr(module, orig_attr):
setattr(module, orig_attr, curr_inv_freq.clone())
setattr(module, f"{layer_type}_attention_scaling", curr_attention_scaling)
logging.debug(f"Reinitialized RoPE {inv_freq_attr} for {name} on device {device}")
except Exception as e:
logging.warning(f"Failed to reinitialize RoPE {inv_freq_attr} for {name}: {e}")

# Pattern 3: ScaledWordEmbedding embed_scale (Gemma family)
if hasattr(module, "scalar_embed_scale") and "embed_scale" in getattr(module, "_buffers", {}):
try:
module.embed_scale = torch.tensor(module.scalar_embed_scale, device=device)
logging.debug(f"Reinitialized embed_scale={module.scalar_embed_scale} for {name} on device {device}")
except Exception as e:
logging.warning(f"Failed to reinitialize embed_scale for {name}: {e}")

# Pattern 4: Vision embedding position_ids (SigLIP and similar)
if hasattr(module, "num_positions") and "position_ids" in getattr(module, "_buffers", {}):
try:
module.position_ids = torch.arange(module.num_positions, device=device).expand((1, -1))
logging.debug(f"Reinitialized position_ids (num_positions={module.num_positions}) for {name}")
except Exception as e:
logging.warning(f"Failed to reinitialize position_ids for {name}: {e}")


def _apply(module, fn, recurse=True) -> nn.Module:
"""
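To make the new docstring's premise concrete, a self-contained sketch (again not part of the diff) of why a `persistent=False` buffer cannot be recovered from a checkpoint after meta-device materialization:

```python
import torch
import torch.nn as nn

class Rope(nn.Module):
    def __init__(self, dim: int = 8) -> None:
        super().__init__()
        inv_freq = 1.0 / (10000.0 ** (torch.arange(0, dim, 2).float() / dim))
        # persistent=False keeps inv_freq out of state_dict entirely.
        self.register_buffer("inv_freq", inv_freq, persistent=False)

src = Rope()
assert "inv_freq" not in src.state_dict()  # never reaches a checkpoint

with torch.device("meta"):
    dst = Rope()
dst.to_empty(device="cpu")                 # inv_freq: uninitialized memory
dst.load_state_dict(src.state_dict(), strict=False)  # cannot restore it
# The loader must recompute such buffers explicitly, which is exactly
# what _reinit_non_persistent_buffers does for the listed model types.
```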
17 changes: 17 additions & 0 deletions nemo_automodel/components/checkpoint/conversion_mapping.py
@@ -165,6 +165,16 @@ def get_model_conversion_mapping(
)


_VLM_KEY_MAPPINGS: dict[str, dict[str, str]] = {
"gemma3": {
r"^language_model\.model\.": "model.language_model.",
r"^vision_tower\.": "model.vision_tower.",
r"^multi_modal_projector\.": "model.multi_modal_projector.",
r"^language_model\.lm_head\.": "lm_head.",
},
}


def get_combined_key_mapping(
model_type: str,
model_key_mapping: Optional[dict[str, str]] = None,
@@ -188,6 +198,13 @@
Combined key mapping dictionary (regex pattern -> replacement),
or None if no mappings are defined.
"""
# VLM models with known restructured hierarchies get explicit mappings
# that override the generic transformers conversion (e.g. transformers 5.5.0
# aliases gemma3→llava, but the llava mapping produces wrong FQNs for
# Gemma3's model.language_model.* hierarchy).
if model_type in _VLM_KEY_MAPPINGS:
Contributor:

Note-to-self: I think this is ok for the time being, but long term, it might be a good idea to move all the model-specific patching under nemo_automodel/components/models/, so that each model has its own patches. No action is needed.

return dict(_VLM_KEY_MAPPINGS[model_type])

result = {}

# First add model-specific key mapping (takes precedence)
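For reviewers, a small usage sketch of how a regex mapping like `_VLM_KEY_MAPPINGS["gemma3"]` rewrites checkpoint FQNs; the `remap` helper below is illustrative only, not the library's actual application code:

```python
import re

mapping = {
    r"^language_model\.model\.": "model.language_model.",
    r"^language_model\.lm_head\.": "lm_head.",
}

def remap(key: str) -> str:
    # First pattern that matches wins; unmatched keys pass through.
    for pattern, repl in mapping.items():
        new_key, n = re.subn(pattern, repl, key)
        if n:
            return new_key
    return key

assert (
    remap("language_model.model.layers.0.self_attn.q_proj.weight")
    == "model.language_model.layers.0.self_attn.q_proj.weight"
)
assert remap("language_model.lm_head.weight") == "lm_head.weight"
```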
7 changes: 5 additions & 2 deletions nemo_automodel/components/models/gemma4_moe/model.py
@@ -262,7 +262,8 @@ def __init__(
moe_defaults = dict(
dim=config.hidden_size,
inter_dim=config.intermediate_size,
moe_inter_dim=config.expert_intermediate_size or getattr(config, "moe_intermediate_size", None),
moe_inter_dim=getattr(config, "moe_intermediate_size", None)
or getattr(config, "expert_intermediate_size", None),
n_routed_experts=config.num_experts,
n_shared_experts=0,
n_activated_experts=config.top_k_experts,
@@ -440,8 +441,10 @@ def __init__(
for k, v in text_config.items():
setattr(cfg_text, k, v)

# Compat: checkpoints renamed expert_intermediate_size moe_intermediate_size.
# Compat: older checkpoints used expert_intermediate_size, v5.5+ uses moe_intermediate_size.
cfg_text = config.text_config if hasattr(config, "text_config") else config
if not getattr(cfg_text, "moe_intermediate_size", None) and getattr(cfg_text, "expert_intermediate_size", None):
cfg_text.moe_intermediate_size = cfg_text.expert_intermediate_size
if not getattr(cfg_text, "expert_intermediate_size", None) and getattr(cfg_text, "moe_intermediate_size", None):
cfg_text.expert_intermediate_size = cfg_text.moe_intermediate_size

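A tiny worked example of the two-way alias above, with a plain namespace standing in for the real transformers config object:

```python
from types import SimpleNamespace

# Legacy checkpoint: only the pre-v5.5 attribute name is populated.
cfg = SimpleNamespace(expert_intermediate_size=1024, moe_intermediate_size=None)

# Mirror whichever attribute is set onto the other so downstream code
# can read either name.
if not getattr(cfg, "moe_intermediate_size", None) and getattr(cfg, "expert_intermediate_size", None):
    cfg.moe_intermediate_size = cfg.expert_intermediate_size
if not getattr(cfg, "expert_intermediate_size", None) and getattr(cfg, "moe_intermediate_size", None):
    cfg.expert_intermediate_size = cfg.moe_intermediate_size

assert cfg.moe_intermediate_size == cfg.expert_intermediate_size == 1024
```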
15 changes: 8 additions & 7 deletions nemo_automodel/components/models/nemotron_parse/model.py
@@ -323,7 +323,6 @@ def forward(
encoder_hidden_states,
encoder_attention_mask,
None, # past_key_values
output_attentions,
False, # use_cache
)
else:
@@ -333,15 +332,17 @@
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
past_key_values=None,
output_attentions=output_attentions,
use_cache=False,
)
hidden_states = layer_outputs[0]

if output_attentions:
all_self_attns += (layer_outputs[1],)
if encoder_hidden_states is not None:
all_cross_attentions += (layer_outputs[2],)
if isinstance(layer_outputs, torch.Tensor):
hidden_states = layer_outputs
else:
hidden_states = layer_outputs[0]
if output_attentions:
all_self_attns += (layer_outputs[1],)
if encoder_hidden_states is not None:
all_cross_attentions += (layer_outputs[2],)

hidden_states = self.layer_norm(hidden_states)

5 changes: 3 additions & 2 deletions nemo_automodel/components/models/nemotron_v3/model.py
@@ -485,8 +485,9 @@ def prepare_inputs_for_generation(

batch_size = input_ids.shape[0]

# Create cache on first call
if past_key_values is None:
# Create cache on first call, or replace non-NemotronHybridCache
# (transformers v5.5+ GenerationMixin may pre-create a DynamicCache)
if past_key_values is None or not isinstance(past_key_values, NemotronHybridCache):
past_key_values = NemotronHybridCache(self.config, batch_size, self.dtype, self.device)
# First call: cache_position covers the full prompt
if cache_position is None:
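For context on the comment above, an illustrative version of the cache-replacement guard; `HybridCache` here is a stand-in class rather than the real `NemotronHybridCache` API:

```python
from transformers import DynamicCache  # generic cache used by GenerationMixin

class HybridCache:
    """Stand-in for a model-specific cache with its own memory layout."""
    def __init__(self, batch_size: int) -> None:
        self.batch_size = batch_size

def ensure_hybrid_cache(past_key_values, batch_size: int) -> HybridCache:
    # transformers v5.5+ generation may pre-create a DynamicCache; any
    # cache that is not our hybrid type gets replaced wholesale.
    if past_key_values is None or not isinstance(past_key_values, HybridCache):
        past_key_values = HybridCache(batch_size)
    return past_key_values

assert isinstance(ensure_hybrid_cache(None, 2), HybridCache)
assert isinstance(ensure_hybrid_cache(DynamicCache(), 2), HybridCache)
```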
@@ -93,7 +93,6 @@ def forward(
return super().forward(
hidden_states,
cache_params=cache_params,
cache_position=cache_position,
attention_mask=attention_mask,
)

4 changes: 2 additions & 2 deletions pyproject.toml
@@ -82,7 +82,7 @@ dependencies = [
"pyyaml",
"torch>=2.6.0,<=2.10.0",
"torchdata",
"transformers>=5.3.0,<5.4.0",
"transformers==5.5.0",
"wandb",
"torchao",
"mlflow",
@@ -128,7 +128,7 @@ moe = [
vlm = [
"albumentations",
"backoff",
"mistral_common[opencv]>=1.9.0",
"mistral_common[opencv]>=1.11.0",
"numpy",
"numba",
"open-clip-torch",
24 changes: 13 additions & 11 deletions tests/functional_tests/checkpoint/test_peft_vlm.py
@@ -20,21 +20,21 @@
import shutil
from pathlib import Path

import datasets
import torch
import torch.distributed.checkpoint as dcp
import torch.distributed.tensor
import torch.nn as nn
import yaml
from peft import PeftModel
from safetensors import safe_open
from transformers import AutoModelForImageTextToText
import yaml

from nemo_automodel.components.checkpoint._backports.hf_storage import _HuggingFaceStorageReader
from nemo_automodel.components.checkpoint.stateful_wrappers import ModelState, OptimizerState
from nemo_automodel.components.config._arg_parser import parse_args_and_load_config
from nemo_automodel.recipes.vlm.finetune import FinetuneRecipeForVLM, calculate_loss

import datasets
datasets.disable_caching()


@@ -52,11 +52,11 @@ def get_validation_loss(
with torch.no_grad():
out = model(**val_batch)
loss = calculate_loss(
loss_fn,
logits=out.logits,
labels=labels,
mask=loss_mask,
)
loss_fn,
logits=out.logits,
labels=labels,
mask=loss_mask,
)
return loss


@@ -95,13 +95,15 @@


def compare_configs(source_config: dict, restored_config: dict):
""" Recursively compare two configs."""
"""Recursively compare two configs."""
for k, v in source_config.items():
if k in restored_config:
if isinstance(v, dict):
compare_configs(v, restored_config[k])
else:
assert v == restored_config[k], f"Config mismatch for key {k}. Expected {v} but got {restored_config[k]}"
assert v == restored_config[k], (
f"Config mismatch for key {k}. Expected {v} but got {restored_config[k]}"
)


def load_safetensors(ckpt_dir: Path | str) -> dict[str, torch.Tensor]:
@@ -125,6 +127,7 @@ def to_cpu(
"""
return {k: v.cpu() for k, v in state_dict.items() if isinstance(v, torch.Tensor)}


def get_test_peft_vlm_checkpoint_expected_keys():
expected_model_keys = {
"base_model.model.model.language_model.layers.0.self_attn.q_proj.lora_A.weight": (
@@ -953,8 +956,7 @@ def test_hf_peft_checkpoint():


def _rename_keys(d: dict, prepend: str):
"""Rename the keys of *d* by prepending *prepend* to each key.
"""
"""Rename the keys of *d* by prepending *prepend* to each key."""
flat: dict[str, torch.Tensor] = {}
for k, v in d.items():
key = f"{prepend}{k}"