Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ step_scheduler:

dist_env:
backend: nccl
timeout_minutes: 1
timeout_minutes: 20

rng:
_target_: nemo_automodel.components.training.rng.StatefulRNG
Expand Down Expand Up @@ -105,8 +105,11 @@ ci:
recipe_owner: akoumpa
time: "00:15:00"
checkpoint_robustness:
hf_kl_threshold: 5e-3
kl_threshold: 5e-3
hf_kl_threshold: 1e1
tokenizer_name: nvidia/Nemotron-Flash-1B
trust_remote_code: true
no_check_resume: true
dataset.limit_dataset_samples: 500
validation_dataset.limit_dataset_samples: 500

Expand Down
20 changes: 18 additions & 2 deletions nemo_automodel/_transformers/model_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,14 @@
if not hasattr(PretrainedConfig, "pad_token_id"):
PretrainedConfig.pad_token_id = None

from nemo_automodel._transformers.utils import apply_qwen3_omni_config_patch
from nemo_automodel._transformers.utils import _patch_layer_types_validator, apply_qwen3_omni_config_patch

apply_qwen3_omni_config_patch()
# Relax transformers v5 strict ``layer_types`` validation early so that
# remote-code configs (e.g. nvidia/Nemotron-Flash-1B) with non-canonical
# taxonomies can be loaded by ``AutoConfig`` / ``AutoTokenizer`` before the
# recipe gets a chance to call ``apply_cache_compatibility_patches``.
_patch_layer_types_validator()

import nemo_automodel.components.checkpoint.utils as checkpoint_utils
import nemo_automodel.components.distributed.utils as dist_utils
Expand Down Expand Up @@ -74,11 +79,22 @@ def _filter_meta_device_from_init_context(contexts):
return [c for c in contexts if not (isinstance(c, torch.device) and getattr(c, "type", None) == "meta")]


def _is_remote_code_class(cls) -> bool:
"""Return True if ``cls`` is a dynamically-loaded ``trust_remote_code`` class.

Such classes live under ``transformers_modules.*`` (the HF module cache)
and commonly perform ``.to(device)`` on meta tensors during ``__init__``,
which explodes when HF wraps init with ``torch.device("meta")``.
"""
mod = getattr(cls, "__module__", "") or ""
return mod.startswith("transformers_modules.") or ".transformers_modules." in mod


def _patched_get_init_context(cls, *args, **kwargs):
"""Wrapper around PreTrainedModel.get_init_context that strips meta device when requested."""
original = _patched_get_init_context.__wrapped__
contexts = original(cls, *args, **kwargs)
if _get_hf_meta_device_disabled():
if _get_hf_meta_device_disabled() or _is_remote_code_class(cls):
return _filter_meta_device_from_init_context(contexts)
return contexts

Expand Down
112 changes: 110 additions & 2 deletions nemo_automodel/_transformers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,105 @@ def _patched_init(self, *args, **kwargs):
PreTrainedTokenizer.__init__._nemo_stp_patched = True # type: ignore[attr-defined]


def _patch_dynamic_module_local_copy():
"""Recursively copy transitive relative imports when loading remote code from a local dir.

Transformers v5's ``get_cached_module_file`` only copies the direct relative
imports of the top-level module when the source is a local folder — it does
not recurse. Models like ``nvidia/Nemotron-Flash-1B`` whose entrypoint
module file imports a dep (``fused_mha_with_cache``) that in turn imports
another local file (``triton_attention``) end up with the transitive file
missing from the HF module cache, causing
``FileNotFoundError: .../triton_attention.py`` when the cached module is
imported. The hub branch of the same function already uses recursive
resolution; we mirror that behaviour for local folders.
"""
import os
import shutil

import transformers.dynamic_module_utils as dmu

if getattr(dmu.get_cached_module_file, "_nemo_local_recurse_patched", False):
return

_orig = dmu.get_cached_module_file

def _patched_get_cached_module_file(
pretrained_model_name_or_path, module_file, *args, **kwargs
):
result = _orig(pretrained_model_name_or_path, module_file, *args, **kwargs)
try:
pmp = str(pretrained_model_name_or_path)
if not os.path.isdir(pmp):
return result
src_file = os.path.join(pmp, module_file)
if not os.path.isfile(src_file):
return result
# Mirror transformers' own cache layout derivation.
submodule = dmu._sanitize_module_name(os.path.basename(pmp))
submodule_path = dmu.Path(dmu.HF_MODULES_CACHE) / (
dmu.TRANSFORMERS_DYNAMIC_MODULE_NAME + os.sep + submodule
)
needed = dmu.get_relative_import_files(src_file)
for nf in needed:
rel = os.path.relpath(nf, pmp)
dst = submodule_path / rel
if not os.path.isfile(nf):
continue
if dst.exists():
continue
dst.parent.mkdir(parents=True, exist_ok=True)
shutil.copy(nf, dst)
except Exception:
# Best-effort: if anything goes wrong, fall through to the original
# behaviour and let HF raise a clearer error downstream.
pass
return result

_patched_get_cached_module_file._nemo_local_recurse_patched = True # type: ignore[attr-defined]
dmu.get_cached_module_file = _patched_get_cached_module_file


def _patch_layer_types_validator():
"""Relax ``PreTrainedConfig.validate_layer_type`` to tolerate custom taxonomies.

Transformers v5 validates ``layer_types`` entries against a fixed whitelist
(``ALLOWED_LAYER_TYPES``). Remote-code configs (``trust_remote_code=True``)
such as ``nvidia/Nemotron-Flash-1B`` ship their own layer taxonomy (e.g.
``deltanet``, ``m2``, ``f``) which isn't in that set, so strict validation
raises at config instantiation and the model never loads.

We downgrade the value check from ``ValueError`` to a warning while keeping
the length check intact. The custom model code consumes ``config.layer_types``
directly and maps its own values to the standard taxonomy internally.
"""
from transformers import configuration_utils as cu

if getattr(cu.PreTrainedConfig.validate_layer_type, "_nemo_layer_types_patched", False):
return

def _patched_validate_layer_type(self):
lt = getattr(self, "layer_types", None)
if lt is None or not hasattr(self, "num_hidden_layers"):
return
unknown = [x for x in lt if x not in cu.ALLOWED_LAYER_TYPES]
if unknown:
logger.warning(
"layer_types contains non-canonical entries %s (allowed: %s); "
"skipping strict value validation (likely remote-code model with its own taxonomy).",
sorted(set(unknown)),
cu.ALLOWED_LAYER_TYPES,
)
if self.num_hidden_layers is not None and self.num_hidden_layers != len(lt):
raise ValueError(
f"`num_hidden_layers` ({self.num_hidden_layers}) must be equal to the number of layer types "
f"({len(lt)})"
)

_patched_validate_layer_type._nemo_layer_types_patched = True # type: ignore[attr-defined]
cu.PreTrainedConfig.validate_layer_type = _patched_validate_layer_type


def apply_cache_compatibility_patches():
"""Apply compatibility patches for transformers cache utilities.

Expand All @@ -154,6 +253,8 @@ def apply_cache_compatibility_patches():
"""
_patch_bytes_to_unicode()
_patch_special_tokens_pattern()
_patch_layer_types_validator()
_patch_dynamic_module_local_copy()

import transformers.cache_utils as cache_utils

Expand Down Expand Up @@ -225,8 +326,15 @@ def _patched_post_init(self):
source = _find_embedding_source(self)
if source is None:
raise ValueError("Could not find the source of the embedding layer")
self._nemo_tied_weights_keys = {k: source for k in tied}
self._tied_weights_keys = {}
tied_dict = {k: source for k in tied}
self._nemo_tied_weights_keys = tied_dict
# Keep the v5 dict form on the model so that any downstream HF
# code path (e.g. vanilla ``AutoModelForCausalLM.from_pretrained``
# used by the checkpoint-robustness test) ties the weights via
# HF's own ``tie_weights`` and does not leave the tied sibling
# (``lm_head.weight``) randomly initialized — which would cause
# NaN logits for tied-embedding remote-code models.
self._tied_weights_keys = tied_dict
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@adil-a does this actually work or is it llm hallucination? I thought the tied_weights_keys is written in the model file, which is copied as-is, IDK if i misunderstood something.

# call orig post init
_orig_post_init(self)

Expand Down
Loading