add: DFlash block diffusion speculative decoding #1211
Conversation
Note: Reviews paused
It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior in the settings. Use the following commands to manage reviews:
Use the checkboxes below for quick actions:
📝 Walkthrough
This PR introduces comprehensive DFlash (Block Diffusion for Speculative Decoding) support to ModelOpt, including configuration layers, model conversion and plugins for HuggingFace, training infrastructure updates, export utilities, validation methods, and end-to-end recipe and launcher tooling for training and evaluation.
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~75 minutes
🚥 Pre-merge checks: ✅ 4 passed
✏️ Tip: You can configure your own custom pre-merge checks in the settings.
✨ Finishing Touches: 🧪 Generate unit tests (beta)
Force-pushed from f990e5a to e45cc37
Actionable comments posted: 17
🧹 Nitpick comments (1)
examples/speculative_decoding/doc/dflash_results.md (1)
5-85: Add reproducibility metadata alongside reported metrics. Please include the exact eval command(s), seed(s), and checkpoint artifact identifier(s) used for these tables so others can reproduce the numbers without guessing.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/speculative_decoding/doc/dflash_results.md` around lines 5 - 85, Add reproducibility metadata to dflash_results.md by appending exact evaluation commands, random seed values, and checkpoint artifact identifiers used to produce the reported tables (e.g., the Key Metrics, MT-Bench Per-Category AR, Comparison with z-lab, and Evaluation Method Impact sections). For each table/experiment (such as the gsm8k and MT-Bench runs and the ModelOpt 306K checkpoint), include the full CLI or python invocation (including flags like block_size, osl, sequence length, draft layers, anchors per sample), the seed(s) used, and the storage/registry identifiers or S3/GS/artifact names for the specific checkpoint(s) (e.g., the 306K checkpoint), plus the environment (GPU count/nodes) and any non-default preprocessing or eval-mode choices (Fixed GT vs Online GT). Place this reproducibility block near the top or directly under "Key Metrics" so readers can immediately reproduce the results.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@examples/speculative_decoding/eagle_utils.py`:
- Around line 274-279: The call to load_dataset in the validate_ar invocation
hard-codes a local mirror path and should be made configurable or
fault-tolerant; update the validate_ar call site (where model=raw_model,
tokenizer=kwargs["processing_class"], ds=load_dataset(...),
device=next(raw_model.parameters()).device, num_samples=8) to derive the dataset
source from a new parameter or env var (e.g. mt_bench_dataset) with a default of
"HuggingFaceH4/mt_bench_prompts", and wrap the load_dataset call in a try/except
that falls back to the default Hub ID if the local path load fails. Ensure the
new parameter is plumbed from kwargs or config so single-GPU / public setups use
the Hub dataset by default.
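A minimal sketch of the suggested fallback pattern (the helper name and the `MT_BENCH_DATASET` env var are illustrative; `load_fn` stands in for `datasets.load_dataset`):

```python
import os

DEFAULT_MT_BENCH = "HuggingFaceH4/mt_bench_prompts"  # public Hub ID

def resolve_mt_bench_dataset(load_fn, dataset_id=None):
    """Load MT-Bench prompts from a configurable source, falling back to the
    public Hub ID when a local mirror path fails to load."""
    source = dataset_id or os.environ.get("MT_BENCH_DATASET", DEFAULT_MT_BENCH)
    try:
        return load_fn(source)
    except Exception:
        if source == DEFAULT_MT_BENCH:
            raise  # nothing left to fall back to
        return load_fn(DEFAULT_MT_BENCH)
```

At the `validate_ar` call site this would replace the hard-coded local path, e.g. `ds=resolve_mt_bench_dataset(load_dataset)`.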
In `@examples/speculative_decoding/main.py`:
- Around line 317-334: The fallback currently loads trainer_state.json into
trainer.state but calls trainer.train() without resume_from_checkpoint, so the
dataloader/loop restarts at step 0; update the fallback to load the JSON into
trainer.state (using trainer.state.load_from_json(state_file)) and then call
trainer.train(resume_from_checkpoint=checkpoint) so Hugging Face Trainer
receives the checkpoint and correctly resumes the dataloader and training loop;
ensure you still print the resumed step/max_steps using resumed_step and
resumed_max_steps as before.
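The step/max-steps bookkeeping in that fallback can be isolated as a small helper (a sketch; the actual resume still has to go through `trainer.train(resume_from_checkpoint=checkpoint)` so the Trainer itself restores the dataloader position):

```python
import json
import os

def read_resume_state(checkpoint):
    """Return (global_step, max_steps) recorded in a HF checkpoint's
    trainer_state.json, or (0, -1) when the file is missing."""
    state_file = os.path.join(checkpoint, "trainer_state.json")
    if not os.path.isfile(state_file):
        return 0, -1
    with open(state_file) as f:
        state = json.load(f)
    return state.get("global_step", 0), state.get("max_steps", -1)
```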
In `@examples/speculative_decoding/README.md`:
- Around line 369-378: The README table lists
dflash.dflash_architecture_config.mask_token_id as having default "auto" but the
recipe doesn't define it; either add a mask_token_id key with value "auto" to
the dflash.yaml recipe that backs speculative decoding (so the recipe explicitly
documents the default), or update the README table entry to indicate
mask_token_id is optional/auto-inferred (e.g., mark as "auto
(optional/inferred)") so the documentation matches the actual recipe; locate the
mask_token_id entry under the dflash architecture config in the spec/config for
speculative decoding and make the corresponding change.
In `@examples/speculative_decoding/scripts/export_hf_checkpoint.py`:
- Line 41: The script calls load_vlm_or_llm(args.model_path, torch_dtype="auto")
without exposing trust_remote_code; add a CLI flag/argument (e.g.,
args.trust_remote_code defaulting to False) and pass it through to
load_vlm_or_llm as load_vlm_or_llm(args.model_path, torch_dtype="auto",
trust_remote_code=args.trust_remote_code) so callers can opt into remote code
when needed; update the argument parser to document the flag and set the default
to False.
In `@examples/speculative_decoding/train_dflash.py`:
- Around line 292-293: The code currently hardcodes
load_dataset("/hf-local/HuggingFaceH4/mt_bench_prompts") which requires an
internal mirror; change the load to accept an external dataset identifier via an
argument or environment variable and fall back to the public
"HuggingFaceH4/mt_bench_prompts" if none is provided. Update where ds is created
(the load_dataset call near validator = HFARValidation(raw_model, tokenizer)) to
read a CLI flag or os.environ key (e.g., MT_BENCH_DATASET) and pass that value
into load_dataset, defaulting to "HuggingFaceH4/mt_bench_prompts" so the script
works standalone while still allowing an internal path when supplied.
- Around line 150-153: Add a new CLI flag (e.g., --trust-remote-code) that
defaults to False (use action='store_true') and expose it as
args.trust_remote_code; then remove the hardcoded True and pass
args.trust_remote_code into both AutoModelForCausalLM.from_pretrained(...) and
AutoTokenizer.from_pretrained(...). Update any argument parsing logic where args
is created so the new flag is available to the model/tokenizer loading calls.
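A sketch of the opt-in flag (parser contents beyond the flag itself are illustrative):

```python
import argparse

def build_parser():
    """Argument parser with an opt-in --trust-remote-code flag (default False)."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", required=True)
    parser.add_argument(
        "--trust-remote-code",
        action="store_true",  # absent -> False; remote-code execution stays opt-in
        help="Allow executing custom code bundled with the checkpoint.",
    )
    return parser

# The parsed value would then be threaded into both loading calls, e.g.:
# AutoModelForCausalLM.from_pretrained(args.model, trust_remote_code=args.trust_remote_code)
```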
In `@modelopt/torch/export/plugins/hf_spec_export.py`:
- Around line 272-316: The _export_config method currently hardcodes
"torch_dtype": "bfloat16" causing a mismatch when export(dtype=...) saves
model.safetensors in a different dtype; update the export flow so export(...)
passes the chosen dtype into _export_config (add a dtype parameter to
_export_config) and have _export_config serialize that dtype value for the
"torch_dtype" field instead of the literal "bfloat16"; update all call sites
(e.g., wherever export calls _export_config) and ensure the same change is
applied for the similar block around lines 328-343 so the config.json always
matches the actual exported tensor dtype.
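One way to keep the serialized field in sync is to derive the string from whatever dtype `export(...)` was actually called with, rather than a literal (the helper name here is hypothetical):

```python
def dtype_to_config_str(dtype):
    """Map a torch dtype (or its string form) to the config.json
    'torch_dtype' convention, e.g. torch.bfloat16 -> 'bfloat16'."""
    s = str(dtype)
    return s[len("torch."):] if s.startswith("torch.") else s

# _export_config(dtype=...) would then write:
# config["torch_dtype"] = dtype_to_config_str(dtype)
```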
In `@modelopt/torch/speculative/plugins/hf_dflash.py`:
- Around line 83-85: The module-level assignment of _MLP_CLS, _NORM_CLS,
_ROTARY_CLS, and _rotate_half must be made instance-local so converting a second
model doesn't mutate shared globals: remove or stop relying on the top-level
assignment from _resolve_model_components("llama") and instead have modify()
store the resolved classes/func (result of _resolve_model_components) on the
converter/model instance (e.g., self._mlp_cls, self._norm_cls, self._rotary_cls,
self._rotate_half). Update apply_rotary_pos_emb(), the meta-buffer recovery
path, and any other callers that read the globals (including the locations
referenced around lines ~99-106, ~287-289, ~510-514) to read the instance
attributes instead of module globals, and ensure any factory or conversion
helpers receive the instance (or the specific classes) so they use the
per-instance types rather than the module-level variables.
- Around line 832-845: The one-off debug block guarded by self._psg_debug
accesses base_token.item() which fails for batch size B>1; update that debug
block (the code that sets self._psg_debug, selects base_outputs.hidden_states
for self.target_layer_ids, and prints seq_len / dflash_block_size /
self.mask_token_id) to either remove the prints entirely or log a batch-safe
representation of base_token (e.g., replace base_token.item() with
base_token[:,0].cpu().tolist() or base_token.reshape(-1).cpu().tolist()), and
ensure any other printed tensors (sel/th_dbg) are summarized (e.g.,
shapes/norms) to avoid per-sample indexing errors.
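The instance-local resolution suggested in the first comment above can be sketched like this (the class names are placeholder strings standing in for the real transformers classes returned by `_resolve_model_components`):

```python
# Per-architecture components keyed by model type; in the real plugin these
# would be the classes/functions returned by _resolve_model_components().
_COMPONENTS = {
    "llama": ("LlamaMLP", "LlamaRMSNorm", "LlamaRotaryEmbedding"),
    "qwen3": ("Qwen3MLP", "Qwen3RMSNorm", "Qwen3RotaryEmbedding"),
}

class DFlashConverter:
    """Stores resolved components on the instance, so converting a second
    model of a different architecture cannot clobber shared module globals."""

    def __init__(self, model_type):
        self._mlp_cls, self._norm_cls, self._rotary_cls = _COMPONENTS[model_type]

a = DFlashConverter("llama")
b = DFlashConverter("qwen3")  # resolving qwen3 does not mutate a's components
```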
In `@modelopt/torch/utils/plugins/transformers_dataset.py`:
- Around line 156-157: The initializer currently calls
self._ensure_generation_tags() whenever self.answer_only_loss is true, which
swaps templates to text-only variants and breaks multimodal flows (e.g.,
VisionLanguageDataCollator) that expect message['content'] as blocks; modify the
logic in the __init__ (and the similar block around the 220-271 region) to only
perform the template rewrite for text-only collators — e.g., add a guard that
checks a multimodal flag or the collator class/type (or introduce an explicit
is_text_only property) before calling _ensure_generation_tags(), or
alternatively supply multimodal-safe fallback templates and use those when the
collator indicates multimodal input; ensure VisionLanguageDataCollator path does
not trigger the text-only swap.
- Around line 391-418: The collator is dropping samples without an assistant
turn regardless of answer-only mode; update the logic in the block using
messages/conversations (and the call to _sharegpt_to_openai_messages and
print_rank_0) so that the assistant-role existence check and skipping only run
when self.answer_only_loss is True, otherwise accept and batch the sample as-is
(i.e., append messages or converted conversations without the assistant-role
filter); keep references to the same symbols (messages, conversations, batch,
_sharegpt_to_openai_messages, print_rank_0) and ensure the existing dummy-batch
fallback remains unchanged.
- Around line 349-356: The assistant_masks are aligned to the original input_ids
but labels have been shifted with labels[..., :-1] = input_ids[..., 1:], so
shift assistant_mask by one before applying it to labels; inside the
answer_only_loss block (where assistant_mask is read from tokenized_examples)
compute a shifted mask like mask_shifted = assistant_mask[..., 1:] (or align to
labels' shape) and then set labels[..., :-1][mask_shifted == 0] =
IGNORE_TOKEN_ID (ensure tensor type and shape match, e.g., only do this when
assistant_mask is a torch.Tensor and mask_shifted.any()).
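The shift alignment in the last comment can be checked with a plain-Python sketch (list-based rather than torch tensors; `IGNORE_TOKEN_ID = -100` following the usual HF convention):

```python
IGNORE_TOKEN_ID = -100

def build_answer_only_labels(input_ids, assistant_mask, ignore_id=IGNORE_TOKEN_ID):
    """labels[j] predicts input_ids[j+1], so the mask that applies to
    labels[j] is assistant_mask[j+1] -- i.e. the mask shifted by one."""
    labels = list(input_ids[1:]) + [ignore_id]   # labels[..., :-1] = input_ids[..., 1:]
    mask_shifted = assistant_mask[1:]            # align mask with shifted labels
    for j, keep in enumerate(mask_shifted):
        if not keep:
            labels[j] = ignore_id
    return labels
```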
In `@tools/launcher/common/dflash/ar_validate.sh`:
- Around line 111-113: The code is calling validator.validate(...) which runs
the offline HFAR path; change it to call the online validation loop by invoking
validator.validate_online(osl=32, input_ids=input_ids, steps=3) (or the correct
parameter names for validate_online) and keep extracting the AR result (e.g.,
"_, ar = validator.validate_online(...)") and appending ar to ars so the script
uses AcceptanceRateValidation.validate_online() instead of
HFARValidation.validate().
- Around line 63-66: The calls to AutoModelForCausalLM.from_pretrained and
AutoTokenizer.from_pretrained currently hardcode trust_remote_code=True; change
them to read a new environment variable (e.g., ALLOW_TRUST_REMOTE_CODE) that
defaults to false and convert it to a boolean (treat "1", "true", "yes"
case-insensitively as true). Pass that boolean into the trust_remote_code
parameter for both AutoModelForCausalLM.from_pretrained and
AutoTokenizer.from_pretrained so remote-code execution is opt-in when
HF_MODEL_CKPT is used.
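A small helper matching the suggested truthy parsing (env var name as proposed in the comment; the helper name is illustrative):

```python
import os

def env_flag(name, default=False):
    """Treat '1', 'true', 'yes' (case-insensitive) as True; unset -> default."""
    val = os.environ.get(name)
    if val is None:
        return default
    return val.strip().lower() in ("1", "true", "yes")

# trust_remote_code = env_flag("ALLOW_TRUST_REMOTE_CODE")  # defaults to False
```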
In `@tools/launcher/common/dflash/online_training.sh`:
- Line 34: The pip install line currently uses an unquoted comparison operator
so the shell interprets ">" as redirection; update the package spec in the
script by quoting or escaping the version constraint (e.g., change the pip
command to use "huggingface-hub>=1.2.1" or huggingface-hub\>=1.2.1) so the
minimum version constraint is passed to pip rather than redirecting stdout.
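The quoting issue can be seen directly in the shell (a sketch; the spec value mirrors the script's constraint):

```shell
#!/bin/sh
# Quoted, the version constraint stays one intact argument; unquoted, the
# shell would consume '>' as output redirection, so pip would only see
# 'huggingface-hub' while a stray file named '=1.2.1' got created.
spec='huggingface-hub>=1.2.1'
set -- "$spec"
echo "argc=$#"
echo "arg1=$1"
# Correct install: pip install "huggingface-hub>=1.2.1"
```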
- Around line 181-223: The inline python invoked via the python3 -c block
insecurely interpolates shell variables (e.g., ${DFLASH_BLOCK_SIZE},
${DFLASH_NUM_LAYERS}, ${MASK_ARG}, ${HF_MODEL_CKPT}) and hardcodes
trust_remote_code=True in AutoModelForCausalLM.from_pretrained and
AutoTokenizer.from_pretrained; fix by changing the script to read values from
environment variables or sys.argv inside the Python snippet (use os.environ or
argparse) instead of shell interpolation, validate/convert numeric values
(dflash_block_size, num_hidden_layers) there, and make trust_remote_code
configurable (read from env/default to False) before calling from_pretrained so
no untrusted remote code is loaded by default.
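A sketch of reading the launcher's values inside the `python3 -c` snippet instead of interpolating shell variables (variable names follow the script; the defaults here are illustrative):

```python
import os

def read_dflash_env(env=None):
    """Parse launcher settings from the environment, validating numerics and
    keeping remote-code execution opt-in (default False)."""
    env = os.environ if env is None else env
    block_size = int(env.get("DFLASH_BLOCK_SIZE", "4"))   # validated as int
    num_layers = int(env.get("DFLASH_NUM_LAYERS", "1"))   # validated as int
    trust = env.get("ALLOW_TRUST_REMOTE_CODE", "").strip().lower() in ("1", "true", "yes")
    return block_size, num_layers, trust
```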
In `@tools/launcher/examples/Qwen/Qwen3-8B/hf_online_dflash.yaml`:
- Around line 50-54: Replace the speculative algorithm flag value from EAGLE3 to
DRAFT_TARGET so VLLM treats the DFlash output as a generic draft model; update
the entry in the YAML where "--speculative_algorithm EAGLE3" appears to
"--speculative_algorithm DRAFT_TARGET" (this ensures the DFlash draft produced
by common/dflash/online_training.sh is loaded instead of mapping to the EAGLE3
backend in examples/specdec_bench/specdec_bench/models/vllm.py).
---
Nitpick comments:
In `@examples/speculative_decoding/doc/dflash_results.md`:
- Around line 5-85: Add reproducibility metadata to dflash_results.md by
appending exact evaluation commands, random seed values, and checkpoint artifact
identifiers used to produce the reported tables (e.g., the Key Metrics, MT-Bench
Per-Category AR, Comparison with z-lab, and Evaluation Method Impact sections).
For each table/experiment (such as the gsm8k and MT-Bench runs and the ModelOpt
306K checkpoint), include the full CLI or python invocation (including flags
like block_size, osl, sequence length, draft layers, anchors per sample), the
seed(s) used, and the storage/registry identifiers or S3/GS/artifact names for
the specific checkpoint(s) (e.g., the 306K checkpoint), plus the environment
(GPU count/nodes) and any non-default preprocessing or eval-mode choices (Fixed
GT vs Online GT). Place this reproducibility block near the top or directly
under "Key Metrics" so readers can immediately reproduce the results.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 57539396-3987-4d1b-9a04-2d7606707bb5
📒 Files selected for processing (24)
- doc/results/dflash_results.html
- examples/speculative_decoding/README.md
- examples/speculative_decoding/doc/dflash_results.md
- examples/speculative_decoding/eagle_utils.py
- examples/speculative_decoding/main.py
- examples/speculative_decoding/scripts/export_hf_checkpoint.py
- examples/speculative_decoding/train_dflash.py
- modelopt/torch/export/plugins/hf_spec_export.py
- modelopt/torch/speculative/config.py
- modelopt/torch/speculative/dflash/__init__.py
- modelopt/torch/speculative/dflash/conversion.py
- modelopt/torch/speculative/dflash/default_config.py
- modelopt/torch/speculative/dflash/dflash_model.py
- modelopt/torch/speculative/mode.py
- modelopt/torch/speculative/plugins/__init__.py
- modelopt/torch/speculative/plugins/hf_dflash.py
- modelopt/torch/speculative/utils.py
- modelopt/torch/utils/plugins/transformers_dataset.py
- modelopt_recipes/general/speculative_decoding/dflash.yaml
- tests/gpu/torch/speculative/plugins/test_hf_dflash.py
- tests/unit/torch/speculative/plugins/test_hf_dflash.py
- tools/launcher/common/dflash/ar_validate.sh
- tools/launcher/common/dflash/online_training.sh
- tools/launcher/examples/Qwen/Qwen3-8B/hf_online_dflash.yaml
try:
    trainer.train(resume_from_checkpoint=checkpoint)
except ValueError as e:
    if "parameter group" in str(e):
        print_rank_0(
            f"Optimizer state mismatch: {e}\n"
            f"Resuming with fresh optimizer from {checkpoint}"
        )
        state_file = os.path.join(checkpoint, "trainer_state.json")
        if os.path.isfile(state_file):
            state = json.load(open(state_file))
            resumed_step = state.get("global_step", 0)
            resumed_max_steps = state.get("max_steps", -1)
            print_rank_0(f"Resuming from step {resumed_step}/{resumed_max_steps}")
            if resumed_max_steps > 0:
                training_args.max_steps = resumed_max_steps
            trainer.state = trainer.state.load_from_json(state_file)
        trainer.train()
🧩 Analysis chain
🏁 Scripts executed (Repository: NVIDIA/Model-Optimizer):
- cd /repo && find . -name "main.py" -path "*/speculative_decoding/*" -type f
- cd /repo && cat -n examples/speculative_decoding/main.py | sed -n '310,340p'
- cd /repo && wc -l examples/speculative_decoding/main.py
- pwd && ls -la
- find . -name "main.py" -path "*/speculative_decoding/*" -type f 2>/dev/null
- if [ -f "examples/speculative_decoding/main.py" ]; then wc -l examples/speculative_decoding/main.py; fi
- if [ -f "examples/speculative_decoding/main.py" ]; then cat -n examples/speculative_decoding/main.py | sed -n '310,340p'; fi
- web_search: HuggingFace Transformers Trainer.train() resume_from_checkpoint state initialization behavior
- rg -A 10 "trainer\.train\(" examples/ --type py | head -50
- rg "trainer\.state\s*=" examples/ --type py -B 2 -A 2
- grep -r "load_from_json" examples/ --include="*.py" -B 2 -A 2
🌐 Web query:
HuggingFace Transformers Trainer.train() resume_from_checkpoint parameter state initialization
💡 Result:
In 🤗 Transformers, Trainer.train(resume_from_checkpoint=...) resumes by loading saved state on top of freshly created objects:

- What the parameter means
  - If resume_from_checkpoint is a path, it resumes from that checkpoint folder.
  - If it is True, it finds and uses the last checkpoint in args.output_dir.
  - When provided, training resumes from the model + optimizer + LR scheduler states found in that checkpoint. [1]
- What gets (re)initialized vs loaded
  - Model weights: the model is first present in memory (from your model= or model_init), then Trainer loads the checkpoint weights into it (typically from pytorch_model.bin / model.safetensors, depending on setup/version). In the Trainer implementation you can see it loading WEIGHTS_NAME and calling model.load_state_dict(..., strict=False). [2]
  - Trainer state (self.state): if {checkpoint}/trainer_state.json exists, Trainer loads it into self.state and uses it to compute how much training was already done (e.g., derives epochs_trained from global_step). [2]
  - Optimizer + LR scheduler state: the optimizer and scheduler objects are created as usual (i.e., "fresh"), and then Trainer calls _load_optimizer_and_scheduler(checkpoint). If {checkpoint}/optimizer.pt and {checkpoint}/scheduler.pt exist (and you are not using DeepSpeed), Trainer loads those state dicts into the already-created optimizer/scheduler via load_state_dict. With DeepSpeed, this path is skipped because DeepSpeed handles resume as part of its own init/checkpoint logic. [2]
  - RNG state: checkpoints also save RNG state (e.g., rng_state.pth / rank-specific variants), so resuming can continue deterministically when those files are present. [2]
- Common implication
  - "Resume" is not "don't initialize anything"; it's "initialize normal training objects, then overwrite their internal state from checkpoint files if they exist." This is why mismatches can happen if your optimizer parameter groups differ between runs (e.g., changed trainable params, LoRA targets, etc.), causing optimizer.load_state_dict(...) to fail.

Sources: Hugging Face Trainer docs for resume_from_checkpoint [1] and Trainer implementation showing loading of trainer_state.json, optimizer/scheduler state, and RNG save/load logic [2].
Citations:
- 1: https://huggingface.co/docs/transformers/main_classes/trainer
- 2: https://huggingface.co/transformers/v4.6.0/_modules/transformers/trainer.html
🏁 Scripts executed:
- rg "def train" examples/speculative_decoding/main.py -A 5
- head -30 examples/speculative_decoding/main.py  # check what imports are in the file
- rg "Trainer\(" examples/speculative_decoding/main.py -B 5 -A 10  # trainer initialization / state handling patterns
- rg "resume" examples/llm_qat/main.py -B 3 -A 3  # documentation or comments about state resumption in similar files
- cat -n examples/speculative_decoding/main.py | sed -n '314,340p'
- rg "Optimizer state mismatch" . --type py -B 5 -A 10  # notes or context about why the manual state loading is done
🌐 Web query:
HuggingFace Trainer trainer.state initialization without resume_from_checkpoint parameter fresh initialization
💡 Result:
In 🤗 Transformers, a fresh (non-resumed) trainer.state comes from the default TrainerState created when you instantiate Trainer, and it is only replaced if you pass (or propagate) a resume_from_checkpoint and a trainer_state.json is found.

What happens on a fresh run (no resume):
- Trainer.__init__ creates a new state: self.state = TrainerState(is_local_process_zero=..., is_world_process_zero=..., stateful_callbacks=[...]) [1]
- TrainerState defaults include global_step=0, max_steps=0, etc. [2]
- At training start, _inner_training_loop sets self.state.epoch = 0 and only then checks whether it should load a checkpoint state. [3]

When/why it stops being "fresh":
- In Trainer.train(...), if you pass resume_from_checkpoint=True, it will resolve the last checkpoint in output_dir via get_last_checkpoint(...). [4]
- In _inner_training_loop, if resume_from_checkpoint points to a checkpoint and that folder contains trainer_state.json, Trainer does: self.state = TrainerState.load_from_json(.../trainer_state.json) [3]

How to force "fresh initialization" in practice:
- Make sure you are not doing any of these:
  - trainer.train(resume_from_checkpoint=True) (auto-picks the last checkpoint) [4]
  - trainer.train(resume_from_checkpoint=".../checkpoint-XXXX") [3]
  - Passing TrainingArguments(resume_from_checkpoint=...) through an example script/CLI (many scripts forward it into trainer.train(...)) [5]

A robust pattern:
training_args = TrainingArguments(..., resume_from_checkpoint=None)
trainer = Trainer(..., args=training_args)
trainer.train(resume_from_checkpoint=None)  # or just trainer.train()

If you still see resume behavior, it's almost always because some wrapper/script is passing resume_from_checkpoint=True/path under the hood.

Sources:
[1] Trainer initializes self.state = TrainerState(...) in __init__ (Transformers trainer.py).
[2] TrainerState default fields (HF docs: Callbacks → TrainerState).
[3] _inner_training_loop sets epoch=0 and loads TrainerState from trainer_state.json only when resuming.
[4] Trainer.train: if resume_from_checkpoint is True, it resolves get_last_checkpoint(output_dir).
[5] TrainingArguments includes a resume_from_checkpoint field (can be propagated by scripts).
Citations:
- 1: https://github.com/huggingface/transformers/blob/b7672826cad31e30319487af876e608d8af7d37b/src/transformers/trainer.py
- 2: https://huggingface.co/docs/transformers/en/main_classes/callback?utm_source=openai
- 3: https://github.com/huggingface/transformers/raw/b7672826cad31e30319487af876e608d8af7d37b/src/transformers/trainer.py
- 4: https://github.com/huggingface/transformers/raw/b7672826cad31e30319487af876e608d8af7d37b/src/transformers/trainer.py
- 5: https://huggingface.co/docs/transformers/en/main_classes/trainer?utm_source=openai
🏁 Scripts executed:
- rg "trainer\.state" examples/speculative_decoding/main.py -B 3 -A 3  # setup of trainer.state before the resume logic
- rg "Trainer\(" examples/speculative_decoding/main.py -B 10 | head -40  # trainer initialization / baseline state
The fallback doesn't properly resume training—it restarts the dataloader from step 0.
Loading trainer_state.json into trainer.state before trainer.train() without resume_from_checkpoint is insufficient. In Hugging Face Trainer, the dataloader position and training loop resumption depend on passing the resume_from_checkpoint parameter to trainer.train(). Without it, the dataloader begins at step 0 despite the manually-loaded state values. The printed "Resuming from step X" message is misleading—training actually replays from the start, defeating the purpose of the fallback recovery path (lines 315–316).
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@examples/speculative_decoding/main.py` around lines 317 - 334, The fallback
currently loads trainer_state.json into trainer.state but calls trainer.train()
without resume_from_checkpoint, so the dataloader/loop restarts at step 0;
update the fallback to load the JSON into trainer.state (using
trainer.state.load_from_json(state_file)) and then call
trainer.train(resume_from_checkpoint=checkpoint) so Hugging Face Trainer
receives the checkpoint and correctly resumes the dataloader and training loop;
ensure you still print the resumed step/max_steps using resumed_step and
resumed_max_steps as before.
Suggested:
model = load_vlm_or_llm(
    args.model_path, torch_dtype="auto", trust_remote_code=args.trust_remote_code
)
Original (line 41):
model = load_vlm_or_llm(args.model_path, torch_dtype="auto")
🧩 Analysis chain
🏁 Scripts executed:
- rg -n -C2 'add_argument\("--trust_remote_code"|load_vlm_or_llm\(' examples/speculative_decoding/scripts/export_hf_checkpoint.py  # check exporter CLI and model load call
- rg -n -C2 'def load_vlm_or_llm\(|trust_remote_code' modelopt/torch/speculative/utils.py  # check loader API supports trust_remote_code and its default
- cat -n examples/speculative_decoding/scripts/export_hf_checkpoint.py
Expose caller-controlled trust_remote_code parameter to allow exporting models requiring custom code.
The exporter script does not expose control for the trust_remote_code parameter, preventing users from loading models that require custom remote code during export. The underlying load_vlm_or_llm() function already accepts and properly propagates this parameter with a safe default of False; add CLI exposure to let users opt in when needed.
Suggested patch
def parse_args():
parser = argparse.ArgumentParser(
description="Export a HF checkpoint (with ModelOpt state) for deployment."
)
parser.add_argument("--model_path", type=str, default="Path of the trained checkpoint.")
parser.add_argument(
"--export_path", type=str, default="Destination directory for exported files."
)
+ parser.add_argument(
+ "--trust_remote_code",
+ action="store_true",
+ help="Allow loading custom remote code from model repos (default: False).",
+ )
return parser.parse_args()
mto.enable_huggingface_checkpointing()
args = parse_args()
-model = load_vlm_or_llm(args.model_path, torch_dtype="auto")
+model = load_vlm_or_llm(
+    args.model_path,
+    torch_dtype="auto",
+    trust_remote_code=args.trust_remote_code,
+)

This follows the guideline to "let the caller decide via a parameter; default to False" for trust_remote_code.
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
Original:
model = load_vlm_or_llm(args.model_path, torch_dtype="auto")
Suggested:
model = load_vlm_or_llm(
    args.model_path,
    torch_dtype="auto",
    trust_remote_code=args.trust_remote_code,
)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@examples/speculative_decoding/scripts/export_hf_checkpoint.py` at line 41,
The script calls load_vlm_or_llm(args.model_path, torch_dtype="auto") without
exposing trust_remote_code; add a CLI flag/argument (e.g.,
args.trust_remote_code defaulting to False) and pass it through to
load_vlm_or_llm as load_vlm_or_llm(args.model_path, torch_dtype="auto",
trust_remote_code=args.trust_remote_code) so callers can opt into remote code
when needed; update the argument parser to document the flag and set the default
to False.
model = AutoModelForCausalLM.from_pretrained(
    args.model, torch_dtype=torch.bfloat16, device_map={"": device}, trust_remote_code=True
)
tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True)
🧩 Analysis chain
🏁 Scripts executed:
- fd train_dflash.py
- cat -n examples/speculative_decoding/train_dflash.py | sed -n '140,165p'
- grep -n "model" examples/speculative_decoding/train_dflash.py | head -20  # check for --model argument definition
- cat -n examples/speculative_decoding/train_dflash.py | sed -n '145,158p'  # comments/configuration around the trust_remote_code lines
Make trust_remote_code opt-in via CLI flag.
Lines 150–153 hardcode trust_remote_code=True for both model and tokenizer loads. Since --model is caller-controlled, this forces execution of arbitrary Python from any untrusted checkpoint during startup. Add a CLI flag (e.g., --trust-remote-code) that defaults to False and thread it through both from_pretrained() calls.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@examples/speculative_decoding/train_dflash.py` around lines 150 - 153, Add a
new CLI flag (e.g., --trust-remote-code) that defaults to False (use
action='store_true') and expose it as args.trust_remote_code; then remove the
hardcoded True and pass args.trust_remote_code into both
AutoModelForCausalLM.from_pretrained(...) and
AutoTokenizer.from_pretrained(...). Update any argument parsing logic where args
is created so the new flag is available to the model/tokenizer loading calls.
```python
model = AutoModelForCausalLM.from_pretrained(
    '${HF_MODEL_CKPT}', torch_dtype=torch.bfloat16, device_map={'': 0}, trust_remote_code=True
)
tokenizer = AutoTokenizer.from_pretrained('${HF_MODEL_CKPT}', trust_remote_code=True)
```
🧩 Analysis chain

🏁 Script executed:

```shell
# Find the file
fd ar_validate.sh
```

Length of output: 109

🏁 Script executed:

```shell
# Read the entire file to see the context
cat -n tools/launcher/common/dflash/ar_validate.sh | head -100
```

Length of output: 4475

🏁 Script executed:

```shell
# Read the rest of the file
cat -n tools/launcher/common/dflash/ar_validate.sh | tail -50
```

Length of output: 2292
Make remote-code execution opt-in; do not hardcode trust_remote_code=True.
Lines 63-66 hardcode trust_remote_code=True for both model and tokenizer loading. Since HF_MODEL_CKPT is environment-controlled, this allows any untrusted model checkpoint to execute arbitrary Python code during loading. Add an environment variable to control this flag and default it to False.
🔒 Proposed fix

```diff
 import torch
+import os
 from datasets import load_dataset
 from transformers import AutoModelForCausalLM, AutoTokenizer
 from modelopt.torch.speculative.plugins.transformers import HFARValidation
 import modelopt.torch.opt as mto
 import modelopt.torch.speculative as mtsp
 mto.enable_huggingface_checkpointing()
+trust_remote_code = os.getenv("TRUST_REMOTE_CODE", "0") == "1"
 model = AutoModelForCausalLM.from_pretrained(
-    '${HF_MODEL_CKPT}', torch_dtype=torch.bfloat16, device_map={'': 0}, trust_remote_code=True
+    '${HF_MODEL_CKPT}',
+    torch_dtype=torch.bfloat16,
+    device_map={'': 0},
+    trust_remote_code=trust_remote_code,
 )
-tokenizer = AutoTokenizer.from_pretrained('${HF_MODEL_CKPT}', trust_remote_code=True)
+tokenizer = AutoTokenizer.from_pretrained(
+    '${HF_MODEL_CKPT}', trust_remote_code=trust_remote_code
+)
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tools/launcher/common/dflash/ar_validate.sh` around lines 63 - 66, The calls
to AutoModelForCausalLM.from_pretrained and AutoTokenizer.from_pretrained
currently hardcode trust_remote_code=True; change them to read a new environment
variable (e.g., ALLOW_TRUST_REMOTE_CODE) that defaults to false and convert it
to a boolean (treat "1", "true", "yes" case-insensitively as true). Pass that
boolean into the trust_remote_code parameter for both
AutoModelForCausalLM.from_pretrained and AutoTokenizer.from_pretrained so
remote-code execution is opt-in when HF_MODEL_CKPT is used.
```python
try:
    _, ar = validator.validate(osl=32, input_ids=input_ids, steps=3)
    ars.append(ar)
```
Call the online AR path here.
This still invokes HFARValidation.validate(), so the script computes AR against fixed ground truth instead of the new context-dependent verification loop in AcceptanceRateValidation.validate_online(). That means the reported DFlash metric does not match the new online-validation path added in this PR.
♻️ Proposed fix
```diff
-    _, ar = validator.validate(osl=32, input_ids=input_ids, steps=3)
+    _, ar = validator.validate_online(osl=32, input_ids=input_ids, steps=3)
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
try:
    _, ar = validator.validate_online(osl=32, input_ids=input_ids, steps=3)
    ars.append(ar)
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tools/launcher/common/dflash/ar_validate.sh` around lines 111 - 113, The code
is calling validator.validate(...) which runs the offline HFAR path; change it
to call the online validation loop by invoking validator.validate_online(osl=32,
input_ids=input_ids, steps=3) (or the correct parameter names for
validate_online) and keep extracting the AR result (e.g., "_, ar =
validator.validate_online(...)") and appending ar to ars so the script uses
AcceptanceRateValidation.validate_online() instead of HFARValidation.validate().
```yaml
- --engine VLLM
- --tp_size 4
- --ep_size 1
- --speculative_algorithm EAGLE3
- --mtbench /hf-local/HuggingFaceH4/mt_bench_prompts/raw/question.jsonl
```
Use DRAFT_TARGET for the DFlash benchmark.
examples/specdec_bench/specdec_bench/models/vllm.py maps EAGLE3 to the eagle3 backend, and there is no DFlash handler on that path. This task will benchmark the wrong speculative-decoding implementation for the draft produced by common/dflash/online_training.sh. Switch this to DRAFT_TARGET so VLLM loads it as a generic draft model instead.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tools/launcher/examples/Qwen/Qwen3-8B/hf_online_dflash.yaml` around lines 50
- 54, Replace the speculative algorithm flag value from EAGLE3 to DRAFT_TARGET
so VLLM treats the DFlash output as a generic draft model; update the entry in
the YAML where "--speculative_algorithm EAGLE3" appears to
"--speculative_algorithm DRAFT_TARGET" (this ensures the DFlash draft produced
by common/dflash/online_training.sh is loaded instead of mapping to the EAGLE3
backend in examples/specdec_bench/specdec_bench/models/vllm.py).
Actionable comments posted: 3
♻️ Duplicate comments (13)
tools/launcher/common/dflash/ar_validate.sh (2)
111-113: ⚠️ Potential issue | 🟠 Major — Use the online AR path here.
`HFARValidation.validate()` still measures against fixed ground truth, so this script reports a different metric than the new context-dependent DFlash validation path. Call `validate_online(...)` in this loop instead.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tools/launcher/common/dflash/ar_validate.sh` around lines 111 - 113, Replace the offline validation call to validator.validate(...) with the online/context-dependent path by invoking validator.validate_online(...) in the loop; keep the same arguments (e.g., osl=32, input_ids=input_ids, steps=3) and continue to append the returned ar to ars so the script measures DFlash’s context-dependent AR instead of comparing to fixed ground truth.
63-66: ⚠️ Potential issue | 🔴 Critical — Make remote-code execution opt-in.
`HF_MODEL_CKPT` is environment-controlled, but both loads hardcode `trust_remote_code=True`. That lets an arbitrary checkpoint execute Python during validation. Thread this through an env flag that defaults to `False`. As per coding guidelines, "Never hardcode `trust_remote_code=True`; remote-code execution is an RCE vector."
🤖 Prompt for AI Agents
HF_MODEL_CKPTis environment-controlled, but both loads hardcodetrust_remote_code=True. That lets an arbitrary checkpoint execute Python during validation. Thread this through an env flag that defaults toFalse. As per coding guidelines, "Never hardcodetrust_remote_code=True; remote-code execution is an RCE vector."🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tools/launcher/common/dflash/ar_validate.sh` around lines 63 - 66, The code currently passes trust_remote_code=True into AutoModelForCausalLM.from_pretrained and AutoTokenizer.from_pretrained for HF_MODEL_CKPT; change this to be controlled by a new environment flag (e.g., TRUST_REMOTE_CODE or ENABLE_TRUST_REMOTE_CODE) that defaults to False, parse it as a boolean, and pass that variable into the trust_remote_code parameter of both AutoModelForCausalLM.from_pretrained and AutoTokenizer.from_pretrained so remote-code execution is opt-in.examples/speculative_decoding/train_dflash.py (2)
150-153:⚠️ Potential issue | 🔴 CriticalMake
trust_remote_codeopt-in via CLI.
--modelis caller-controlled, but bothfrom_pretrained()calls hardcodetrust_remote_code=True, which executes arbitrary checkpoint code during startup. Add a flag like--trust-remote-codedefaulting toFalseand thread it through both loads. As per coding guidelines, "Never hardcodetrust_remote_code=True; remote-code execution is an RCE vector."🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/speculative_decoding/train_dflash.py` around lines 150 - 153, Add a new CLI boolean flag (e.g., args.trust_remote_code defaulting to False) and pass it to both AutoModelForCausalLM.from_pretrained and AutoTokenizer.from_pretrained instead of hardcoding trust_remote_code=True; update the argument parser where args.model is defined to include "--trust-remote-code" (store_true) and then change the two calls (AutoModelForCausalLM.from_pretrained and AutoTokenizer.from_pretrained) to use trust_remote_code=args.trust_remote_code so remote-code execution remains opt-in.
292-293:⚠️ Potential issue | 🟠 MajorDon't hard-require an internal MT-Bench mirror in a standalone script.
The post-train AR check now depends on
/hf-local/HuggingFaceH4/mt_bench_prompts, so this entrypoint fails outside the internal environment. Accept the dataset source via CLI/env and default toHuggingFaceH4/mt_bench_prompts.🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/speculative_decoding/train_dflash.py` around lines 292 - 293, The script currently hardcodes the MT-Bench dataset path when creating ds with load_dataset; make it configurable via a CLI flag and env var with a sensible default: add an argparse option (e.g. --mtbench-dataset) that falls back to os.environ.get("MT_BENCH_DATASET") and then to the default string "HuggingFaceH4/mt_bench_prompts"; replace the literal in the load_dataset call (the line that sets ds = load_dataset(... )["train"]) to use that resolved dataset identifier so HFARValidation(raw_model, tokenizer) continues to run against the user-provided or default dataset.modelopt/torch/utils/plugins/transformers_dataset.py (3)
391-418:⚠️ Potential issue | 🟠 MajorOnly skip assistant-free chats when
answer_only_loss=True.This branch now drops prompt-only/system-user chats even in the normal full-loss path. If an entire batch is filtered, the dummy user/assistant sample becomes synthetic training data instead of a no-op. Keep the assistant-turn check behind
self.answer_only_loss.🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/utils/plugins/transformers_dataset.py` around lines 391 - 418, The current filtering drops samples without an assistant turn regardless of loss mode; only skip samples when self.answer_only_loss is True. Modify the branch that checks for assistant turns (the blocks handling messages and conversations and calling _sharegpt_to_openai_messages) to perform the "no assistant turn -> print warning and continue" only when self.answer_only_loss is True; otherwise append the messages/conversations as-is. Ensure references: messages, conversations, converted (from _sharegpt_to_openai_messages), batch, and self.answer_only_loss are used to gate the skip logic so dummy batch creation behavior only occurs when answer_only_loss is enabled.
156-157:⚠️ Potential issue | 🟠 MajorStill unresolved: answer-only template rewriting breaks multimodal collators.
VisionLanguageDataCollatorreaches this initializer too, but the fallback templates here explicitly drop VLM content handling and treatmessage["content"]as plain text. Withanswer_only_loss=True, multimodal batches still get rewritten to a text-only template beforeprocessor.apply_chat_template(...)runs, so list-of-block content will be misformatted or fail.Also applies to: 177-215, 441-448
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/utils/plugins/transformers_dataset.py` around lines 156 - 157, The answer-only template rewrite in the constructor (triggered by answer_only_loss calling _ensure_generation_tags) incorrectly converts multimodal messages into text-only content, breaking VisionLanguageDataCollator and processor.apply_chat_template; fix by detecting multimodal inputs (e.g., presence of non-text keys or list-of-blocks/image/block structures in message["content"] used by VisionLanguageDataCollator) and skip the answer-only template rewrite for those cases, i.e., in _ensure_generation_tags (or the place where answer_only_loss is handled) add a guard that returns early when message content is multimodal so processor.apply_chat_template receives the original multimodal structure. Ensure the check references the same shapes used by VisionLanguageDataCollator (e.g., 'blocks' or image fields) so multimodal batches are preserved.
350-356: ⚠️ Potential issue | 🟠 Major — Shift `assistant_masks` into label space before masking.
`labels[..., :-1]` already contains next-token targets, but `assistant_masks` is aligned to the original token positions. Applying it unshifted drops the first assistant target token and keeps one token past each assistant span.
Suggested fix

```diff
 if self.answer_only_loss:
     if "assistant_masks" in tokenized_examples:
         assistant_mask = tokenized_examples["assistant_masks"]
         if isinstance(assistant_mask, torch.Tensor) and assistant_mask.any():
-            labels[assistant_mask == 0] = IGNORE_TOKEN_ID
+            shifted_assistant_mask = torch.zeros_like(assistant_mask)
+            shifted_assistant_mask[..., :-1] = assistant_mask[..., 1:]
+            labels[shifted_assistant_mask == 0] = IGNORE_TOKEN_ID
 else:
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/utils/plugins/transformers_dataset.py` around lines 350 - 356, The assistant_mask is aligned to original input positions but labels are next-token targets (labels[..., :-1] = input_ids[..., 1:]), so you must shift assistant_masks into label-space before masking; compute a shifted mask (e.g., shifted = assistant_mask[..., 1:]) and apply it to labels[..., :-1] (or equivalently mask labels where shifted == 0) instead of using the unshifted assistant_mask; keep the existing checks for torch.Tensor and .any() and set masked positions to IGNORE_TOKEN_ID when answer_only_loss is true.modelopt/torch/export/plugins/hf_spec_export.py (1)
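A torch-free toy illustrating the off-by-one this comment describes: labels are next-token targets (`labels[t] = input_ids[t + 1]`), so a mask aligned to token positions must be shifted left by one before it can select label positions. Names and token values here are made up for the demo:

```python
IGNORE = -100  # same sentinel role as IGNORE_TOKEN_ID in the collator


def mask_labels(input_ids: list[int], assistant_mask: list[int], shift: bool) -> list[int]:
    """Build next-token labels, then mask out non-assistant targets."""
    # labels[t] predicts input_ids[t + 1]; the final position has no target.
    labels = input_ids[1:] + [IGNORE]
    if shift:
        # Move the mask into label space: label position t covers token t + 1.
        mask = assistant_mask[1:] + [0]
    else:
        mask = assistant_mask
    return [lab if m else IGNORE for lab, m in zip(labels, mask)]


tokens = [10, 11, 12, 13, 14]  # positions 2..3 (tokens 12, 13) are the assistant span
a_mask = [0, 0, 1, 1, 0]

print(mask_labels(tokens, a_mask, shift=False))  # [-100, -100, 13, 14, -100]: misses 12, leaks 14
print(mask_labels(tokens, a_mask, shift=True))   # [-100, 12, 13, -100, -100]: exactly the span
```

The unshifted variant trains the model to predict one token past the assistant span while never supervising its first token, which is the bug the suggested diff fixes with `shifted_assistant_mask`.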
272-316: ⚠️ Potential issue | 🟠 Major — Keep `config.json` dtype aligned with the exported tensors.
`export(dtype=...)` casts `model.safetensors`, but `_export_config()` still hardcodes `"torch_dtype": "bfloat16"`. Exporting fp16/fp32 will advertise the wrong dtype to downstream loaders.
Suggested fix

```diff
-    def _export_config(self):
+    def _export_config(self, dtype: torch.dtype | None = None):
         """Build config.json matching z-lab DFlash format."""
@@
-            "torch_dtype": "bfloat16",
+            "torch_dtype": (
+                str(dtype).replace("torch.", "")
+                if dtype is not None
+                else str(
+                    getattr(
+                        draft_config,
+                        "torch_dtype",
+                        getattr(base_config, "torch_dtype", torch.bfloat16),
+                    )
+                ).replace("torch.", "")
+            ),
@@
-        drafter_config = self._export_config()
+        drafter_config = self._export_config(dtype=dtype)
```

Also applies to: 328-343
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/export/plugins/hf_spec_export.py` around lines 272 - 316, The config currently hardcodes "torch_dtype": "bfloat16" in _export_config(), causing a mismatch when export(dtype=...) is used; update _export_config to determine the dtype dynamically (prefer the explicit export dtype passed to export, e.g., self.export_dtype or self.dtype if present) and set "torch_dtype" to that dtype's string (falling back to the current bfloat16 value if no export dtype is available); apply the same change to the other config block around lines 328-343 so the advertised dtype matches the exported safetensors.examples/speculative_decoding/main.py (1)
311-334: ⚠️ Potential issue | 🟠 Major — The optimizer-mismatch fallback still doesn't actually resume.
Loading `trainer.state` and then calling `trainer.train()` restarts the dataloader at step 0. This path will replay already-seen batches while logging `Resuming from step ...`, so it is not a safe recovery path for optimizer state mismatches.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/speculative_decoding/main.py` around lines 311 - 334, The fallback path sets trainer.state from trainer_state.json but then calls trainer.train() with no resume flag, which restarts from step 0 and replays batches; change the fallback to call trainer.train(resume_from_checkpoint=checkpoint) so the Trainer resumes at the saved step, and also clear any existing optimizer (e.g., set trainer.optimizer = None) before training to ensure a fresh optimizer when you intentionally fall back from an optimizer-state mismatch; keep the existing use of trainer.state.load_from_json(state_file) and the checkpoint/state_file variables.
tools/launcher/common/dflash/online_training.sh (2)
30-35: ⚠️ Potential issue | 🟡 Minor — Quote shell variables and version specifiers.
Line 31: `${SCRIPT_DIR}` should be double-quoted to prevent globbing/word splitting.
Line 34: The unquoted `>=` is parsed as shell redirection, causing the version constraint to be lost.
🔧 Proposed fix

```diff
-source ${SCRIPT_DIR}/../service_utils.sh
+source "${SCRIPT_DIR}/../service_utils.sh"
 pip install -r modules/Model-Optimizer/examples/speculative_decoding/requirements.txt
-pip install huggingface-hub>=1.2.1
+pip install "huggingface-hub>=1.2.1"
```
Verify each finding against the current code and only fix it if needed. In `@tools/launcher/common/dflash/online_training.sh` around lines 30 - 35, The script uses unquoted shell variables and an unquoted version specifier which can cause globbing/word-splitting and shell redirection: quote ${SCRIPT_DIR} when sourcing service_utils.sh (refer to SCRIPT_DIR and service_utils.sh) and quote the pip requirement string so the shell doesn't treat >= as a redirection (refer to the pip install command for huggingface-hub), and also quote PATH expansions when exporting (refer to the export PATH line) to avoid word-splitting.
181-223: ⚠️ Potential issue | 🔴 Critical — Hardcoded `trust_remote_code=True` and shell variable interpolation pose security risks.
Lines 192 and 194 hardcode `trust_remote_code=True` for model and tokenizer loading with no override capability. Per coding guidelines, this should be configurable and default to `False`.
Additionally, shell variables (`${HF_MODEL_CKPT}`, `${DFLASH_BLOCK_SIZE}`, `${DFLASH_NUM_LAYERS}`, `${MASK_ARG}`) are interpolated directly into the Python heredoc. Maliciously crafted values could inject arbitrary code.
Pass values via environment variables and read them with `os.environ.get()` inside Python:

```python
import os

model_path = os.environ.get("HF_MODEL_CKPT")
trust_remote = os.environ.get("TRUST_REMOTE_CODE", "false").lower() == "true"
# ...
model = AutoModelForCausalLM.from_pretrained(
    model_path, torch_dtype=torch.bfloat16, device_map={'': 0}, trust_remote_code=trust_remote
)
```
Verify each finding against the current code and only fix it if needed. In `@tools/launcher/common/dflash/online_training.sh` around lines 181 - 223, The heredoc currently hardcodes trust_remote_code=True and interpolates shell variables directly, which is unsafe; change the script to pass HF_MODEL_CKPT, DFLASH_BLOCK_SIZE, DFLASH_NUM_LAYERS, MASK_ARG, AR_CKPT and a new TRUST_REMOTE_CODE via environment variables and read them inside Python with os.environ.get() (casting block/num to int and TRUST_REMOTE_CODE to a boolean defaulting to False) before calling AutoModelForCausalLM.from_pretrained and AutoTokenizer.from_pretrained; remove all ${...} interpolations from the Python snippet, use the env-derived variables when building the dflash config and when loading checkpoints (model.load_state_dict and model.dflash_module.load_state_dict), and ensure a safe default trust_remote_code=False unless the env explicitly sets it to "true".
modelopt/torch/speculative/plugins/hf_dflash.py (2)
83-84: ⚠️ Potential issue | 🟠 Major — Module-level globals create cross-model interference.
`_MLP_CLS`, `_NORM_CLS`, `_ROTARY_CLS`, and `_rotate_half` are module-scope variables that `modify()` (lines 510-514) reassigns. If two DFlash models with different base architectures are instantiated in the same process, the second `modify()` call will overwrite the globals, potentially corrupting the first model's behavior at runtime (e.g., in `apply_rotary_pos_emb`, `DFlashModule.__init__`).
Store these resolved components on the model instance (`self._mlp_cls`, etc.) and thread them through to `DFlashModule` and `DFlashAttention`.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/speculative/plugins/hf_dflash.py` around lines 83 - 84, The module-level globals _MLP_CLS, _NORM_CLS, _ROTARY_CLS, and _rotate_half cause cross-model interference; change the code to resolve these per-model and store them on the model instance (e.g., self._mlp_cls, self._norm_cls, self._rotary_cls, self._rotate_half) inside modify()/convert() instead of reassigning the module globals. Update DFlashModule.__init__ to accept or read instance attributes (self._mlp_cls, self._norm_cls, etc.) and pass them into DFlashAttention and apply_rotary_pos_emb so those functions/classes no longer reference module-level names; ensure all call sites (modify/convert, DFlashModule creation, DFlashAttention usage, and apply_rotary_pos_emb) are updated to thread the instance-specific components through.
832-845: ⚠️ Potential issue | 🟠 Major — `base_token.item()` fails for batch size > 1.
Line 841 calls `.item()` on `base_token`, which has shape `[B, 1]`. For batched inputs with `B > 1`, this raises a `RuntimeError`. Either remove the debug block or use a batch-safe representation:

```diff
-print(f"[psg] base_token: {base_token.item()}, mask_token_id: {self.mask_token_id}")
+print(f"[psg] base_token: {base_token[:, 0].tolist()}, mask_token_id: {self.mask_token_id}")
```
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/speculative/plugins/hf_dflash.py` around lines 832 - 845, The debug block uses base_token.item() which fails for batch size >1; update the debug printing in hf_dflash.py (the block that sets self._psg_debug and prints using base_token and mask_token_id) to use a batch-safe representation instead of .item() — e.g., convert the tensor to CPU, detach it and render as a list (base_token.detach().cpu().view(-1).tolist() or base_token.detach().cpu()[:,0].tolist()) or print only the first batch element (base_token[0].item()) so it won't raise for B>1; keep other debug prints (th_dbg, seq_len, dflash_block_size, target_layer_ids) unchanged.
🧹 Nitpick comments (3)
tools/launcher/common/dflash/online_training.sh (2)
149-152: Quote path variables.
`OUTPUT_DIR` and `EXPORT_DIR` should be double-quoted to handle paths containing spaces.
🔧 Proposed fix

```diff
 python3 modules/Model-Optimizer/examples/speculative_decoding/scripts/export_hf_checkpoint.py \
-    --model_path ${OUTPUT_DIR} \
-    --export_path ${EXPORT_DIR} \
+    --model_path "${OUTPUT_DIR}" \
+    --export_path "${EXPORT_DIR}" \
     || echo "WARNING: Export failed, continuing with AR validation"
```
Verify each finding against the current code and only fix it if needed. In `@tools/launcher/common/dflash/online_training.sh` around lines 149 - 152, The shell command invoking export_hf_checkpoint uses unquoted path variables which will break on spaces; update the invocation that references OUTPUT_DIR and EXPORT_DIR (the python3 ... --model_path ${OUTPUT_DIR} --export_path ${EXPORT_DIR} || ...) to wrap both variables in double quotes (e.g. --model_path "${OUTPUT_DIR}" --export_path "${EXPORT_DIR}") so paths with spaces are handled correctly and the fallback echo behavior remains unchanged.
124-128: Quote variables where appropriate.
`CONFIG_FILE` and `NUM_NODES` should be quoted to guard against paths with spaces or other edge cases. `OVERRIDES` is intentionally unquoted for word splitting, but consider using an array for safer argument passing.
🛠️ Minimal fix

```diff
 bash modules/Model-Optimizer/examples/speculative_decoding/launch_train.sh \
-    --config ${CONFIG_FILE} \
-    --num_nodes ${NUM_NODES:-1} \
+    --config "${CONFIG_FILE}" \
+    --num_nodes "${NUM_NODES:-1}" \
     --head_node_ip ${HEAD_NODE_IP:-} \
     ${OVERRIDES}
```
bash modules/Model-Optimizer/examples/speculative_decoding/launch_train.sh \ - --config ${CONFIG_FILE} \ - --num_nodes ${NUM_NODES:-1} \ + --config "${CONFIG_FILE}" \ + --num_nodes "${NUM_NODES:-1}" \ --head_node_ip ${HEAD_NODE_IP:-} \ ${OVERRIDES}🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tools/launcher/common/dflash/online_training.sh` around lines 124 - 128, The command invocation to launch_train.sh should quote CONFIG_FILE and NUM_NODES to protect against spaces and edge cases: update the call to use "--config \"${CONFIG_FILE}\"" and "--num_nodes \"${NUM_NODES:-1}\""; keep OVERRIDES unquoted for word-splitting but preferably refactor OVERRIDES into an array (e.g., OVERRIDES_ARGS) and expand it safely (e.g., "${OVERRIDES_ARGS[@]}") when invoking launch_train.sh so arguments are passed reliably; adjust the invocation around launch_train.sh and the variables CONFIG_FILE, NUM_NODES, and OVERRIDES accordingly.
modelopt/torch/speculative/plugins/hf_dflash.py (1)
499-499: Consider using logging instead of print statements.
Lines 499 and 548 emit unconditional `print()` calls. In production, these add noise to output. Use Python's `logging` module or gate behind a debug flag.
🔧 Example using logging

```diff
+import logging
+
+_logger = logging.getLogger(__name__)
+
 # In modify():
-print(f"DFlash mask_token_id: {self.mask_token_id}")
+_logger.info(f"DFlash mask_token_id: {self.mask_token_id}")
 # ...
-print(f"DFlash: using {original_cls.__name__}.forward as base forward")
+_logger.info(f"DFlash: using {original_cls.__name__}.forward as base forward")
```
+import logging + +_logger = logging.getLogger(__name__) + # In modify(): - print(f"DFlash mask_token_id: {self.mask_token_id}") + _logger.info(f"DFlash mask_token_id: {self.mask_token_id}") # ... - print(f"DFlash: using {original_cls.__name__}.forward as base forward") + _logger.info(f"DFlash: using {original_cls.__name__}.forward as base forward")Also applies to: 548-548
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/speculative/plugins/hf_dflash.py` at line 499, Replace the unconditional print calls that output the mask token id with Python logging: create/get a module/class logger (eg. logging.getLogger(__name__) or self.logger) and replace the prints (the occurrences that print DFlash mask_token_id) with logger.debug or logger.info as appropriate; ensure the log message includes the same context (e.g., "DFlash mask_token_id: %s") and, if desired, gate emission behind a debug flag or log level so these messages don't appear in production.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@tools/launcher/common/dflash/ar_eval_mtbench.sh`:
- Around line 147-163: The script currently hardcodes trust_remote_code=True in
calls to AutoTokenizer.from_pretrained and AutoModelForCausalLM.from_pretrained
(for MODEL), allowing execution of untrusted HF repo code; add a CLI boolean
flag (e.g., --trust_remote_code defaulting to False) parsed at startup and pass
that flag's value into both AutoTokenizer.from_pretrained(...) and
AutoModelForCausalLM.from_pretrained(...), ensuring MODEL load calls use
trust_remote_code=trust_remote_code_flag so remote code execution is opt-in.
- Around line 104-123: The Python snippet currently injects shell variables
(MODEL, LAST_CKPT, MASK_TOKEN_ID, ONLINE, etc.) directly into the python -c
string and hardcodes trust_remote_code=True; change it to read those values from
os.environ or sys.argv inside the Python block (e.g., os.environ['MODEL'],
os.environ['LAST_CKPT'], os.environ.get('MASK_TOKEN_ID')) instead of
interpolating into the source, and remove the hardcoded trust_remote_code=True
from AutoTokenizer.from_pretrained and AutoModelForCausalLM.from_pretrained so
the flag is controlled by an explicit parameter or environment variable
(defaulting to False) that callers can opt into.
In `@tools/launcher/common/dflash/ar_validate.sh`:
- Around line 53-127: The script currently injects shell variables directly into
the embedded Python code (HF_MODEL_CKPT, DFLASH_BLOCK_SIZE, DFLASH_NUM_LAYERS,
MASK_ARG, DFLASH_CKPT, NUM_SAMPLES), which risks code injection and quoting
errors; change the launcher to pass these values via environment variables or
explicit CLI args and read them inside the Python block (e.g., os.environ or
argparse) and remove all ${...} interpolation from the python3 -c string, then
parse/convert types (ints, dict parts) inside Python before using them (used
when calling mtsp.convert and loading checkpoints and dataset sampling). Also
make trust_remote_code configurable instead of hardcoding True in
AutoModelForCausalLM.from_pretrained and AutoTokenizer.from_pretrained (add a
boolean env/arg like TRUST_REMOTE_CODE defaulting to False and pass it into both
calls), so loading remote code requires explicit opt-in.
---
Duplicate comments:
In `@examples/speculative_decoding/main.py`:
- Around line 311-334: The fallback path sets trainer.state from
trainer_state.json but then calls trainer.train() with no resume flag, which
restarts from step 0 and replays batches; change the fallback to call
trainer.train(resume_from_checkpoint=checkpoint) so the Trainer resumes at the
saved step, and also clear any existing optimizer (e.g., set trainer.optimizer =
None) before training to ensure a fresh optimizer when you intentionally fall
back from an optimizer-state mismatch; keep the existing use of
trainer.state.load_from_json(state_file) and the checkpoint/state_file
variables.
In `@examples/speculative_decoding/train_dflash.py`:
- Around line 150-153: Add a new CLI boolean flag (e.g., args.trust_remote_code
defaulting to False) and pass it to both AutoModelForCausalLM.from_pretrained
and AutoTokenizer.from_pretrained instead of hardcoding trust_remote_code=True;
update the argument parser where args.model is defined to include
"--trust-remote-code" (store_true) and then change the two calls
(AutoModelForCausalLM.from_pretrained and AutoTokenizer.from_pretrained) to use
trust_remote_code=args.trust_remote_code so remote-code execution remains
opt-in.
- Around line 292-293: The script currently hardcodes the MT-Bench dataset path
when creating ds with load_dataset; make it configurable via a CLI flag and env
var with a sensible default: add an argparse option (e.g. --mtbench-dataset)
that falls back to os.environ.get("MT_BENCH_DATASET") and then to the default
string "HuggingFaceH4/mt_bench_prompts"; replace the literal in the load_dataset
call (the line that sets ds = load_dataset(... )["train"]) to use that resolved
dataset identifier so HFARValidation(raw_model, tokenizer) continues to run
against the user-provided or default dataset.
In `@modelopt/torch/export/plugins/hf_spec_export.py`:
- Around line 272-316: The config currently hardcodes "torch_dtype": "bfloat16"
in _export_config(), causing a mismatch when export(dtype=...) is used; update
_export_config to determine the dtype dynamically (prefer the explicit export
dtype passed to export, e.g., self.export_dtype or self.dtype if present) and
set "torch_dtype" to that dtype's string (falling back to the current bfloat16
value if no export dtype is available); apply the same change to the other
config block around lines 328-343 so the advertised dtype matches the exported
safetensors.
In `@modelopt/torch/speculative/plugins/hf_dflash.py`:
- Around line 83-84: The module-level globals _MLP_CLS, _NORM_CLS, _ROTARY_CLS,
and _rotate_half cause cross-model interference; change the code to resolve
these per-model and store them on the model instance (e.g., self._mlp_cls,
self._norm_cls, self._rotary_cls, self._rotate_half) inside modify()/convert()
instead of reassigning the module globals. Update DFlashModule.__init__ to
accept or read instance attributes (self._mlp_cls, self._norm_cls, etc.) and
pass them into DFlashAttention and apply_rotary_pos_emb so those
functions/classes no longer reference module-level names; ensure all call sites
(modify/convert, DFlashModule creation, DFlashAttention usage, and
apply_rotary_pos_emb) are updated to thread the instance-specific components
through.
- Around line 832-845: The debug block uses base_token.item() which fails for
batch size >1; update the debug printing in hf_dflash.py (the block that sets
self._psg_debug and prints using base_token and mask_token_id) to use a
batch-safe representation instead of .item() — e.g., convert the tensor to CPU,
detach it and render as a list (base_token.detach().cpu().view(-1).tolist() or
base_token.detach().cpu()[:,0].tolist()) or print only the first batch element
(base_token[0].item()) so it won't raise for B>1; keep other debug prints
(th_dbg, seq_len, dflash_block_size, target_layer_ids) unchanged.
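The batch-safe rendering this prompt asks for can be illustrated with a toy tensor (the token values are made up):

```python
import torch

base_token = torch.tensor([[101], [102]])  # shape (B, 1) with B=2

# .item() would raise here because the tensor has more than one element;
# .tolist() on a flattened CPU copy is safe for any batch size.
rendered = base_token.detach().cpu().view(-1).tolist()
print(f"base tokens: {rendered}")
```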
In `@modelopt/torch/utils/plugins/transformers_dataset.py`:
- Around line 391-418: The current filtering drops samples without an assistant
turn regardless of loss mode; only skip samples when self.answer_only_loss is
True. Modify the branch that checks for assistant turns (the blocks handling
messages and conversations and calling _sharegpt_to_openai_messages) to perform
the "no assistant turn -> print warning and continue" only when
self.answer_only_loss is True; otherwise append the messages/conversations
as-is. Ensure references: messages, conversations, converted (from
_sharegpt_to_openai_messages), batch, and self.answer_only_loss are used to gate
the skip logic so dummy batch creation behavior only occurs when
answer_only_loss is enabled.
- Around line 156-157: The answer-only template rewrite in the constructor
(triggered by answer_only_loss calling _ensure_generation_tags) incorrectly
converts multimodal messages into text-only content, breaking
VisionLanguageDataCollator and processor.apply_chat_template; fix by detecting
multimodal inputs (e.g., presence of non-text keys or list-of-blocks/image/block
structures in message["content"] used by VisionLanguageDataCollator) and skip
the answer-only template rewrite for those cases, i.e., in
_ensure_generation_tags (or the place where answer_only_loss is handled) add a
guard that returns early when message content is multimodal so
processor.apply_chat_template receives the original multimodal structure. Ensure
the check references the same shapes used by VisionLanguageDataCollator (e.g.,
'blocks' or image fields) so multimodal batches are preserved.
- Around line 350-356: The assistant_mask is aligned to original input positions
but labels are next-token targets (labels[..., :-1] = input_ids[..., 1:]), so
you must shift assistant_masks into label-space before masking; compute a
shifted mask (e.g., shifted = assistant_mask[..., 1:]) and apply it to
labels[..., :-1] (or equivalently mask labels where shifted == 0) instead of
using the unshifted assistant_mask; keep the existing checks for torch.Tensor
and .any() and set masked positions to IGNORE_TOKEN_ID when answer_only_loss is
true.
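A toy illustration of the shift this prompt describes, assuming labels are built as next-token targets via `labels[..., :-1] = input_ids[..., 1:]` (token values are made up):

```python
import torch

IGNORE_TOKEN_ID = -100

input_ids = torch.tensor([[10, 20, 30, 40]])
assistant_mask = torch.tensor([[0, 0, 1, 1]])  # tokens 30 and 40 are assistant output

labels = torch.full_like(input_ids, IGNORE_TOKEN_ID)
labels[..., :-1] = input_ids[..., 1:]  # next-token targets

# Move the mask one position left so it aligns with label positions.
shifted_mask = torch.zeros_like(assistant_mask)
shifted_mask[..., :-1] = assistant_mask[..., 1:]
labels[shifted_mask == 0] = IGNORE_TOKEN_ID
```

The positions that predict the assistant tokens (30 and 40) keep their targets; everything else is ignored.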
In `@tools/launcher/common/dflash/ar_validate.sh`:
- Around line 111-113: Replace the offline validation call to
validator.validate(...) with the online/context-dependent path by invoking
validator.validate_online(...) in the loop; keep the same arguments (e.g.,
osl=32, input_ids=input_ids, steps=3) and continue to append the returned ar to
ars so the script measures DFlash’s context-dependent AR instead of comparing to
fixed ground truth.
- Around line 63-66: The code currently passes trust_remote_code=True into
AutoModelForCausalLM.from_pretrained and AutoTokenizer.from_pretrained for
HF_MODEL_CKPT; change this to be controlled by a new environment flag (e.g.,
TRUST_REMOTE_CODE or ENABLE_TRUST_REMOTE_CODE) that defaults to False, parse it
as a boolean, and pass that variable into the trust_remote_code parameter of
both AutoModelForCausalLM.from_pretrained and AutoTokenizer.from_pretrained so
remote-code execution is opt-in.
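One way to parse such a flag in the shell script, as a sketch; the `TRUST_FLAG` variable name is hypothetical:

```shell
# Opt-in flag: defaults to false; normalized to a Python-style boolean string
# before it is handed to the Python side.
TRUST_REMOTE_CODE="${TRUST_REMOTE_CODE:-false}"
case "$(printf '%s' "${TRUST_REMOTE_CODE}" | tr '[:upper:]' '[:lower:]')" in
  true|1|yes) TRUST_FLAG="True" ;;
  *)          TRUST_FLAG="False" ;;
esac
echo "trust_remote_code=${TRUST_FLAG}"
```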
In `@tools/launcher/common/dflash/online_training.sh`:
- Around line 30-35: The script uses unquoted shell variables and an unquoted
version specifier which can cause globbing/word-splitting and shell redirection:
quote ${SCRIPT_DIR} when sourcing service_utils.sh (refer to SCRIPT_DIR and
service_utils.sh) and quote the pip requirement string so the shell doesn't
treat >= as a redirection (refer to the pip install command for
huggingface-hub), and also quote PATH expansions when exporting (refer to the
export PATH line) to avoid word-splitting.
- Around line 181-223: The heredoc currently hardcodes trust_remote_code=True
and interpolates shell variables directly, which is unsafe; change the script to
pass HF_MODEL_CKPT, DFLASH_BLOCK_SIZE, DFLASH_NUM_LAYERS, MASK_ARG, AR_CKPT and
a new TRUST_REMOTE_CODE via environment variables and read them inside Python
with os.environ.get() (casting block/num to int and TRUST_REMOTE_CODE to a
boolean defaulting to False) before calling AutoModelForCausalLM.from_pretrained
and AutoTokenizer.from_pretrained; remove all ${...} interpolations from the
Python snippet, use the env-derived variables when building the dflash config
and when loading checkpoints (model.load_state_dict and
model.dflash_module.load_state_dict), and ensure a safe default
trust_remote_code=False unless the env explicitly sets it to "true".
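Inside the heredoc, the env-driven reads might look like this; a sketch in which the int-cast defaults are placeholders and the commented `from_pretrained` call is an assumption about the surrounding code:

```python
import os

hf_model_ckpt = os.environ.get("HF_MODEL_CKPT", "")
block_size = int(os.environ.get("DFLASH_BLOCK_SIZE", "8"))   # placeholder default
num_layers = int(os.environ.get("DFLASH_NUM_LAYERS", "1"))   # placeholder default
trust_remote_code = os.environ.get("TRUST_REMOTE_CODE", "false").lower() == "true"

# No shell interpolation anywhere in the Python snippet:
# model = AutoModelForCausalLM.from_pretrained(hf_model_ckpt, trust_remote_code=trust_remote_code)
```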
---
Nitpick comments:
In `@modelopt/torch/speculative/plugins/hf_dflash.py`:
- Line 499: Replace the unconditional print calls that output the mask token id
with Python logging: create/get a module/class logger (eg.
logging.getLogger(__name__) or self.logger) and replace the prints (the
occurrences that print DFlash mask_token_id) with logger.debug or logger.info as
appropriate; ensure the log message includes the same context (e.g., "DFlash
mask_token_id: %s") and, if desired, gate emission behind a debug flag or log
level so these messages don't appear in production.
In `@tools/launcher/common/dflash/online_training.sh`:
- Around line 149-152: The shell command invoking export_hf_checkpoint uses
unquoted path variables which will break on spaces; update the invocation that
references OUTPUT_DIR and EXPORT_DIR (the python3 ... --model_path ${OUTPUT_DIR}
--export_path ${EXPORT_DIR} || ...) to wrap both variables in double quotes
(e.g. --model_path "${OUTPUT_DIR}" --export_path "${EXPORT_DIR}") so paths with
spaces are handled correctly and the fallback echo behavior remains unchanged.
- Around line 124-128: The command invocation to launch_train.sh should quote
CONFIG_FILE and NUM_NODES to protect against spaces and edge cases: update the
call to use "--config \"${CONFIG_FILE}\"" and "--num_nodes \"${NUM_NODES:-1}\"";
keep OVERRIDES unquoted for word-splitting but preferably refactor OVERRIDES
into an array (e.g., OVERRIDES_ARGS) and expand it safely (e.g.,
"${OVERRIDES_ARGS[@]}") when invoking launch_train.sh so arguments are passed
reliably; adjust the invocation around launch_train.sh and the variables
CONFIG_FILE, NUM_NODES, and OVERRIDES accordingly.
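The array refactor suggested here, in miniature (the paths and override values are made-up examples):

```shell
# Each override is one array element, so values containing spaces survive expansion.
OVERRIDES_ARGS=(--lr 1e-4 --output_dir "/tmp/run with spaces")

# Invocation sketch (the launch_train.sh call is commented out here):
# bash launch_train.sh --config "${CONFIG_FILE}" --num_nodes "${NUM_NODES:-1}" "${OVERRIDES_ARGS[@]}"
printf '%s\n' "${OVERRIDES_ARGS[@]}"
```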
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 6e6d829b-957a-4ee3-a64b-949827a0b751
📒 Files selected for processing (25)
- doc/results/dflash_results.html
- examples/speculative_decoding/README.md
- examples/speculative_decoding/doc/dflash_results.md
- examples/speculative_decoding/eagle_utils.py
- examples/speculative_decoding/main.py
- examples/speculative_decoding/scripts/export_hf_checkpoint.py
- examples/speculative_decoding/train_dflash.py
- modelopt/torch/export/plugins/hf_spec_export.py
- modelopt/torch/speculative/config.py
- modelopt/torch/speculative/dflash/__init__.py
- modelopt/torch/speculative/dflash/conversion.py
- modelopt/torch/speculative/dflash/default_config.py
- modelopt/torch/speculative/dflash/dflash_model.py
- modelopt/torch/speculative/mode.py
- modelopt/torch/speculative/plugins/__init__.py
- modelopt/torch/speculative/plugins/hf_dflash.py
- modelopt/torch/speculative/utils.py
- modelopt/torch/utils/plugins/transformers_dataset.py
- modelopt_recipes/general/speculative_decoding/dflash.yaml
- tests/gpu/torch/speculative/plugins/test_hf_dflash.py
- tests/unit/torch/speculative/plugins/test_hf_dflash.py
- tools/launcher/common/dflash/ar_eval_mtbench.sh
- tools/launcher/common/dflash/ar_validate.sh
- tools/launcher/common/dflash/online_training.sh
- tools/launcher/examples/Qwen/Qwen3-8B/hf_online_dflash.yaml
✅ Files skipped from review due to trivial changes (5)
- modelopt/torch/speculative/dflash/default_config.py
- examples/speculative_decoding/README.md
- tools/launcher/examples/Qwen/Qwen3-8B/hf_online_dflash.yaml
- modelopt_recipes/general/speculative_decoding/dflash.yaml
- examples/speculative_decoding/doc/dflash_results.md
🚧 Files skipped from review as they are similar to previous changes (9)
- examples/speculative_decoding/scripts/export_hf_checkpoint.py
- modelopt/torch/speculative/plugins/__init__.py
- modelopt/torch/speculative/dflash/__init__.py
- modelopt/torch/speculative/mode.py
- modelopt/torch/speculative/config.py
- modelopt/torch/speculative/utils.py
- modelopt/torch/speculative/dflash/conversion.py
- tests/unit/torch/speculative/plugins/test_hf_dflash.py
- examples/speculative_decoding/eagle_utils.py
DFlash (Block Diffusion for Flash Speculative Decoding) predicts an entire block of tokens in a single forward pass using masked parallel prediction with KV injection from the target model's hidden states.

Key features:
- Feature fusion (multi-layer hidden states -> FC + RMSNorm)
- KV injection (fused features as K/V in every draft layer with QK-norm)
- Random anchor sampling with bidirectional intra-block attention
- Logit distillation with exponential loss decay (gamma weighting)
- Multi-node DDP training with checkpoint resume
- Export to z-lab compatible HF format
- Online validation (context-dependent ground truth)

Training recipe: modelopt_recipes/general/speculative_decoding/dflash.yaml
Results: examples/speculative_decoding/doc/dflash_results.md

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
e45cc37 to
5f8d004
Compare
Actionable comments posted: 3
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
examples/speculative_decoding/main.py (1)
287-301: ⚠️ Potential issue | 🔴 Critical
`medusa` mode still reaches the trainer with `data_module` unset.
`data_module` is only initialized inside the `("eagle3", "dflash")` branch, but `TrainingArguments.mode` and the conversion block still accept `"medusa"`. A medusa run will hit `EagleTrainerWithAccLog(..., **data_module)` with `data_module` undefined.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/speculative_decoding/main.py` around lines 287 - 301, The code only sets data_module inside the ("eagle3", "dflash") branch so a run with training_args.mode == "medusa" leaves data_module undefined and crashes when passed into EagleTrainerWithAccLog; fix by initializing data_module before the conditional (e.g. data_module = {}) and either add an explicit branch to create a medusa-specific module (call the appropriate builder if one exists) or raise a clear error for unsupported modes, making sure the symbols training_args.mode, make_eagle_supervised_data_module, and EagleTrainerWithAccLog are updated accordingly.
♻️ Duplicate comments (7)
tools/launcher/common/dflash/ar_eval_mtbench.sh (2)
147-163: ⚠️ Potential issue | 🔴 Critical
Make remote code execution opt-in for MT-Bench eval.
Both `AutoTokenizer.from_pretrained()` and `AutoModelForCausalLM.from_pretrained()` hardcode `trust_remote_code=True`. That executes arbitrary repository code from `HF_MODEL_CKPT` on the launcher node. Thread this through a flag or env var and default it to `False`.
As per coding guidelines, "Do not hardcode `trust_remote_code=True` when loading Hugging Face Transformers models. Let the caller decide via a parameter; default to `False`."
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tools/launcher/common/dflash/ar_eval_mtbench.sh` around lines 147 - 163, The code currently hardcodes trust_remote_code=True when calling AutoTokenizer.from_pretrained and AutoModelForCausalLM.from_pretrained (with MODEL), which enables remote code execution; add a configurable flag (e.g., a function/CLI param or env var like TRUST_REMOTE_CODE defaulting to False) and pass that flag into both AutoTokenizer.from_pretrained(...) and AutoModelForCausalLM.from_pretrained(...). Ensure the new flag is read early (before tokenizer/model load) and used in the model load call that also includes ATTN_IMPL/device_map so behavior is consistent and safe by default.
104-123: ⚠️ Potential issue | 🔴 Critical
Don't splice shell variables directly into the `python -c` source.
`MODEL`, `LAST_CKPT`, `MASK_TOKEN_ID`, and `ONLINE` are interpolated straight into Python literals. A checkpoint path containing a quote breaks the script, and a user-controlled value can turn into arbitrary Python. Pass these values via env vars or argv and read them inside the Python block instead.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tools/launcher/common/dflash/ar_eval_mtbench.sh` around lines 104 - 123, The current python -c block injects shell variables directly into Python literals (see MODEL, CKPT_PATH/ LAST_CKPT, MASK_TOKEN_ID_STR, ONLINE) which can break on quotes or allow code injection; instead pass those values via environment variables or command-line args and read them inside the Python snippet (use os.environ.get(...) or sys.argv parsing), convert MASK_TOKEN_ID and numeric flags (BLOCK_SIZE, NUM_LAYERS, OSL, STEPS) to ints and ONLINE to a boolean safely, and replace the direct interpolations in the python -c source with references to the env/argv variables to eliminate quoting/injection issues when loading the checkpoint or interpreting flags.modelopt/torch/utils/plugins/transformers_dataset.py (3)
351-356: ⚠️ Potential issue | 🟠 Major
Shift `assistant_masks` into label space before masking.
Line 351 pre-shifts `labels` for next-token prediction, but Line 356 applies the unshifted tokenizer mask. That drops the first assistant token from loss and keeps one token after the assistant span.
🛠️ Suggested change
```diff
 if self.answer_only_loss:
     if "assistant_masks" in tokenized_examples:
         assistant_mask = tokenized_examples["assistant_masks"]
         if isinstance(assistant_mask, torch.Tensor) and assistant_mask.any():
-            labels[assistant_mask == 0] = IGNORE_TOKEN_ID
+            shifted_assistant_mask = torch.zeros_like(assistant_mask)
+            shifted_assistant_mask[..., :-1] = assistant_mask[..., 1:]
+            labels[shifted_assistant_mask == 0] = IGNORE_TOKEN_ID
         else:
             # All assistant content truncated or no assistant in batch — mask all
             labels[:] = IGNORE_TOKEN_ID
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/utils/plugins/transformers_dataset.py` around lines 351 - 356, The code shifts labels for next-token prediction (labels[..., :-1] = input_ids[..., 1:]) but then applies the unshifted assistant mask, causing misalignment; update the masking to shift assistant_masks into label space before applying IGNORE_TOKEN_ID (e.g., use assistant_mask[..., :-1] or equivalent) so that assistant token positions align with labels when setting labels[assistant_mask == 0] = IGNORE_TOKEN_ID; adjust the block around labels, self.answer_only_loss, and tokenized_examples["assistant_masks"] accordingly to use the shifted mask.
220-271: ⚠️ Potential issue | 🟠 Major
These fallback templates are still text-only.
All three variants concatenate `message["content"]` as a string, but `VisionLanguageDataCollator` turns multimodal content into block lists before templating. Enabling `answer_only_loss` on that path will misformat or fail batches. Gate this rewrite to text-only collators or provide multimodal-safe templates.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/utils/plugins/transformers_dataset.py` around lines 220 - 271, The fallback _GENERATION_TEMPLATES are text-only but are applied to multimodal inputs by VisionLanguageDataCollator (especially when answer_only_loss is used), causing misformatted batches; update the code that selects or applies _GENERATION_TEMPLATES so it only uses these string-concatenating templates for text-only messages (e.g., detect collator type or check message['content'] is a str/does not contain block lists) and either (a) provide multimodal-safe templates/serialization for non-text message['content'] or (b) raise/explicitly gate with a clear error when VisionLanguageDataCollator or answer_only_loss would feed non-str content into _GENERATION_TEMPLATES (reference _GENERATION_TEMPLATES, VisionLanguageDataCollator, and answer_only_loss to locate where to implement the guard).
391-409: ⚠️ Potential issue | 🟠 Major
Only drop assistant-free chats in answer-only mode.
This now skips prompt-only/system-user samples even when `answer_only_loss=False`, which changes the generic collator behavior and can collapse a whole batch into the dummy sample. The assistant-turn filter should only run in answer-only mode.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/utils/plugins/transformers_dataset.py` around lines 391 - 409, The current logic in the collator (the block handling example.get("messages") and example.get("conversations") and calling _sharegpt_to_openai_messages) unconditionally drops samples with no assistant turn; change it so the "no assistant turn" checks (the any(m.get("role") == "assistant") guards and the print_rank_0 warnings) only run when answer_only_loss is True. Concretely, update the branches around messages/conversations in the collator to check answer_only_loss before skipping or warning (leave normal batching behavior intact when answer_only_loss is False), referencing the variables/messages, conversations, _sharegpt_to_openai_messages, and print_rank_0 to locate and modify the conditions.examples/speculative_decoding/main.py (1)
317-334: ⚠️ Potential issue | 🟠 Major
The optimizer-mismatch fallback still replays training data from step 0.
Loading `trainer_state.json` into `trainer.state` is not enough by itself. Because the retry drops `resume_from_checkpoint`, this branch logs a resumed step but restarts the input pipeline from the beginning. Keep the checkpoint on the second call and bypass only optimizer/scheduler restoration.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/speculative_decoding/main.py` around lines 317 - 334, The retry branch currently calls trainer.train() without resume_from_checkpoint so the data pipeline restarts; instead call trainer.train(resume_from_checkpoint=checkpoint) and ensure only optimizer/scheduler state is bypassed by clearing or reinitializing those objects after loading trainer.state: after trainer.state = trainer.state.load_from_json(state_file) explicitly set trainer.optimizer = None and trainer.lr_scheduler = None (or reinitialize them as appropriate) so checkpoint is used for data/resume but optimizer/scheduler are not restored.modelopt/torch/export/plugins/hf_spec_export.py (1)
272-318: ⚠️ Potential issue | 🟠 Major
Keep `config.json` dtype aligned with the exported tensors.
`export(dtype=...)` casts `model.safetensors`, but `_export_config()` still writes `"torch_dtype": "bfloat16"`. fp16/fp32 exports will therefore advertise the wrong dtype to downstream loaders.
🛠️ Suggested change
```diff
-    def _export_config(self):
+    def _export_config(self, dtype: torch.dtype | None = None):
         """Build config.json matching z-lab DFlash format."""
         model = self.model
         base_config = (
             getattr(model.config, "text_config", None)
             or getattr(model.config, "llm_config", None)
@@
             "attention_dropout": getattr(draft_config, "attention_dropout", 0.0),
             "rope_theta": getattr(base_config, "rope_theta", 1000000.0),
             "rope_scaling": getattr(base_config, "rope_scaling", None),
             "tie_word_embeddings": False,
-            "torch_dtype": "bfloat16",
+            "torch_dtype": (
+                str(dtype).replace("torch.", "")
+                if dtype is not None
+                else str(
+                    getattr(
+                        draft_config,
+                        "torch_dtype",
+                        getattr(base_config, "torch_dtype", torch.bfloat16),
+                    )
+                ).replace("torch.", "")
+            ),
             "num_target_layers": getattr(base_config, "num_hidden_layers", 36),
         }
@@
-        drafter_config = self._export_config()
+        drafter_config = self._export_config(dtype=dtype)
         with open(f"{export_dir}/config.json", "w") as f:
             json.dump(drafter_config, f, indent=2)
```
Also applies to: 329-343
🧹 Nitpick comments (8)
examples/speculative_decoding/train_dflash.py (3)
264-270: Guard against missing `train_acc` attribute.
If the model output lacks `train_acc`, accessing `output.train_acc[0][0]` before the `hasattr` check completes would fail. The current code handles this correctly with `hasattr`, but consider using `getattr` for cleaner extraction.

```diff
-acc = output.train_acc[0][0] if hasattr(output, "train_acc") else 0.0
+acc = getattr(output, "train_acc", [[0.0]])[0][0]
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/speculative_decoding/train_dflash.py` around lines 264 - 270, Replace the hasattr pattern with a safe getattr extraction so acc is read atomically: use getattr(output, "train_acc", None) to fetch train_acc, handle None by defaulting acc to 0.0, and keep using scheduler.get_last_lr()[0] and print_rank0 for logging; update the block around global_step / args.log_interval where acc is computed to reference output.train_acc via getattr to avoid any race or attribute-access issues.
102-108: Silent exception swallowing hides data pipeline errors.
The bare `except Exception:` discards the error details. Logging the exception would help diagnose tokenization or parsing failures during debugging.
Proposed fix

```diff
 try:
     input_ids, loss_mask = parser.parse(convs, max_length=max_length)
     processed["input_ids"].append(input_ids)
     processed["loss_mask"].append(loss_mask)
-except Exception:
+except Exception as e:
+    if is_rank0():
+        print(f"Skipping sample: {e}")
     skipped += 1
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/speculative_decoding/train_dflash.py` around lines 102 - 108, The try/except in the loop around parser.parse(convs, max_length=max_length) is swallowing errors silently; change the bare except to catch Exception as e and log the exception (e.g., using logger.exception or logger.error(..., exc_info=True)) before incrementing skipped so tokenization/parsing failures are recorded for debugging while preserving the existing behavior that appends to processed["input_ids"]/["loss_mask"] only on success; update the block around parser.parse and the skipped increment accordingly.
207-212: Consider implications of `find_unused_parameters=True` on performance.
This setting adds overhead by tracking parameter usage each iteration. It's necessary here since only `dflash_module` parameters are trained, but worth documenting for future maintainers.

```diff
 # Wrap with DDP
+# find_unused_parameters=True needed because only dflash_module is trained
 model = torch.nn.parallel.DistributedDataParallel(
     model,
     device_ids=[local_rank],
     find_unused_parameters=True,
 )
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/speculative_decoding/train_dflash.py` around lines 207 - 212, The DDP wrapper uses find_unused_parameters=True which incurs runtime overhead; update the DistributedDataParallel instantiation (torch.nn.parallel.DistributedDataParallel for the model variable) to include a clear inline comment explaining that find_unused_parameters=True is required because only dflash_module parameters are being trained (so many parameters remain unused), and add a TODO/docs note suggesting to set it to False when all model parameters are trained or to conditionally set this flag based on whether only dflash_module is being optimized; keep the DDP call otherwise unchanged.tools/launcher/common/dflash/online_training.sh (2)
124-128: Quote variables to prevent word splitting and globbing.
Multiple unquoted variables could cause issues with paths containing spaces or special characters.
Proposed fix

```diff
 bash modules/Model-Optimizer/examples/speculative_decoding/launch_train.sh \
-    --config ${CONFIG_FILE} \
-    --num_nodes ${NUM_NODES:-1} \
-    --head_node_ip ${HEAD_NODE_IP:-} \
-    ${OVERRIDES}
+    --config "${CONFIG_FILE}" \
+    --num_nodes "${NUM_NODES:-1}" \
+    --head_node_ip "${HEAD_NODE_IP:-}" \
+    ${OVERRIDES} # OVERRIDES intentionally unquoted for word splitting
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tools/launcher/common/dflash/online_training.sh` around lines 124 - 128, The invocation of launch_train.sh in online_training.sh uses unquoted shell variables which can cause word-splitting and globbing; update the call to quote variables like "${CONFIG_FILE}", "${NUM_NODES:-1}", "${HEAD_NODE_IP:-}" and "${OVERRIDES}" (or handle OVERRIDES as an array if it may contain multiple args) so the command in the launch_train.sh call is robust to spaces and special characters.
31-31: Quote variable to prevent word splitting.

```diff
-source ${SCRIPT_DIR}/../service_utils.sh
+source "${SCRIPT_DIR}/../service_utils.sh"
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tools/launcher/common/dflash/online_training.sh` at line 31, The source command uses an unquoted SCRIPT_DIR which can cause word-splitting for paths with spaces; update the invocation that sources service_utils.sh to quote the SCRIPT_DIR expansion (i.e. use the quoted form of the existing source command) so the path is treated as a single token when calling source "${SCRIPT_DIR}/../service_utils.sh".modelopt/torch/speculative/plugins/hf_dflash.py (2)
362-410: Mask token ID detection relies on fragile heuristics.
The auto-detection uses hardcoded offsets (26, 25, 24) and magic numbers (128002 for Llama3, vocab_size thresholds). These assumptions may break with new model versions.
Consider:
- Adding a warning when falling back to heuristics
- Documenting the known model-specific token IDs
- Encouraging explicit `mask_token_id` configuration
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/speculative/plugins/hf_dflash.py` around lines 362 - 410, The _auto_detect_mask_token_id function relies on fragile hardcoded heuristics (offsets [26,25,24], magic 128002, vocab_size thresholds); update it to log a warning via the module logger when any heuristic/fallback path is used (e.g., when returning a candidate from offsets, the vocab_size heuristics, falling back to 128002, pad/eos or final fallback), add a short docstring comment inside _auto_detect_mask_token_id enumerating the known model-specific IDs (Qwen mask region, Llama3 reserved_special_token_0) and clarify they are heuristics, and update any public-facing docs or function docstring to recommend explicitly supplying mask_token_id in config (mention symbol base_config.mask_token_id) so callers can avoid autodetection.
499-499: Replace unconditional print statements with proper logging.
Direct `print()` calls bypass the logging framework, making it harder to control output verbosity in production.
Proposed fix

```diff
+import logging
+
+logger = logging.getLogger(__name__)
+
 # In modify():
-print(f"DFlash mask_token_id: {self.mask_token_id}")
+logger.info(f"DFlash mask_token_id: {self.mask_token_id}")

 # Later:
-print(f"DFlash: using {original_cls.__name__}.forward as base forward")
+logger.info(f"DFlash: using {original_cls.__name__}.forward as base forward")
```
Also applies to: 548-548
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/speculative/plugins/hf_dflash.py` at line 499, Replace the unconditional print statements that output the mask token id (e.g. print(f"DFlash mask_token_id: {self.mask_token_id}")) with calls to the logging framework; obtain a logger (module-level logging.getLogger(__name__) or reuse an existing self.logger if the class provides one) and emit the message at an appropriate level (debug/info). Ensure the logging import is added if missing and update both occurrences (the one referencing self.mask_token_id and the other at the noted second location) to use logger.debug(...) or logger.info(...) instead of print so output respects configured log levels.tools/launcher/common/dflash/ar_validate.sh (1)
30-31: Quote variable to prevent word splitting.
The unquoted `${SCRIPT_DIR}` could cause issues with paths containing spaces.

```diff
-source ${SCRIPT_DIR}/../service_utils.sh
+source "${SCRIPT_DIR}/../service_utils.sh"
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tools/launcher/common/dflash/ar_validate.sh` around lines 30 - 31, The source invocation uses an unquoted variable which can break on paths with spaces; update the source command to quote SCRIPT_DIR (use "${SCRIPT_DIR}/../service_utils.sh") so the shell treats the path as a single token and keep the trap 'error_handler $0 $LINENO' ERR unchanged.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@modelopt/torch/utils/plugins/transformers_dataset.py`:
- Around line 156-157: The bug is that _ensure_generation_tags() mutates
tokenizer.chat_template in-place, causing later apply_chat_template() calls
(used by trainer/validation) to see the mutated template; change the
implementation so it does not rewrite the shared tokenizer template in place —
instead operate on a shallow copy of the chat template (or create a new template
object/string) and assign that copy to the collator/processing_class or return
it from _ensure_generation_tags() without modifying tokenizer.chat_template;
ensure code paths that use the generated tags (answer_only_loss branch,
_ensure_generation_tags, and apply_chat_template) consume the copied template
rather than the original tokenizer.chat_template so reuse of the tokenizer
elsewhere remains unchanged.
In `@tools/launcher/common/dflash/ar_validate.sh`:
- Line 102: Replace the hardcoded internal dataset path in the load_dataset call
so it reads the dataset source from an environment variable (e.g.,
DATASET_SOURCE) with a fallback to the public "HuggingFaceH4/mt_bench_prompts";
update the line that calls load_dataset (the expression assigning ds) to use
process.env.DATASET_SOURCE (or equivalent shell/env expansion) and default to
"HuggingFaceH4/mt_bench_prompts" when the env var is not set, ensuring
portability across environments.
In `@tools/launcher/common/dflash/online_training.sh`:
- Line 228: Replace the hardcoded dataset path in the load_dataset call with an
environment-variable-driven value: read an env var (e.g., MT_BENCH_DATASET) with
a public fallback (for example "HuggingFaceH4/mt_bench_prompts") and pass that
variable into the load_dataset invocation (the line that calls
load_dataset('/hf-local/HuggingFaceH4/mt_bench_prompts')). Ensure the script
documents the env var and uses the fallback when the env var is unset.
---
Outside diff comments:
In `@examples/speculative_decoding/main.py`:
- Around line 287-301: The code only sets data_module inside the ("eagle3",
"dflash") branch so a run with training_args.mode == "medusa" leaves data_module
undefined and crashes when passed into EagleTrainerWithAccLog; fix by
initializing data_module before the conditional (e.g. data_module = {}) and
either add an explicit branch to create a medusa-specific module (call the
appropriate builder if one exists) or raise a clear error for unsupported modes,
making sure the symbols training_args.mode, make_eagle_supervised_data_module,
and EagleTrainerWithAccLog are updated accordingly.
---
Duplicate comments:
In `@examples/speculative_decoding/main.py`:
- Around line 317-334: The retry branch currently calls trainer.train() without
resume_from_checkpoint so the data pipeline restarts; instead call
trainer.train(resume_from_checkpoint=checkpoint) and ensure only
optimizer/scheduler state is bypassed by clearing or reinitializing those
objects after loading trainer.state: after trainer.state =
trainer.state.load_from_json(state_file) explicitly set trainer.optimizer = None
and trainer.lr_scheduler = None (or reinitialize them as appropriate) so
checkpoint is used for data/resume but optimizer/scheduler are not restored.
In `@modelopt/torch/utils/plugins/transformers_dataset.py`:
- Around line 351-356: The code shifts labels for next-token prediction
(labels[..., :-1] = input_ids[..., 1:]) but then applies the unshifted assistant
mask, causing misalignment; update the masking to shift assistant_masks into
label space before applying IGNORE_TOKEN_ID (e.g., use assistant_mask[..., :-1]
or equivalent) so that assistant token positions align with labels when setting
labels[assistant_mask == 0] = IGNORE_TOKEN_ID; adjust the block around labels,
self.answer_only_loss, and tokenized_examples["assistant_masks"] accordingly to
use the shifted mask.
- Around line 220-271: The fallback _GENERATION_TEMPLATES are text-only but are
applied to multimodal inputs by VisionLanguageDataCollator (especially when
answer_only_loss is used), causing misformatted batches; update the code that
selects or applies _GENERATION_TEMPLATES so it only uses these
string-concatenating templates for text-only messages (e.g., detect collator
type or check message['content'] is a str/does not contain block lists) and
either (a) provide multimodal-safe templates/serialization for non-text
message['content'] or (b) raise/explicitly gate with a clear error when
VisionLanguageDataCollator or answer_only_loss would feed non-str content into
_GENERATION_TEMPLATES (reference _GENERATION_TEMPLATES,
VisionLanguageDataCollator, and answer_only_loss to locate where to implement
the guard).
- Around line 391-409: The current logic in the collator (the block handling
example.get("messages") and example.get("conversations") and calling
_sharegpt_to_openai_messages) unconditionally drops samples with no assistant
turn; change it so the "no assistant turn" checks (the any(m.get("role") ==
"assistant") guards and the print_rank_0 warnings) only run when
answer_only_loss is True. Concretely, update the branches around
messages/conversations in the collator to check answer_only_loss before skipping
or warning (leave normal batching behavior intact when answer_only_loss is
False), referencing the variables/messages, conversations,
_sharegpt_to_openai_messages, and print_rank_0 to locate and modify the
conditions.
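For the label-shift item above, one consistent way to align the assistant mask with the shifted labels looks like the sketch below. Names are illustrative and the exact slicing in the collator may differ; the invariant is that the mask applied at label position t must describe token t + 1:

```python
import torch

IGNORE_TOKEN_ID = -100


def build_shifted_labels(input_ids, assistant_mask):
    """Build next-token labels and apply the assistant mask in label space.

    labels[t] predicts input_ids[t + 1], so the mask for position t must be
    the assistant mask of position t + 1 (the mask shifted the same way as
    the labels).
    """
    labels = torch.full_like(input_ids, IGNORE_TOKEN_ID)
    labels[..., :-1] = input_ids[..., 1:]
    shifted_mask = torch.zeros_like(assistant_mask)
    shifted_mask[..., :-1] = assistant_mask[..., 1:]
    labels[shifted_mask == 0] = IGNORE_TOKEN_ID
    return labels
```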
In `@tools/launcher/common/dflash/ar_eval_mtbench.sh`:
- Around line 147-163: The code currently hardcodes trust_remote_code=True when
calling AutoTokenizer.from_pretrained and AutoModelForCausalLM.from_pretrained
(with MODEL), which enables remote code execution; add a configurable flag
(e.g., a function/CLI param or env var like TRUST_REMOTE_CODE defaulting to
False) and pass that flag into both AutoTokenizer.from_pretrained(...) and
AutoModelForCausalLM.from_pretrained(...). Ensure the new flag is read early
(before tokenizer/model load) and used in the model load call that also includes
ATTN_IMPL/device_map so behavior is consistent and safe by default.
- Around line 104-123: The current python -c block injects shell variables
directly into Python literals (see MODEL, CKPT_PATH/ LAST_CKPT,
MASK_TOKEN_ID_STR, ONLINE) which can break on quotes or allow code injection;
instead pass those values via environment variables or command-line args and
read them inside the Python snippet (use os.environ.get(...) or sys.argv
parsing), convert MASK_TOKEN_ID and numeric flags (BLOCK_SIZE, NUM_LAYERS, OSL,
STEPS) to ints and ONLINE to a boolean safely, and replace the direct
interpolations in the python -c source with references to the env/argv variables
to eliminate quoting/injection issues when loading the checkpoint or
interpreting flags.
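The env-var approach suggested for the `python -c` block can be sketched like this. The flag names mirror the shell variables mentioned above, but the defaults are illustrative assumptions, not the script's actual interface:

```python
import os


def read_eval_flags(env=os.environ):
    """Parse eval settings from the environment instead of shell interpolation."""
    flags = {
        "model": env.get("MODEL", ""),
        "ckpt_path": env.get("CKPT_PATH", ""),
        "block_size": int(env.get("BLOCK_SIZE", "16")),
        "steps": int(env.get("STEPS", "1")),
        # Accept common truthy spellings for the ONLINE toggle.
        "online": env.get("ONLINE", "0").lower() in ("1", "true", "yes"),
    }
    mask = env.get("MASK_TOKEN_ID", "")
    flags["mask_token_id"] = int(mask) if mask else None
    return flags
```

Because the values arrive through `os.environ` rather than string interpolation into Python source, quotes or shell metacharacters in a checkpoint path can no longer break or inject into the snippet.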
---
Nitpick comments:
In `@examples/speculative_decoding/train_dflash.py`:
- Around line 264-270: Replace the hasattr pattern with a safe getattr
extraction so acc is read atomically: use getattr(output, "train_acc", None) to
fetch train_acc, handle None by defaulting acc to 0.0, and keep using
scheduler.get_last_lr()[0] and print_rank0 for logging; update the block around
global_step / args.log_interval where acc is computed to reference
output.train_acc via getattr to avoid any race or attribute-access issues.
- Around line 102-108: The try/except in the loop around parser.parse(convs,
max_length=max_length) is swallowing errors silently; change the bare except to
catch Exception as e and log the exception (e.g., using logger.exception or
logger.error(..., exc_info=True)) before incrementing skipped so
tokenization/parsing failures are recorded for debugging while preserving the
existing behavior that appends to processed["input_ids"]/["loss_mask"] only on
success; update the block around parser.parse and the skipped increment
accordingly.
- Around line 207-212: The DDP wrapper uses find_unused_parameters=True which
incurs runtime overhead; update the DistributedDataParallel instantiation
(torch.nn.parallel.DistributedDataParallel for the model variable) to include a
clear inline comment explaining that find_unused_parameters=True is required
because only dflash_module parameters are being trained (so many parameters
remain unused), and add a TODO/docs note suggesting to set it to False when all
model parameters are trained or to conditionally set this flag based on whether
only dflash_module is being optimized; keep the DDP call otherwise unchanged.
In `@modelopt/torch/speculative/plugins/hf_dflash.py`:
- Around line 362-410: The _auto_detect_mask_token_id function relies on fragile
hardcoded heuristics (offsets [26,25,24], magic 128002, vocab_size thresholds);
update it to log a warning via the module logger when any heuristic/fallback
path is used (e.g., when returning a candidate from offsets, the vocab_size
heuristics, falling back to 128002, pad/eos or final fallback), add a short
docstring comment inside _auto_detect_mask_token_id enumerating the known
model-specific IDs (Qwen mask region, Llama3 reserved_special_token_0) and
clarify they are heuristics, and update any public-facing docs or function
docstring to recommend explicitly supplying mask_token_id in config (mention
symbol base_config.mask_token_id) so callers can avoid autodetection.
- Line 499: Replace the unconditional print statements that output the mask
token id (e.g. print(f"DFlash mask_token_id: {self.mask_token_id}")) with calls
to the logging framework; obtain a logger (module-level
logging.getLogger(__name__) or reuse an existing self.logger if the class
provides one) and emit the message at an appropriate level (debug/info). Ensure
the logging import is added if missing and update both occurrences (the one
referencing self.mask_token_id and the other at the noted second location) to
use logger.debug(...) or logger.info(...) instead of print so output respects
configured log levels.
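The print-to-logger swap is mechanical; a minimal sketch (the class and method names are illustrative stand-ins for the plugin class):

```python
import logging

logger = logging.getLogger(__name__)


class DFlashLike:
    """Illustrative stand-in; only demonstrates the logging swap."""

    def __init__(self, mask_token_id: int):
        self.mask_token_id = mask_token_id

    def report_mask_token(self) -> None:
        # Was: print(f"DFlash mask_token_id: {self.mask_token_id}")
        logger.debug("DFlash mask_token_id: %s", self.mask_token_id)
```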
In `@tools/launcher/common/dflash/ar_validate.sh`:
- Around line 30-31: The source invocation uses an unquoted variable which can
break on paths with spaces; update the source command to quote SCRIPT_DIR (use
"${SCRIPT_DIR}/../service_utils.sh") so the shell treats the path as a single
token and keep the trap 'error_handler $0 $LINENO' ERR unchanged.
In `@tools/launcher/common/dflash/online_training.sh`:
- Around line 124-128: The invocation of launch_train.sh in online_training.sh
uses unquoted shell variables which can cause word-splitting and globbing;
update the call to quote variables like "${CONFIG_FILE}", "${NUM_NODES:-1}",
"${HEAD_NODE_IP:-}" and "${OVERRIDES}" (or handle OVERRIDES as an array if it
may contain multiple args) so the command in the launch_train.sh call is robust
to spaces and special characters.
- Line 31: The source command uses an unquoted SCRIPT_DIR which can cause
word-splitting for paths with spaces; update the invocation that sources
service_utils.sh to quote the SCRIPT_DIR expansion (i.e. use the quoted form of
the existing source command) so the path is treated as a single token when
calling source "${SCRIPT_DIR}/../service_utils.sh".
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: d4913aed-3f2b-413c-9f8a-98c6809e510e
📒 Files selected for processing (25)
- doc/results/dflash_results.html
- examples/speculative_decoding/README.md
- examples/speculative_decoding/doc/dflash_results.md
- examples/speculative_decoding/eagle_utils.py
- examples/speculative_decoding/main.py
- examples/speculative_decoding/scripts/export_hf_checkpoint.py
- examples/speculative_decoding/train_dflash.py
- modelopt/torch/export/plugins/hf_spec_export.py
- modelopt/torch/speculative/config.py
- modelopt/torch/speculative/dflash/__init__.py
- modelopt/torch/speculative/dflash/conversion.py
- modelopt/torch/speculative/dflash/default_config.py
- modelopt/torch/speculative/dflash/dflash_model.py
- modelopt/torch/speculative/mode.py
- modelopt/torch/speculative/plugins/__init__.py
- modelopt/torch/speculative/plugins/hf_dflash.py
- modelopt/torch/speculative/utils.py
- modelopt/torch/utils/plugins/transformers_dataset.py
- modelopt_recipes/general/speculative_decoding/dflash.yaml
- tests/gpu/torch/speculative/plugins/test_hf_dflash.py
- tests/unit/torch/speculative/plugins/test_hf_dflash.py
- tools/launcher/common/dflash/ar_eval_mtbench.sh
- tools/launcher/common/dflash/ar_validate.sh
- tools/launcher/common/dflash/online_training.sh
- tools/launcher/examples/Qwen/Qwen3-8B/hf_online_dflash.yaml
✅ Files skipped from review due to trivial changes (5)
- examples/speculative_decoding/README.md
- modelopt/torch/speculative/dflash/__init__.py
- tools/launcher/examples/Qwen/Qwen3-8B/hf_online_dflash.yaml
- modelopt_recipes/general/speculative_decoding/dflash.yaml
- examples/speculative_decoding/doc/dflash_results.md
🚧 Files skipped from review as they are similar to previous changes (7)
- examples/speculative_decoding/scripts/export_hf_checkpoint.py
- modelopt/torch/speculative/dflash/default_config.py
- modelopt/torch/speculative/plugins/__init__.py
- modelopt/torch/speculative/mode.py
- modelopt/torch/speculative/utils.py
- modelopt/torch/speculative/dflash/conversion.py
- examples/speculative_decoding/eagle_utils.py
    if self.answer_only_loss:
        self._ensure_generation_tags()
Don’t rewrite the shared tokenizer template in place.
_ensure_generation_tags() overwrites tokenizer.chat_template on the same tokenizer instance that is later reused as the trainer/validation processing_class. In DFlash mode that changes every later apply_chat_template() call, not just collation, so evaluation and generation can silently diverge from the model’s original prompt format.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@modelopt/torch/utils/plugins/transformers_dataset.py` around lines 156 - 157,
The bug is that _ensure_generation_tags() mutates tokenizer.chat_template
in-place, causing later apply_chat_template() calls (used by trainer/validation)
to see the mutated template; change the implementation so it does not rewrite
the shared tokenizer template in place — instead operate on a shallow copy of
the chat template (or create a new template object/string) and assign that copy
to the collator/processing_class or return it from _ensure_generation_tags()
without modifying tokenizer.chat_template; ensure code paths that use the
generated tags (answer_only_loss branch, _ensure_generation_tags, and
apply_chat_template) consume the copied template rather than the original
tokenizer.chat_template so reuse of the tokenizer elsewhere remains unchanged.
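The copy-not-mutate pattern this prompt describes can be sketched as follows. The collator attribute in the usage note and the exact template transformation are illustrative; only the idea of leaving `tokenizer.chat_template` untouched comes from the review:

```python
def ensure_generation_tags(chat_template: str) -> str:
    """Return a template with generation tags added, without mutating the input.

    Strings are immutable in Python, so the important part is to store the
    result on the collator rather than writing it back to
    tokenizer.chat_template.
    """
    if "{% generation %}" in chat_template:
        return chat_template
    # Illustrative transformation: wrap assistant content in generation tags.
    return chat_template.replace(
        "{{ message['content'] }}",
        "{% generation %}{{ message['content'] }}{% endgeneration %}",
    )


# Usage sketch: keep the modified copy on the collator only.
# collator.chat_template = ensure_generation_tags(tokenizer.chat_template)
```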
Actionable comments posted: 3
♻️ Duplicate comments (1)
examples/speculative_decoding/main.py (1)
323-324: ⚠️ Potential issue | 🟠 Major

The fallback resume still doesn't properly resume the dataloader position.

Loading `trainer_state.json` into `trainer.state` and then calling `trainer.train()` without `resume_from_checkpoint` causes the dataloader to restart from step 0. The HuggingFace Trainer only resumes dataloader position when `resume_from_checkpoint` is passed to `trainer.train()`. If the intent is a "fresh optimizer" resume (model weights + state, skip optimizer), this approach needs a different strategy: either patch the checkpoint to remove optimizer state before calling `trainer.train(resume_from_checkpoint=checkpoint)`, or accept that dataloader replay from step 0 is the intended behavior and document it clearly.

🛠️ Possible approaches

Option A: If dataloader replay is acceptable, add a comment explaining this tradeoff:

      trainer.state = trainer.state.load_from_json(state_file)
    - trainer.train()
    + # Note: dataloader restarts from step 0; only trainer.state is restored
    + trainer.train()

Option B: For true resume, remove/rename optimizer files before retrying:

    + # Remove optimizer state to allow resume without optimizer mismatch
    + optimizer_file = os.path.join(checkpoint, "optimizer.pt")
    + if os.path.isfile(optimizer_file):
    +     os.rename(optimizer_file, optimizer_file + ".bak")
      trainer.state = trainer.state.load_from_json(state_file)
    - trainer.train()
    + trainer.train(resume_from_checkpoint=checkpoint)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/speculative_decoding/main.py` around lines 323 - 324, The current fallback replaces trainer.state from trainer_state.json and then calls trainer.train(), which does not restore the dataloader position; change the logic where trainer.state = trainer.state.load_from_json(state_file) followed by trainer.train() to one of two fixes: (A) call trainer.train(resume_from_checkpoint=checkpoint) so HF Trainer resumes dataloader/step position, or (B) implement a "fresh-optimizer" resume path that patches the checkpoint before calling trainer.train(resume_from_checkpoint=checkpoint) by removing/renaming optimizer-related files (optimizer.pt/scheduler state) so weights and trainer.state are used but optimizer state is skipped; also add a short comment explaining the chosen behavior.
🧹 Nitpick comments (3)
examples/speculative_decoding/main.py (1)
170-171: Consider moving `import json` to module-level.

The `json` import is placed inside the function. While this works, Python convention prefers module-level imports for standard library modules. This is a minor style preference.

♻️ Suggested change

Move `import json` to the top of the file with other imports (around lines 32-47).

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/speculative_decoding/main.py` around lines 170 - 171, The local "import json" inside the function should be moved to the module-level with the other imports at the top of examples/speculative_decoding/main.py: remove the in-function "import json" and add "import json" alongside the other imports near lines 32–47 so the standard-library import follows Python conventions and avoids repeated imports; update any references to json (e.g., in the function where it was imported) to use the top-level import.
modelopt/torch/speculative/plugins/hf_dflash.py (2)
475-478: Infer the placement device without assuming `.layers[-1]`.

`_find_base_model_parts()` probes several backbone layouts, but this placement path only works when the resolved base model exposes `.layers`. Pulling the device from `next(self._base_model.parameters())` or the embeddings module would match the broader probing logic and avoid conversion failures on other supported layouts.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/speculative/plugins/hf_dflash.py` around lines 475 - 478, The placement of the DFlashModule assumes the base model has .layers; change the device resolution in the DFlashModule initialization so it does not rely on self._base_model.layers[-1]. Instead, determine the target device from the base model's parameters or embeddings—e.g., use next(self._base_model.parameters()).device (or the resolved embeddings module if present) when calling self.dflash_module.to(self._base_model.dtype).to(...); update references in the DFlashModule creation sequence (DFlashModule, self.dflash_module, self.dflash_config, and _base_model) accordingly so it matches the probing logic in _find_base_model_parts().
48-58: Add type hints to the new plugin entry points.

This module adds several public helpers and runtime entry points without annotations, which leaves `mypy` blind on a config- and tensor-heavy surface. Please type the arguments and return values for helpers like `build_target_layer_ids()` and the main `HFDFlashModel` methods before this lands. As per coding guidelines, "Ensure type hints are properly annotated for static type checking with mypy".
Also applies to: 401-402, 557-570, 761-762
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/speculative/plugins/hf_dflash.py` around lines 48 - 58, Add static type hints for the public helpers and runtime entry points so mypy can check them: annotate build_target_layer_ids(num_target_layers: int, num_draft_layers: int) -> list[int] (or List[int]) and annotate apply_rotary_pos_emb(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor] (or Tuple[torch.Tensor, torch.Tensor]); likewise add type annotations to the HFDFlashModel public methods referenced (the constructor and all methods around the 557-570 and 761-762 regions), using torch.Tensor for tensor params, int/float/bool for scalars, and Optional[...] or List[...] where appropriate, and import typing names (List, Optional, Tuple) as needed to satisfy mypy.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@examples/speculative_decoding/main.py`:
- Line 317: The line that reads state = json.load(open(state_file)) leaks a file
handle; change it to open the file using a context manager so the handle is
closed automatically (e.g., use with open(state_file) as f: then json.load(f)) —
update the code around the state and state_file usage (the assignment to state)
to use a with-block to ensure proper resource cleanup.
In `@modelopt/torch/speculative/plugins/hf_dflash.py`:
- Around line 48-55: The build_target_layer_ids function can produce negative or
out-of-range indices for tiny backbones; change it to special-case small models
and/or clamp every returned id into the valid range [0, num_target_layers - 1].
Specifically, in build_target_layer_ids ensure when num_target_layers < 4 you
return safe indices (e.g., center or 0..n-1) and after computing the list, map
each id to max(0, min(id, num_target_layers - 1)) so downstream logic that uses
lid + offset (the decoder embedding lookup) never receives a negative or
>=num_target_layers index.
- Around line 599-605: The forward call to the teacher/base model
(super().forward) is invoked while self.training may be True, so dropout remains
active; wrap the base-model forward in a context that sets the teacher to eval
mode (e.g., call model.eval() on the teacher/base instance) before calling
super().forward to produce base_outputs/target_hidden, and restore the original
training mode afterwards (use a try/finally or a small context manager) to
ensure deterministic hidden states during distillation while not changing the
overall module training flag.
---
Duplicate comments:
In `@examples/speculative_decoding/main.py`:
- Around line 323-324: The current fallback replaces trainer.state from
trainer_state.json and then calls trainer.train(), which does not restore the
dataloader position; change the logic where trainer.state =
trainer.state.load_from_json(state_file) followed by trainer.train() to one of
two fixes: (A) call trainer.train(resume_from_checkpoint=checkpoint) so HF
Trainer resumes dataloader/step position, or (B) implement a "fresh-optimizer"
resume path that patches the checkpoint before calling
trainer.train(resume_from_checkpoint=checkpoint) by removing/renaming
optimizer-related files (optimizer.pt/scheduler state) so weights and
trainer.state are used but optimizer state is skipped; also add a short comment
explaining the chosen behavior.
---
Nitpick comments:
In `@examples/speculative_decoding/main.py`:
- Around line 170-171: The local "import json" inside the function should be
moved to the module-level with the other imports at the top of
examples/speculative_decoding/main.py: remove the in-function "import json" and
add "import json" alongside the other imports near lines 32–47 so the
standard-library import follows Python conventions and avoids repeated imports;
update any references to json (e.g., in the function where it was imported) to
use the top-level import.
In `@modelopt/torch/speculative/plugins/hf_dflash.py`:
- Around line 475-478: The placement of the DFlashModule assumes the base model
has .layers; change the device resolution in the DFlashModule initialization so
it does not rely on self._base_model.layers[-1]. Instead, determine the target
device from the base model's parameters or embeddings—e.g., use
next(self._base_model.parameters()).device (or the resolved embeddings module if
present) when calling self.dflash_module.to(self._base_model.dtype).to(...);
update references in the DFlashModule creation sequence (DFlashModule,
self.dflash_module, self.dflash_config, and _base_model) accordingly so it
matches the probing logic in _find_base_model_parts().
- Around line 48-58: Add static type hints for the public helpers and runtime
entry points so mypy can check them: annotate
build_target_layer_ids(num_target_layers: int, num_draft_layers: int) ->
list[int] (or List[int]) and annotate apply_rotary_pos_emb(q: torch.Tensor, k:
torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> tuple[torch.Tensor,
torch.Tensor] (or Tuple[torch.Tensor, torch.Tensor]); likewise add type
annotations to the HFDFlashModel public methods referenced (the constructor and
all methods around the 557-570 and 761-762 regions), using torch.Tensor for
tensor params, int/float/bool for scalars, and Optional[...] or List[...] where
appropriate, and import typing names (List, Optional, Tuple) as needed to
satisfy mypy.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: fd57b686-ff1c-4721-afd4-c52f8a65f250
📒 Files selected for processing (2)
- examples/speculative_decoding/main.py
- modelopt/torch/speculative/plugins/hf_dflash.py
    def build_target_layer_ids(num_target_layers, num_draft_layers):
        """Select layers uniformly from the target model for feature extraction."""
        if num_draft_layers == 1:
            return [num_target_layers // 2]
        start = 1
        end = num_target_layers - 3
        span = end - start
        return [round(start + (i * span) / (num_draft_layers - 1)) for i in range(num_draft_layers)]
Clamp target-layer selection for tiny backbones.
For `num_target_layers < 4`, Line 53 makes `end` negative. A 2-layer target with 2 draft layers returns `[1, -1]`, and the later `lid + offset` at Line 608 then pulls the embedding state for the `-1` slot instead of a decoder layer. Please special-case small models or clamp the result so every id stays in `[0, num_target_layers - 1]`.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@modelopt/torch/speculative/plugins/hf_dflash.py` around lines 48 - 55, The
build_target_layer_ids function can produce negative or out-of-range indices for
tiny backbones; change it to special-case small models and/or clamp every
returned id into the valid range [0, num_target_layers - 1]. Specifically, in
build_target_layer_ids ensure when num_target_layers < 4 you return safe indices
(e.g., center or 0..n-1) and after computing the list, map each id to max(0,
min(id, num_target_layers - 1)) so downstream logic that uses lid + offset (the
decoder embedding lookup) never receives a negative or >=num_target_layers
index.
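A clamped variant along the lines this prompt describes might look like the sketch below; the clamping step is the point, and the resulting small-model indices are one possible choice, not the definitive fix:

```python
def build_target_layer_ids_clamped(num_target_layers, num_draft_layers):
    """Uniform layer selection with every index clamped into the valid range."""
    if num_draft_layers == 1:
        return [num_target_layers // 2]
    start, end = 1, num_target_layers - 3
    span = end - start
    ids = [round(start + i * span / (num_draft_layers - 1)) for i in range(num_draft_layers)]
    # Clamp so tiny backbones never yield negative or out-of-range layer ids.
    return [max(0, min(i, num_target_layers - 1)) for i in ids]
```

For typical backbones the clamp is a no-op; it only changes behavior in the degenerate small-model case.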
    # 1. Run base model → hidden states
    with torch.no_grad():
        base_outputs = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
        )
🧩 Analysis chain
🏁 Script executed:
cat -n modelopt/torch/speculative/plugins/hf_dflash.py | sed -n '580,620p'
Repository: NVIDIA/Model-Optimizer
🏁 Script executed:
cat -n modelopt/torch/speculative/plugins/hf_dflash.py | sed -n '1,80p'
🏁 Script executed:
cat -n modelopt/torch/speculative/plugins/hf_dflash.py | sed -n '500,580p'
🏁 Script executed:
cat -n modelopt/torch/speculative/plugins/hf_dflash.py | sed -n '400,500p'
🏁 Script executed:
rg "\.eval\(\)" modelopt/torch/speculative/plugins/hf_dflash.py
🏁 Script executed:
rg "\.train\(\)" modelopt/torch/speculative/plugins/hf_dflash.py
🏁 Script executed:
cat -n modelopt/torch/speculative/plugins/hf_dflash.py | sed -n '300,400p'
🏁 Script executed:
cat -n modelopt/torch/speculative/plugins/hf_dflash.py | sed -n '200,300p'
🏁 Script executed:
rg -A 20 "class DFlashModel" modelopt/torch/speculative/dflash/dflash_model.py | head -40
🏁 Script executed:
cat -n modelopt/torch/speculative/plugins/hf_dflash.py | sed -n '595,650p'
🏁 Script executed:
cat -n modelopt/torch/speculative/plugins/hf_dflash.py | sed -n '1,100p' | grep -E "(\.eval|\.train|training)"
🏁 Script executed:
rg "self.training" modelopt/torch/speculative/plugins/hf_dflash.py
🏁 Script executed:
cat -n modelopt/torch/speculative/plugins/hf_dflash.py | sed -n '468,482p'
🏁 Script executed:
rg -B 5 -A 15 "torch.no_grad" modelopt/torch/speculative/plugins/hf_dflash.py | head -50
Switch the teacher base model to eval mode for the forward pass.
torch.no_grad() disables gradients but does not disable dropout. Since the model is in training mode (self.training==True at this point), stochastic layers will be active, causing target_hidden to jitter across identical batches. This destabilizes distillation training. Before this forward pass, set the base model to eval mode using a context manager, then restore training mode afterward.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@modelopt/torch/speculative/plugins/hf_dflash.py` around lines 599 - 605, The
forward call to the teacher/base model (super().forward) is invoked while
self.training may be True, so dropout remains active; wrap the base-model
forward in a context that sets the teacher to eval mode (e.g., call model.eval()
on the teacher/base instance) before calling super().forward to produce
base_outputs/target_hidden, and restore the original training mode afterwards
(use a try/finally or a small context manager) to ensure deterministic hidden
states during distillation while not changing the overall module training flag.
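The eval-then-restore pattern the prompt asks for can be written as a small context manager (a sketch; the module and attribute names are generic, not the plugin's):

```python
from contextlib import contextmanager

import torch


@contextmanager
def eval_mode(module: torch.nn.Module):
    """Temporarily put `module` in eval mode (disabling dropout), then restore."""
    was_training = module.training
    module.eval()
    try:
        yield module
    finally:
        module.train(was_training)


# Usage sketch inside the distillation forward:
# with torch.no_grad(), eval_mode(base_model):
#     base_outputs = base_model(input_ids, output_hidden_states=True)
```

Using `try/finally` guarantees the training flag is restored even if the teacher forward raises.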
- Use Qwen3 components directly (no dynamic _resolve_model_components)
- Add sliding window attention support (config.layer_types)
- Move rotary meta buffer fix to DFlashModule._apply() with detailed docs
- Remove DFlash-specific resume code from main.py (standard resume works)
- Remove unused train_dflash.py and ar_validate.sh
- Simplify online_training.sh: direct accelerate launch, no arg parsing
- YAML uses OmegaConf overrides directly (matching eagle3 pattern)
- Update README to point to launcher example
- Add extension docs for MoE and MLA support

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
- Revert on_step_end AR validation to upstream (DFlash deadlocks with DDP)
- Revert checkpoint resume to upstream (load from checkpoint directly)
- Keep: answer_only_loss pass-through, accuracy console/tensorboard logging
- Document sliding window support in README and recipe YAML

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Actionable comments posted: 2
♻️ Duplicate comments (4)
tools/launcher/common/dflash/online_training.sh (1)
42-42: ⚠️ Potential issue | 🟠 Major

Quote the package constraint to avoid shell redirection.

Line 42 is parsed by Bash as redirection (`>`), so pip may not receive the `>=1.2.1` constraint.

Minimal fix

    -pip install huggingface-hub>=1.2.1
    +pip install "huggingface-hub>=1.2.1"

    #!/bin/bash
    set -euo pipefail
    target="tools/launcher/common/dflash/online_training.sh"
    echo "Inspect current command:"
    nl -ba "$target" | sed -n '40,44p'
    echo "Reproduce Bash parsing safely in temp dir:"
    tmpdir="$(mktemp -d)"
    (
      cd "$tmpdir"
      mkdir -p bin
      cat > bin/pip <<'EOF'
    #!/bin/bash
    printf 'pip args:\n'
    for a in "$@"; do printf '  [%s]\n' "$a"; done
    EOF
      chmod +x bin/pip
      PATH="$tmpdir/bin:$PATH" bash -lc 'pip install huggingface-hub>=1.2.1' || true
      echo "Created files (redirection artifact expected):"
      ls -la
    )
    rm -rf "$tmpdir"

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tools/launcher/common/dflash/online_training.sh` at line 42, The pip install line in online_training.sh currently uses an unquoted version specifier (`pip install huggingface-hub>=1.2.1`) which Bash will parse as a redirection; update that command in tools/launcher/common/dflash/online_training.sh to quote the constraint (e.g. use pip install 'huggingface-hub>=1.2.1' or pip install "huggingface-hub>=1.2.1") so the version operator is passed to pip rather than treated as shell redirection.
modelopt/torch/speculative/plugins/hf_dflash.py (3)
843-856:⚠️ Potential issue | 🟡 MinorDebug block fails for batch size > 1.
base_token.item()on line 852 raises an error whenbase_tokenhas shape[B, 1]withB > 1. This breaks the first batched call topseudo_speculative_generate().Proposed fix
if not hasattr(self, "_psg_debug"): self._psg_debug = True sel = [base_outputs.hidden_states[lid + hid_offset] for lid in self.target_layer_ids] th_dbg = torch.cat(sel, dim=-1) n_layers = len(base_outputs.hidden_states) th_norm = th_dbg.norm().item() print( f"[psg] hidden layers: {n_layers}, target_hidden: {th_dbg.shape}, norm: {th_norm:.2f}" ) - print(f"[psg] base_token: {base_token.item()}, mask_token_id: {self.mask_token_id}") + print(f"[psg] base_token: {base_token.squeeze().tolist()}, mask_token_id: {self.mask_token_id}") seq_len = input_ids.shape[1] blk = self.dflash_block_size print(f"[psg] pos: ctx=[0..{seq_len - 1}], blk=[{seq_len}..{seq_len + blk - 1}]")Or consider removing the debug prints entirely and using
logging.debug()if needed.🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/speculative/plugins/hf_dflash.py` around lines 843 - 856, The debug block guarded by self._psg_debug calls base_token.item(), which fails for batch sizes >1 in pseudo_speculative_generate; change the debug to handle batched base_token (e.g., log a representative value or use base_token.tolist()/base_token.flatten() or base_token[0].item()) or remove the prints and use logging.debug; update the debug block around _psg_debug in hf_dflash.py (references: _psg_debug, base_token, pseudo_speculative_generate, target_layer_ids, dflash_block_size) so it no longer calls .item() on a batched tensor and safely formats the token(s) for any batch size.
74-82:⚠️ Potential issue | 🟡 MinorEdge case: negative or out-of-order indices for small models.
For
num_target_layers < 4,endbecomes negative or zero, producing potentially invalid or reversed layer indices. For example, withnum_target_layers=2andnum_draft_layers=2, this returns[1, -1].Consider adding bounds clamping or a special case for small models:
Proposed fix
 def build_target_layer_ids(num_target_layers, num_draft_layers):
     """Select layers uniformly from the target model for feature extraction."""
+    if num_target_layers < 4:
+        # For tiny models, return evenly spaced indices within valid range
+        return [i * (num_target_layers - 1) // max(num_draft_layers - 1, 1)
+                for i in range(min(num_draft_layers, num_target_layers))]
     if num_draft_layers == 1:
         return [num_target_layers // 2]
     start = 1
     end = num_target_layers - 3
     span = end - start
-    return [round(start + (i * span) / (num_draft_layers - 1)) for i in range(num_draft_layers)]
+    ids = [round(start + (i * span) / (num_draft_layers - 1)) for i in range(num_draft_layers)]
+    return [max(0, min(i, num_target_layers - 1)) for i in ids]
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/speculative/plugins/hf_dflash.py` around lines 74 - 82, The function build_target_layer_ids can produce negative or out-of-range indices for small models; update it to handle small/edge cases by adding guards: if num_target_layers <= 0 or num_draft_layers <= 0 return an empty list; if num_target_layers < 4 treat target layers as 0..num_target_layers-1 and clip selection accordingly (e.g., spread indices across that valid range); compute start = max(0, 1) and end = max(0, num_target_layers-1) (or special-case when num_draft_layers == 1 to return the middle valid index), generate indices with the existing uniform formula but clamp each resulting index to the inclusive range [0, num_target_layers-1] before returning to ensure no negative or out-of-order values from build_target_layer_ids.
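The clamped selection proposed above can be sanity-checked as a standalone sketch (this re-implementation is illustrative, not the PR's code):

```python
def build_target_layer_ids(num_target_layers, num_draft_layers):
    """Illustrative re-implementation of the clamped layer selection."""
    if num_target_layers < 4:
        # Tiny models: spread indices evenly across the valid range
        return [
            i * (num_target_layers - 1) // max(num_draft_layers - 1, 1)
            for i in range(min(num_draft_layers, num_target_layers))
        ]
    if num_draft_layers == 1:
        return [num_target_layers // 2]
    start, end = 1, num_target_layers - 3
    span = end - start
    ids = [round(start + (i * span) / (num_draft_layers - 1)) for i in range(num_draft_layers)]
    # Clamp every index into [0, num_target_layers - 1]
    return [max(0, min(i, num_target_layers - 1)) for i in ids]


print(build_target_layer_ids(36, 4))  # [1, 12, 22, 33]
print(build_target_layer_ids(2, 2))   # [0, 1] instead of the invalid [1, -1]
```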
649-656: ⚠️ Potential issue | 🟠 Major: Teacher forward should use eval mode to disable dropout.
The base model forward runs under `torch.no_grad()`, which disables gradient computation but not dropout. When `self.training=True`, any dropout layers in the base model remain active, causing `target_hidden` to vary stochastically across identical inputs. This can destabilize distillation training.

Proposed fix
 # 1. Run base model → hidden states
+# Temporarily switch to eval to disable dropout in teacher
+was_training = self.training
 with torch.no_grad():
+    self.eval()
     base_outputs = super().forward(
         input_ids=input_ids,
         attention_mask=attention_mask,
         output_hidden_states=True,
     )
+if was_training:
+    self.train()
selfto avoid affecting dflash_module's training state.🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/speculative/plugins/hf_dflash.py` around lines 649 - 656, The base model forward currently runs under torch.no_grad() but not eval(), so dropout can remain active; wrap the call to super().forward(...) in evaluation mode (call eval() on the base model submodule) and restore its original training state after the call so only the base model is in eval and not the whole dflash module—i.e., around the super().forward(...) invocation (which produces base_outputs/target_hidden from input_ids and attention_mask with output_hidden_states=True) set the base model to eval, call super().forward(...), then revert the base model to its prior training flag to avoid affecting self.training.
🧹 Nitpick comments (1)
modelopt/torch/speculative/plugins/hf_dflash.py (1)
516-516: Consider using logging instead of print statements.

Debug information like `mask_token_id` and `original_cls` (lines 516, 559) should use the `logging` module for better control over verbosity levels in production.

Proposed fix
+import logging
+
+logger = logging.getLogger(__name__)
+
 # In modify():
-print(f"DFlash mask_token_id: {self.mask_token_id}")
+logger.info(f"DFlash mask_token_id: {self.mask_token_id}")
 ...
-print(f"DFlash: using {original_cls.__name__}.forward as base forward")
+logger.info(f"DFlash: using {original_cls.__name__}.forward as base forward")
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/speculative/plugins/hf_dflash.py` at line 516, Replace ad-hoc print statements that output debug info (e.g., the print of self.mask_token_id and the one referencing self.original_cls) with the Python logging module: add import logging and a module-level logger = logging.getLogger(__name__), then change the prints to logger.debug (or logger.info if more appropriate) inside the class/method where they occur (references: self.mask_token_id and self.original_cls) so verbosity can be controlled in production.
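A minimal sketch of the suggested swap to `logging`, with a capturing handler so the effect is observable (the logger name and token value are placeholders, not the PR's values):

```python
import logging

logger = logging.getLogger("dflash_example")  # illustrative logger name
logger.setLevel(logging.INFO)

records = []


class ListHandler(logging.Handler):
    """Collects formatted messages so the demo is assertable."""

    def emit(self, record):
        records.append(record.getMessage())


logger.addHandler(ListHandler())

mask_token_id = 151669  # placeholder value, not the real token id
logger.info("DFlash mask_token_id: %s", mask_token_id)
logger.debug("suppressed: below the configured level")
print(records)  # ['DFlash mask_token_id: 151669']
```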
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@tools/launcher/common/dflash/online_training.sh`:
- Around line 77-87: When NUM_NODES != "1" ensure HEAD_NODE_IP is validated
before building MULTI_NODE_ARGS: check that HEAD_NODE_IP is non-empty and a
plausible IP/hostname (non-empty string) after the auto-detection logic, and if
it is empty print a clear error message (including value of NUM_NODES and hint
about missing detection/SLURM variables) to stderr and exit with non-zero status
to fail fast; update the multi-node branch that constructs MULTI_NODE_ARGS
(referencing MULTI_NODE_ARGS, NUM_NODES, GPU_PER_NODE, SLURM_PROCID, and
HEAD_NODE_IP) to perform this validation immediately before using HEAD_NODE_IP.
- Around line 41-47: Move the failure handling and enable fail-fast before any
package installs: add "set -e" (or "set -o errexit") and install the trap
invocation for error_handler (the line "trap 'error_handler $0 $LINENO' ERR")
before the two pip install lines (the commands that install requirements and
huggingface-hub). This ensures any failure in the pip install commands triggers
error_handler immediately and the script exits instead of continuing into
training with partial dependencies.
---
Duplicate comments:
In `@modelopt/torch/speculative/plugins/hf_dflash.py`:
- Around line 843-856: The debug block guarded by self._psg_debug calls
base_token.item(), which fails for batch sizes >1 in
pseudo_speculative_generate; change the debug to handle batched base_token
(e.g., log a representative value or use
base_token.tolist()/base_token.flatten() or base_token[0].item()) or remove the
prints and use logging.debug; update the debug block around _psg_debug in
hf_dflash.py (references: _psg_debug, base_token, pseudo_speculative_generate,
target_layer_ids, dflash_block_size) so it no longer calls .item() on a batched
tensor and safely formats the token(s) for any batch size.
- Around line 74-82: The function build_target_layer_ids can produce negative or
out-of-range indices for small models; update it to handle small/edge cases by
adding guards: if num_target_layers <= 0 or num_draft_layers <= 0 return an
empty list; if num_target_layers < 4 treat target layers as
0..num_target_layers-1 and clip selection accordingly (e.g., spread indices
across that valid range); compute start = max(0, 1) and end = max(0,
num_target_layers-1) (or special-case when num_draft_layers == 1 to return the
middle valid index), generate indices with the existing uniform formula but
clamp each resulting index to the inclusive range [0, num_target_layers-1]
before returning to ensure no negative or out-of-order values from
build_target_layer_ids.
- Around line 649-656: The base model forward currently runs under
torch.no_grad() but not eval(), so dropout can remain active; wrap the call to
super().forward(...) in evaluation mode (call eval() on the base model
submodule) and restore its original training state after the call so only the
base model is in eval and not the whole dflash module—i.e., around the
super().forward(...) invocation (which produces base_outputs/target_hidden from
input_ids and attention_mask with output_hidden_states=True) set the base model
to eval, call super().forward(...), then revert the base model to its prior
training flag to avoid affecting self.training.
In `@tools/launcher/common/dflash/online_training.sh`:
- Line 42: The pip install line in online_training.sh currently uses an unquoted
version specifier (`pip install huggingface-hub>=1.2.1`) which Bash will parse
as a redirection; update that command in
tools/launcher/common/dflash/online_training.sh to quote the constraint (e.g.
use pip install 'huggingface-hub>=1.2.1' or pip install
"huggingface-hub>=1.2.1") so the version operator is passed to pip rather than
treated as shell redirection.
---
Nitpick comments:
In `@modelopt/torch/speculative/plugins/hf_dflash.py`:
- Line 516: Replace ad-hoc print statements that output debug info (e.g., the
print of self.mask_token_id and the one referencing self.original_cls) with the
Python logging module: add import logging and a module-level logger =
logging.getLogger(__name__), then change the prints to logger.debug (or
logger.info if more appropriate) inside the class/method where they occur
(references: self.mask_token_id and self.original_cls) so verbosity can be
controlled in production.
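Taken together, the shell-side suggestions above amount to: enable fail-fast and the ERR trap before any installs, and quote version specifiers. A minimal sketch, with `echo` standing in for the real `pip` calls and an illustrative `error_handler`:

```shell
#!/bin/bash
# Sketch only: error_handler and install_deps are illustrative stand-ins,
# and echo replaces pip so the snippet is side-effect free.
set -e
error_handler() { echo "failure in $1 at line $2" >&2; }
trap 'error_handler "$0" "$LINENO"' ERR   # installed before any installs run

install_deps() {
  echo pip install -r requirements.txt
  # Quoted so '>=' reaches the command instead of becoming a redirection
  echo pip install "huggingface-hub>=1.2.1"
}

install_deps
echo "deps installed"
```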
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 9d2ccb5b-e089-4948-8b21-8672dcbead83
📒 Files selected for processing (5)
- examples/speculative_decoding/README.md
- examples/speculative_decoding/main.py
- modelopt/torch/speculative/plugins/hf_dflash.py
- tools/launcher/common/dflash/online_training.sh
- tools/launcher/examples/Qwen/Qwen3-8B/hf_online_dflash.yaml
✅ Files skipped from review due to trivial changes (2)
- examples/speculative_decoding/README.md
- tools/launcher/examples/Qwen/Qwen3-8B/hf_online_dflash.yaml
🚧 Files skipped from review as they are similar to previous changes (1)
- examples/speculative_decoding/main.py
if [[ "$NUM_NODES" != "1" ]]; then
  GPU_PER_NODE=${GPU_PER_NODE:-$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l)}
  TOTAL_GPU=$((NUM_NODES * GPU_PER_NODE))
  echo "Total GPUs: $TOTAL_GPU (NUM_NODES: $NUM_NODES, GPU_PER_NODE: $GPU_PER_NODE)"
  MULTI_NODE_ARGS="--num_processes $TOTAL_GPU \
    --num_machines $NUM_NODES \
    --machine_rank $SLURM_PROCID \
    --rdzv_backend c10d \
    --main_process_ip $HEAD_NODE_IP \
    --main_process_port 29500"
else
🧩 Analysis chain
🏁 Script executed:
#!/bin/bash
set -euo pipefail
target="tools/launcher/common/dflash/online_training.sh"
rg -n 'HEAD_NODE_IP|main_process_ip|NUM_NODES' "$target" -C3

Repository: NVIDIA/Model-Optimizer
Length of output: 2312
Add validation for HEAD_NODE_IP in multi-node mode to fail fast.
The auto-detection logic (lines 49–72) attempts to resolve HEAD_NODE_IP but does not validate success. If all detection methods fail (e.g., in containerized environments or when SLURM variables are absent), the variable remains empty. The multi-node branch at line 77 then passes --main_process_ip $HEAD_NODE_IP to accelerate with an empty value, causing unclear failures downstream.
Suggested guard
if [[ "$NUM_NODES" != "1" ]]; then
+ if [[ -z "${HEAD_NODE_IP:-}" ]]; then
+ echo "[ERROR] HEAD_NODE_IP is empty for NUM_NODES=$NUM_NODES" >&2
+ exit 1
+ fi
GPU_PER_NODE=${GPU_PER_NODE:-$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l)}
TOTAL_GPU=$((NUM_NODES * GPU_PER_NODE))
echo "Total GPUs: $TOTAL_GPU (NUM_NODES: $NUM_NODES, GPU_PER_NODE: $GPU_PER_NODE)"🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tools/launcher/common/dflash/online_training.sh` around lines 77 - 87, When
NUM_NODES != "1" ensure HEAD_NODE_IP is validated before building
MULTI_NODE_ARGS: check that HEAD_NODE_IP is non-empty and a plausible
IP/hostname (non-empty string) after the auto-detection logic, and if it is
empty print a clear error message (including value of NUM_NODES and hint about
missing detection/SLURM variables) to stderr and exit with non-zero status to
fail fast; update the multi-node branch that constructs MULTI_NODE_ARGS
(referencing MULTI_NODE_ARGS, NUM_NODES, GPU_PER_NODE, SLURM_PROCID, and
HEAD_NODE_IP) to perform this validation immediately before using HEAD_NODE_IP.
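The guard described in this prompt can be sketched as a small function (the function name and messages are illustrative, not the script's actual code):

```shell
#!/bin/bash
# Illustrative fail-fast guard: reject multi-node launches with an empty
# HEAD_NODE_IP before any accelerate args are built.
require_head_node_ip() {
  local num_nodes="$1" head_ip="$2"
  if [[ "$num_nodes" != "1" && -z "$head_ip" ]]; then
    echo "[ERROR] HEAD_NODE_IP is empty for NUM_NODES=$num_nodes" >&2
    return 1
  fi
}

require_head_node_ip 1 "" && echo "single-node: ok"
require_head_node_ip 2 "10.0.0.5" && echo "multi-node with ip: ok"
require_head_node_ip 2 "" 2>/dev/null || echo "multi-node without ip: rejected"
```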
- Consolidate dflash_results.md into comprehensive dflash.md
- Simplify ar_validate.py: online GT as default, per-category support
- Simplify ar_eval_mtbench.sh: calls ar_validate.py instead of inline Python
- Error on unsupported mask_token_id instead of falling back to pad/eos
- Add sliding window, FP8/NVFP4, offline training, MLA docs

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
- export.sh: standalone checkpoint export to z-lab format
- ptq_and_export.sh: FP8/NVFP4 quantization via hf_ptq.py
- Fix rope_theta export (prefer draft_config over base_config)
- Document vLLM integration gap, FP8/NVFP4 flow in dflash.md

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
DFlash (Block Diffusion for Flash Speculative Decoding) predicts an entire block of tokens in a single forward pass using masked parallel prediction with KV injection from the target model's hidden states.
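A schematic toy of the block-drafting idea, not the ModelOpt implementation: the draft fills a block of mask placeholders in one call, conditioned on the context (KV injection from target hidden states is omitted here; all names are hypothetical):

```python
MASK = -1  # illustrative mask placeholder


def draft_block(context_tokens, block_size, predict_fn):
    # Append block_size masked slots and fill them all in one call,
    # rather than one token per forward pass.
    block = [MASK] * block_size
    return predict_fn(context_tokens, block)


def toy_predict(context, block):
    # Stand-in "draft model": repeats the last context token per masked slot.
    return [context[-1]] * len(block)


print(draft_block([5, 9, 3], block_size=4, predict_fn=toy_predict))  # [3, 3, 3, 3]
```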
Key features:
Training recipe: modelopt_recipes/general/speculative_decoding/dflash.yaml
Results: examples/speculative_decoding/doc/dflash_results.md
ModelOpt Eval (online validation, osl=512)
z-lab Official Eval (dflash.benchmark, osl=512)
Evaluation Method Impact (gsm8k)
What does this PR do?
Type of change: ?
Usage
# Add a code snippet demonstrating how to use this

Testing
Before your PR is "Ready for review"
Make sure you read and follow Contributor guidelines and your commits are signed (`git commit -s -S`).
Make sure you read and follow the Security Best Practices (e.g. avoiding hardcoded `trust_remote_code=True`, `torch.load(..., weights_only=False)`, `pickle`, etc.).
CONTRIBUTING.md: ✅ / ❌ / N/A

Additional Information
Summary by CodeRabbit
New Features
Documentation