Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,8 @@ ci:
time: "00:45:00"
vllm_deploy: true
checkpoint_robustness:
hf_kl_threshold: 5e-3
hf_kl_threshold: 2.5e-2
resume_loss_threshold: 5e-2
distributed.tp_size: 8
tokenizer_name: nvidia/Llama-3_3-Nemotron-Super-49B-v1_5
hf_device_map_auto: true
Expand Down
71 changes: 70 additions & 1 deletion nemo_automodel/components/checkpoint/addons.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,12 +75,21 @@ def pre_save(self, **kwargs) -> None:
config_name = "config.v5.json"

_maybe_strip_quantization_config(model_part)
with open(os.path.join(hf_metadata_dir, config_name), "w") as f:
config_path = os.path.join(hf_metadata_dir, config_name)
with open(config_path, "w") as f:
if hasattr(model_part.config, "to_json_string"):
f.write(model_part.config.to_json_string())
else:
# Diffusers models use FrozenDict for config instead of PretrainedConfig
json.dump(dict(model_part.config), f, indent=2, default=str)
if hasattr(model_part.config, "to_json_string"):
# Guarantee ``model_type`` and ``auto_map`` land in the serialized JSON.
# HF's ``PreTrainedConfig.to_json_string`` defaults to ``use_diff=True``,
# which can drop these keys for ``trust_remote_code`` configs (e.g.
# DeciLM / Llama-Nemotron-Super) — causing ``AutoConfig.from_pretrained``
# on the consolidated dir to raise ``Unrecognized model ... Should have
# a 'model_type' key``, breaking checkpoint-robustness reload (Phase 3/4).
_ensure_model_type_and_auto_map(config_path, model_part.config, original_model_path)

# save the generation_config.json file
if getattr(model_part, "generation_config", None) is not None:
Expand Down Expand Up @@ -361,6 +370,66 @@ def _extract_target_modules(model: nn.Module, v4_compatible: bool = False) -> li
return sorted(final_target_modules)


def _ensure_model_type_and_auto_map(config_path: str, config_obj, original_model_path: str | None) -> None:
"""Ensure the saved ``config.json`` has ``model_type`` and (when applicable) ``auto_map``.

Context: HF ``PreTrainedConfig.to_json_string`` defaults to ``use_diff=True`` and
may omit ``model_type`` or ``auto_map`` for ``trust_remote_code`` configs
(e.g. DeciLM / Llama-Nemotron-Super ``model_type='nemotron-nas'``). Without
``model_type`` in ``config.json``, ``AutoConfig.from_pretrained`` on the
consolidated directory raises ``Unrecognized model ... Should have a
'model_type' key``, breaking checkpoint-robustness reload (Phase 3/4). Without
``auto_map``, HF cannot locate the custom config class even when
``trust_remote_code=True`` is passed.
"""
try:
with open(config_path) as f:
config_dict = json.load(f)
except (OSError, ValueError):
config_dict = {}

# Load the original pretrained config.json once (source of truth for trust_remote_code
# models whose ``config_obj`` may have lost ``auto_map`` on the NeMo code path).
original: dict = {}
if original_model_path and os.path.isdir(original_model_path):
src = os.path.join(original_model_path, "config.json")
if os.path.isfile(src):
try:
with open(src) as f:
original = json.load(f)
except (OSError, ValueError):
original = {}

changed = False

if not config_dict.get("model_type"):
model_type = (
getattr(type(config_obj), "model_type", None)
or getattr(config_obj, "model_type", None)
or original.get("model_type")
)
if model_type:
config_dict["model_type"] = model_type
changed = True

if not config_dict.get("auto_map"):
auto_map = getattr(config_obj, "auto_map", None) or original.get("auto_map")
if auto_map:
config_dict["auto_map"] = auto_map
changed = True

# Also preserve ``architectures`` from the original if missing; some transformers
# versions drop it when serializing custom configs registered via
# ``register_for_auto_class``.
if not config_dict.get("architectures") and original.get("architectures"):
config_dict["architectures"] = original["architectures"]
changed = True

if changed:
with open(config_path, "w") as f:
json.dump(config_dict, f, indent=2, sort_keys=True)


def _maybe_strip_quantization_config(model_part: nn.Module) -> None:
"""Remove ``quantization_config`` from the HF config when no parameters are quantized.

Expand Down
Loading