From 5f8d004ec751e9bb8d49b2a214b07ce1019c9f60 Mon Sep 17 00:00:00 2001 From: Chenhan Yu Date: Wed, 8 Apr 2026 13:06:17 -0700 Subject: [PATCH 01/24] add: DFlash block diffusion speculative decoding DFlash (Block Diffusion for Flash Speculative Decoding) predicts an entire block of tokens in a single forward pass using masked parallel prediction with KV injection from the target model's hidden states. Key features: - Feature fusion (multi-layer hidden states -> FC + RMSNorm) - KV injection (fused features as K/V in every draft layer with QK-norm) - Random anchor sampling with bidirectional intra-block attention - Logit distillation with exponential loss decay (gamma weighting) - Multi-node DDP training with checkpoint resume - Export to z-lab compatible HF format - Online validation (context-dependent ground truth) Training recipe: modelopt_recipes/general/speculative_decoding/dflash.yaml Results: examples/speculative_decoding/doc/dflash_results.md Co-Authored-By: Claude Opus 4.6 (1M context) --- doc/results/dflash_results.html | 0 examples/speculative_decoding/README.md | 38 + .../doc/dflash_results.md | 85 ++ examples/speculative_decoding/eagle_utils.py | 69 +- examples/speculative_decoding/main.py | 85 +- .../scripts/export_hf_checkpoint.py | 10 +- examples/speculative_decoding/train_dflash.py | 319 +++++++ .../torch/export/plugins/hf_spec_export.py | 106 ++- modelopt/torch/speculative/config.py | 55 ++ modelopt/torch/speculative/dflash/__init__.py | 20 + .../torch/speculative/dflash/conversion.py | 58 ++ .../speculative/dflash/default_config.py | 28 + .../torch/speculative/dflash/dflash_model.py | 36 + modelopt/torch/speculative/mode.py | 31 +- .../torch/speculative/plugins/__init__.py | 3 + .../torch/speculative/plugins/hf_dflash.py | 886 ++++++++++++++++++ modelopt/torch/speculative/utils.py | 78 ++ .../utils/plugins/transformers_dataset.py | 212 ++++- .../general/speculative_decoding/dflash.yaml | 53 ++ .../speculative/plugins/test_hf_dflash.py | 196 ++++ 
.../speculative/plugins/test_hf_dflash.py | 243 +++++ .../launcher/common/dflash/ar_eval_mtbench.sh | 225 +++++ tools/launcher/common/dflash/ar_validate.sh | 127 +++ .../launcher/common/dflash/online_training.sh | 255 +++++ .../Qwen/Qwen3-8B/hf_online_dflash.yaml | 63 ++ 25 files changed, 3234 insertions(+), 47 deletions(-) create mode 100644 doc/results/dflash_results.html create mode 100644 examples/speculative_decoding/doc/dflash_results.md create mode 100644 examples/speculative_decoding/train_dflash.py create mode 100644 modelopt/torch/speculative/dflash/__init__.py create mode 100644 modelopt/torch/speculative/dflash/conversion.py create mode 100644 modelopt/torch/speculative/dflash/default_config.py create mode 100644 modelopt/torch/speculative/dflash/dflash_model.py create mode 100644 modelopt/torch/speculative/plugins/hf_dflash.py create mode 100644 modelopt_recipes/general/speculative_decoding/dflash.yaml create mode 100644 tests/gpu/torch/speculative/plugins/test_hf_dflash.py create mode 100644 tests/unit/torch/speculative/plugins/test_hf_dflash.py create mode 100644 tools/launcher/common/dflash/ar_eval_mtbench.sh create mode 100644 tools/launcher/common/dflash/ar_validate.sh create mode 100644 tools/launcher/common/dflash/online_training.sh create mode 100644 tools/launcher/examples/Qwen/Qwen3-8B/hf_online_dflash.yaml diff --git a/doc/results/dflash_results.html b/doc/results/dflash_results.html new file mode 100644 index 0000000000..e69de29bb2 diff --git a/examples/speculative_decoding/README.md b/examples/speculative_decoding/README.md index 820b587b91..b51c90c55c 100644 --- a/examples/speculative_decoding/README.md +++ b/examples/speculative_decoding/README.md @@ -350,3 +350,41 @@ More models coming soon! 
- 💡 [Release Notes](https://nvidia.github.io/Model-Optimizer/reference/0_changelog.html) - 🐛 [File a bug](https://github.com/NVIDIA/Model-Optimizer/issues/new?template=1_bug_report.md) - ✨ [File a Feature Request](https://github.com/NVIDIA/Model-Optimizer/issues/new?template=2_feature_request.md) + +## DFlash (Block Diffusion for Speculative Decoding) + +DFlash is a parallel speculative decoding method based on [Block Diffusion](https://arxiv.org/abs/2602.06036). +Unlike autoregressive draft models (EAGLE3), DFlash predicts an entire block of tokens in a single forward pass +using masked parallel prediction with KV injection from the target model's hidden states. + +### Quick Start + +```bash +./launch_train.sh --config ../../modelopt_recipes/general/speculative_decoding/dflash.yaml \ + model.model_name_or_path=/path/to/Qwen3-8B \ + data.data_path=/path/to/train.jsonl \ + training.output_dir=/path/to/output +``` + +### Key Configuration (dflash.yaml) + +| Field | Default | Description | +|-------|---------|-------------| +| `dflash.dflash_block_size` | 8 | Block size for parallel prediction | +| `dflash.dflash_num_anchors` | 512 | Number of anchor positions per sample | +| `dflash.dflash_loss_decay_factor` | 4.0 | Exponential decay gamma (0 disables) | +| `dflash.dflash_self_logit_distillation` | true | Use logit distillation from target | +| `dflash.dflash_architecture_config.num_hidden_layers` | 5 | Draft decoder layers | +| `dflash.dflash_architecture_config.mask_token_id` | auto | Token ID for masked positions | + +### Export + +```bash +python scripts/export_hf_checkpoint.py \ + --model_path /path/to/training/output \ + --export_path /path/to/exported/model +``` + +### Results + +See [doc/dflash_results.md](doc/dflash_results.md) for benchmark results on Qwen3-8B. 
diff --git a/examples/speculative_decoding/doc/dflash_results.md b/examples/speculative_decoding/doc/dflash_results.md new file mode 100644 index 0000000000..a6a4f2d252 --- /dev/null +++ b/examples/speculative_decoding/doc/dflash_results.md @@ -0,0 +1,85 @@ +# DFlash Block Diffusion — ModelOpt Training Results + +Qwen3-8B target model, trained on nvidia/Nemotron-Post-Training-Dataset-v2 (2M samples) + +## Key Metrics + +| Benchmark | Acceptance Rate | +|-----------|----------------| +| **gsm8k** | **5.19** | +| **MT-Bench** | **4.36** | + +> Online validation, block_size=8, osl=512 + +## Training Configuration + +| Parameter | Value | +|-----------|-------| +| Target Model | Qwen3-8B | +| Draft Layers | 5 | +| Block Size | 8 | +| Sequence Length | 4096 | +| Anchors per Sample | 512 | +| Loss | KD (logit distillation) + exponential decay (gamma=4) | +| Learning Rate | 6e-4 (linear decay) | +| Epochs | 10 | +| GPUs | 64 (8 nodes x 8 H100) | +| Total Steps | 306,620 | +| Final Loss | 1.129 | +| Final Per-Token Acc | 67.0% | + +## MT-Bench Per-Category AR (Online Validation) + +80 prompts, block_size=8, osl=512, steps=7 + +| Category | 80K | 150K | 306K (final) | +|----------|-----|------|-------------| +| math | 5.44 | 5.54 | **5.52** | +| extraction | 4.81 | 4.82 | **4.88** | +| coding | 4.40 | 4.53 | **4.60** | +| reasoning | 4.34 | 4.41 | **4.44** | +| stem | 4.05 | 4.15 | **4.17** | +| writing | 3.76 | 3.79 | **3.84** | +| roleplay | 3.58 | 3.73 | **3.78** | +| humanities | 3.55 | 3.62 | **3.65** | +| **ALL** | **4.24** | **4.32** | **4.36** | + +## Comparison with z-lab/Qwen3-8B-DFlash-b16 + +### ModelOpt Eval (online validation, osl=512) + +| Dataset | z-lab | ModelOpt (306K) | Diff | +|---------|-------|-----------------|------| +| gsm8k | 4.10 | **5.19** | **+1.09** | +| MT-Bench | 3.58 | **4.36** | **+0.78** | + +### z-lab Official Eval (dflash.benchmark, osl=512) + +| Dataset | z-lab | ModelOpt (306K) | Diff | +|---------|-------|-----------------|------| +| 
gsm8k | **5.00** | 4.08 | -0.92 | +| MT-Bench | **3.28** | 2.99 | -0.29 | + +> z-lab model trained with block_size=16. ModelOpt trained with block_size=8. + +## Evaluation Method Impact (gsm8k) + +| Eval Method | z-lab checkpoint | ModelOpt (306K) | +|-------------|-----------------|-----------------| +| Fixed GT (ModelOpt eval) | 2.95 | 4.23 | +| Online GT (ModelOpt eval) | 4.10 | **5.19** | +| z-lab official eval | **5.00** | 4.08 | + +- **Fixed GT**: pre-compute greedy ground truth, check draft against it. +- **Online GT**: recompute ground truth after each accepted draft (context-dependent). +- **z-lab official**: actual speculative decoding with draft KV cache. + +## Key Findings + +| Finding | Evidence | +|---------|----------| +| Loss decay boosts AR | +0.12 AR at 55K steps (gamma=7, bs16); consistent across all checkpoints | +| Longer sequences help | seq=4096 vs 512: +0.49 AR on AA-Synthetic at same checkpoint | +| Online validation essential | Fixed GT underestimates by ~1.0 AR; context-dependent GT matches actual spec-decode | +| Forward pass identical to z-lab | Max diff 0.5 (bf16 noise) on same mask_token_id; 6/7 draft tokens match | +| sdpa vs flash_attn: negligible | Overall AR 3.31 vs 3.31; hidden states identical, logits differ <2% | diff --git a/examples/speculative_decoding/eagle_utils.py b/examples/speculative_decoding/eagle_utils.py index 99c8ef4e03..2b01239cb4 100644 --- a/examples/speculative_decoding/eagle_utils.py +++ b/examples/speculative_decoding/eagle_utils.py @@ -141,6 +141,7 @@ def make_eagle_supervised_data_module( tokenizer: transformers.PreTrainedTokenizer, data_args, train_len=None, + answer_only_loss=False, ) -> dict: if data_args.offline_data_path is None: train_dataset = ShardedDataset("json", data_files=data_args.data_path) @@ -150,6 +151,7 @@ def make_eagle_supervised_data_module( tokenizer=tokenizer, train_len=train_len, return_labels=True, + answer_only_loss=answer_only_loss, ) else: data_collator = 
VisionLanguageDataCollator( @@ -205,6 +207,12 @@ def on_log(self, args, state, control, **kwargs): if not hasattr(state, "training_accs") or len(state.training_accs) == 0: return control average_acc = np.mean(state.training_accs, axis=0) + # Always print accuracy to console + try: + acc_str = ", ".join(f"{a:.4f}" for a in np.array(average_acc).flatten()) + print_rank_0(f"Step {state.global_step} Training Acc: [{acc_str}]") + except Exception: + print_rank_0(f"Step {state.global_step} Training Acc: {average_acc}") if self.estimate_ar: # Calculate mean training AR since last log # NOTE: This is only an estimate of the real AR. @@ -219,41 +227,64 @@ def on_log(self, args, state, control, **kwargs): est_ar += acc_cumprod print_rank_0(f"Step {state.global_step} Estimated Training AR: {est_ar:.4f}") + # Log accuracy to HF Trainer's logs dict (picked up by TensorBoard) + logs = kwargs.get("logs") or {} + for i, draft_acc in enumerate(average_acc): + for j, step_acc in enumerate(draft_acc): + logs[f"train_acc/parallel_{i}_step_{j}"] = float(step_acc) + if self.estimate_ar: + logs["estimated_training_ar"] = est_ar + # log to wandb - if wandb and is_master(): - logs = kwargs.get("logs") or {} + if hasattr(wandb, "init") and is_master(): if logs: wandb.log({k: v for k, v in logs.items() if v is not None}, step=state.global_step) - for i, draft_acc in enumerate(average_acc): - for j, step_acc in enumerate(draft_acc): - wandb.log( - {f"parallel_{i}_step_{j}_train_acc": step_acc}, step=state.global_step - ) - if self.estimate_ar: - wandb.log({"estimated_training_ar": est_ar}, step=state.global_step) # reset training_accs state.training_accs = [] return control def on_step_end(self, args, state, control, **kwargs): - """Run AR validation periodically, if available.""" + """Run AR validation periodically (single-GPU only). + + AR validation with DDP is not supported because pseudo_speculative_generate + runs only on rank 0 while other ranks deadlock waiting for collective ops. 
+ When world_size > 1, AR validation is skipped with a one-time warning. + Use post-training AR validation instead (online_training.sh runs it after training). + """ if self.ar_validate_steps <= 0: return control if state.global_step % self.ar_validate_steps == 0 and state.global_step > 0: + if torch.distributed.is_initialized() and torch.distributed.get_world_size() > 1: + if not hasattr(self, "_ar_ddp_warned"): + self._ar_ddp_warned = True + print_rank_0( + "=== WARNING === AR validation during training is not supported with " + "DDP (world_size > 1). Skipping. Use post-training AR validation." + ) + return control + + model = kwargs["model"] + raw_model = model.module if hasattr(model, "module") else model + was_training = raw_model.training + raw_model.eval() print_rank_0("Running AR validation...") try: - ars = validate_ar( - model=kwargs["model"], - tokenizer=kwargs["processing_class"], - ds=load_dataset("HuggingFaceH4/mt_bench_prompts")["train"], - device=kwargs["model"].device, - ) + with torch.no_grad(): + ars = validate_ar( + model=raw_model, + tokenizer=kwargs["processing_class"], + ds=load_dataset("/hf-local/HuggingFaceH4/mt_bench_prompts")["train"], + device=next(raw_model.parameters()).device, + num_samples=8, + ) print_rank_0(f"Step {state.global_step} AR: {sum(ars) / len(ars):.4f}") - if wandb and is_master(): + if wandb: wandb.log({"validate_ar": sum(ars) / len(ars)}, step=state.global_step) - except Exception: - print_rank_0("AR validation not available.") + except Exception as e: + print_rank_0(f"AR validation failed: {e}") + if was_training: + raw_model.train() return control diff --git a/examples/speculative_decoding/main.py b/examples/speculative_decoding/main.py index 694aa3303f..1bb70f871f 100644 --- a/examples/speculative_decoding/main.py +++ b/examples/speculative_decoding/main.py @@ -103,7 +103,7 @@ class TrainingArguments(transformers.TrainingArguments): ) }, ) - mode: Literal["eagle3", "medusa"] = "eagle3" + mode: Literal["eagle3", 
"medusa", "dflash"] = "eagle3" estimate_ar: bool = field( default=False, metadata={"help": "Whether to estimate AR using training accuracy to log."} ) @@ -133,8 +133,8 @@ def _parse_cli() -> tuple[str, list[str]]: return args.config, overrides -def _load_config(config_path: str, overrides: list[str] = ()) -> tuple[dict, dict]: - """Load training config from a YAML file with sections: model, data, training, eagle. +def _load_config(config_path: str, overrides: list[str] = ()) -> tuple[dict, dict, dict]: + """Load training config from a YAML file with sections: model, data, training, eagle/dflash. *overrides* are OmegaConf dotlist entries (e.g. ``["model.model_name_or_path=xxx"]``) applied on top of the YAML. @@ -142,15 +142,16 @@ def _load_config(config_path: str, overrides: list[str] = ()) -> tuple[dict, dic Returns: hf_cfg: Flat dict from model/data/training sections, for HfArgumentParser.parse_dict() eagle_cfg: Eagle section dict (EagleConfig fields), passed directly to mtsp.convert() + dflash_cfg: DFlash section dict (DFlashConfig fields), passed directly to mtsp.convert() """ merged = OmegaConf.load(config_path) if overrides: merged = OmegaConf.merge(merged, OmegaConf.from_dotlist(list(overrides))) cfg = OmegaConf.to_container(merged, resolve=True) - # Eagle section maps directly to EagleConfig fields — no field enumeration needed. - # eagle_architecture_config is a nested dict and is included as-is. + # Eagle/DFlash sections map directly to config fields — no field enumeration needed. 
eagle_cfg = cfg.get("eagle", {}) + dflash_cfg = cfg.get("dflash", {}) hf_cfg = { **cfg.get("model", {}), @@ -162,12 +163,14 @@ def _load_config(config_path: str, overrides: list[str] = ()) -> tuple[dict, dic cp_size = hf_cfg.get("cp_size", 1) hf_cfg["dp_shard_size"] = torch.cuda.device_count() // cp_size - return hf_cfg, eagle_cfg + return hf_cfg, eagle_cfg, dflash_cfg def train(): + import json + config_path, overrides = _parse_cli() - hf_cfg, eagle_cfg = _load_config(config_path, overrides) + hf_cfg, eagle_cfg, dflash_cfg = _load_config(config_path, overrides) parser = transformers.HfArgumentParser( ( @@ -193,7 +196,10 @@ def train(): patch_ring_attention_for_ttt() # Specific patch to accelerate 1.12.0. Removable after move to 1.13.0 training_args.parallelism_config.sp_backend = None - print_rank_0(f"arguments: {model_args}, {training_args}, {medusa_args}, eagle_cfg={eagle_cfg}") + print_rank_0( + f"arguments: {model_args}, {training_args}, {medusa_args}, " + f"eagle_cfg={eagle_cfg}, dflash_cfg={dflash_cfg}" + ) # Detect checkpoint to resume from last_checkpoint = ( @@ -209,12 +215,32 @@ def train(): use_offline_training = data_args.offline_data_path is not None if checkpoint: + # Prefer top-level output_dir, fall back to checkpoint subdir + model_load_path = training_args.output_dir + if not os.path.isfile(os.path.join(model_load_path, "model.safetensors")): + model_load_path = checkpoint + print_rank_0( + f"No model.safetensors in {training_args.output_dir}, " + f"loading from checkpoint: {model_load_path}" + ) with patch_transformers5_params_loading(): model = load_vlm_or_llm( - checkpoint, torch_dtype="auto", trust_remote_code=model_args.trust_remote_code + model_load_path, + torch_dtype="auto", + trust_remote_code=model_args.trust_remote_code, ) + # DFlash: re-create rotary embeddings with meta-tensor buffers on CPU. + # inv_freq is computed (not saved in checkpoints), stays on meta after restore. 
+ if training_args.mode == "dflash": + for mod in model.modules(): + if hasattr(mod, "rotary_emb"): + rotary = mod.rotary_emb + if any(b.is_meta for b in rotary.buffers()): + cfg = getattr(rotary, "config", None) + if cfg is not None: + mod.rotary_emb = type(rotary)(config=cfg, device="cpu") tokenizer = transformers.AutoTokenizer.from_pretrained( - checkpoint, trust_remote_code=model_args.trust_remote_code + model_load_path, trust_remote_code=model_args.trust_remote_code ) else: # To avoid OOM for large models, we load and convert model on CPU first. @@ -251,13 +277,19 @@ def train(): ) model.eagle_module.d2t = torch.load(data_args.draft_vocab_cache, weights_only=True) print_rank_0(f"Loaded draft vocab cache from {data_args.draft_vocab_cache}.") + elif training_args.mode == "dflash": + # dflash_cfg maps directly to DFlashConfig fields. + mtsp.convert(model, [("dflash", dflash_cfg)]) else: raise Exception(f"{training_args.mode} is not supported!") print_rank_0("Loading dataset...") - if training_args.mode == "eagle3": + if training_args.mode in ("eagle3", "dflash"): data_module = make_eagle_supervised_data_module( - tokenizer, data_args, train_len=training_args.training_seq_len + tokenizer, + data_args, + train_len=training_args.training_seq_len, + answer_only_loss=(training_args.mode == "dflash"), ) trainer = EagleTrainerWithAccLog( @@ -276,7 +308,34 @@ def train(): ) print_rank_0("Start training...") - trainer.train(resume_from_checkpoint=checkpoint) + if checkpoint and not os.path.isfile( + os.path.join(training_args.output_dir, "model.safetensors") + ): + # Resume from checkpoint subdir: try full resume first, fall back to + # partial resume (model weights + trainer state, fresh optimizer) if + # the optimizer state doesn't match. 
+ try: + trainer.train(resume_from_checkpoint=checkpoint) + except ValueError as e: + if "parameter group" in str(e): + print_rank_0( + f"Optimizer state mismatch: {e}\n" + f"Resuming with fresh optimizer from {checkpoint}" + ) + state_file = os.path.join(checkpoint, "trainer_state.json") + if os.path.isfile(state_file): + state = json.load(open(state_file)) + resumed_step = state.get("global_step", 0) + resumed_max_steps = state.get("max_steps", -1) + print_rank_0(f"Resuming from step {resumed_step}/{resumed_max_steps}") + if resumed_max_steps > 0: + training_args.max_steps = resumed_max_steps + trainer.state = trainer.state.load_from_json(state_file) + trainer.train() + else: + raise + else: + trainer.train(resume_from_checkpoint=checkpoint) trainer.save_state() trainer.save_model(training_args.output_dir) diff --git a/examples/speculative_decoding/scripts/export_hf_checkpoint.py b/examples/speculative_decoding/scripts/export_hf_checkpoint.py index 2771ab1513..925f4b73d0 100644 --- a/examples/speculative_decoding/scripts/export_hf_checkpoint.py +++ b/examples/speculative_decoding/scripts/export_hf_checkpoint.py @@ -29,7 +29,6 @@ def parse_args(): description="Export a HF checkpoint (with ModelOpt state) for deployment." ) parser.add_argument("--model_path", type=str, default="Path of the trained checkpoint.") - parser.add_argument("--trust_remote_code", action="store_true", help="Trust remote code") parser.add_argument( "--export_path", type=str, default="Destination directory for exported files." 
) @@ -39,10 +38,11 @@ def parse_args(): mto.enable_huggingface_checkpointing() args = parse_args() -model = load_vlm_or_llm( - args.model_path, torch_dtype="auto", trust_remote_code=args.trust_remote_code -) +model = load_vlm_or_llm(args.model_path, torch_dtype="auto") model.eval() with torch.inference_mode(): - export_speculative_decoding(model, export_dir=args.export_path) + export_speculative_decoding( + model, + export_dir=args.export_path, + ) print(f"Exported checkpoint to {args.export_path}") diff --git a/examples/speculative_decoding/train_dflash.py b/examples/speculative_decoding/train_dflash.py new file mode 100644 index 0000000000..20be8a85f0 --- /dev/null +++ b/examples/speculative_decoding/train_dflash.py @@ -0,0 +1,319 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Standalone DFlash training script using SpecForge's data pipeline. + +Uses SpecForge's tokenizer template + offset-mapping loss mask for data +preprocessing, and ModelOpt's DFlash module for the draft model. This +isolates data pipeline differences from model architecture differences. 
+ +Usage: + torchrun --nproc_per_node=8 train_dflash.py \ + --model /path/to/Qwen3-8B \ + --data /path/to/train.jsonl \ + --chat-template qwen \ + --block-size 16 \ + --num-draft-layers 5 \ + --num-epochs 3 \ + --lr 1e-4 \ + --output-dir /path/to/output +""" + +import argparse +import math +import os + +import torch +import torch.distributed as dist +from datasets import load_dataset +from torch.utils.data import DataLoader, DistributedSampler +from transformers import AutoModelForCausalLM, AutoTokenizer + +import modelopt.torch.opt as mto +import modelopt.torch.speculative as mtsp + + +def parse_args(): + """Parse command line arguments.""" + parser = argparse.ArgumentParser(description="DFlash training with SpecForge data pipeline") + parser.add_argument("--model", type=str, required=True, help="Target model path") + parser.add_argument("--data", type=str, required=True, help="Training data JSONL path") + parser.add_argument("--chat-template", type=str, default="qwen", help="Chat template name") + parser.add_argument("--block-size", type=int, default=16) + parser.add_argument("--num-draft-layers", type=int, default=5) + parser.add_argument("--mask-token-id", type=int, default=None) + parser.add_argument("--max-length", type=int, default=512) + parser.add_argument("--num-epochs", type=int, default=3) + parser.add_argument("--lr", type=float, default=1e-4) + parser.add_argument("--batch-size", type=int, default=1) + parser.add_argument("--warmup-ratio", type=float, default=0.01) + parser.add_argument("--log-interval", type=int, default=100) + parser.add_argument("--save-interval", type=int, default=0, help="0 = save at end only") + parser.add_argument("--output-dir", type=str, required=True) + parser.add_argument("--seed", type=int, default=42) + parser.add_argument("--num-ar-samples", type=int, default=20, help="AR validation samples") + return parser.parse_args() + + +def is_rank0(): + """Check if current process is rank 0.""" + return not dist.is_initialized() 
or dist.get_rank() == 0 + + +def print_rank0(msg): + """Print only on rank 0.""" + if is_rank0(): + print(msg, flush=True) + + +def build_dataset(tokenizer, data_path, chat_template_name, max_length): + """Build dataset using SpecForge's data pipeline. + + Uses SpecForge's GeneralParser to tokenize conversations with the + proper chat template and compute offset-mapping-based loss masks. + """ + from specforge.data.parse import GeneralParser + from specforge.data.template import TEMPLATE_REGISTRY + + template = TEMPLATE_REGISTRY.get(chat_template_name) + parser = GeneralParser(tokenizer, template) + + raw_dataset = load_dataset("json", data_files=data_path)["train"] + + processed = {"input_ids": [], "loss_mask": []} + skipped = 0 + for sample in raw_dataset: + convs = sample.get("conversations", sample.get("messages", [])) + if not convs: + skipped += 1 + continue + try: + input_ids, loss_mask = parser.parse(convs, max_length=max_length) + processed["input_ids"].append(input_ids) + processed["loss_mask"].append(loss_mask) + except Exception: + skipped += 1 + + print_rank0(f"Processed {len(processed['input_ids'])} samples, skipped {skipped}") + return processed + + +class DFlashDataset(torch.utils.data.Dataset): + """Simple dataset wrapping tokenized input_ids and loss_mask.""" + + def __init__(self, data): + self.input_ids = data["input_ids"] + self.loss_mask = data["loss_mask"] + + def __len__(self): + return len(self.input_ids) + + def __getitem__(self, idx): + return { + "input_ids": self.input_ids[idx], + "loss_mask": self.loss_mask[idx], + } + + +def collate_fn(batch): + """Collate batch of samples.""" + input_ids = torch.stack([b["input_ids"] for b in batch]) + loss_mask = torch.stack([b["loss_mask"] for b in batch]) + return {"input_ids": input_ids, "loss_mask": loss_mask} + + +def train(args): + """Main training loop.""" + # Init distributed + dist.init_process_group("nccl") + local_rank = int(os.environ.get("LOCAL_RANK", 0)) + 
torch.cuda.set_device(local_rank) + device = torch.device("cuda", local_rank) + + torch.manual_seed(args.seed) + mto.enable_huggingface_checkpointing() + + # Load model + print_rank0(f"Loading model: {args.model}") + model = AutoModelForCausalLM.from_pretrained( + args.model, torch_dtype=torch.bfloat16, device_map={"": device}, trust_remote_code=True + ) + tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True) + + # Detect mask_token_id + mask_token_id = args.mask_token_id + if mask_token_id is None: + if hasattr(tokenizer, "mask_token_id") and tokenizer.mask_token_id is not None: + mask_token_id = tokenizer.mask_token_id + elif hasattr(tokenizer, "pad_token_id") and tokenizer.pad_token_id is not None: + mask_token_id = tokenizer.pad_token_id + else: + mask_token_id = tokenizer.eos_token_id + print_rank0(f"mask_token_id: {mask_token_id}") + + # Convert to DFlash + config = { + "dflash_block_size": args.block_size, + "dflash_use_torch_compile": False, + "dflash_architecture_config": { + "num_hidden_layers": args.num_draft_layers, + "mask_token_id": mask_token_id, + }, + } + mtsp.convert(model, [("dflash", config)]) + print_rank0( + f"DFlash module created: {sum(p.numel() for p in model.dflash_module.parameters()):,} params" + ) + + # Build dataset using SpecForge pipeline + print_rank0("Building dataset with SpecForge pipeline...") + data = build_dataset(tokenizer, args.data, args.chat_template, args.max_length) + + # Filter samples with too few loss tokens + min_loss_tokens = 2 * args.block_size + filtered_ids = [] + filtered_masks = [] + for i in range(len(data["input_ids"])): + if data["loss_mask"][i].sum() >= min_loss_tokens: + filtered_ids.append(data["input_ids"][i]) + filtered_masks.append(data["loss_mask"][i]) + print_rank0(f"After filtering: {len(filtered_ids)} samples (min {min_loss_tokens} loss tokens)") + data = {"input_ids": filtered_ids, "loss_mask": filtered_masks} + + dataset = DFlashDataset(data) + sampler = 
DistributedSampler(dataset, shuffle=True) + dataloader = DataLoader( + dataset, + batch_size=args.batch_size, + sampler=sampler, + collate_fn=collate_fn, + num_workers=2, + pin_memory=True, + drop_last=True, + ) + + # Wrap with DDP + model = torch.nn.parallel.DistributedDataParallel( + model, + device_ids=[local_rank], + find_unused_parameters=True, + ) + raw_model = model.module + + # Optimizer — only train dflash_module + optimizer = torch.optim.AdamW( + [p for p in raw_model.dflash_module.parameters() if p.requires_grad], + lr=args.lr, + weight_decay=0.0, + ) + + # LR scheduler + steps_per_epoch = len(dataloader) + total_steps = args.num_epochs * steps_per_epoch + warmup_steps = int(total_steps * args.warmup_ratio) + + def lr_lambda(step): + if step < warmup_steps: + return step / max(warmup_steps, 1) + progress = (step - warmup_steps) / max(total_steps - warmup_steps, 1) + return 0.5 * (1.0 + math.cos(math.pi * progress)) + + scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda) + + print_rank0(f"Training: {total_steps} steps, {warmup_steps} warmup, {steps_per_epoch}/epoch") + + # Training loop + global_step = 0 + for epoch in range(args.num_epochs): + sampler.set_epoch(epoch) + model.train() + + for batch in dataloader: + input_ids = batch["input_ids"].to(device) + loss_mask = batch["loss_mask"].to(device) + + # Create labels from loss_mask: -100 for masked positions + labels = input_ids.clone() + labels[loss_mask == 0] = -100 + + output = model( + input_ids=input_ids, + attention_mask=torch.ones_like(input_ids), + labels=labels, + ) + + loss = output.loss + loss.backward() + optimizer.step() + scheduler.step() + optimizer.zero_grad() + + global_step += 1 + + if global_step % args.log_interval == 0: + acc = output.train_acc[0][0] if hasattr(output, "train_acc") else 0.0 + lr = scheduler.get_last_lr()[0] + print_rank0( + f"Step {global_step} | loss={loss.item():.4f} | acc={acc:.4f} | lr={lr:.2e}" + ) + + if args.save_interval > 0 and global_step % 
args.save_interval == 0: + if is_rank0(): + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + raw_model.save_pretrained(save_path) + print_rank0(f"Saved checkpoint: {save_path}") + + # Save final model + if is_rank0(): + os.makedirs(args.output_dir, exist_ok=True) + raw_model.save_pretrained(args.output_dir) + print_rank0(f"Saved final model: {args.output_dir}") + + dist.barrier() + + # AR validation on rank 0 + if is_rank0() and args.num_ar_samples > 0: + print_rank0("\n=== AR Validation ===") + model.eval() + from modelopt.torch.speculative.plugins.transformers import HFARValidation + + validator = HFARValidation(raw_model, tokenizer) + ds = load_dataset("/hf-local/HuggingFaceH4/mt_bench_prompts")["train"] + + ars = [] + for i in range(min(args.num_ar_samples, len(ds))): + prompt = ds[i]["prompt"][0] + chat = [{"role": "user", "content": prompt}] + text = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True) + inp = tokenizer(text, return_tensors="pt").input_ids.to(device) + try: + _, ar = validator.validate(osl=32, input_ids=inp, steps=3) + ars.append(ar) + print_rank0(f" AR={ar:.2f} | {prompt[:60]}") + except Exception as e: + print_rank0(f" ERROR | {prompt[:60]}... 
class DFlashExporter(SpeculativeDecodingExporter):
    """Draft model exporter for DFlash.

    Writes a z-lab compatible checkpoint:
      - ``model.safetensors``: draft-module weights with the ``dflash_module.``
        prefix stripped (no wrapper prefix in exported keys)
      - ``config.json``: Qwen3-style config carrying a ``dflash_config`` field
    """

    def __init__(self, model: nn.Module):
        """Initialize the DFlashExporter."""
        super().__init__(model)

    def _extract_state_dict(self, full_state_dict: dict):
        """Return only DFlash module weights, with the ``dflash_module.`` prefix removed."""
        exported = {}
        for name, tensor in full_state_dict.items():
            if "dflash_module." not in name:
                continue
            stripped = name.split("dflash_module.", 1)[1]
            if "rotary_emb" in stripped:
                # Rotary buffers are recomputed at load time; do not serialize.
                continue
            exported[stripped] = tensor.clone()
        return exported

    def _export_config(self):
        """Build the ``config.json`` dict matching the z-lab DFlash format."""
        model = self.model
        # Multimodal wrappers keep the LLM config under text_config/llm_config.
        base = (
            getattr(model.config, "text_config", None)
            or getattr(model.config, "llm_config", None)
            or model.config
        )
        draft = model.dflash_config

        cfg = {
            "architectures": ["DFlashDraftModel"],
            "model_type": getattr(base, "model_type", "qwen3"),
            "block_size": model.dflash_block_size,
            "dflash_config": {
                "mask_token_id": model.mask_token_id,
                "target_layer_ids": list(model.target_layer_ids),
            },
            # Architecture dimensions: prefer the draft config, fall back to base.
            "hidden_size": getattr(draft, "hidden_size", base.hidden_size),
            "num_hidden_layers": draft.num_hidden_layers,
            "num_attention_heads": getattr(draft, "num_attention_heads", base.num_attention_heads),
            "num_key_value_heads": getattr(draft, "num_key_value_heads", base.num_key_value_heads),
            "head_dim": getattr(draft, "head_dim", base.hidden_size // base.num_attention_heads),
            "intermediate_size": getattr(draft, "intermediate_size", base.intermediate_size),
            "hidden_act": getattr(draft, "hidden_act", "silu"),
            "rms_norm_eps": getattr(draft, "rms_norm_eps", 1e-6),
            "vocab_size": base.vocab_size,
            "max_position_embeddings": getattr(base, "max_position_embeddings", 32768),
            "initializer_range": getattr(base, "initializer_range", 0.02),
            "attention_bias": getattr(draft, "attention_bias", False),
            "attention_dropout": getattr(draft, "attention_dropout", 0.0),
            "rope_theta": getattr(base, "rope_theta", 1000000.0),
            "rope_scaling": getattr(base, "rope_scaling", None),
            "tie_word_embeddings": False,
            "torch_dtype": "bfloat16",
            "num_target_layers": getattr(base, "num_hidden_layers", 36),
        }

        # Qwen3-style per-layer attention types; default to full attention everywhere.
        cfg["layer_types"] = (
            draft.layer_types
            if hasattr(draft, "layer_types")
            else ["full_attention"] * draft.num_hidden_layers
        )
        return cfg

    def export(self, export_dir: Path | str, dtype: torch.dtype | None = None):
        """Write ``model.safetensors`` and ``config.json`` under ``export_dir``."""
        out = Path(export_dir)
        out.mkdir(parents=True, exist_ok=True)

        weights = self._extract_state_dict(self.model.state_dict())
        if dtype is not None:
            weights = {name: t.to(dtype) for name, t in weights.items()}
        save_file(weights, f"{out}/model.safetensors")

        cfg = self._export_config()
        with open(f"{out}/config.json", "w") as f:
            json.dump(cfg, f, indent=2)

        print(
            f"Exported DFlash draft model: {len(weights)} tensors, "
            f"config keys: {list(cfg.keys())[:5]}..."
        )
class DFlashConfig(ModeloptBaseConfig):
    """DFlash config for block-wise parallel speculative decoding.

    Controls both the draft-module architecture and the training behavior of
    the ``"dflash"`` conversion mode. ``dflash_architecture_config`` is merged
    on top of the package defaults at convert time.
    """

    # Number of tokens the draft module predicts in one forward pass.
    dflash_block_size: int = ModeloptField(
        default=16,
        description="Block size for parallel prediction. Draft predicts this many tokens per block.",
    )

    # When True, only the DFlash draft module receives gradients.
    dflash_freeze_base_model: bool = ModeloptField(
        default=True, description="Whether to freeze base model during DFlash module training."
    )

    # When True, the base model's logits are the distillation target.
    dflash_self_logit_distillation: bool = ModeloptField(
        default=True, description="Whether to use logit distillation from base model."
    )

    # Exponential decay gamma for per-position loss weights; 0 means uniform.
    dflash_loss_decay_factor: float = ModeloptField(
        default=0.0,
        description="Gamma for exponential loss decay weighting (paper Eq.4). "
        "Suggested: 7 for block_size=16, 5 for 10, 4 for 8. 0 disables.",
    )

    # Per-sequence count of random anchor positions used for block training.
    dflash_num_anchors: int = ModeloptField(
        default=512,
        description="Number of random anchor positions sampled per sequence during training.",
    )

    dflash_report_acc: bool = ModeloptField(
        default=True, description="Whether to report eval accuracy."
    )

    # User overrides for the draft architecture; merged with package defaults
    # (see dflash/default_config.py) inside convert_to_dflash_model().
    dflash_architecture_config: dict = ModeloptField(
        default={}, description="Config for the DFlash draft module architecture."
    )

    dflash_use_torch_compile: bool = ModeloptField(
        default=True,
        description="Whether to use torch.compile on DFlash forward/loss methods.",
    )
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""DFlash Optimization Method.""" + +from .conversion import * +from .default_config import * +from .dflash_model import * diff --git a/modelopt/torch/speculative/dflash/conversion.py b/modelopt/torch/speculative/dflash/conversion.py new file mode 100644 index 0000000000..943be90ca0 --- /dev/null +++ b/modelopt/torch/speculative/dflash/conversion.py @@ -0,0 +1,58 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
DFlashDMRegistry = _DMRegistryCls(prefix="DFlash")  # global instance for the registry


def convert_to_dflash_model(model: nn.Module, config: DFlashConfig) -> ConvertReturnType:
    """Convert the model to a DFlash model as per `config`."""
    if isinstance(model, ModelLikeModule):
        model = model.init_modellike()

    # Register the concrete model class through its nearest registered ancestor
    # so subclasses of registered base classes are convertible too.
    model_cls = type(model)
    if model_cls not in DFlashDMRegistry:
        for registered_cls in DFlashDMRegistry._registry:
            if issubclass(model_cls, registered_cls):
                DFlashDMRegistry.register({model_cls: "base_model_class"})(
                    DFlashDMRegistry[registered_cls]
                )
                break

    # Merge user overrides on top of the package defaults
    # (lazy import avoids a circular dependency).
    from .default_config import default_dflash_config

    config.dflash_architecture_config = {
        **default_dflash_config,
        **config.dflash_architecture_config,
    }

    dflash_model = DFlashDMRegistry.convert(model)
    dflash_model.modify(config)
    return dflash_model, {}


def restore_dflash_model(
    model: nn.Module, config: DFlashConfig, metadata: MetadataDict
) -> nn.Module:
    """Function for restoring a previously converted model to a DFlash model."""
    assert not metadata, "No metadata expected!"
    restored, _ = convert_to_dflash_model(model, config)
    return restored
# Only DFlash-specific defaults live here; model-specific fields
# (hidden_size, attention heads, rope settings, ...) are inherited from the
# base model inside HFDFlashModel.modify().
default_dflash_config = dict(
    num_hidden_layers=5,  # depth of the draft decoder stack
    rms_norm_eps=1e-06,
    attention_bias=False,
    attention_dropout=0.0,
)
class DFlashModel(DynamicModule):
    """Base DFlash Model."""

    def _setup(self):
        """Register temporary attributes for the DFlash module."""
        self._register_temp_attribute("dflash_module", None)

    def modify(self, config):
        """Copy DFlash hyper-parameters from ``config`` onto the model.

        Child classes implement the architecture-specific details.
        """
        for field in (
            "dflash_block_size",
            "dflash_freeze_base_model",
            "dflash_loss_decay_factor",
            "dflash_self_logit_distillation",
            "dflash_num_anchors",
            "dflash_report_acc",
            "dflash_use_torch_compile",
        ):
            setattr(self, field, getattr(config, field))
@SpeculativeDecodingModeRegistry.register_mode
class DFlashModeDescriptor(ModeDescriptor):
    """Class to describe the ``"dflash"`` mode.

    The properties of this mode can be inspected via the source code.
    Mirrors the structure of MedusaModeDescriptor/EagleModeDescriptor.
    """

    @property
    def name(self) -> str:
        """Returns the value (str representation) of the mode."""
        return "dflash"

    @property
    def config_class(self) -> type[ModeloptBaseConfig]:
        """Specifies the config class for the mode."""
        return DFlashConfig

    @property
    def convert(self) -> ConvertEntrypoint:
        """The mode's entrypoint for converting a model."""
        return convert_to_dflash_model

    @property
    def restore(self) -> RestoreEntrypoint:
        """The mode's entrypoint for restoring a model."""
        return restore_dflash_model
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""DFlash speculative decoding plugin for HuggingFace models. + +Matches the reference SpecForge implementation (github.com/sgl-project/SpecForge PR #415). + +Architecture: +- Feature Fusion: multi-layer target hidden states → FC + RMSNorm +- KV Injection: fused features as K/V in every draft layer with QK-norm +- Parallel Drafting: mask_token_id for unknown positions, causal within blocks +- Loss: hard CE on input_ids[i] (position i predicts token i) + +Reference: "DFlash: Block Diffusion for Flash Speculative Decoding" (arXiv:2602.06036) +""" + +import importlib + +import torch +import torch.nn.functional as F +from torch import nn +from transformers import PretrainedConfig, PreTrainedModel +from transformers.utils import ModelOutput + +from ..dflash.conversion import DFlashDMRegistry +from ..dflash.dflash_model import DFlashModel + + +def _resolve_model_components(model_type): + """Resolve MLP, RMSNorm, RotaryEmbedding from the base model's transformers module. + + Falls back to Llama components if the model type is unknown. 
def build_target_layer_ids(num_target_layers, num_draft_layers):
    """Pick target-model layer indices to tap for feature fusion.

    A single draft layer uses the middle of the target stack; otherwise the
    indices are spread evenly across ``[1, num_target_layers - 3]``.
    """
    if num_draft_layers == 1:
        return [num_target_layers // 2]
    lo = 1
    hi = num_target_layers - 3
    return [
        round(lo + (step * (hi - lo)) / (num_draft_layers - 1))
        for step in range(num_draft_layers)
    ]
class DFlashAttention(nn.Module):
    """Attention with KV injection, using HF's attention dispatch for exact SpecForge parity."""

    def __init__(self, config, layer_idx):
        """Initialize DFlash attention with KV injection projections and QK-norm."""
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.head_dim = getattr(
            config, "head_dim", config.hidden_size // config.num_attention_heads
        )
        self.num_heads = config.num_attention_heads
        self.num_kv_heads = config.num_key_value_heads
        # GQA group size: how many query heads share one KV head.
        self.num_key_value_groups = self.num_heads // self.num_kv_heads
        self.scaling = self.head_dim**-0.5
        self.attention_dropout = getattr(config, "attention_dropout", 0.0)
        # Masking is supplied explicitly; the dispatch must not add causality.
        self.is_causal = False

        attn_bias = getattr(config, "attention_bias", False)
        self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=attn_bias)
        self.k_proj = nn.Linear(
            config.hidden_size, self.num_kv_heads * self.head_dim, bias=attn_bias
        )
        self.v_proj = nn.Linear(
            config.hidden_size, self.num_kv_heads * self.head_dim, bias=attn_bias
        )
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, config.hidden_size, bias=attn_bias)

        # Per-head-dim QK-norm (Qwen3 style).
        self.q_norm = _NORM_CLS(self.head_dim, eps=config.rms_norm_eps)
        self.k_norm = _NORM_CLS(self.head_dim, eps=config.rms_norm_eps)

        # Resolve HF attention function matching SpecForge's dispatch
        self._attn_fn = None
        self.sliding_window = None

    def _get_attn_fn(self):
        """Lazily resolve the HF attention function."""
        if self._attn_fn is not None:
            return self._attn_fn
        try:
            from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS

            impl = getattr(self.config, "_attn_implementation", "eager")
            # Use HF's registered kernel (sdpa/flash) when available; otherwise
            # fall back to the local eager implementation.
            if impl and impl != "eager" and impl in ALL_ATTENTION_FUNCTIONS:
                self._attn_fn = ALL_ATTENTION_FUNCTIONS[impl]
            else:
                self._attn_fn = self._eager_attention
        except (ImportError, AttributeError):
            self._attn_fn = self._eager_attention
        return self._attn_fn

    def _eager_attention(self, module, q, k, v, attention_mask, **kwargs):
        """Eager attention matching HF's eager_attention_forward."""
        scaling = kwargs.get("scaling", self.scaling)
        n_rep = self.num_key_value_groups
        # Expand KV heads to match query heads (GQA).
        if n_rep > 1:
            k = k.repeat_interleave(n_rep, dim=1)
            v = v.repeat_interleave(n_rep, dim=1)
        attn_weights = torch.matmul(q, k.transpose(2, 3)) * scaling
        if attention_mask is not None:
            # Additive mask: blocked positions carry large negative values.
            attn_weights = attn_weights + attention_mask
        # Softmax in float32 for numerical stability, then cast back.
        attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(
            q.dtype
        )
        attn_output = torch.matmul(attn_weights, v)
        attn_output = attn_output.transpose(1, 2).contiguous()
        return attn_output, None

    def forward(self, hidden_states, target_hidden, position_embeddings, attention_mask=None):
        """Forward with KV injection: Q from noise, K/V from context+noise."""
        bsz, q_len, _ = hidden_states.shape
        ctx_len = target_hidden.shape[1]

        # Q from noise only, with QK-norm -> [B, H, q_len, head_dim]
        q = self.q_proj(hidden_states).view(bsz, q_len, -1, self.head_dim)
        q = self.q_norm(q).transpose(1, 2)

        # K from context + noise, with QK-norm -> [B, H_kv, ctx_len+q_len, head_dim]
        k_ctx = self.k_proj(target_hidden)
        k_noise = self.k_proj(hidden_states)
        k = torch.cat([k_ctx, k_noise], dim=1).view(bsz, ctx_len + q_len, -1, self.head_dim)
        k = self.k_norm(k).transpose(1, 2)

        # V from context + noise (no norm)
        v_ctx = self.v_proj(target_hidden)
        v_noise = self.v_proj(hidden_states)
        v = (
            torch.cat([v_ctx, v_noise], dim=1)
            .view(bsz, ctx_len + q_len, -1, self.head_dim)
            .transpose(1, 2)
        )

        # RoPE: K is rotated over the full ctx+noise range, Q over the last
        # q_len positions (see apply_rotary_pos_emb).
        cos, sin = position_embeddings
        q, k = apply_rotary_pos_emb(q, k, cos, sin)

        # Use HF's attention dispatch (handles GQA internally)
        attn_fn = self._get_attn_fn()
        attn_output, _ = attn_fn(
            self,
            q,
            k,
            v,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            sliding_window=self.sliding_window,
        )
        attn_output = attn_output.reshape(bsz, q_len, -1)
        return self.o_proj(attn_output)


class DFlashDecoderLayer(nn.Module):
    """Draft decoder layer with KV injection."""

    def __init__(self, config, layer_idx):
        """Initialize decoder layer with attention, MLP, and layer norms."""
        super().__init__()
        self.self_attn = DFlashAttention(config, layer_idx)
        self.mlp = _MLP_CLS(config)
        self.input_layernorm = _NORM_CLS(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = _NORM_CLS(config.hidden_size, eps=config.rms_norm_eps)

    def forward(self, hidden_states, target_hidden, position_embeddings, attention_mask=None):
        """Forward pass with residual connections."""
        # Pre-norm attention block.
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states = self.self_attn(
            hidden_states, target_hidden, position_embeddings, attention_mask
        )
        hidden_states = residual + hidden_states

        # Pre-norm MLP block.
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states
        return hidden_states
def create_dflash_attention_mask(
    seq_len, block_size, device, dtype
):  # Legacy: used for inference only
    """Build the [1, 1, L, 2L] additive attention mask used by DFlash.

    Key columns are [context | noise], each of length ``seq_len``:
      - context half: a query in block B may attend any key in a strictly
        earlier block (0..B-1);
      - noise half: a query at position i may attend keys in its own block at
        positions j >= i (reverse-causal within the block, matching SpecForge).

    The mask is built in float32 and only then cast, so that blocked positions
    become -inf in bf16 (float32 min overflows to -inf on cast) rather than
    the largest finite negative bf16 value.
    """
    pos = torch.arange(seq_len, device=device)
    blk = pos // block_size

    q_blk = blk.view(-1, 1)  # [L, 1] query block ids
    k_blk = blk.view(1, -1)  # [1, L] key block ids

    allow_ctx = k_blk < q_blk
    allow_noise = (q_blk == k_blk) & (pos.view(1, -1) >= pos.view(-1, 1))

    allowed = torch.cat([allow_ctx, allow_noise], dim=1)

    blocked = torch.full(
        (seq_len, 2 * seq_len),
        torch.finfo(torch.float32).min,
        device=device,
        dtype=torch.float32,
    )
    mask = blocked.masked_fill(allowed, 0.0).to(dtype=dtype)
    return mask[None, None]  # [1, 1, L, 2L]


def create_dflash_loss_mask(seq_len, block_size, device):  # Legacy: used for inference only
    """Return a float loss mask that zeroes out block 0 and every block start."""
    pos = torch.arange(seq_len, device=device)
    in_first_block = pos // block_size == 0
    at_block_start = pos % block_size == 0
    return (~(in_first_block | at_block_start)).float()
an appropriate mask token ID for DFlash. + + Different model families use different strategies: + - Qwen3/3.5: built-in [MASK] token in vocabulary + - Llama3: reserved special tokens (128002 = reserved_special_token_0) + - Others: try tokenizer.mask_token_id, then fall back to pad/eos + """ + model_type = getattr(base_config, "model_type", "") + vocab_size = getattr(base_config, "vocab_size", 0) + + # Qwen3/3.5: known mask token positions + if "qwen3" in model_type.lower() or "qwen" in model_type.lower(): + # Qwen3 vocab has dedicated mask tokens + # Qwen3.5-4B: 248070, Qwen3-8B: similar range + # Heuristic: eos_token_id + some offset, or check known values + eos = getattr(base_config, "eos_token_id", None) + if isinstance(eos, list): + eos = eos[0] + if eos and vocab_size > 200000: + # Large Qwen vocab — mask token is typically near end of special tokens + # Known: Qwen3.5 eos=248044, mask=248070 (offset ~26) + # Try common offsets + for offset in [26, 25, 24]: + candidate = eos + offset + if candidate < vocab_size: + return candidate + # Fallback for smaller Qwen models + if vocab_size > 150000: + return vocab_size - 250 # heuristic for Qwen special token region + + # Llama3: use reserved_special_token_0 (128002) + if "llama" in model_type.lower(): + if vocab_size >= 128256: # Llama3 vocab size + return 128002 # <|reserved_special_token_0|> + + # Generic: try pad_token_id, then eos + pad_id = getattr(base_config, "pad_token_id", None) + eos_id = getattr(base_config, "eos_token_id", None) + if isinstance(eos_id, list): + eos_id = eos_id[0] + + # Prefer pad over eos (pad is less likely to interfere) + if pad_id is not None and pad_id != eos_id: + return pad_id + + # Last resort + return eos_id or 0 + + def _find_base_model_parts(self): + """Locate base model submodules (backbone, embeddings, lm_head) by probing known paths.""" + for name, paths in { + "base_model_path": ["model.language_model", "model", "backbone"], + "base_model_embeddings_path": [ + 
"model.embed_tokens", + "backbone.embeddings", + "model.language_model.embed_tokens", + ], + "base_model_lm_head_path": ["lm_head", "language_model.lm_head"], + }.items(): + for path in paths: + try: + submodule = self.get_submodule(path) + assert isinstance(submodule, torch.nn.Module) + setattr(self, name, path) + break + except Exception: + continue + else: + raise ValueError(f"Part {name} not found in model") + + def modify(self, config): + """Initialize DFlash draft module.""" + super().modify(config) + + base_config = self._base_llm_config + self.dflash_config = PretrainedConfig.from_dict(config.dflash_architecture_config) + + # Inherit settings from base model, but only those NOT already in the user config. + # hidden_size and vocab_size MUST match. Others (heads, intermediate_size) can differ. + # This allows the draft model to have a different architecture than the base model. + self.dflash_config.hidden_size = base_config.hidden_size + self.dflash_config.vocab_size = base_config.vocab_size + + # These use base model defaults if not specified in dflash_architecture_config + for attr, default_from_base in [ + ("max_position_embeddings", True), + ("intermediate_size", True), + ("num_attention_heads", True), + ("num_key_value_heads", True), + ("hidden_act", True), + ("rope_theta", True), + ("rope_scaling", True), + ("rope_type", False), + ("position_embedding_type", False), + ("rope_interleaved", False), + ("rms_norm_eps", True), + ("attention_bias", False), + ("tie_word_embeddings", False), + ]: + if not hasattr(self.dflash_config, attr) or getattr(self.dflash_config, attr) is None: + if default_from_base and hasattr(base_config, attr): + setattr(self.dflash_config, attr, getattr(base_config, attr)) + + # Ensure required attrs have defaults + if not hasattr(self.dflash_config, "mlp_bias") or self.dflash_config.mlp_bias is None: + self.dflash_config.mlp_bias = False + + self.dflash_config.head_dim = getattr( + self.dflash_config, + "head_dim", + 
self.dflash_config.hidden_size // self.dflash_config.num_attention_heads, + ) + self.dflash_config.block_size = self.dflash_block_size + # Default to sdpa, matching SpecForge's DFlashDraftModel(Qwen3PreTrainedModel) + # which resolves to sdpa via post_init() + if self.dflash_config._attn_implementation is None: + self.dflash_config._attn_implementation = "sdpa" + + # Target layer IDs + num_target_layers = base_config.num_hidden_layers + num_draft_layers = self.dflash_config.num_hidden_layers + self.target_layer_ids = build_target_layer_ids(num_target_layers, num_draft_layers) + self.dflash_config.target_layer_ids = self.target_layer_ids + + # mask_token_id resolution order: + # 1. Explicit in dflash_architecture_config (user override) + # 2. Auto-detect from model vocabulary: + # - Qwen3/3.5: built-in [MASK] token + # - Llama3: reserved_special_token_0 (128002) + # - Others: tokenizer.mask_token_id + # 3. Fallback to pad_token_id or eos_token_id (suboptimal) + mask_id = config.dflash_architecture_config.get("mask_token_id", None) + if mask_id is None: + mask_id = self._auto_detect_mask_token_id(base_config) + self.mask_token_id = mask_id[0] if isinstance(mask_id, list) else mask_id + print(f"DFlash mask_token_id: {self.mask_token_id}") + + # Freeze base model + if self.dflash_freeze_base_model: + for param in self.parameters(): + param.requires_grad = False + + self._find_base_model_parts() + + # Resolve model-specific components (MLP, RMSNorm, RotaryEmbedding) + # from the base model's architecture for weight compatibility + global _MLP_CLS, _NORM_CLS, _ROTARY_CLS, _rotate_half + _MLP_CLS, _NORM_CLS, _ROTARY_CLS, _rotate_half = _resolve_model_components( + getattr(base_config, "model_type", "llama") + ) + self.dflash_module = DFlashModule(self.dflash_config) + self.dflash_module.to(self._base_model.dtype).to( + next(self._base_model.layers[-1].parameters()).device + ) + + self.is_quantized = False + self._num_anchors = self.dflash_num_anchors + + # Store bound 
reference to the original model class's forward. + # DynamicModule changes type(self) but the original class is in _original_cls. + # Find the original HF model class (e.g., Qwen3_5ForConditionalGeneration) + # by walking MRO and skipping DFlash/DynamicModule classes + skip_names = { + "HFDFlashModel", + "DFlashModel", + "DynamicModule", + "DFlashPreTrainedModel", + "DFlashDraftModel", + } + original_cls = None + for cls in type(self).__mro__: + if ( + hasattr(cls, "forward") + and cls.__name__ not in skip_names + and cls is not type(self) + and issubclass(cls, PreTrainedModel) + and cls is not PreTrainedModel + ): + original_cls = cls + break + if original_cls is None: + # Last resort: use the class two levels up (skip DFlash wrapper + DynamicModule) + original_cls = type(self).__mro__[2] + self._original_forward_cls = original_cls + print(f"DFlash: using {original_cls.__name__}.forward as base forward") + + def get_exporter(self): + """Get the exporter for the DFlash draft model.""" + from modelopt.torch.export.plugins.hf_spec_export import DFlashExporter + + return DFlashExporter(self) + + def _base_forward(self, **kwargs): + """Call the original model's forward, bypassing DFlash wrapper.""" + return self._original_forward_cls.forward(self, **kwargs) + + def _sample_anchor_positions(self, seq_len, loss_mask, device): + """Randomly sample anchor positions per sample, matching SpecForge PR #473. + + Returns (anchor_positions [B, N], block_keep_mask [B, N]). 
+ """ + bs = self.dflash_block_size + bsz = loss_mask.shape[0] + max_anchor = max(seq_len - bs, 0) + num_anchors = getattr(self, "_num_anchors", 512) + + valid = loss_mask[:, : max_anchor + 1] > 0.5 + valid_counts = valid.sum(dim=1) + max_n = min(num_anchors, int(valid_counts.max().item()) - 1) + + if max_n <= 0: + # No valid anchors — return empty + anchors = torch.zeros(bsz, 1, dtype=torch.long, device=device) + keep = torch.zeros(bsz, 1, dtype=torch.bool, device=device) + return anchors, keep + + indices = torch.arange(max_anchor + 1, device=device).unsqueeze(0).expand(bsz, -1) + masked_indices = torch.where(valid, indices, torch.tensor(seq_len + 1, device=device)) + + random_vals = torch.rand(bsz, max_anchor + 1, device=device) + random_vals = torch.where(valid, random_vals, torch.tensor(2.0, device=device)) + + _, sorted_idx = random_vals.sort(dim=1) + gathered = torch.gather(masked_indices, 1, sorted_idx) + anchors = gathered[:, :max_n].sort(dim=1).values + + keep = torch.arange(max_n, device=device).unsqueeze(0) < valid_counts.unsqueeze(1).clamp( + max=max_n + ) + anchors = torch.where(keep, anchors, torch.tensor(0, dtype=torch.long, device=device)) + return anchors, keep + + def forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + past_key_values=None, + inputs_embeds=None, + labels=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + cache_position=None, + **kwargs, + ): + """Training forward matching SpecForge latest (post-PR #473). 
+ + Key changes from original PR #415: + - Random anchor sampling instead of uniform block division + - Bidirectional intra-block attention (no causal constraint) + - Context sees strictly before anchor position + - Label alignment: position k predicts token at anchor+k + - Optional loss decay weighting + """ + if not self.training: + return super().forward( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + labels=labels, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + cache_position=cache_position, + **kwargs, + ) + + bsz, seq_len = input_ids.shape + block_size = self.dflash_block_size + device = input_ids.device + + # 1. Run base model → hidden states + with torch.no_grad(): + base_outputs = super().forward( + input_ids=input_ids, + attention_mask=attention_mask, + output_hidden_states=True, + ) + + offset = 1 + selected = [base_outputs.hidden_states[lid + offset] for lid in self.target_layer_ids] + target_hidden = torch.cat(selected, dim=-1) # [B, seq, num_layers * H] + + # 2. Build loss mask from labels or attention_mask + if labels is not None: + loss_mask = (labels != -100).float() + elif attention_mask is not None: + loss_mask = attention_mask.float() + else: + loss_mask = torch.ones(bsz, seq_len, device=device) + + # 3. Random anchor sampling (SpecForge PR #463/#473) + anchor_positions, block_keep_mask = self._sample_anchor_positions( + seq_len, loss_mask, device + ) + n_blocks = anchor_positions.shape[1] + + if n_blocks == 0 or not block_keep_mask.any(): + # Zero loss that still flows through dflash_module for DDP gradient sync + dummy = self.dflash_module.fc.weight.sum() * 0.0 + return ModelOutput(loss=dummy, logits=base_outputs.logits, train_acc=[[0.0]]) + + # 4. 
Create noise embeddings: anchor token at block start, mask_token elsewhere + noise_ids = torch.full( + (bsz, n_blocks * block_size), self.mask_token_id, dtype=torch.long, device=device + ) + block_starts = torch.arange(n_blocks, device=device) * block_size + block_starts_exp = block_starts.unsqueeze(0).expand(bsz, -1) + valid_anchors = anchor_positions.clamp(0, seq_len - 1) + anchor_tokens = torch.gather(input_ids, 1, valid_anchors) + batch_idx = torch.arange(bsz, device=device).unsqueeze(1).expand(bsz, n_blocks) + noise_ids[batch_idx, block_starts_exp] = torch.where( + block_keep_mask, + anchor_tokens, + torch.tensor(self.mask_token_id, dtype=torch.long, device=device), + ) + noise_embedding = self._base_model_embeddings(noise_ids) + + # 5. Position IDs: context [0..S-1], draft blocks [anchor+0..anchor+B-1] + ctx_pos = torch.arange(seq_len, device=device).unsqueeze(0).expand(bsz, -1) + offsets = torch.arange(block_size, device=device).view(1, 1, -1) + draft_pos = (anchor_positions.unsqueeze(-1) + offsets).view(bsz, -1) + full_pos = torch.cat([ctx_pos, draft_pos], dim=1) + + # 6. 
Attention mask: SDPA bool mask [B, 1, Q_LEN, KV_LEN] + q_len = n_blocks * block_size + kv_len = seq_len + q_len + + q_indices = torch.arange(q_len, device=device).view(1, 1, -1, 1) + kv_indices = torch.arange(kv_len, device=device).view(1, 1, 1, -1) + q_block_ids = q_indices // block_size + + anchor_exp = anchor_positions.view(bsz, 1, n_blocks, 1).repeat_interleave(block_size, dim=2) + + # Context: kv < S and kv < anchor + mask_ctx = (kv_indices < seq_len) & (kv_indices < anchor_exp) + # Draft: kv >= S and same block + is_draft = kv_indices >= seq_len + kv_block_ids = (kv_indices - seq_len) // block_size + mask_draft = is_draft & (q_block_ids == kv_block_ids) + # Valid block + valid_block = block_keep_mask.view(bsz, 1, n_blocks, 1).repeat_interleave(block_size, dim=2) + + final_mask = (mask_ctx | mask_draft) & valid_block # [B, 1, Q, KV] + + # Convert bool mask to float additive mask for SDPA + dtype = target_hidden.dtype + attn_mask = torch.zeros(bsz, 1, q_len, kv_len, device=device, dtype=torch.float32) + attn_mask.masked_fill_(~final_mask, torch.finfo(torch.float32).min) + attn_mask = attn_mask.to(dtype=dtype) + + # 7. Draft forward + hidden = self.dflash_module( + noise_embedding=noise_embedding, + target_hidden=target_hidden, + position_ids=full_pos, + attention_mask=attn_mask, + ) + + # 8. 
Loss: same-position prediction (position k predicts token at anchor+k) + logits = self._base_model_lm_head(hidden) + + label_offsets = torch.arange(0, block_size, device=device).view(1, 1, -1) + label_indices = anchor_positions.unsqueeze(-1) + label_offsets + valid_label = label_indices < seq_len + safe_label_indices = label_indices.clamp(max=seq_len - 1) + + target_ids = torch.gather( + input_ids.unsqueeze(1).expand(-1, n_blocks, -1), 2, safe_label_indices + ) + + # Weight mask: valid block * in bounds * exclude anchor (pos 0) * loss_mask + weight_mask = block_keep_mask.unsqueeze(-1).expand(-1, -1, block_size).float() + weight_mask = weight_mask * valid_label.float() + pos_in_block = torch.arange(block_size, device=device).view(1, 1, -1) + weight_mask = weight_mask * (pos_in_block > 0).float() + + orig_loss_mask = torch.gather( + loss_mask.unsqueeze(1).expand(-1, n_blocks, -1), 2, safe_label_indices + ) + weight_mask = weight_mask * orig_loss_mask + + binary_eval_mask = weight_mask.view(-1) + + # Optional loss decay + if self.dflash_loss_decay_factor > 0: + k = torch.arange(block_size, device=device).view(1, 1, -1) + decay = torch.exp(-(k - 1).clamp(min=0).float() / self.dflash_loss_decay_factor) + weight_mask = weight_mask * decay + + # Cross entropy or logit distillation + flat_logits = logits.view(-1, logits.size(-1)) + flat_targets = target_ids.view(-1) + flat_weights = weight_mask.view(-1) + + valid_count = flat_weights.sum() + 1e-6 + + if valid_count > 1.0: + if self.dflash_self_logit_distillation: + # Teacher logits at position p predict token p+1 (autoregressive). + # Draft position k predicts token at anchor+k (same position). + # So teacher logits for token anchor+k are at position anchor+k-1. 
+ base_logits = base_outputs.logits # [B, seq, vocab] + teacher_indices = (safe_label_indices - 1).clamp(min=0) + teacher_logits = torch.gather( + base_logits.unsqueeze(1).expand(-1, n_blocks, -1, -1), + 2, + teacher_indices.unsqueeze(-1).expand(-1, -1, -1, base_logits.size(-1)), + ) # [B, N, block_size, vocab] + flat_teacher = teacher_logits.reshape(-1, base_logits.size(-1)).detach() + target_soft = torch.softmax(flat_teacher, dim=-1) + draft_logsoft = torch.log_softmax(flat_logits, dim=-1) + kd_loss = -(target_soft * draft_logsoft).sum(dim=-1) + loss = (kd_loss * flat_weights).sum() / valid_count + else: + loss_per_token = F.cross_entropy(flat_logits, flat_targets, reduction="none") + loss = (loss_per_token * flat_weights).sum() / valid_count + + with torch.no_grad(): + preds = flat_logits.argmax(dim=-1) + correct = (preds == flat_targets) & (binary_eval_mask > 0.5) + accuracy = correct.sum().float() / (binary_eval_mask.sum() + 1e-6) + accuracy = accuracy.item() + else: + loss = flat_logits.sum() * 0.0 + accuracy = 0.0 + + return ModelOutput( + loss=loss, + logits=base_outputs.logits, + train_acc=[[accuracy]], + ) + + @torch.no_grad() + def pseudo_speculative_generate(self, input_ids, steps=1): + """Generate draft tokens using one DFlash block. + + DFlash generates block_size-1 draft tokens in a single forward pass. + The `steps` parameter is used as the number of tokens to return + (capped at block_size-1). + + Returns: + base_token: Next token from base model [B, 1]. + draft_tokens: Draft tokens [B, min(steps, block_size-1)] or None. 
+ """ + # Call the base model's inner model directly (avoids DynamicModule dispatch) + model_output = self._base_model( + input_ids=input_ids, + output_hidden_states=True, + ) + # Compute logits via lm_head + base_logits = self._base_model_lm_head(model_output.last_hidden_state) + # Build output with hidden_states + base_outputs = ModelOutput( + logits=base_logits, + hidden_states=model_output.hidden_states, + ) + base_logits = base_outputs.logits + base_token = base_logits[:, -1:, :].argmax(dim=-1).to(input_ids.device) + + if steps < 1: + return base_token, None + + # Extract target hidden states (raw, before FC projection) + hid_offset = 1 + if not hasattr(self, "_psg_debug"): + self._psg_debug = True + sel = [base_outputs.hidden_states[lid + hid_offset] for lid in self.target_layer_ids] + th_dbg = torch.cat(sel, dim=-1) + n_layers = len(base_outputs.hidden_states) + th_norm = th_dbg.norm().item() + print( + f"[psg] hidden layers: {n_layers}, target_hidden: {th_dbg.shape}, norm: {th_norm:.2f}" + ) + print(f"[psg] base_token: {base_token.item()}, mask_token_id: {self.mask_token_id}") + seq_len = input_ids.shape[1] + blk = self.dflash_block_size + print(f"[psg] pos: ctx=[0..{seq_len - 1}], blk=[{seq_len}..{seq_len + blk - 1}]") + selected = [base_outputs.hidden_states[lid + hid_offset] for lid in self.target_layer_ids] + target_hidden = torch.cat(selected, dim=-1) + + block_size = self.dflash_block_size + bsz = input_ids.shape[0] + seq_len = input_ids.shape[1] + device = input_ids.device + + # Block: first token is base_token (anchor), rest are mask + block_ids = torch.full( + (bsz, block_size), self.mask_token_id, dtype=torch.long, device=device + ) + block_ids[:, 0] = base_token.squeeze(-1) + noise_embedding = self._base_model_embeddings(block_ids) + + # Position IDs: training uses [0..L-1, 0..L-1] where noise positions + # mirror context positions. At inference, block predicts tokens at + # seq_len..seq_len+B-1, so noise positions continue from ctx_len. 
+ ctx_len = target_hidden.shape[1] + ctx_positions = torch.arange(ctx_len, device=device) + block_positions = torch.arange(ctx_len, ctx_len + block_size, device=device) + pos_ids = torch.cat([ctx_positions, block_positions]).unsqueeze(0).expand(bsz, -1) + + # No attention mask at inference — matching SpecForge's spec_generate + # which uses KV cache with no mask. All positions attend freely to + # context and each other within the block. + + # Draft forward + draft_hidden = self.dflash_module( + noise_embedding=noise_embedding, + target_hidden=target_hidden, + position_ids=pos_ids, + attention_mask=None, + ) + + # Logits on positions 1..block_size-1 (skip anchor at position 0) + draft_logits = self._base_model_lm_head(draft_hidden[:, 1:, :]) + draft_tokens = draft_logits.argmax(dim=-1) # [B, block_size-1] + + # Return up to `steps` tokens + num_tokens = min(steps, block_size - 1) + return base_token, draft_tokens[:, :num_tokens] diff --git a/modelopt/torch/speculative/utils.py b/modelopt/torch/speculative/utils.py index 9e167c8dc9..ca8bbcd0af 100644 --- a/modelopt/torch/speculative/utils.py +++ b/modelopt/torch/speculative/utils.py @@ -376,6 +376,84 @@ def validate( return ground_truth, ar + def validate_online( + self, + osl, + prompt=None, + input_ids=None, + steps=1, + ): + """Validate AR with online (context-dependent) ground truth. + + Instead of pre-computing a fixed ground truth, this method verifies + draft tokens against the target model's response to the current + sequence (including previously accepted draft tokens). This matches + the actual speculative decoding verification loop. 
+ + Args: + osl: output sequence length + prompt: text prompt (alternative to input_ids) + input_ids: tokenized input + steps: number of draft tokens per step + """ + if input_ids is None: + input_ids = self.tokenize(prompt) + + isl = input_ids.shape[1] + max_len = isl + osl + total_accepted = 0 + cnt = 0 + + while input_ids.shape[1] < max_len: + cnt += 1 + + # Generate base token + draft tokens + input_id, draft_tokens = self.model.pseudo_speculative_generate(input_ids, steps=steps) + draft_tokens = self.check_data_consistency_across_ranks(draft_tokens) + input_id = self.check_data_consistency_across_ranks(input_id) + + # Append base token + input_ids = torch.cat((input_ids, input_id), dim=-1) + + if draft_tokens is None or input_ids.shape[1] >= max_len: + total_accepted += 1 # base token + continue + + # Build candidate sequence with draft tokens appended + candidate = torch.cat((input_ids, draft_tokens), dim=-1) + + # Get target model's response to the candidate sequence + with torch.no_grad(): + target_output = self.model._base_model(candidate) + target_logits = self.model._base_model_lm_head(target_output.last_hidden_state) + # posterior[i] = target's prediction given candidate[:i+1] + # For positions where we placed draft tokens, compare + # target's prediction at position i-1 with draft token at i + posterior = target_logits.argmax(dim=-1) + + # Check acceptance: compare draft[i] with posterior at input_ids_len-1+i + accepted = 0 + pos = input_ids.shape[1] - 1 # position of base token in candidate + for i in range(draft_tokens.shape[-1]): + if pos + i >= candidate.shape[1] - 1: + break + if posterior[:, pos + i] == draft_tokens[:, i]: + accepted += 1 + input_ids = torch.cat((input_ids, draft_tokens[:, i : i + 1]), dim=-1) + else: + # Rejected — append target's token instead + input_ids = torch.cat((input_ids, posterior[:, pos + i : pos + i + 1]), dim=-1) + accepted += 1 # target's token counts + break + + if input_ids.shape[1] >= max_len: + break + + 
total_accepted += 1 + accepted # base token + accepted drafts + + ar = total_accepted / cnt if cnt > 0 else 0.0 + return input_ids, ar + @contextlib.contextmanager def temporary_set_config_value(config, field, value): diff --git a/modelopt/torch/utils/plugins/transformers_dataset.py b/modelopt/torch/utils/plugins/transformers_dataset.py index e147ebf2c2..b9a5367cd9 100644 --- a/modelopt/torch/utils/plugins/transformers_dataset.py +++ b/modelopt/torch/utils/plugins/transformers_dataset.py @@ -153,6 +153,9 @@ def __init__( if self.tokenizer.chat_template is None: raise ValueError("No valid chat template!") + if self.answer_only_loss: + self._ensure_generation_tags() + def _post_process_tokenizer(self): if self.tokenizer.pad_token_id is None: print_rank_0("The tokenizer has no pad_token_id, using eos_token_id instead.") @@ -171,6 +174,166 @@ def _post_process_chat_template(self): REMOVE_THINK_CHAT_TEMPLATE, "" ) + # Simplified chat templates with {% generation %} tags for answer_only_loss. + # + # PURPOSE: + # HuggingFace's return_assistant_tokens_mask requires {% generation %} / + # {% endgeneration %} tags in the Jinja chat template to identify which tokens + # belong to assistant responses. Many models (Qwen3, Llama3) ship without these + # tags. These simplified templates add them so that answer_only_loss works + # reliably without regex fallbacks. + # + # HOW IT WORKS: + # When answer_only_loss=True, _ensure_generation_tags() detects the model's + # template style (ChatML, Llama3) and replaces the tokenizer's chat_template + # with one of these simplified versions. The {% generation %} tags tell HF + # exactly which tokens are assistant content for loss masking. 
+ # + # WHAT IS PRESERVED: + # - System / user / assistant role formatting (exact token match) + # - Multi-turn conversation structure + # - <think> block injection on last assistant turn (Qwen3-style, chatml_think) + # - Content is output as-is — training data with <think> blocks is handled correctly + # + # WHAT IS DROPPED (vs original model templates): + # - Tool call formatting (tool_call XML tags, function signatures) + # - Multi-step tool response handling + # - reasoning_content vs content splitting logic + # - enable_thinking parameter support + # - VLM/multimodal content handling + # + # LIMITATIONS: + # - Training data with tool_call messages will not be formatted correctly. + # Use the original template with manually added {% generation %} tags for + # tool-use training data. + # - The chatml_think variant adds <think>\n\n</think>\n\n only to the last + # assistant turn (matching Qwen3 behavior). Non-last turns without + # <think> in their content will differ from the original template which also + # conditionally adds think wrappers based on multi-step reasoning context. + # - Only ChatML (<|im_start|>/<|im_end|>) and Llama3 + # (<|start_header_id|>/<|eot_id|>) styles are supported. Other template + # styles fall back to regex-based assistant span detection. + # + # TO USE A CUSTOM TEMPLATE INSTEAD: + # Pass chat_template= to LanguageDataCollator with your own template that + # includes {% generation %}...{% endgeneration %} around assistant content.
+ _GENERATION_TEMPLATES = { + # Basic ChatML without <think> injection (Phi, older Qwen, generic ChatML) + "chatml": ( + "{% for message in messages %}" + "{% if message['role'] == 'system' %}" + "{{ '<|im_start|>system\n' + message['content'] + '<|im_end|>\n' }}" + "{% elif message['role'] == 'user' %}" + "{{ '<|im_start|>user\n' + message['content'] + '<|im_end|>\n' }}" + "{% elif message['role'] == 'assistant' %}" + "{{ '<|im_start|>assistant\n' }}" + "{% generation %}" + "{{ message['content'] }}" + "{% endgeneration %}" + "{{ '<|im_end|>\n' }}" + "{% endif %}" + "{% endfor %}" + "{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}" + ), + # ChatML with <think> wrapper on last assistant turn (Qwen3-style) + "chatml_think": ( + "{% for message in messages %}" + "{% if message['role'] == 'system' %}" + "{{ '<|im_start|>system\n' + message['content'] + '<|im_end|>\n' }}" + "{% elif message['role'] == 'user' %}" + "{{ '<|im_start|>user\n' + message['content'] + '<|im_end|>\n' }}" + "{% elif message['role'] == 'assistant' %}" + "{{ '<|im_start|>assistant\n' }}" + "{% generation %}" + "{% if loop.last and not message['content'].startswith('<think>') %}" + "{{ '<think>\n\n</think>\n\n' }}" + "{% endif %}" + "{{ message['content'] }}" + "{% endgeneration %}" + "{{ '<|im_end|>\n' }}" + "{% endif %}" + "{% endfor %}" + "{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}" + ), + "llama3": ( + "{% for message in messages %}" + "{% if message['role'] == 'system' %}" + "{{ '<|start_header_id|>system<|end_header_id|>\n\n' + message['content'] + '<|eot_id|>' }}" + "{% elif message['role'] == 'user' %}" + "{{ '<|start_header_id|>user<|end_header_id|>\n\n' + message['content'] + '<|eot_id|>' }}" + "{% elif message['role'] == 'assistant' %}" + "{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% generation %}" + "{{ message['content'] }}{% endgeneration %}{{ '<|eot_id|>' }}" + "{% endif %}" + "{% endfor %}" + "{% if add_generation_prompt %}{{ 
'<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}" + ), + } + + def _ensure_generation_tags(self): + """Ensure chat template has {% generation %} tags for answer_only_loss. + + If the template already has generation tags, no action taken. + Otherwise, detect the template style and replace with a simplified + version that includes proper generation tags. + """ + template = self.tokenizer.chat_template + if template is None: + return + + if "{% generation %}" in template or "{%generation%}" in template: + return + + # Detect template style and replace with generation-tagged version + old_template = template + if "<|im_start|>" in template and "<|im_end|>" in template: + # Check if original template injects <think> (Qwen3-style) + style = "chatml_think" if "<think>" in template else "chatml" + elif "<|start_header_id|>" in template and "<|eot_id|>" in template: + style = "llama3" + else: + print_rank_0( + "=== WARNING === Cannot auto-inject {% generation %} tags for this chat " + "template. answer_only_loss will not work correctly. Provide a template " + "with {% generation %} tags via the chat_template parameter." + ) + return + + new_template = self._GENERATION_TEMPLATES[style] + self.tokenizer.chat_template = new_template + + # Verify + try: + test_msgs = [ + [ + {"role": "user", "content": "hi"}, + {"role": "assistant", "content": "hello"}, + ] + ] + result = self.tokenizer.apply_chat_template( + test_msgs, + return_dict=True, + return_assistant_tokens_mask=True, + padding=True, + return_tensors="pt", + ) + mask = result.get("assistant_masks", None) + if mask is not None and mask.any(): + print_rank_0( + f"Replaced chat template with {style} generation-tagged version " + f"for answer_only_loss." + ) + return + except Exception: + pass + + # Revert on failure + self.tokenizer.chat_template = old_template + print_rank_0( + f"=== WARNING === Failed to apply {style} generation template. " + "answer_only_loss will not work correctly."
+ ) + def _process_chat_sample(self, examples: list): tokenized_examples = self.tokenizer.apply_chat_template( examples, @@ -186,6 +349,20 @@ def _process_chat_sample(self, examples: list): input_ids = tokenized_examples["input_ids"] labels = input_ids.new_full(input_ids.shape, IGNORE_TOKEN_ID) labels[..., :-1] = input_ids[..., 1:] + if self.answer_only_loss: + if "assistant_masks" in tokenized_examples: + assistant_mask = tokenized_examples["assistant_masks"] + if isinstance(assistant_mask, torch.Tensor) and assistant_mask.any(): + labels[assistant_mask == 0] = IGNORE_TOKEN_ID + else: + # All assistant content truncated or no assistant in batch — mask all + labels[:] = IGNORE_TOKEN_ID + else: + raise ValueError( + "answer_only_loss requires {% generation %} tags in the chat " + "template but assistant_masks was not returned by the tokenizer. " + "Ensure _ensure_generation_tags() ran successfully." + ) tokenized_examples["labels"] = labels return tokenized_examples @@ -211,15 +388,34 @@ def __call__(self, examples): batch.append(text) else: messages = example.get("messages", None) - if messages is None: - conversations = example.get("conversations", None) - if conversations is None: - raise ValueError( - "The sample must in either OpenAI messages format or ShareGPT conversations format." + conversations = example.get("conversations", None) + # Prefer whichever has an assistant turn for training + if messages and any(m.get("role") == "assistant" for m in messages): + batch.append(messages) + elif conversations: + converted = _sharegpt_to_openai_messages(conversations) + if not any(m.get("role") == "assistant" for m in converted): + print_rank_0( + "=== WARNING === Skipping sample with no assistant turn in conversations." 
) - else: - messages = _sharegpt_to_openai_messages(conversations) - batch.append(messages) + continue + batch.append(converted) + elif messages: + if not any(m.get("role") == "assistant" for m in messages): + print_rank_0( + "=== WARNING === Skipping sample with no assistant turn in messages." + ) + continue + batch.append(messages) + else: + raise ValueError( + "The sample must in either OpenAI messages format or ShareGPT conversations format." + ) + + if not batch: + # All samples skipped — create a dummy batch with all-masked labels + # so the training step produces zero loss without crashing DDP + batch = [[{"role": "user", "content": ""}, {"role": "assistant", "content": ""}]] # type: ignore[list-item] return self._process_chat_sample(batch) diff --git a/modelopt_recipes/general/speculative_decoding/dflash.yaml b/modelopt_recipes/general/speculative_decoding/dflash.yaml new file mode 100644 index 0000000000..83fd54fb20 --- /dev/null +++ b/modelopt_recipes/general/speculative_decoding/dflash.yaml @@ -0,0 +1,53 @@ +# Base config for DFlash training. Override fields via OmegaConf dotlist on the CLI. 
+ +# maps to ModelArguments (main.py) +model: + model_name_or_path: + trust_remote_code: false + use_fake_base_for_offline: false + +# maps to DataArguments (main.py) +data: + data_path: + offline_data_path: + +# maps to TrainingArguments (main.py) +training: + # --- commonly modified --- + mode: dflash + output_dir: + num_train_epochs: 10 + per_device_train_batch_size: 1 + learning_rate: 6.0e-4 + warmup_steps: 100 + training_seq_len: 4096 + logging_steps: 100 + save_steps: 5000 + cp_size: 1 + dp_shard_size: 1 + disable_tqdm: true + estimate_ar: false + ar_validate_steps: 0 + + # --- rarely modified --- + do_eval: false + lr_scheduler_type: linear + save_strategy: steps + weight_decay: 0.0 + dataloader_drop_last: true + bf16: true + tf32: true + remove_unused_columns: false + ddp_find_unused_parameters: true + ddp_timeout: 1800 + report_to: tensorboard + +# maps to DFlashConfig (modelopt/torch/speculative/config.py). +dflash: + dflash_block_size: 8 + dflash_num_anchors: 512 + dflash_use_torch_compile: false + dflash_self_logit_distillation: true + dflash_loss_decay_factor: 4.0 + dflash_architecture_config: + num_hidden_layers: 5 diff --git a/tests/gpu/torch/speculative/plugins/test_hf_dflash.py b/tests/gpu/torch/speculative/plugins/test_hf_dflash.py new file mode 100644 index 0000000000..230b67c45d --- /dev/null +++ b/tests/gpu/torch/speculative/plugins/test_hf_dflash.py @@ -0,0 +1,196 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""GPU tests for DFlash speculative decoding plugin. + +These tests require a CUDA GPU. CPU-only tests are in tests/unit/. +""" + +from copy import deepcopy + +import pytest +import torch +from _test_utils.torch.transformers_models import get_tiny_llama + +import modelopt.torch.speculative as mtsp +from modelopt.torch.speculative.config import DFLASH_DEFAULT_CFG + +BLOCK_SIZE = 4 +NUM_DRAFT_LAYERS = 2 +SEQ_LEN = 16 # must be multiple of BLOCK_SIZE + + +def _get_dflash_config(block_size=BLOCK_SIZE, num_layers=NUM_DRAFT_LAYERS): + """Create a DFlash config for testing.""" + config = deepcopy(DFLASH_DEFAULT_CFG["config"]) + config["dflash_block_size"] = block_size + config["dflash_use_torch_compile"] = False + config["dflash_architecture_config"] = { + "num_hidden_layers": num_layers, + "mask_token_id": 0, + } + return config + + +@pytest.fixture +def dflash_model(): + """Create a tiny DFlash model on GPU.""" + model = get_tiny_llama(num_hidden_layers=4) + config = _get_dflash_config() + mtsp.convert(model, [("dflash", config)]) + model = model.cuda() + return model + + +class TestDFlashModuleGPU: + """Test DFlash draft module forward pass on GPU.""" + + def test_dflash_module_forward_shape(self, dflash_model): + """Test that draft module produces correct output shape.""" + model = dflash_model + bsz = 2 + hidden_size = model.config.hidden_size + num_layers = len(model.target_layer_ids) + + dtype = next(model.dflash_module.parameters()).dtype + target_hidden = torch.randn( + bsz, SEQ_LEN, num_layers * hidden_size, device="cuda", dtype=dtype + ) + noise_emb = torch.randn(bsz, SEQ_LEN, hidden_size, device="cuda", dtype=dtype) + pos_ids = ( + torch.cat([torch.arange(SEQ_LEN), torch.arange(SEQ_LEN)]) + .unsqueeze(0) + .expand(bsz, -1) + .cuda() + ) + + output = model.dflash_module( + noise_embedding=noise_emb, + target_hidden=target_hidden, + position_ids=pos_ids, + 
attention_mask=None, + ) + assert output.shape == (bsz, SEQ_LEN, hidden_size) + + def test_dflash_module_deterministic(self, dflash_model): + """Test that draft module produces identical outputs for same input.""" + model = dflash_model + model.eval() + bsz = 1 + hidden_size = model.config.hidden_size + num_layers = len(model.target_layer_ids) + + dtype = next(model.dflash_module.parameters()).dtype + target_hidden = torch.randn( + bsz, SEQ_LEN, num_layers * hidden_size, device="cuda", dtype=dtype + ) + noise_emb = torch.randn(bsz, SEQ_LEN, hidden_size, device="cuda", dtype=dtype) + pos_ids = torch.cat([torch.arange(SEQ_LEN), torch.arange(SEQ_LEN)]).unsqueeze(0).cuda() + + with torch.no_grad(): + out1 = model.dflash_module( + noise_embedding=noise_emb, + target_hidden=target_hidden, + position_ids=pos_ids, + ) + out2 = model.dflash_module( + noise_embedding=noise_emb, + target_hidden=target_hidden, + position_ids=pos_ids, + ) + assert torch.allclose(out1, out2) + + +class TestDFlashTrainingForwardGPU: + """Test DFlash training forward pass end-to-end on GPU.""" + + @pytest.fixture + def model(self): + """Create a tiny DFlash model in training mode on GPU.""" + model = get_tiny_llama(num_hidden_layers=4) + config = _get_dflash_config() + mtsp.convert(model, [("dflash", config)]) + model = model.cuda() + model.train() + return model + + def test_training_forward_returns_loss(self, model): + """Test that training forward returns a differentiable loss.""" + bsz = 2 + input_ids = torch.randint(0, model.config.vocab_size, (bsz, SEQ_LEN), device="cuda") + attention_mask = torch.ones(bsz, SEQ_LEN, dtype=torch.long, device="cuda") + + output = model(input_ids=input_ids, attention_mask=attention_mask) + assert hasattr(output, "loss") + assert output.loss.requires_grad + + def test_training_forward_returns_accuracy(self, model): + """Test that training forward returns train_acc.""" + bsz = 2 + input_ids = torch.randint(0, model.config.vocab_size, (bsz, SEQ_LEN), 
device="cuda") + attention_mask = torch.ones(bsz, SEQ_LEN, dtype=torch.long, device="cuda") + + output = model(input_ids=input_ids, attention_mask=attention_mask) + assert hasattr(output, "train_acc") + + def test_training_forward_with_labels(self, model): + """Test that labels are used for response-only loss masking.""" + bsz = 2 + input_ids = torch.randint(0, model.config.vocab_size, (bsz, SEQ_LEN), device="cuda") + attention_mask = torch.ones(bsz, SEQ_LEN, dtype=torch.long, device="cuda") + + # Labels with -100 for first half (masked), real labels for second half + labels = torch.full((bsz, SEQ_LEN), -100, dtype=torch.long, device="cuda") + labels[:, SEQ_LEN // 2 :] = input_ids[:, SEQ_LEN // 2 :] + + output = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels) + assert hasattr(output, "loss") + assert output.loss.requires_grad + + def test_training_forward_all_masked_labels(self, model): + """Test that all-masked labels produce zero loss without crashing.""" + bsz = 2 + input_ids = torch.randint(0, model.config.vocab_size, (bsz, SEQ_LEN), device="cuda") + attention_mask = torch.ones(bsz, SEQ_LEN, dtype=torch.long, device="cuda") + labels = torch.full((bsz, SEQ_LEN), -100, dtype=torch.long, device="cuda") + + output = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels) + assert output.loss.item() == 0.0 + + def test_training_backward(self, model): + """Test that gradients flow to dflash_module.""" + bsz = 2 + input_ids = torch.randint(0, model.config.vocab_size, (bsz, SEQ_LEN), device="cuda") + attention_mask = torch.ones(bsz, SEQ_LEN, dtype=torch.long, device="cuda") + + output = model(input_ids=input_ids, attention_mask=attention_mask) + output.loss.backward() + + has_grad = False + for name, param in model.dflash_module.named_parameters(): + if param.grad is not None and param.grad.abs().sum() > 0: + has_grad = True + break + assert has_grad, "DFlash module should receive gradients" + + def 
test_eval_forward_uses_base_model(self, model): + """In eval mode, forward should use base model (not DFlash training).""" + model.eval() + bsz = 1 + input_ids = torch.randint(0, model.config.vocab_size, (bsz, SEQ_LEN), device="cuda") + + with torch.no_grad(): + output = model(input_ids=input_ids) + assert output.logits.shape == (bsz, SEQ_LEN, model.config.vocab_size) diff --git a/tests/unit/torch/speculative/plugins/test_hf_dflash.py b/tests/unit/torch/speculative/plugins/test_hf_dflash.py new file mode 100644 index 0000000000..50d3c9768b --- /dev/null +++ b/tests/unit/torch/speculative/plugins/test_hf_dflash.py @@ -0,0 +1,243 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""CPU unit tests for DFlash speculative decoding plugin. + +GPU-dependent tests (training forward, module forward) are in tests/gpu/. 
+""" + +import os +from copy import deepcopy + +import torch +from _test_utils.torch.transformers_models import ( + get_tiny_llama, + tf_modelopt_state_and_output_tester, +) +from transformers import AutoModelForCausalLM + +import modelopt.torch.opt as mto +import modelopt.torch.speculative as mtsp +from modelopt.torch.speculative.config import DFLASH_DEFAULT_CFG +from modelopt.torch.speculative.plugins.hf_dflash import ( + DFlashModule, + HFDFlashModel, + create_dflash_attention_mask, + create_dflash_loss_mask, +) + +BLOCK_SIZE = 4 +NUM_DRAFT_LAYERS = 2 +SEQ_LEN = 16 # must be multiple of BLOCK_SIZE + + +def _get_dflash_config(block_size=BLOCK_SIZE, num_layers=NUM_DRAFT_LAYERS): + """Create a DFlash config for testing.""" + config = deepcopy(DFLASH_DEFAULT_CFG["config"]) + config["dflash_block_size"] = block_size + config["dflash_use_torch_compile"] = False + config["dflash_architecture_config"] = { + "num_hidden_layers": num_layers, + "mask_token_id": 0, # use token 0 as mask for tiny model + } + return config + + +class TestDFlashConvert: + """Test DFlash model conversion.""" + + def test_convert_creates_dflash_model(self): + """Test that convert produces an HFDFlashModel.""" + model = get_tiny_llama(num_hidden_layers=4) + config = _get_dflash_config() + mtsp.convert(model, [("dflash", config)]) + assert isinstance(model, HFDFlashModel) + + def test_convert_creates_dflash_module(self): + """Test that convert attaches a DFlashModule.""" + model = get_tiny_llama(num_hidden_layers=4) + config = _get_dflash_config() + mtsp.convert(model, [("dflash", config)]) + assert hasattr(model, "dflash_module") + assert isinstance(model.dflash_module, DFlashModule) + + def test_convert_freezes_base_model(self): + """Test that base model parameters are frozen after convert.""" + model = get_tiny_llama(num_hidden_layers=4) + config = _get_dflash_config() + mtsp.convert(model, [("dflash", config)]) + for name, param in model.named_parameters(): + if "dflash_module" not in name: + 
assert not param.requires_grad, f"Base param {name} should be frozen" + + def test_convert_dflash_module_trainable(self): + """Test that DFlash module parameters are trainable after convert.""" + model = get_tiny_llama(num_hidden_layers=4) + config = _get_dflash_config() + mtsp.convert(model, [("dflash", config)]) + dflash_params = [(n, p) for n, p in model.named_parameters() if "dflash_module" in n] + assert len(dflash_params) > 0 + for name, param in dflash_params: + assert param.requires_grad, f"DFlash param {name} should be trainable" + + def test_convert_sets_target_layer_ids(self): + """Test that target layer IDs are set correctly.""" + model = get_tiny_llama(num_hidden_layers=8) + config = _get_dflash_config(num_layers=3) + mtsp.convert(model, [("dflash", config)]) + assert hasattr(model, "target_layer_ids") + assert len(model.target_layer_ids) == 3 + for lid in model.target_layer_ids: + assert 0 <= lid < 8 + + def test_convert_sets_mask_token_id(self): + """Test that mask_token_id is set from config.""" + model = get_tiny_llama(num_hidden_layers=4) + config = _get_dflash_config() + mtsp.convert(model, [("dflash", config)]) + assert hasattr(model, "mask_token_id") + assert model.mask_token_id == 0 + + +class TestDFlashSaveRestore: + """Test DFlash model save and restore.""" + + def test_save_and_restore(self, tmp_path): + """Test round-trip save/load preserves modelopt state and outputs.""" + mto.enable_huggingface_checkpointing() + model_ref = get_tiny_llama(num_hidden_layers=4) + config = _get_dflash_config() + mtsp.convert(model_ref, [("dflash", config)]) + + model_ref.save_pretrained(tmp_path / "modelopt_model") + assert os.path.exists(tmp_path / "modelopt_model/modelopt_state.pth") + + model_test = AutoModelForCausalLM.from_pretrained(tmp_path / "modelopt_model") + assert isinstance(model_test, HFDFlashModel) + tf_modelopt_state_and_output_tester(model_ref, model_test) + + +class TestDFlashAttentionMask: + """Test DFlash attention mask construction.""" 
+ + def test_mask_shape(self): + """Test mask has shape [1, 1, L, 2L].""" + mask = create_dflash_attention_mask(SEQ_LEN, BLOCK_SIZE, "cpu", torch.float32) + assert mask.shape == (1, 1, SEQ_LEN, 2 * SEQ_LEN) + + def test_mask_context_strictly_previous_blocks(self): + """Context (left half): block B can only see blocks 0..B-1.""" + mask = create_dflash_attention_mask(8, 4, "cpu", torch.float32) + mask_2d = mask[0, 0] # [8, 16] + ctx_mask = mask_2d[:, :8] # context part + + # Block 0 (rows 0-3) should NOT see any context + assert (ctx_mask[:4, :] < 0).all() + + # Block 1 (rows 4-7) should see block 0 context only + assert (ctx_mask[4:8, :4] == 0).all() # can see block 0 + assert (ctx_mask[4:8, 4:8] < 0).all() # cannot see own block + + def test_mask_noise_causal_within_block(self): + """Noise (right half): reverse-causal within same block, matching SpecForge. + + SpecForge uses j >= i: position 0 (anchor) sees all positions in block, + position B-1 sees only itself. Cross-block noise is fully masked. 
+        """
+        mask = create_dflash_attention_mask(8, 4, "cpu", torch.float32)
+        mask_2d = mask[0, 0]
+        noise_mask = mask_2d[:, 8:]  # noise part
+
+        # Block 0, position 0: can see all positions in block (0-3)
+        assert (noise_mask[0, :4] == 0).all()
+
+        # Block 0, position 3: can only see position 3
+        assert (noise_mask[3, :3] < 0).all()
+        assert noise_mask[3, 3] == 0
+
+        # Block 1 cannot see block 0 noise
+        assert (noise_mask[4:8, :4] < 0).all()
+
+    def test_mask_values_are_zero_or_neg_inf(self):
+        """Test mask contains only 0 (attend) and the float32 minimum, used as -inf (mask)."""
+        mask = create_dflash_attention_mask(SEQ_LEN, BLOCK_SIZE, "cpu", torch.float32)
+        unique_vals = mask.unique()
+        assert len(unique_vals) == 2
+        assert 0.0 in unique_vals
+        assert unique_vals.min() == torch.finfo(torch.float32).min
+
+
+class TestDFlashLossMask:
+    """Test DFlash loss mask construction."""
+
+    def test_loss_mask_shape(self):
+        """Test loss mask has shape [L]."""
+        mask = create_dflash_loss_mask(SEQ_LEN, BLOCK_SIZE, "cpu")
+        assert mask.shape == (SEQ_LEN,)
+
+    def test_loss_mask_excludes_block_zero(self):
+        """Test all positions in block 0 are masked out."""
+        mask = create_dflash_loss_mask(SEQ_LEN, BLOCK_SIZE, "cpu")
+        assert (mask[:BLOCK_SIZE] == 0).all()
+
+    def test_loss_mask_excludes_block_starts(self):
+        """Test block start positions are masked."""
+        mask = create_dflash_loss_mask(SEQ_LEN, BLOCK_SIZE, "cpu")
+        for i in range(0, SEQ_LEN, BLOCK_SIZE):
+            assert mask[i] == 0, f"Block start position {i} should be masked"
+
+    def test_loss_mask_includes_non_start_positions(self):
+        """Test non-start positions in non-zero blocks are included."""
+        mask = create_dflash_loss_mask(SEQ_LEN, BLOCK_SIZE, "cpu")
+        for b in range(1, SEQ_LEN // BLOCK_SIZE):
+            for offset in range(1, BLOCK_SIZE):
+                pos = b * BLOCK_SIZE + offset
+                assert mask[pos] == 1, f"Position {pos} should be in loss"
+
+    def test_loss_mask_count(self):
+        """Test total active positions matches expected count."""
+        mask = 
create_dflash_loss_mask(SEQ_LEN, BLOCK_SIZE, "cpu") + num_blocks = SEQ_LEN // BLOCK_SIZE + expected = (num_blocks - 1) * (BLOCK_SIZE - 1) + assert mask.sum().item() == expected + + +class TestBuildTargetLayerIds: + """Test target layer selection.""" + + def test_single_draft_layer(self): + """Test single draft layer selects middle target layer.""" + from modelopt.torch.speculative.plugins.hf_dflash import build_target_layer_ids + + ids = build_target_layer_ids(32, 1) + assert len(ids) == 1 + assert ids[0] == 16 # middle layer + + def test_multiple_draft_layers(self): + """Test multiple draft layers are monotonically increasing and in bounds.""" + from modelopt.torch.speculative.plugins.hf_dflash import build_target_layer_ids + + ids = build_target_layer_ids(36, 5) + assert len(ids) == 5 + assert ids == sorted(ids) + assert all(1 <= lid <= 33 for lid in ids) + + def test_layer_ids_spread(self): + """Test layer IDs have no duplicates.""" + from modelopt.torch.speculative.plugins.hf_dflash import build_target_layer_ids + + ids = build_target_layer_ids(32, 5) + assert len(ids) == 5 + assert len(set(ids)) == 5 diff --git a/tools/launcher/common/dflash/ar_eval_mtbench.sh b/tools/launcher/common/dflash/ar_eval_mtbench.sh new file mode 100644 index 0000000000..3971f61b56 --- /dev/null +++ b/tools/launcher/common/dflash/ar_eval_mtbench.sh @@ -0,0 +1,225 @@ +#!/bin/bash +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +# MT-Bench per-category AR evaluation for DFlash checkpoints. +# Evaluates the latest checkpoint using ModelOpt's pseudo_speculative_generate +# with online (context-dependent) ground truth validation. +# +# Required env vars: +# HF_MODEL_CKPT — path to the target HuggingFace model +# +# Args: +# --ckpt_dir Path to directory containing checkpoint-* subdirs +# --block_size Block size for DFlash (default: 16) +# --num_layers Number of draft decoder layers (default: 5) +# --mask_token_id Mask token ID (default: auto-detect from checkpoint) +# --osl Output sequence length (default: 512) +# --steps Draft steps per block (default: block_size-1) +# --online Use online validation (default: true) + +SCRIPT_DIR="$(dirname "$(readlink -f "$0")")" +source ${SCRIPT_DIR}/../service_utils.sh + +pip install -r modules/Model-Optimizer/examples/speculative_decoding/requirements.txt 2>&1 | tail -3 + +# Overlay DFlash code +for INSTALL_PATH in $(python3 -c " +import modelopt, os, site +paths = set() +paths.add(os.path.dirname(modelopt.__file__)) +for sp in site.getsitepackages(): + p = os.path.join(sp, 'modelopt') + if os.path.isdir(p): paths.add(p) +for p in paths: print(p) +"); do + cp -rf modules/Model-Optimizer/modelopt/torch/speculative/dflash ${INSTALL_PATH}/torch/speculative/ 2>/dev/null || true + cp -f modules/Model-Optimizer/modelopt/torch/speculative/plugins/hf_dflash.py ${INSTALL_PATH}/torch/speculative/plugins/ 2>/dev/null || true + cp -f modules/Model-Optimizer/modelopt/torch/speculative/plugins/__init__.py ${INSTALL_PATH}/torch/speculative/plugins/ 2>/dev/null || true + cp -f modules/Model-Optimizer/modelopt/torch/speculative/config.py ${INSTALL_PATH}/torch/speculative/ 2>/dev/null || true + cp -f modules/Model-Optimizer/modelopt/torch/speculative/mode.py ${INSTALL_PATH}/torch/speculative/ 2>/dev/null || true + cp -f 
modules/Model-Optimizer/modelopt/torch/speculative/utils.py ${INSTALL_PATH}/torch/speculative/ 2>/dev/null || true +done + +# Parse args +CKPT_DIR="" +BLOCK_SIZE=16 +NUM_LAYERS=5 +MASK_TOKEN_ID="" +OSL=512 +STEPS="" +ONLINE=true + +while [ $# -gt 0 ]; do + case "$1" in + --ckpt_dir) shift; CKPT_DIR="$1" ;; + --block_size) shift; BLOCK_SIZE="$1" ;; + --num_layers) shift; NUM_LAYERS="$1" ;; + --mask_token_id) shift; MASK_TOKEN_ID="$1" ;; + --osl) shift; OSL="$1" ;; + --steps) shift; STEPS="$1" ;; + --online) shift; ONLINE="$1" ;; + *) ;; + esac + shift +done + +if [ -z "$STEPS" ]; then + STEPS=$((BLOCK_SIZE - 1)) +fi + +MODEL=${HF_MODEL_CKPT} + +echo "=== DFlash MT-Bench AR Evaluation ===" +echo "Checkpoint dir: ${CKPT_DIR}" +echo "Model: ${MODEL}" +echo "Block size: ${BLOCK_SIZE}, Layers: ${NUM_LAYERS}" +echo "OSL: ${OSL}, Steps: ${STEPS}, Online: ${ONLINE}" + +# Find latest checkpoint +LAST_CKPT=$(ls -d ${CKPT_DIR}/checkpoint-* 2>/dev/null | sort -t- -k2 -n | tail -1) +if [ -z "$LAST_CKPT" ]; then + # Check for top-level model + if [ -f "${CKPT_DIR}/model.safetensors" ]; then + LAST_CKPT=${CKPT_DIR} + else + echo "ERROR: No checkpoints found in ${CKPT_DIR}" + exit 1 + fi +fi +echo "Evaluating: ${LAST_CKPT}" + +CUDA_VISIBLE_DEVICES=0 python3 -c " +import torch, glob, os, json +from transformers import AutoModelForCausalLM, AutoTokenizer +from safetensors.torch import load_file +from datasets import load_dataset +from collections import defaultdict +import modelopt.torch.opt as mto +import modelopt.torch.speculative as mtsp +from modelopt.torch.speculative.plugins.transformers import HFARValidation + +mto.enable_huggingface_checkpointing() + +MODEL = '${MODEL}' +CKPT_PATH = '${LAST_CKPT}' +BLOCK_SIZE = ${BLOCK_SIZE} +NUM_LAYERS = ${NUM_LAYERS} +MASK_TOKEN_ID_STR = '${MASK_TOKEN_ID}' +OSL = ${OSL} +STEPS = ${STEPS} +ONLINE = '${ONLINE}' == 'true' + +# Auto-detect mask_token_id from checkpoint config +MASK_TOKEN_ID = int(MASK_TOKEN_ID_STR) if MASK_TOKEN_ID_STR else 
None +if MASK_TOKEN_ID is None: + cfg_path = os.path.join(CKPT_PATH, 'config.json') + if os.path.isfile(cfg_path): + with open(cfg_path) as f: + ckpt_cfg = json.load(f) + dflash_cfg = ckpt_cfg.get('dflash_config', {}) + MASK_TOKEN_ID = dflash_cfg.get('mask_token_id') + if MASK_TOKEN_ID is None: + MASK_TOKEN_ID = 151669 # default for Qwen3 + print(f'WARNING: Could not auto-detect mask_token_id, using default {MASK_TOKEN_ID}') +print(f'Using mask_token_id={MASK_TOKEN_ID}') + +# Use flash_attention_2 if available +try: + import flash_attn + ATTN_IMPL = 'flash_attention_2' +except ImportError: + ATTN_IMPL = 'sdpa' +print(f'Using attn_implementation={ATTN_IMPL}') + +tokenizer = AutoTokenizer.from_pretrained(MODEL, trust_remote_code=True) + +# Load MT-Bench by category +ds = load_dataset('/hf-local/HuggingFaceH4/mt_bench_prompts')['train'] +cat_samples = defaultdict(list) +for i in range(len(ds)): + cat = ds[i].get('category', 'unknown') + cat_samples[cat].append(ds[i]['prompt'][0]) +categories = sorted(cat_samples.keys()) +print(f'Categories: {categories}') +for c in categories: + print(f' {c}: {len(cat_samples[c])} samples') + +# Load model +model = AutoModelForCausalLM.from_pretrained( + MODEL, torch_dtype=torch.bfloat16, device_map={'': 0}, trust_remote_code=True, + attn_implementation=ATTN_IMPL, +) +config = { + 'dflash_block_size': BLOCK_SIZE, + 'dflash_architecture_config': { + 'num_hidden_layers': NUM_LAYERS, + 'mask_token_id': MASK_TOKEN_ID, + '_attn_implementation': ATTN_IMPL, + }, + 'dflash_use_torch_compile': False, +} +mtsp.convert(model, [('dflash', config)]) + +# Load weights +sf_files = sorted(glob.glob(os.path.join(CKPT_PATH, 'model*.safetensors'))) +if sf_files: + state = {} + for f in sf_files: + state.update(load_file(f)) + dflash_keys = {k: v for k, v in state.items() if 'dflash_module' in k} + if dflash_keys: + model.load_state_dict(dflash_keys, strict=False) + print(f'Loaded {len(dflash_keys)} DFlash weights (with prefix)') + else: + 
model.dflash_module.load_state_dict(state, strict=False) + print(f'Loaded {len(state)} DFlash weights (no prefix)') +else: + print('ERROR: No safetensors found') + exit(1) + +model.eval() +validator = HFARValidation(model, tokenizer) + +# Evaluate per category +cat_ars = {} +all_ars = [] +for cat in categories: + ars = [] + for prompt in cat_samples[cat]: + chat = [{'role': 'user', 'content': prompt}] + text = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True) + input_ids = tokenizer(text, return_tensors='pt').input_ids.cuda() + try: + if ONLINE: + _, ar = validator.validate_online(osl=OSL, input_ids=input_ids, steps=STEPS) + else: + _, ar = validator.validate(osl=OSL, input_ids=input_ids, steps=STEPS) + ars.append(ar) + all_ars.append(ar) + except Exception as e: + print(f' ERROR [{cat}]: {e}') + cat_ars[cat] = sum(ars) / len(ars) if ars else 0.0 + +avg_all = sum(all_ars) / len(all_ars) if all_ars else 0.0 +mode_str = 'online' if ONLINE else 'fixed GT' + +print(f'\n=== Results (OSL={OSL}, steps={STEPS}, {mode_str}) ===') +for c in categories: + print(f' {c:>12}: {cat_ars[c]:.4f}') +print(f'{\"ALL\":>14}: {avg_all:.4f}') +" + +report_result "PASS: MT-Bench AR evaluation" diff --git a/tools/launcher/common/dflash/ar_validate.sh b/tools/launcher/common/dflash/ar_validate.sh new file mode 100644 index 0000000000..b9df0b5c6f --- /dev/null +++ b/tools/launcher/common/dflash/ar_validate.sh @@ -0,0 +1,127 @@ +#!/bin/bash + +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# DFlash AR (Acceptance Rate) validation script. +# Loads a trained DFlash checkpoint and evaluates speculative decoding AR on MT-Bench. +# +# Required env vars: +# HF_MODEL_CKPT — path to the target HuggingFace model +# DFLASH_CKPT — path to the trained DFlash checkpoint +# DFLASH_BLOCK_SIZE — block size (default: 16) +# DFLASH_NUM_LAYERS — number of draft layers (default: 5) +# DFLASH_MASK_TOKEN_ID — mask token ID (default: auto-detect) +# NUM_SAMPLES — number of MT-Bench samples to evaluate (default: 20) + +SCRIPT_DIR="$(dirname "$(readlink -f "$0")")" +source ${SCRIPT_DIR}/../service_utils.sh +trap 'error_handler $0 $LINENO' ERR + +pip install --upgrade "transformers>=4.57" 2>&1 | tail -3 + +DFLASH_BLOCK_SIZE=${DFLASH_BLOCK_SIZE:-16} +DFLASH_NUM_LAYERS=${DFLASH_NUM_LAYERS:-5} +NUM_SAMPLES=${NUM_SAMPLES:-20} + +# Build mask_token_id arg +if [ -n "$DFLASH_MASK_TOKEN_ID" ]; then + MASK_ARG="'mask_token_id': ${DFLASH_MASK_TOKEN_ID}," +else + MASK_ARG="" +fi + +echo "=== DFlash AR Validation ===" +echo "Target model: ${HF_MODEL_CKPT}" +echo "DFlash checkpoint: ${DFLASH_CKPT}" +echo "Block size: ${DFLASH_BLOCK_SIZE}" +echo "Draft layers: ${DFLASH_NUM_LAYERS}" +echo "Samples: ${NUM_SAMPLES}" + +CUDA_VISIBLE_DEVICES=0 python3 -c " +import torch +from datasets import load_dataset +from transformers import AutoModelForCausalLM, AutoTokenizer +from modelopt.torch.speculative.plugins.transformers import HFARValidation +import modelopt.torch.opt as mto +import modelopt.torch.speculative as mtsp + +mto.enable_huggingface_checkpointing() + +model = 
AutoModelForCausalLM.from_pretrained( + '${HF_MODEL_CKPT}', torch_dtype=torch.bfloat16, device_map={'': 0}, trust_remote_code=True +) +tokenizer = AutoTokenizer.from_pretrained('${HF_MODEL_CKPT}', trust_remote_code=True) + +config = { + 'dflash_block_size': ${DFLASH_BLOCK_SIZE}, + 'dflash_architecture_config': { + 'num_hidden_layers': ${DFLASH_NUM_LAYERS}, + ${MASK_ARG} + }, + 'dflash_use_torch_compile': False, +} +mtsp.convert(model, [('dflash', config)]) + +# Load trained DFlash weights +import glob +from safetensors.torch import load_file +ckpt_files = sorted(glob.glob('${DFLASH_CKPT}/model*.safetensors')) +if ckpt_files: + state = {} + for f in ckpt_files: + state.update(load_file(f)) + # Try with dflash_module prefix first (ModelOpt format) + dflash_keys = {k: v for k, v in state.items() if 'dflash_module' in k} + if dflash_keys: + model.load_state_dict(dflash_keys, strict=False) + print(f'Loaded {len(dflash_keys)} DFlash weights (with prefix)') + else: + # No prefix — SpecForge format, load directly into dflash_module + result = model.dflash_module.load_state_dict(state, strict=False) + loaded = len(state) - len(result.unexpected_keys) + print(f'Loaded {loaded} DFlash weights (no prefix), missing={len(result.missing_keys)}') +else: + print('WARNING: No checkpoint files found, using random weights') + +model.eval() +validator = HFARValidation(model, tokenizer) + +ds = load_dataset('/hf-local/HuggingFaceH4/mt_bench_prompts')['train'] +num_samples = min(${NUM_SAMPLES}, len(ds)) + +ars = [] +for i in range(num_samples): + prompt = ds[i]['prompt'][0] + chat = [{'role': 'user', 'content': prompt}] + text = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True) + input_ids = tokenizer(text, return_tensors='pt').input_ids.cuda() + try: + _, ar = validator.validate(osl=32, input_ids=input_ids, steps=3) + ars.append(ar) + print(f' AR={ar:.2f} | {prompt[:60]}') + except Exception as e: + print(f' ERROR | {prompt[:60]}... 
| {e}') + +if ars: + avg_ar = sum(ars) / len(ars) + print(f'\n==== DFlash AR Results ====') + print(f'Samples: {len(ars)}') + print(f'Average AR: {avg_ar:.4f}') + print(f'Min AR: {min(ars):.4f}') + print(f'Max AR: {max(ars):.4f}') +else: + print('No AR results collected') +" diff --git a/tools/launcher/common/dflash/online_training.sh b/tools/launcher/common/dflash/online_training.sh new file mode 100644 index 0000000000..f30dd7292c --- /dev/null +++ b/tools/launcher/common/dflash/online_training.sh @@ -0,0 +1,255 @@ +#!/bin/bash + +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# DFlash online training + AR validation script for the ModelOpt Launcher. +# Trains a DFlash draft model alongside the frozen target model, +# then evaluates acceptance rate on MT-Bench. +# +# Required env vars: +# HF_MODEL_CKPT — path to the target HuggingFace model +# +# Optional env vars: +# NUM_AR_SAMPLES — number of MT-Bench samples for AR validation (default: 20, 0 to skip) +# +# All other args are passed through to launch_train.sh. 
+
+SCRIPT_DIR="$(dirname "$(readlink -f "$0")")"
+source ${SCRIPT_DIR}/../service_utils.sh
+
+pip install -r modules/Model-Optimizer/examples/speculative_decoding/requirements.txt
+pip install "huggingface-hub>=1.2.1"
+export PATH=$PATH:/workspace/.local/bin
+
+###################################################################################################
+
+trap 'error_handler $0 $LINENO' ERR
+
+# Auto-detect head node IP for multi-node training
+if [ -z "$HEAD_NODE_IP" ]; then
+    # Method 1: scontrol (works outside container)
+    HEAD_NODE_IP=$(scontrol show hostnames "$SLURM_JOB_NODELIST" 2>/dev/null | head -1)
+    # Method 2: SLURM_LAUNCH_NODE_IPADDR (some Slurm versions)
+    HEAD_NODE_IP=${HEAD_NODE_IP:-$SLURM_LAUNCH_NODE_IPADDR}
+    # Method 3: Parse SLURM_NODELIST and resolve via Python
+    if [ -z "$HEAD_NODE_IP" ] && [ -n "$SLURM_JOB_NODELIST" ]; then
+        HEAD_NODE_IP=$(python3 -c "
+import socket, re, os
+nl = os.environ.get('SLURM_JOB_NODELIST', '')
+# Extract first hostname: 'node[001-002]' -> 'node001', 'node001,node002' -> 'node001'
+m = re.match(r'([a-zA-Z0-9-]+?)(?:\[(\d+))?', nl)
+if m:
+    host = m.group(1) + (m.group(2) or '')
+    try:
+        print(socket.gethostbyname(host))
+    except:
+        print(host)
+" 2>/dev/null)
+    fi
+    # Method 4: Use rank 0's hostname
+    if [ -z "$HEAD_NODE_IP" ] && [ "${SLURM_PROCID:-0}" = "0" ]; then
+        HEAD_NODE_IP=$(hostname -I 2>/dev/null | awk '{print $1}')
+    fi
+    export HEAD_NODE_IP
+    echo "Auto-detected HEAD_NODE_IP: ${HEAD_NODE_IP}"
+fi
+
+# Parse DFlash-specific args from the command line for AR validation
+DFLASH_BLOCK_SIZE=16
+DFLASH_NUM_LAYERS=5
+DFLASH_MASK_TOKEN_ID=""
+OUTPUT_DIR=""
+for arg in "$@"; do
+    case "$arg" in
+        --dflash_block_size) next_is_block_size=1 ;;
+        --dflash_num_layers) next_is_num_layers=1 ;;
+        --dflash_mask_token_id) next_is_mask_id=1 ;;
+        --output_dir) next_is_output=1 ;;
+        *)
+            if [ "$next_is_block_size" = "1" ]; then DFLASH_BLOCK_SIZE="$arg"; next_is_block_size=0; fi
+            if [ "$next_is_num_layers" = "1" ]; then 
DFLASH_NUM_LAYERS="$arg"; next_is_num_layers=0; fi + if [ "$next_is_mask_id" = "1" ]; then DFLASH_MASK_TOKEN_ID="$arg"; next_is_mask_id=0; fi + if [ "$next_is_output" = "1" ]; then OUTPUT_DIR="$arg"; next_is_output=0; fi + ;; + esac +done + +# Step 1: Training +# Build config overrides from CLI args and env vars +CONFIG_FILE=${DFLASH_CONFIG:-modules/Model-Optimizer/modelopt_recipes/general/speculative_decoding/dflash.yaml} + +# Parse --key value pairs into OmegaConf dotlist overrides +OVERRIDES="model.model_name_or_path=${HF_MODEL_CKPT}" +while [ $# -gt 0 ]; do + case "$1" in + --data) shift; OVERRIDES="$OVERRIDES data.data_path=$1" ;; + --output_dir) shift; OVERRIDES="$OVERRIDES training.output_dir=$1"; OUTPUT_DIR="$1" ;; + --num_epochs) shift; OVERRIDES="$OVERRIDES training.num_train_epochs=$1" ;; + --lr) shift; OVERRIDES="$OVERRIDES training.learning_rate=$1" ;; + --training_seq_len) shift; OVERRIDES="$OVERRIDES training.training_seq_len=$1" ;; + --save_steps) shift; OVERRIDES="$OVERRIDES training.save_steps=$1" ;; + --log_steps) shift; OVERRIDES="$OVERRIDES training.logging_steps=$1" ;; + --disable_tqdm) shift; OVERRIDES="$OVERRIDES training.disable_tqdm=$1" ;; + --ar_validate_steps) shift; OVERRIDES="$OVERRIDES training.ar_validate_steps=$1" ;; + --dflash_block_size) shift; OVERRIDES="$OVERRIDES dflash.dflash_block_size=$1"; DFLASH_BLOCK_SIZE="$1" ;; + --dflash_num_layers) shift; OVERRIDES="$OVERRIDES dflash.dflash_architecture_config.num_hidden_layers=$1"; DFLASH_NUM_LAYERS="$1" ;; + --dflash_mask_token_id) shift; OVERRIDES="$OVERRIDES dflash.dflash_architecture_config.mask_token_id=$1"; DFLASH_MASK_TOKEN_ID="$1" ;; + --dflash_num_anchors) shift; OVERRIDES="$OVERRIDES dflash.dflash_num_anchors=$1" ;; + --dflash_loss_decay_gamma) shift; OVERRIDES="$OVERRIDES dflash.dflash_loss_decay_factor=$1" ;; + --num_nodes) shift; NUM_NODES="$1" ;; + --config) shift; CONFIG_FILE="$1" ;; + *) ;; + esac + shift +done + +# Add tensorboard logging dir +if [ -n "$OUTPUT_DIR" ]; 
then + OVERRIDES="$OVERRIDES training.logging_dir=${OUTPUT_DIR}/tensorboard" +fi + +bash modules/Model-Optimizer/examples/speculative_decoding/launch_train.sh \ + --config ${CONFIG_FILE} \ + --num_nodes ${NUM_NODES:-1} \ + --head_node_ip ${HEAD_NODE_IP:-} \ + ${OVERRIDES} + +# Step 2: AR Validation +NUM_AR_SAMPLES=${NUM_AR_SAMPLES:-20} +if [ "${NUM_AR_SAMPLES}" = "0" ]; then + echo "Skipping AR validation (NUM_AR_SAMPLES=0)" + exit 0 +fi + +if [ -z "$OUTPUT_DIR" ]; then + echo "WARNING: --output_dir not found in args, skipping export and AR validation" + exit 0 +fi + +# Step 2: Export checkpoint to z-lab HF format +EXPORT_DIR=${OUTPUT_DIR}/export +echo "" +echo "=== Exporting DFlash checkpoint ===" +echo "Source: ${OUTPUT_DIR}" +echo "Export: ${EXPORT_DIR}" + +python3 modules/Model-Optimizer/examples/speculative_decoding/scripts/export_hf_checkpoint.py \ + --model_path ${OUTPUT_DIR} \ + --export_path ${EXPORT_DIR} \ + || echo "WARNING: Export failed, continuing with AR validation" + +echo "" +echo "Export contents:" +ls -la ${EXPORT_DIR}/ 2>/dev/null || echo "No export dir" + +# Step 3: AR Validation +# Build mask_token_id config +if [ -n "$DFLASH_MASK_TOKEN_ID" ]; then + MASK_ARG="'mask_token_id': ${DFLASH_MASK_TOKEN_ID}," +else + MASK_ARG="" +fi + +echo "" +echo "=== DFlash AR Validation ===" +echo "Target model: ${HF_MODEL_CKPT}" +# Prefer exported checkpoint (no prefix), fall back to training output (with prefix) +if [ -f "${EXPORT_DIR}/model.safetensors" ]; then + AR_CKPT=${EXPORT_DIR} + echo "Using exported checkpoint: ${AR_CKPT}" +else + AR_CKPT=${OUTPUT_DIR} + echo "Using training checkpoint: ${AR_CKPT}" +fi +echo "Block size: ${DFLASH_BLOCK_SIZE}" +echo "Draft layers: ${DFLASH_NUM_LAYERS}" +echo "Samples: ${NUM_AR_SAMPLES}" + +CUDA_VISIBLE_DEVICES=0 python3 -c " +import torch +from datasets import load_dataset +from transformers import AutoModelForCausalLM, AutoTokenizer +from modelopt.torch.speculative.plugins.transformers import HFARValidation +import 
modelopt.torch.opt as mto +import modelopt.torch.speculative as mtsp + +mto.enable_huggingface_checkpointing() + +model = AutoModelForCausalLM.from_pretrained( + '${HF_MODEL_CKPT}', torch_dtype=torch.bfloat16, device_map={'': 0}, trust_remote_code=True +) +tokenizer = AutoTokenizer.from_pretrained('${HF_MODEL_CKPT}', trust_remote_code=True) + +config = { + 'dflash_block_size': ${DFLASH_BLOCK_SIZE}, + 'dflash_architecture_config': { + 'num_hidden_layers': ${DFLASH_NUM_LAYERS}, + ${MASK_ARG} + }, + 'dflash_use_torch_compile': False, +} +mtsp.convert(model, [('dflash', config)]) + +# Load trained DFlash weights +import glob +from safetensors.torch import load_file +ckpt_files = sorted(glob.glob('${AR_CKPT}/model*.safetensors')) +if ckpt_files: + state = {} + for f in ckpt_files: + state.update(load_file(f)) + dflash_keys = {k: v for k, v in state.items() if 'dflash_module' in k} + if dflash_keys: + model.load_state_dict(dflash_keys, strict=False) + print(f'Loaded {len(dflash_keys)} DFlash weights (with prefix)') + else: + result = model.dflash_module.load_state_dict(state, strict=False) + loaded = len(state) - len(result.unexpected_keys) + print(f'Loaded {loaded} DFlash weights (no prefix), missing={len(result.missing_keys)}') +else: + print('WARNING: No checkpoint files found, using random weights') + +model.eval() +validator = HFARValidation(model, tokenizer) + +ds = load_dataset('/hf-local/HuggingFaceH4/mt_bench_prompts')['train'] +num_samples = min(${NUM_AR_SAMPLES}, len(ds)) + +ars = [] +for i in range(num_samples): + prompt = ds[i]['prompt'][0] + chat = [{'role': 'user', 'content': prompt}] + text = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True) + input_ids = tokenizer(text, return_tensors='pt').input_ids.cuda() + try: + _, ar = validator.validate(osl=32, input_ids=input_ids, steps=3) + ars.append(ar) + print(f' AR={ar:.2f} | {prompt[:60]}') + except Exception as e: + print(f' ERROR | {prompt[:60]}... 
| {e}') + +if ars: + avg_ar = sum(ars) / len(ars) + print(f'\n==== DFlash AR Results ====') + print(f'Samples: {len(ars)}') + print(f'Average AR: {avg_ar:.4f}') + print(f'Min AR: {min(ars):.4f}') + print(f'Max AR: {max(ars):.4f}') +else: + print('No AR results collected') +" + +################################################################################################### diff --git a/tools/launcher/examples/Qwen/Qwen3-8B/hf_online_dflash.yaml b/tools/launcher/examples/Qwen/Qwen3-8B/hf_online_dflash.yaml new file mode 100644 index 0000000000..73d608e492 --- /dev/null +++ b/tools/launcher/examples/Qwen/Qwen3-8B/hf_online_dflash.yaml @@ -0,0 +1,63 @@ +# DFlash online speculative decoding training for Qwen3-8B. +# +# Trains a DFlash draft model (block diffusion) using the frozen target model +# to extract multi-layer hidden states on the fly, then evaluates AR on MT-Bench. +# +# 2-step pipeline: +# task_0: Online DFlash training (8 nodes, 64 GPUs) +# task_1: MT-Bench per-category AR evaluation (1 GPU) +# +# Reference: "DFlash: Block Diffusion for Flash Speculative Decoding" (arXiv:2602.06036) +# +# Usage: +# uv run launch.py --yaml examples/Qwen/Qwen3-8B/hf_online_dflash.yaml --yes +# uv run slurm.py --yaml modules/Model-Optimizer/tools/launcher/examples/Qwen/Qwen3-8B/hf_online_dflash.yaml --yes + +job_name: Qwen3-8B_DFlash_online +pipeline: + # Step 1: Online DFlash training + task_0: + script: common/dflash/online_training.sh + args: + - --data /hf-local/modelopt/Speculative-Decoding-Dataset-v1-Qwen3-8B/default.jsonl + - --output_dir /scratchspace/dflash_bs16 + - --num_epochs 5 + - --lr 6e-4 + - --training_seq_len 4096 + - --save_steps 5000 + - --log_steps 1000 + - --disable_tqdm True + - --ar_validate_steps 0 + - --dflash_block_size 16 + - --dflash_num_layers 5 + - --dflash_mask_token_id 151669 + - --dflash_num_anchors 512 + - --dflash_loss_decay_gamma 7 + - --num_nodes 8 + environment: + - HF_MODEL_CKPT: /hf-local/Qwen/Qwen3-8B + - NUM_AR_SAMPLES: "0" + 
slurm_config: + _factory_: "slurm_factory" + nodes: 8 + ntasks_per_node: 1 + gpus_per_node: 8 + + # Step 2: MT-Bench per-category AR evaluation (ModelOpt online validation) + task_1: + script: common/dflash/ar_eval_mtbench.sh + args: + - --ckpt_dir /scratchspace/dflash_bs16 + - --block_size 16 + - --num_layers 5 + - --mask_token_id 151669 + - --osl 512 + - --steps 15 + - --online true + environment: + - HF_MODEL_CKPT: /hf-local/Qwen/Qwen3-8B + slurm_config: + _factory_: "slurm_factory" + nodes: 1 + ntasks_per_node: 1 + gpus_per_node: 1 From 306186859ffbc7f69e50b3e4fd07db804c65293d Mon Sep 17 00:00:00 2001 From: Chenhan Yu Date: Wed, 8 Apr 2026 16:04:35 -0700 Subject: [PATCH 02/24] refactor: simplify DFlash --- examples/speculative_decoding/main.py | 10 - examples/speculative_decoding/train_dflash.py | 319 ------------------ .../torch/speculative/plugins/hf_dflash.py | 75 +--- 3 files changed, 18 insertions(+), 386 deletions(-) delete mode 100644 examples/speculative_decoding/train_dflash.py diff --git a/examples/speculative_decoding/main.py b/examples/speculative_decoding/main.py index 1bb70f871f..5774ad5a03 100644 --- a/examples/speculative_decoding/main.py +++ b/examples/speculative_decoding/main.py @@ -229,16 +229,6 @@ def train(): torch_dtype="auto", trust_remote_code=model_args.trust_remote_code, ) - # DFlash: re-create rotary embeddings with meta-tensor buffers on CPU. - # inv_freq is computed (not saved in checkpoints), stays on meta after restore. 
- if training_args.mode == "dflash": - for mod in model.modules(): - if hasattr(mod, "rotary_emb"): - rotary = mod.rotary_emb - if any(b.is_meta for b in rotary.buffers()): - cfg = getattr(rotary, "config", None) - if cfg is not None: - mod.rotary_emb = type(rotary)(config=cfg, device="cpu") tokenizer = transformers.AutoTokenizer.from_pretrained( model_load_path, trust_remote_code=model_args.trust_remote_code ) diff --git a/examples/speculative_decoding/train_dflash.py b/examples/speculative_decoding/train_dflash.py deleted file mode 100644 index 20be8a85f0..0000000000 --- a/examples/speculative_decoding/train_dflash.py +++ /dev/null @@ -1,319 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Standalone DFlash training script using SpecForge's data pipeline. - -Uses SpecForge's tokenizer template + offset-mapping loss mask for data -preprocessing, and ModelOpt's DFlash module for the draft model. This -isolates data pipeline differences from model architecture differences. 
- -Usage: - torchrun --nproc_per_node=8 train_dflash.py \ - --model /path/to/Qwen3-8B \ - --data /path/to/train.jsonl \ - --chat-template qwen \ - --block-size 16 \ - --num-draft-layers 5 \ - --num-epochs 3 \ - --lr 1e-4 \ - --output-dir /path/to/output -""" - -import argparse -import math -import os - -import torch -import torch.distributed as dist -from datasets import load_dataset -from torch.utils.data import DataLoader, DistributedSampler -from transformers import AutoModelForCausalLM, AutoTokenizer - -import modelopt.torch.opt as mto -import modelopt.torch.speculative as mtsp - - -def parse_args(): - """Parse command line arguments.""" - parser = argparse.ArgumentParser(description="DFlash training with SpecForge data pipeline") - parser.add_argument("--model", type=str, required=True, help="Target model path") - parser.add_argument("--data", type=str, required=True, help="Training data JSONL path") - parser.add_argument("--chat-template", type=str, default="qwen", help="Chat template name") - parser.add_argument("--block-size", type=int, default=16) - parser.add_argument("--num-draft-layers", type=int, default=5) - parser.add_argument("--mask-token-id", type=int, default=None) - parser.add_argument("--max-length", type=int, default=512) - parser.add_argument("--num-epochs", type=int, default=3) - parser.add_argument("--lr", type=float, default=1e-4) - parser.add_argument("--batch-size", type=int, default=1) - parser.add_argument("--warmup-ratio", type=float, default=0.01) - parser.add_argument("--log-interval", type=int, default=100) - parser.add_argument("--save-interval", type=int, default=0, help="0 = save at end only") - parser.add_argument("--output-dir", type=str, required=True) - parser.add_argument("--seed", type=int, default=42) - parser.add_argument("--num-ar-samples", type=int, default=20, help="AR validation samples") - return parser.parse_args() - - -def is_rank0(): - """Check if current process is rank 0.""" - return not dist.is_initialized() 
or dist.get_rank() == 0 - - -def print_rank0(msg): - """Print only on rank 0.""" - if is_rank0(): - print(msg, flush=True) - - -def build_dataset(tokenizer, data_path, chat_template_name, max_length): - """Build dataset using SpecForge's data pipeline. - - Uses SpecForge's GeneralParser to tokenize conversations with the - proper chat template and compute offset-mapping-based loss masks. - """ - from specforge.data.parse import GeneralParser - from specforge.data.template import TEMPLATE_REGISTRY - - template = TEMPLATE_REGISTRY.get(chat_template_name) - parser = GeneralParser(tokenizer, template) - - raw_dataset = load_dataset("json", data_files=data_path)["train"] - - processed = {"input_ids": [], "loss_mask": []} - skipped = 0 - for sample in raw_dataset: - convs = sample.get("conversations", sample.get("messages", [])) - if not convs: - skipped += 1 - continue - try: - input_ids, loss_mask = parser.parse(convs, max_length=max_length) - processed["input_ids"].append(input_ids) - processed["loss_mask"].append(loss_mask) - except Exception: - skipped += 1 - - print_rank0(f"Processed {len(processed['input_ids'])} samples, skipped {skipped}") - return processed - - -class DFlashDataset(torch.utils.data.Dataset): - """Simple dataset wrapping tokenized input_ids and loss_mask.""" - - def __init__(self, data): - self.input_ids = data["input_ids"] - self.loss_mask = data["loss_mask"] - - def __len__(self): - return len(self.input_ids) - - def __getitem__(self, idx): - return { - "input_ids": self.input_ids[idx], - "loss_mask": self.loss_mask[idx], - } - - -def collate_fn(batch): - """Collate batch of samples.""" - input_ids = torch.stack([b["input_ids"] for b in batch]) - loss_mask = torch.stack([b["loss_mask"] for b in batch]) - return {"input_ids": input_ids, "loss_mask": loss_mask} - - -def train(args): - """Main training loop.""" - # Init distributed - dist.init_process_group("nccl") - local_rank = int(os.environ.get("LOCAL_RANK", 0)) - 
torch.cuda.set_device(local_rank) - device = torch.device("cuda", local_rank) - - torch.manual_seed(args.seed) - mto.enable_huggingface_checkpointing() - - # Load model - print_rank0(f"Loading model: {args.model}") - model = AutoModelForCausalLM.from_pretrained( - args.model, torch_dtype=torch.bfloat16, device_map={"": device}, trust_remote_code=True - ) - tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True) - - # Detect mask_token_id - mask_token_id = args.mask_token_id - if mask_token_id is None: - if hasattr(tokenizer, "mask_token_id") and tokenizer.mask_token_id is not None: - mask_token_id = tokenizer.mask_token_id - elif hasattr(tokenizer, "pad_token_id") and tokenizer.pad_token_id is not None: - mask_token_id = tokenizer.pad_token_id - else: - mask_token_id = tokenizer.eos_token_id - print_rank0(f"mask_token_id: {mask_token_id}") - - # Convert to DFlash - config = { - "dflash_block_size": args.block_size, - "dflash_use_torch_compile": False, - "dflash_architecture_config": { - "num_hidden_layers": args.num_draft_layers, - "mask_token_id": mask_token_id, - }, - } - mtsp.convert(model, [("dflash", config)]) - print_rank0( - f"DFlash module created: {sum(p.numel() for p in model.dflash_module.parameters()):,} params" - ) - - # Build dataset using SpecForge pipeline - print_rank0("Building dataset with SpecForge pipeline...") - data = build_dataset(tokenizer, args.data, args.chat_template, args.max_length) - - # Filter samples with too few loss tokens - min_loss_tokens = 2 * args.block_size - filtered_ids = [] - filtered_masks = [] - for i in range(len(data["input_ids"])): - if data["loss_mask"][i].sum() >= min_loss_tokens: - filtered_ids.append(data["input_ids"][i]) - filtered_masks.append(data["loss_mask"][i]) - print_rank0(f"After filtering: {len(filtered_ids)} samples (min {min_loss_tokens} loss tokens)") - data = {"input_ids": filtered_ids, "loss_mask": filtered_masks} - - dataset = DFlashDataset(data) - sampler = 
DistributedSampler(dataset, shuffle=True) - dataloader = DataLoader( - dataset, - batch_size=args.batch_size, - sampler=sampler, - collate_fn=collate_fn, - num_workers=2, - pin_memory=True, - drop_last=True, - ) - - # Wrap with DDP - model = torch.nn.parallel.DistributedDataParallel( - model, - device_ids=[local_rank], - find_unused_parameters=True, - ) - raw_model = model.module - - # Optimizer — only train dflash_module - optimizer = torch.optim.AdamW( - [p for p in raw_model.dflash_module.parameters() if p.requires_grad], - lr=args.lr, - weight_decay=0.0, - ) - - # LR scheduler - steps_per_epoch = len(dataloader) - total_steps = args.num_epochs * steps_per_epoch - warmup_steps = int(total_steps * args.warmup_ratio) - - def lr_lambda(step): - if step < warmup_steps: - return step / max(warmup_steps, 1) - progress = (step - warmup_steps) / max(total_steps - warmup_steps, 1) - return 0.5 * (1.0 + math.cos(math.pi * progress)) - - scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda) - - print_rank0(f"Training: {total_steps} steps, {warmup_steps} warmup, {steps_per_epoch}/epoch") - - # Training loop - global_step = 0 - for epoch in range(args.num_epochs): - sampler.set_epoch(epoch) - model.train() - - for batch in dataloader: - input_ids = batch["input_ids"].to(device) - loss_mask = batch["loss_mask"].to(device) - - # Create labels from loss_mask: -100 for masked positions - labels = input_ids.clone() - labels[loss_mask == 0] = -100 - - output = model( - input_ids=input_ids, - attention_mask=torch.ones_like(input_ids), - labels=labels, - ) - - loss = output.loss - loss.backward() - optimizer.step() - scheduler.step() - optimizer.zero_grad() - - global_step += 1 - - if global_step % args.log_interval == 0: - acc = output.train_acc[0][0] if hasattr(output, "train_acc") else 0.0 - lr = scheduler.get_last_lr()[0] - print_rank0( - f"Step {global_step} | loss={loss.item():.4f} | acc={acc:.4f} | lr={lr:.2e}" - ) - - if args.save_interval > 0 and global_step % 
args.save_interval == 0: - if is_rank0(): - save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") - raw_model.save_pretrained(save_path) - print_rank0(f"Saved checkpoint: {save_path}") - - # Save final model - if is_rank0(): - os.makedirs(args.output_dir, exist_ok=True) - raw_model.save_pretrained(args.output_dir) - print_rank0(f"Saved final model: {args.output_dir}") - - dist.barrier() - - # AR validation on rank 0 - if is_rank0() and args.num_ar_samples > 0: - print_rank0("\n=== AR Validation ===") - model.eval() - from modelopt.torch.speculative.plugins.transformers import HFARValidation - - validator = HFARValidation(raw_model, tokenizer) - ds = load_dataset("/hf-local/HuggingFaceH4/mt_bench_prompts")["train"] - - ars = [] - for i in range(min(args.num_ar_samples, len(ds))): - prompt = ds[i]["prompt"][0] - chat = [{"role": "user", "content": prompt}] - text = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True) - inp = tokenizer(text, return_tensors="pt").input_ids.to(device) - try: - _, ar = validator.validate(osl=32, input_ids=inp, steps=3) - ars.append(ar) - print_rank0(f" AR={ar:.2f} | {prompt[:60]}") - except Exception as e: - print_rank0(f" ERROR | {prompt[:60]}... 
| {e}") - - if ars: - avg = sum(ars) / len(ars) - print_rank0("\n==== DFlash AR Results ====") - print_rank0(f"Average AR: {avg:.4f}") - print_rank0(f"Min: {min(ars):.4f}, Max: {max(ars):.4f}") - - dist.destroy_process_group() - - -if __name__ == "__main__": - args = parse_args() - train(args) diff --git a/modelopt/torch/speculative/plugins/hf_dflash.py b/modelopt/torch/speculative/plugins/hf_dflash.py index 2601fd3431..fe3b923878 100644 --- a/modelopt/torch/speculative/plugins/hf_dflash.py +++ b/modelopt/torch/speculative/plugins/hf_dflash.py @@ -26,63 +26,22 @@ Reference: "DFlash: Block Diffusion for Flash Speculative Decoding" (arXiv:2602.06036) """ -import importlib - import torch import torch.nn.functional as F from torch import nn from transformers import PretrainedConfig, PreTrainedModel + +# DFlash draft model uses Qwen3 components regardless of the target model. +# This matches z-lab's implementation which inherits from Qwen3PreTrainedModel. +from transformers.models.qwen3.modeling_qwen3 import Qwen3MLP as _MLP_CLS +from transformers.models.qwen3.modeling_qwen3 import Qwen3RMSNorm as _NORM_CLS +from transformers.models.qwen3.modeling_qwen3 import Qwen3RotaryEmbedding as _ROTARY_CLS +from transformers.models.qwen3.modeling_qwen3 import rotate_half as _rotate_half from transformers.utils import ModelOutput from ..dflash.conversion import DFlashDMRegistry from ..dflash.dflash_model import DFlashModel - -def _resolve_model_components(model_type): - """Resolve MLP, RMSNorm, RotaryEmbedding from the base model's transformers module. - - Falls back to Llama components if the model type is unknown. 
- """ - fallback = "llama" - model_type = model_type or fallback - try: - mod = importlib.import_module(f"transformers.models.{model_type}.modeling_{model_type}") - except (ImportError, ModuleNotFoundError): - mod = importlib.import_module(f"transformers.models.{fallback}.modeling_{fallback}") - model_type = fallback - - prefix = model_type.capitalize() - # Handle multi-word model types (e.g., "qwen3" -> "Qwen3") - for attr in dir(mod): - if attr.lower() == f"{model_type}mlp": - prefix = attr.replace("MLP", "") - break - - mlp_cls = getattr(mod, f"{prefix}MLP", None) - norm_cls = getattr(mod, f"{prefix}RMSNorm", None) - rotary_cls = getattr(mod, f"{prefix}RotaryEmbedding", None) - rotate_half_fn = getattr(mod, "rotate_half", None) - - # Fallback to Llama if any component is missing - if not all([mlp_cls, norm_cls, rotary_cls, rotate_half_fn]): - from transformers.models.llama.modeling_llama import ( - LlamaMLP, - LlamaRMSNorm, - LlamaRotaryEmbedding, - ) - from transformers.models.llama.modeling_llama import rotate_half as _rotate_half - - mlp_cls = mlp_cls or LlamaMLP - norm_cls = norm_cls or LlamaRMSNorm - rotary_cls = rotary_cls or LlamaRotaryEmbedding - rotate_half_fn = rotate_half_fn or _rotate_half - - return mlp_cls, norm_cls, rotary_cls, rotate_half_fn - - -# Default to Llama components; overridden per-model during convert() -_MLP_CLS, _NORM_CLS, _ROTARY_CLS, _rotate_half = _resolve_model_components("llama") - __all__ = ["HFDFlashModel"] @@ -271,6 +230,17 @@ def __init__(self, config): # SpecForge's DFlashDraftModel uses Qwen3PreTrainedModel.post_init() which does this. self._init_weights(config) + def _apply(self, fn, recurse=True): + """Override _apply to handle meta-tensor rotary buffers during .to(device). + + After checkpoint restore, rotary inv_freq buffers may be on meta device + (they are computed, not saved). Re-create rotary_emb before applying the + device/dtype transfer to avoid 'Cannot copy out of meta tensor' errors. 
+ """ + if hasattr(self, "rotary_emb") and any(b.is_meta for b in self.rotary_emb.buffers()): + self.rotary_emb = _ROTARY_CLS(config=self._rotary_config, device="cpu") + return super()._apply(fn, recurse) + def _init_weights(self, config): """Initialize weights matching HF PreTrainedModel._init_weights.""" std = getattr(config, "initializer_range", 0.02) @@ -284,9 +254,6 @@ def forward(self, noise_embedding, target_hidden, position_ids, attention_mask=N """Forward matching SpecForge DFlashDraftModel.forward.""" hidden_states = noise_embedding target_hidden = self.hidden_norm(self.fc(target_hidden)) - # Re-create rotary_emb on correct device if buffers are on meta (checkpoint resume) - if any(b.is_meta for b in self.rotary_emb.buffers()): - self.rotary_emb = _ROTARY_CLS(config=self._rotary_config, device=hidden_states.device) position_embeddings = self.rotary_emb(hidden_states, position_ids) for layer in self.layers: @@ -505,12 +472,6 @@ def modify(self, config): self._find_base_model_parts() - # Resolve model-specific components (MLP, RMSNorm, RotaryEmbedding) - # from the base model's architecture for weight compatibility - global _MLP_CLS, _NORM_CLS, _ROTARY_CLS, _rotate_half - _MLP_CLS, _NORM_CLS, _ROTARY_CLS, _rotate_half = _resolve_model_components( - getattr(base_config, "model_type", "llama") - ) self.dflash_module = DFlashModule(self.dflash_config) self.dflash_module.to(self._base_model.dtype).to( next(self._base_model.layers[-1].parameters()).device From c694c5307e0e71bd5f1a96b7cbeeac6a9be8909e Mon Sep 17 00:00:00 2001 From: Chenhan Yu Date: Wed, 8 Apr 2026 17:00:31 -0700 Subject: [PATCH 03/24] refactor: simplify DFlash implementation and training pipeline - Use Qwen3 components directly (no dynamic _resolve_model_components) - Add sliding window attention support (config.layer_types) - Move rotary meta buffer fix to DFlashModule._apply() with detailed docs - Remove DFlash-specific resume code from main.py (standard resume works) - Remove unused 
train_dflash.py and ar_validate.sh - Simplify online_training.sh: direct accelerate launch, no arg parsing - YAML uses OmegaConf overrides directly (matching eagle3 pattern) - Update README to point to launcher example - Add extension docs for MoE and MLA support Co-Authored-By: Claude Opus 4.6 (1M context) --- examples/speculative_decoding/README.md | 11 +- examples/speculative_decoding/main.py | 31 +-- .../torch/speculative/plugins/hf_dflash.py | 68 ++++- tools/launcher/common/dflash/ar_validate.sh | 127 ---------- .../launcher/common/dflash/online_training.sh | 235 +++--------------- .../Qwen/Qwen3-8B/hf_online_dflash.yaml | 45 ++-- 6 files changed, 126 insertions(+), 391 deletions(-) delete mode 100644 tools/launcher/common/dflash/ar_validate.sh diff --git a/examples/speculative_decoding/README.md b/examples/speculative_decoding/README.md index b51c90c55c..a5e90640c1 100644 --- a/examples/speculative_decoding/README.md +++ b/examples/speculative_decoding/README.md @@ -359,14 +359,14 @@ using masked parallel prediction with KV injection from the target model's hidde ### Quick Start +For a complete end-to-end example (training + evaluation), see the +[launcher example](../../tools/launcher/examples/Qwen/Qwen3-8B/hf_online_dflash.yaml): + ```bash -./launch_train.sh --config ../../modelopt_recipes/general/speculative_decoding/dflash.yaml \ - model.model_name_or_path=/path/to/Qwen3-8B \ - data.data_path=/path/to/train.jsonl \ - training.output_dir=/path/to/output +uv run launch.py --yaml examples/Qwen/Qwen3-8B/hf_online_dflash.yaml --yes ``` -### Key Configuration (dflash.yaml) +### Key Configuration ([dflash.yaml](../../modelopt_recipes/general/speculative_decoding/dflash.yaml)) | Field | Default | Description | |-------|---------|-------------| @@ -376,6 +376,7 @@ using masked parallel prediction with KV injection from the target model's hidde | `dflash.dflash_self_logit_distillation` | true | Use logit distillation from target | | 
`dflash.dflash_architecture_config.num_hidden_layers` | 5 | Draft decoder layers | | `dflash.dflash_architecture_config.mask_token_id` | auto | Token ID for masked positions | +| `training.answer_only_loss` | false | Mask loss on non-assistant tokens | ### Export diff --git a/examples/speculative_decoding/main.py b/examples/speculative_decoding/main.py index 5774ad5a03..c029b83bd3 100644 --- a/examples/speculative_decoding/main.py +++ b/examples/speculative_decoding/main.py @@ -167,8 +167,6 @@ def _load_config(config_path: str, overrides: list[str] = ()) -> tuple[dict, dic def train(): - import json - config_path, overrides = _parse_cli() hf_cfg, eagle_cfg, dflash_cfg = _load_config(config_path, overrides) @@ -298,34 +296,7 @@ def train(): ) print_rank_0("Start training...") - if checkpoint and not os.path.isfile( - os.path.join(training_args.output_dir, "model.safetensors") - ): - # Resume from checkpoint subdir: try full resume first, fall back to - # partial resume (model weights + trainer state, fresh optimizer) if - # the optimizer state doesn't match. 
- try: - trainer.train(resume_from_checkpoint=checkpoint) - except ValueError as e: - if "parameter group" in str(e): - print_rank_0( - f"Optimizer state mismatch: {e}\n" - f"Resuming with fresh optimizer from {checkpoint}" - ) - state_file = os.path.join(checkpoint, "trainer_state.json") - if os.path.isfile(state_file): - state = json.load(open(state_file)) - resumed_step = state.get("global_step", 0) - resumed_max_steps = state.get("max_steps", -1) - print_rank_0(f"Resuming from step {resumed_step}/{resumed_max_steps}") - if resumed_max_steps > 0: - training_args.max_steps = resumed_max_steps - trainer.state = trainer.state.load_from_json(state_file) - trainer.train() - else: - raise - else: - trainer.train(resume_from_checkpoint=checkpoint) + trainer.train(resume_from_checkpoint=checkpoint) trainer.save_state() trainer.save_model(training_args.output_dir) diff --git a/modelopt/torch/speculative/plugins/hf_dflash.py b/modelopt/torch/speculative/plugins/hf_dflash.py index fe3b923878..5ec932e987 100644 --- a/modelopt/torch/speculative/plugins/hf_dflash.py +++ b/modelopt/torch/speculative/plugins/hf_dflash.py @@ -15,15 +15,41 @@ """DFlash speculative decoding plugin for HuggingFace models. -Matches the reference SpecForge implementation (github.com/sgl-project/SpecForge PR #415). +Matches the reference SpecForge implementation (github.com/sgl-project/SpecForge). 
Architecture: - Feature Fusion: multi-layer target hidden states → FC + RMSNorm - KV Injection: fused features as K/V in every draft layer with QK-norm -- Parallel Drafting: mask_token_id for unknown positions, causal within blocks -- Loss: hard CE on input_ids[i] (position i predicts token i) +- Parallel Drafting: mask_token_id for unknown positions, bidirectional within blocks +- Random anchor sampling with exponential loss decay +- Logit distillation from target model Reference: "DFlash: Block Diffusion for Flash Speculative Decoding" (arXiv:2602.06036) + +Draft model components: + The draft model currently uses Qwen3 components (MLP, RMSNorm, RotaryEmbedding) + from ``transformers.models.qwen3``, matching z-lab's reference checkpoint format. + Qwen3 sliding window attention is supported via ``config.layer_types``. + The draft architecture is independent of the target model — any target model can + be used as long as it provides hidden states. + + To add support for other draft architectures: + + Qwen3MoE (MoE MLP): + 1. Import ``Qwen3MoeMLP`` from ``transformers.models.qwen3_moe`` + 2. Add a config flag (e.g., ``use_moe``) in ``dflash_architecture_config`` + 3. In ``DFlashDecoderLayer.__init__``, select MLP based on the flag + RMSNorm, RotaryEmbedding, and attention are shared across Qwen3 variants. + + MLA (Multi-head Latent Attention, e.g., DeepseekV3/Kimi-K2): + MLA compresses K/V into a low-rank latent space. To support MLA in DFlash: + 1. Replace ``DFlashAttention`` with an MLA-aware variant that handles + compressed KV injection (project target_hidden through MLA's down/up + projections before concatenating with noise K/V) + 2. Handle lazy rope initialization (see ``_setup_kimi_k2_decoder`` in + ``modelopt.torch.speculative.utils`` for the EAGLE3 approach) + 3. The ``_apply`` meta buffer fix in ``DFlashModule`` already handles the + lazy rope pattern needed for MLA models. 
""" import torch @@ -98,7 +124,12 @@ def __init__(self, config, layer_idx): # Resolve HF attention function matching SpecForge's dispatch self._attn_fn = None - self.sliding_window = None + # Qwen3 uses sliding window attention on some layers (config.layer_types) + if hasattr(config, "layer_types") and hasattr(config, "sliding_window"): + is_sliding = config.layer_types[layer_idx] == "sliding_attention" + self.sliding_window = config.sliding_window if is_sliding else None + else: + self.sliding_window = None def _get_attn_fn(self): """Lazily resolve the HF attention function.""" @@ -231,11 +262,30 @@ def __init__(self, config): self._init_weights(config) def _apply(self, fn, recurse=True): - """Override _apply to handle meta-tensor rotary buffers during .to(device). - - After checkpoint restore, rotary inv_freq buffers may be on meta device - (they are computed, not saved). Re-create rotary_emb before applying the - device/dtype transfer to avoid 'Cannot copy out of meta tensor' errors. + """Override _apply to fix meta-tensor rotary buffers before device transfer. + + Why this is needed: + When resuming from a checkpoint, ModelOpt's ``enable_huggingface_checkpointing`` + restores the model architecture from ``modelopt_state.pth``. During this restore, + ``DFlashModule.__init__`` runs and creates ``rotary_emb`` with its ``inv_freq`` + buffer. However, ``inv_freq`` is a *computed* buffer (derived from ``rope_theta`` + and ``head_dim``), not a learned parameter, so it is NOT saved in + ``model.safetensors``. After ``from_pretrained`` loads the saved weights, all + learned parameters are materialized on CPU, but ``inv_freq`` remains on the + **meta device** (a placeholder with shape but no data). + + Later, HF Trainer calls ``model.to(device)`` which internally calls ``_apply`` + on every submodule. When ``_apply`` reaches the meta ``inv_freq`` buffer, it + raises ``NotImplementedError: Cannot copy out of meta tensor``. 
+ + Fix: + Before ``super()._apply()`` transfers tensors to the target device, we check + if ``rotary_emb`` has any meta buffers. If so, we re-create it on CPU using + the stored config (``_rotary_config``). This produces a real ``inv_freq`` tensor + with correct values, which ``_apply`` can then safely move to GPU. + + This approach is transparent to the training script (``main.py``) — no + mode-specific resume logic is needed there. """ if hasattr(self, "rotary_emb") and any(b.is_meta for b in self.rotary_emb.buffers()): self.rotary_emb = _ROTARY_CLS(config=self._rotary_config, device="cpu") diff --git a/tools/launcher/common/dflash/ar_validate.sh b/tools/launcher/common/dflash/ar_validate.sh deleted file mode 100644 index b9df0b5c6f..0000000000 --- a/tools/launcher/common/dflash/ar_validate.sh +++ /dev/null @@ -1,127 +0,0 @@ -#!/bin/bash - -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# DFlash AR (Acceptance Rate) validation script. -# Loads a trained DFlash checkpoint and evaluates speculative decoding AR on MT-Bench. 
-# -# Required env vars: -# HF_MODEL_CKPT — path to the target HuggingFace model -# DFLASH_CKPT — path to the trained DFlash checkpoint -# DFLASH_BLOCK_SIZE — block size (default: 16) -# DFLASH_NUM_LAYERS — number of draft layers (default: 5) -# DFLASH_MASK_TOKEN_ID — mask token ID (default: auto-detect) -# NUM_SAMPLES — number of MT-Bench samples to evaluate (default: 20) - -SCRIPT_DIR="$(dirname "$(readlink -f "$0")")" -source ${SCRIPT_DIR}/../service_utils.sh -trap 'error_handler $0 $LINENO' ERR - -pip install --upgrade "transformers>=4.57" 2>&1 | tail -3 - -DFLASH_BLOCK_SIZE=${DFLASH_BLOCK_SIZE:-16} -DFLASH_NUM_LAYERS=${DFLASH_NUM_LAYERS:-5} -NUM_SAMPLES=${NUM_SAMPLES:-20} - -# Build mask_token_id arg -if [ -n "$DFLASH_MASK_TOKEN_ID" ]; then - MASK_ARG="'mask_token_id': ${DFLASH_MASK_TOKEN_ID}," -else - MASK_ARG="" -fi - -echo "=== DFlash AR Validation ===" -echo "Target model: ${HF_MODEL_CKPT}" -echo "DFlash checkpoint: ${DFLASH_CKPT}" -echo "Block size: ${DFLASH_BLOCK_SIZE}" -echo "Draft layers: ${DFLASH_NUM_LAYERS}" -echo "Samples: ${NUM_SAMPLES}" - -CUDA_VISIBLE_DEVICES=0 python3 -c " -import torch -from datasets import load_dataset -from transformers import AutoModelForCausalLM, AutoTokenizer -from modelopt.torch.speculative.plugins.transformers import HFARValidation -import modelopt.torch.opt as mto -import modelopt.torch.speculative as mtsp - -mto.enable_huggingface_checkpointing() - -model = AutoModelForCausalLM.from_pretrained( - '${HF_MODEL_CKPT}', torch_dtype=torch.bfloat16, device_map={'': 0}, trust_remote_code=True -) -tokenizer = AutoTokenizer.from_pretrained('${HF_MODEL_CKPT}', trust_remote_code=True) - -config = { - 'dflash_block_size': ${DFLASH_BLOCK_SIZE}, - 'dflash_architecture_config': { - 'num_hidden_layers': ${DFLASH_NUM_LAYERS}, - ${MASK_ARG} - }, - 'dflash_use_torch_compile': False, -} -mtsp.convert(model, [('dflash', config)]) - -# Load trained DFlash weights -import glob -from safetensors.torch import load_file -ckpt_files = 
sorted(glob.glob('${DFLASH_CKPT}/model*.safetensors')) -if ckpt_files: - state = {} - for f in ckpt_files: - state.update(load_file(f)) - # Try with dflash_module prefix first (ModelOpt format) - dflash_keys = {k: v for k, v in state.items() if 'dflash_module' in k} - if dflash_keys: - model.load_state_dict(dflash_keys, strict=False) - print(f'Loaded {len(dflash_keys)} DFlash weights (with prefix)') - else: - # No prefix — SpecForge format, load directly into dflash_module - result = model.dflash_module.load_state_dict(state, strict=False) - loaded = len(state) - len(result.unexpected_keys) - print(f'Loaded {loaded} DFlash weights (no prefix), missing={len(result.missing_keys)}') -else: - print('WARNING: No checkpoint files found, using random weights') - -model.eval() -validator = HFARValidation(model, tokenizer) - -ds = load_dataset('/hf-local/HuggingFaceH4/mt_bench_prompts')['train'] -num_samples = min(${NUM_SAMPLES}, len(ds)) - -ars = [] -for i in range(num_samples): - prompt = ds[i]['prompt'][0] - chat = [{'role': 'user', 'content': prompt}] - text = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True) - input_ids = tokenizer(text, return_tensors='pt').input_ids.cuda() - try: - _, ar = validator.validate(osl=32, input_ids=input_ids, steps=3) - ars.append(ar) - print(f' AR={ar:.2f} | {prompt[:60]}') - except Exception as e: - print(f' ERROR | {prompt[:60]}... 
| {e}') - -if ars: - avg_ar = sum(ars) / len(ars) - print(f'\n==== DFlash AR Results ====') - print(f'Samples: {len(ars)}') - print(f'Average AR: {avg_ar:.4f}') - print(f'Min AR: {min(ars):.4f}') - print(f'Max AR: {max(ars):.4f}') -else: - print('No AR results collected') -" diff --git a/tools/launcher/common/dflash/online_training.sh b/tools/launcher/common/dflash/online_training.sh index f30dd7292c..2bce9c9ded 100644 --- a/tools/launcher/common/dflash/online_training.sh +++ b/tools/launcher/common/dflash/online_training.sh @@ -15,17 +15,25 @@ # See the License for the specific language governing permissions and # limitations under the License. -# DFlash online training + AR validation script for the ModelOpt Launcher. -# Trains a DFlash draft model alongside the frozen target model, -# then evaluates acceptance rate on MT-Bench. +# DFlash online training script for the ModelOpt Launcher. +# Trains a DFlash draft model using accelerate launch + main.py --config. # -# Required env vars: -# HF_MODEL_CKPT — path to the target HuggingFace model +# All training config comes from the YAML recipe (--config) and OmegaConf overrides. +# All args are passed directly to main.py (--config + key=value overrides). # -# Optional env vars: -# NUM_AR_SAMPLES — number of MT-Bench samples for AR validation (default: 20, 0 to skip) +# Multi-node env vars (set by Slurm or user): +# NUM_NODES — number of nodes (default: 1) +# HEAD_NODE_IP — head node IP (auto-detected if not set) # -# All other args are passed through to launch_train.sh. 
+# Usage from YAML: +# script: common/dflash/online_training.sh +# args: +# - --config modules/Model-Optimizer/modelopt_recipes/general/speculative_decoding/dflash.yaml +# - model.model_name_or_path=/hf-local/Qwen/Qwen3-8B +# - data.data_path=/path/to/data.jsonl +# - training.output_dir=/scratchspace/dflash +# environment: +# - NUM_NODES: "8" SCRIPT_DIR="$(dirname "$(readlink -f "$0")")" source ${SCRIPT_DIR}/../service_utils.sh @@ -39,17 +47,14 @@ export PATH=$PATH:/workspace/.local/bin trap 'error_handler $0 $LINENO' ERR # Auto-detect head node IP for multi-node training -if [ -z "$HEAD_NODE_IP" ]; then - # Method 1: scontrol (works outside container) +NUM_NODES=${NUM_NODES:-1} +if [ -z "$HEAD_NODE_IP" ] && [[ "$NUM_NODES" != "1" ]]; then HEAD_NODE_IP=$(scontrol show hostnames "$SLURM_JOB_NODELIST" 2>/dev/null | head -1) - # Method 2: SLURM_LAUNCH_NODE_IPADDR (some Slurm versions) HEAD_NODE_IP=${HEAD_NODE_IP:-$SLURM_LAUNCH_NODE_IPADDR} - # Method 3: Parse SLURM_NODELIST and resolve via Python if [ -z "$HEAD_NODE_IP" ] && [ -n "$SLURM_JOB_NODELIST" ]; then HEAD_NODE_IP=$(python3 -c " import socket, re, os nl = os.environ.get('SLURM_JOB_NODELIST', '') -# Extract first hostname: 'node[001-002]' -> 'node001', 'node001,node002' -> 'node001' m = re.match(r'([a-zA-Z0-9-]+?)(?:\[(\d+))?', nl) if m: host = m.group(1) + (m.group(2) or '') @@ -59,7 +64,6 @@ if m: print(host) " 2>/dev/null) fi - # Method 4: Use rank 0's hostname if [ -z "$HEAD_NODE_IP" ] && [ "${SLURM_PROCID:-0}" = "0" ]; then HEAD_NODE_IP=$(hostname -I 2>/dev/null | awk '{print $1}') fi @@ -67,189 +71,28 @@ if m: echo "Auto-detected HEAD_NODE_IP: ${HEAD_NODE_IP}" fi -# Parse DFlash-specific args from the command line for AR validation -DFLASH_BLOCK_SIZE=16 -DFLASH_NUM_LAYERS=5 -DFLASH_MASK_TOKEN_ID="" -OUTPUT_DIR="" -for arg in "$@"; do - case "$arg" in - --dflash_block_size) next_is_block_size=1 ;; - --dflash_num_layers) next_is_num_layers=1 ;; - --dflash_mask_token_id) next_is_mask_id=1 ;; - --output_dir) 
next_is_output=1 ;; - *) - if [ "$next_is_block_size" = "1" ]; then DFLASH_BLOCK_SIZE="$arg"; next_is_block_size=0; fi - if [ "$next_is_num_layers" = "1" ]; then DFLASH_NUM_LAYERS="$arg"; next_is_num_layers=0; fi - if [ "$next_is_mask_id" = "1" ]; then DFLASH_MASK_TOKEN_ID="$arg"; next_is_mask_id=0; fi - if [ "$next_is_output" = "1" ]; then OUTPUT_DIR="$arg"; next_is_output=0; fi - ;; - esac -done - -# Step 1: Training -# Build config overrides from CLI args and env vars -CONFIG_FILE=${DFLASH_CONFIG:-modules/Model-Optimizer/modelopt_recipes/general/speculative_decoding/dflash.yaml} - -# Parse --key value pairs into OmegaConf dotlist overrides -OVERRIDES="model.model_name_or_path=${HF_MODEL_CKPT}" -while [ $# -gt 0 ]; do - case "$1" in - --data) shift; OVERRIDES="$OVERRIDES data.data_path=$1" ;; - --output_dir) shift; OVERRIDES="$OVERRIDES training.output_dir=$1"; OUTPUT_DIR="$1" ;; - --num_epochs) shift; OVERRIDES="$OVERRIDES training.num_train_epochs=$1" ;; - --lr) shift; OVERRIDES="$OVERRIDES training.learning_rate=$1" ;; - --training_seq_len) shift; OVERRIDES="$OVERRIDES training.training_seq_len=$1" ;; - --save_steps) shift; OVERRIDES="$OVERRIDES training.save_steps=$1" ;; - --log_steps) shift; OVERRIDES="$OVERRIDES training.logging_steps=$1" ;; - --disable_tqdm) shift; OVERRIDES="$OVERRIDES training.disable_tqdm=$1" ;; - --ar_validate_steps) shift; OVERRIDES="$OVERRIDES training.ar_validate_steps=$1" ;; - --dflash_block_size) shift; OVERRIDES="$OVERRIDES dflash.dflash_block_size=$1"; DFLASH_BLOCK_SIZE="$1" ;; - --dflash_num_layers) shift; OVERRIDES="$OVERRIDES dflash.dflash_architecture_config.num_hidden_layers=$1"; DFLASH_NUM_LAYERS="$1" ;; - --dflash_mask_token_id) shift; OVERRIDES="$OVERRIDES dflash.dflash_architecture_config.mask_token_id=$1"; DFLASH_MASK_TOKEN_ID="$1" ;; - --dflash_num_anchors) shift; OVERRIDES="$OVERRIDES dflash.dflash_num_anchors=$1" ;; - --dflash_loss_decay_gamma) shift; OVERRIDES="$OVERRIDES dflash.dflash_loss_decay_factor=$1" ;; - 
--num_nodes) shift; NUM_NODES="$1" ;; - --config) shift; CONFIG_FILE="$1" ;; - *) ;; - esac - shift -done - -# Add tensorboard logging dir -if [ -n "$OUTPUT_DIR" ]; then - OVERRIDES="$OVERRIDES training.logging_dir=${OUTPUT_DIR}/tensorboard" -fi - -bash modules/Model-Optimizer/examples/speculative_decoding/launch_train.sh \ - --config ${CONFIG_FILE} \ - --num_nodes ${NUM_NODES:-1} \ - --head_node_ip ${HEAD_NODE_IP:-} \ - ${OVERRIDES} - -# Step 2: AR Validation -NUM_AR_SAMPLES=${NUM_AR_SAMPLES:-20} -if [ "${NUM_AR_SAMPLES}" = "0" ]; then - echo "Skipping AR validation (NUM_AR_SAMPLES=0)" - exit 0 -fi - -if [ -z "$OUTPUT_DIR" ]; then - echo "WARNING: --output_dir not found in args, skipping export and AR validation" - exit 0 -fi - -# Step 2: Export checkpoint to z-lab HF format -EXPORT_DIR=${OUTPUT_DIR}/export -echo "" -echo "=== Exporting DFlash checkpoint ===" -echo "Source: ${OUTPUT_DIR}" -echo "Export: ${EXPORT_DIR}" - -python3 modules/Model-Optimizer/examples/speculative_decoding/scripts/export_hf_checkpoint.py \ - --model_path ${OUTPUT_DIR} \ - --export_path ${EXPORT_DIR} \ - || echo "WARNING: Export failed, continuing with AR validation" - -echo "" -echo "Export contents:" -ls -la ${EXPORT_DIR}/ 2>/dev/null || echo "No export dir" - -# Step 3: AR Validation -# Build mask_token_id config -if [ -n "$DFLASH_MASK_TOKEN_ID" ]; then - MASK_ARG="'mask_token_id': ${DFLASH_MASK_TOKEN_ID}," +# Build accelerate launch command +MAIN_PY=modules/Model-Optimizer/examples/speculative_decoding/main.py + +if [[ "$NUM_NODES" != "1" ]]; then + GPU_PER_NODE=${GPU_PER_NODE:-$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l)} + TOTAL_GPU=$((NUM_NODES * GPU_PER_NODE)) + echo "Total GPUs: $TOTAL_GPU (NUM_NODES: $NUM_NODES, GPU_PER_NODE: $GPU_PER_NODE)" + MULTI_NODE_ARGS="--num_processes $TOTAL_GPU \ + --num_machines $NUM_NODES \ + --machine_rank $SLURM_PROCID \ + --rdzv_backend c10d \ + --main_process_ip $HEAD_NODE_IP \ + --main_process_port 29500" else - MASK_ARG="" + 
TOTAL_GPU=$(python3 -c "import torch; print(torch.cuda.device_count())") + echo "Total GPUs: $TOTAL_GPU (single node)" + MULTI_NODE_ARGS="" fi -echo "" -echo "=== DFlash AR Validation ===" -echo "Target model: ${HF_MODEL_CKPT}" -# Prefer exported checkpoint (no prefix), fall back to training output (with prefix) -if [ -f "${EXPORT_DIR}/model.safetensors" ]; then - AR_CKPT=${EXPORT_DIR} - echo "Using exported checkpoint: ${AR_CKPT}" -else - AR_CKPT=${OUTPUT_DIR} - echo "Using training checkpoint: ${AR_CKPT}" -fi -echo "Block size: ${DFLASH_BLOCK_SIZE}" -echo "Draft layers: ${DFLASH_NUM_LAYERS}" -echo "Samples: ${NUM_AR_SAMPLES}" - -CUDA_VISIBLE_DEVICES=0 python3 -c " -import torch -from datasets import load_dataset -from transformers import AutoModelForCausalLM, AutoTokenizer -from modelopt.torch.speculative.plugins.transformers import HFARValidation -import modelopt.torch.opt as mto -import modelopt.torch.speculative as mtsp - -mto.enable_huggingface_checkpointing() - -model = AutoModelForCausalLM.from_pretrained( - '${HF_MODEL_CKPT}', torch_dtype=torch.bfloat16, device_map={'': 0}, trust_remote_code=True -) -tokenizer = AutoTokenizer.from_pretrained('${HF_MODEL_CKPT}', trust_remote_code=True) - -config = { - 'dflash_block_size': ${DFLASH_BLOCK_SIZE}, - 'dflash_architecture_config': { - 'num_hidden_layers': ${DFLASH_NUM_LAYERS}, - ${MASK_ARG} - }, - 'dflash_use_torch_compile': False, -} -mtsp.convert(model, [('dflash', config)]) - -# Load trained DFlash weights -import glob -from safetensors.torch import load_file -ckpt_files = sorted(glob.glob('${AR_CKPT}/model*.safetensors')) -if ckpt_files: - state = {} - for f in ckpt_files: - state.update(load_file(f)) - dflash_keys = {k: v for k, v in state.items() if 'dflash_module' in k} - if dflash_keys: - model.load_state_dict(dflash_keys, strict=False) - print(f'Loaded {len(dflash_keys)} DFlash weights (with prefix)') - else: - result = model.dflash_module.load_state_dict(state, strict=False) - loaded = len(state) - 
len(result.unexpected_keys) - print(f'Loaded {loaded} DFlash weights (no prefix), missing={len(result.missing_keys)}') -else: - print('WARNING: No checkpoint files found, using random weights') +export TOKENIZERS_PARALLELISM=False -model.eval() -validator = HFARValidation(model, tokenizer) - -ds = load_dataset('/hf-local/HuggingFaceH4/mt_bench_prompts')['train'] -num_samples = min(${NUM_AR_SAMPLES}, len(ds)) - -ars = [] -for i in range(num_samples): - prompt = ds[i]['prompt'][0] - chat = [{'role': 'user', 'content': prompt}] - text = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True) - input_ids = tokenizer(text, return_tensors='pt').input_ids.cuda() - try: - _, ar = validator.validate(osl=32, input_ids=input_ids, steps=3) - ars.append(ar) - print(f' AR={ar:.2f} | {prompt[:60]}') - except Exception as e: - print(f' ERROR | {prompt[:60]}... | {e}') - -if ars: - avg_ar = sum(ars) / len(ars) - print(f'\n==== DFlash AR Results ====') - print(f'Samples: {len(ars)}') - print(f'Average AR: {avg_ar:.4f}') - print(f'Min AR: {min(ars):.4f}') - print(f'Max AR: {max(ars):.4f}') -else: - print('No AR results collected') -" - -################################################################################################### +set -x +start_time=$(date +%s) +accelerate launch --mixed_precision bf16 $MULTI_NODE_ARGS $MAIN_PY "$@" +echo "Training time: $(( $(date +%s) - start_time )) seconds" diff --git a/tools/launcher/examples/Qwen/Qwen3-8B/hf_online_dflash.yaml b/tools/launcher/examples/Qwen/Qwen3-8B/hf_online_dflash.yaml index 73d608e492..4da2e35b61 100644 --- a/tools/launcher/examples/Qwen/Qwen3-8B/hf_online_dflash.yaml +++ b/tools/launcher/examples/Qwen/Qwen3-8B/hf_online_dflash.yaml @@ -1,8 +1,5 @@ # DFlash online speculative decoding training for Qwen3-8B. # -# Trains a DFlash draft model (block diffusion) using the frozen target model -# to extract multi-layer hidden states on the fly, then evaluates AR on MT-Bench. 
-# # 2-step pipeline: # task_0: Online DFlash training (8 nodes, 64 GPUs) # task_1: MT-Bench per-category AR evaluation (1 GPU) @@ -15,47 +12,47 @@ job_name: Qwen3-8B_DFlash_online pipeline: + global_vars: + hf_model: /hf-local/Qwen/Qwen3-8B + # Step 1: Online DFlash training task_0: script: common/dflash/online_training.sh args: - - --data /hf-local/modelopt/Speculative-Decoding-Dataset-v1-Qwen3-8B/default.jsonl - - --output_dir /scratchspace/dflash_bs16 - - --num_epochs 5 - - --lr 6e-4 - - --training_seq_len 4096 - - --save_steps 5000 - - --log_steps 1000 - - --disable_tqdm True - - --ar_validate_steps 0 - - --dflash_block_size 16 - - --dflash_num_layers 5 - - --dflash_mask_token_id 151669 - - --dflash_num_anchors 512 - - --dflash_loss_decay_gamma 7 - - --num_nodes 8 + - --config modules/Model-Optimizer/modelopt_recipes/general/speculative_decoding/dflash.yaml + - model.model_name_or_path=<> + - data.data_path=/hf-local/modelopt/Speculative-Decoding-Dataset-v1-Qwen3-8B/sample-1K.jsonl + - training.output_dir=/scratchspace/dflash_bs16 + - training.num_train_epochs=1 + - training.training_seq_len=4096 + - training.save_steps=5000 + - training.logging_steps=1000 + - training.disable_tqdm=true + - training.answer_only_loss=true + - dflash.dflash_block_size=16 + - dflash.dflash_num_anchors=512 + - dflash.dflash_loss_decay_factor=7 + - dflash.dflash_architecture_config.mask_token_id=151669 + - dflash.dflash_architecture_config.num_hidden_layers=5 environment: - - HF_MODEL_CKPT: /hf-local/Qwen/Qwen3-8B - - NUM_AR_SAMPLES: "0" slurm_config: _factory_: "slurm_factory" - nodes: 8 + nodes: 1 ntasks_per_node: 1 gpus_per_node: 8 - # Step 2: MT-Bench per-category AR evaluation (ModelOpt online validation) + # Step 2: MT-Bench per-category AR evaluation task_1: script: common/dflash/ar_eval_mtbench.sh args: - --ckpt_dir /scratchspace/dflash_bs16 - --block_size 16 - --num_layers 5 - - --mask_token_id 151669 - --osl 512 - --steps 15 - --online true environment: - - HF_MODEL_CKPT: 
/hf-local/Qwen/Qwen3-8B + - HF_MODEL_CKPT: <> slurm_config: _factory_: "slurm_factory" nodes: 1 From c78c1c83454af523495f7093733107368bea9b95 Mon Sep 17 00:00:00 2001 From: Chenhan Yu Date: Wed, 8 Apr 2026 17:07:21 -0700 Subject: [PATCH 04/24] fix: revert DFlash-specific changes, add sliding window docs - Revert on_step_end AR validation to upstream (DFlash deadlocks with DDP) - Revert checkpoint resume to upstream (load from checkpoint directly) - Keep: answer_only_loss pass-through, accuracy console/tensorboard logging - Document sliding window support in README and recipe YAML Co-Authored-By: Claude Opus 4.6 (1M context) --- examples/speculative_decoding/README.md | 4 ++ examples/speculative_decoding/eagle_utils.py | 43 +++++-------------- examples/speculative_decoding/main.py | 14 +----- .../general/speculative_decoding/dflash.yaml | 2 + 4 files changed, 18 insertions(+), 45 deletions(-) diff --git a/examples/speculative_decoding/README.md b/examples/speculative_decoding/README.md index a5e90640c1..5072214f35 100644 --- a/examples/speculative_decoding/README.md +++ b/examples/speculative_decoding/README.md @@ -378,6 +378,10 @@ uv run launch.py --yaml examples/Qwen/Qwen3-8B/hf_online_dflash.yaml --yes | `dflash.dflash_architecture_config.mask_token_id` | auto | Token ID for masked positions | | `training.answer_only_loss` | false | Mask loss on non-assistant tokens | +Qwen3 sliding window attention is automatically supported — draft layers inherit +`layer_types` and `sliding_window` from the config, matching the target model's +attention pattern. 
+ ### Export ```bash diff --git a/examples/speculative_decoding/eagle_utils.py b/examples/speculative_decoding/eagle_utils.py index 2b01239cb4..fed88401a0 100644 --- a/examples/speculative_decoding/eagle_utils.py +++ b/examples/speculative_decoding/eagle_utils.py @@ -245,46 +245,23 @@ def on_log(self, args, state, control, **kwargs): return control def on_step_end(self, args, state, control, **kwargs): - """Run AR validation periodically (single-GPU only). - - AR validation with DDP is not supported because pseudo_speculative_generate - runs only on rank 0 while other ranks deadlock waiting for collective ops. - When world_size > 1, AR validation is skipped with a one-time warning. - Use post-training AR validation instead (online_training.sh runs it after training). - """ + """Run AR validation periodically, if available.""" if self.ar_validate_steps <= 0: return control if state.global_step % self.ar_validate_steps == 0 and state.global_step > 0: - if torch.distributed.is_initialized() and torch.distributed.get_world_size() > 1: - if not hasattr(self, "_ar_ddp_warned"): - self._ar_ddp_warned = True - print_rank_0( - "=== WARNING === AR validation during training is not supported with " - "DDP (world_size > 1). Skipping. Use post-training AR validation." 
- ) - return control - - model = kwargs["model"] - raw_model = model.module if hasattr(model, "module") else model - was_training = raw_model.training - raw_model.eval() print_rank_0("Running AR validation...") try: - with torch.no_grad(): - ars = validate_ar( - model=raw_model, - tokenizer=kwargs["processing_class"], - ds=load_dataset("/hf-local/HuggingFaceH4/mt_bench_prompts")["train"], - device=next(raw_model.parameters()).device, - num_samples=8, - ) + ars = validate_ar( + model=kwargs["model"], + tokenizer=kwargs["processing_class"], + ds=load_dataset("HuggingFaceH4/mt_bench_prompts")["train"], + device=kwargs["model"].device, + ) print_rank_0(f"Step {state.global_step} AR: {sum(ars) / len(ars):.4f}") - if wandb: + if hasattr(wandb, "init") and is_master(): wandb.log({"validate_ar": sum(ars) / len(ars)}, step=state.global_step) - except Exception as e: - print_rank_0(f"AR validation failed: {e}") - if was_training: - raw_model.train() + except Exception: + print_rank_0("AR validation not available.") return control diff --git a/examples/speculative_decoding/main.py b/examples/speculative_decoding/main.py index c029b83bd3..047a70d864 100644 --- a/examples/speculative_decoding/main.py +++ b/examples/speculative_decoding/main.py @@ -213,22 +213,12 @@ def train(): use_offline_training = data_args.offline_data_path is not None if checkpoint: - # Prefer top-level output_dir, fall back to checkpoint subdir - model_load_path = training_args.output_dir - if not os.path.isfile(os.path.join(model_load_path, "model.safetensors")): - model_load_path = checkpoint - print_rank_0( - f"No model.safetensors in {training_args.output_dir}, " - f"loading from checkpoint: {model_load_path}" - ) with patch_transformers5_params_loading(): model = load_vlm_or_llm( - model_load_path, - torch_dtype="auto", - trust_remote_code=model_args.trust_remote_code, + checkpoint, torch_dtype="auto", trust_remote_code=model_args.trust_remote_code ) tokenizer = 
transformers.AutoTokenizer.from_pretrained( - model_load_path, trust_remote_code=model_args.trust_remote_code + checkpoint, trust_remote_code=model_args.trust_remote_code ) else: # To avoid OOM for large models, we load and convert model on CPU first. diff --git a/modelopt_recipes/general/speculative_decoding/dflash.yaml b/modelopt_recipes/general/speculative_decoding/dflash.yaml index 83fd54fb20..90d161e95d 100644 --- a/modelopt_recipes/general/speculative_decoding/dflash.yaml +++ b/modelopt_recipes/general/speculative_decoding/dflash.yaml @@ -51,3 +51,5 @@ dflash: dflash_loss_decay_factor: 4.0 dflash_architecture_config: num_hidden_layers: 5 + # mask_token_id: auto-detected from model vocab (override for specific models) + # sliding_window and layer_types are inherited from base model config automatically From aa6165a48f3aefe98c9ae161370e0a1c6e71706e Mon Sep 17 00:00:00 2001 From: Chenhan Yu Date: Wed, 8 Apr 2026 19:46:59 -0700 Subject: [PATCH 05/24] refactor: consolidate docs, simplify eval, online GT as default - Consolidate dflash_results.md into comprehensive dflash.md - Simplify ar_validate.py: online GT as default, per-category support - Simplify ar_eval_mtbench.sh: calls ar_validate.py instead of inline Python - Error on unsupported mask_token_id instead of falling back to pad/eos - Add sliding window, FP8/NVFP4, offline training, MLA docs Co-Authored-By: Claude Opus 4.6 (1M context) --- examples/speculative_decoding/README.md | 2 +- examples/speculative_decoding/doc/dflash.md | 206 +++++++++++++++++ .../doc/dflash_results.md | 85 ------- .../scripts/ar_validate.py | 107 +++++---- .../torch/speculative/plugins/hf_dflash.py | 71 ++++-- .../launcher/common/dflash/ar_eval_mtbench.sh | 213 +++--------------- .../Qwen/Qwen3-8B/hf_online_dflash.yaml | 5 - 7 files changed, 354 insertions(+), 335 deletions(-) create mode 100644 examples/speculative_decoding/doc/dflash.md delete mode 100644 examples/speculative_decoding/doc/dflash_results.md diff --git 
a/examples/speculative_decoding/README.md b/examples/speculative_decoding/README.md index 5072214f35..673b2db0e0 100644 --- a/examples/speculative_decoding/README.md +++ b/examples/speculative_decoding/README.md @@ -392,4 +392,4 @@ python scripts/export_hf_checkpoint.py \ ### Results -See [doc/dflash_results.md](doc/dflash_results.md) for benchmark results on Qwen3-8B. +See [doc/dflash.md](doc/dflash.md) for design details, benchmark results, and open items. diff --git a/examples/speculative_decoding/doc/dflash.md b/examples/speculative_decoding/doc/dflash.md new file mode 100644 index 0000000000..e1e7007d44 --- /dev/null +++ b/examples/speculative_decoding/doc/dflash.md @@ -0,0 +1,206 @@ +# DFlash — Block Diffusion for Speculative Decoding + +DFlash predicts an entire block of tokens in a single forward pass using masked parallel +prediction with KV injection from the target model's hidden states. + +Reference: [arXiv:2602.06036](https://arxiv.org/abs/2602.06036) | +[SpecForge](https://github.com/sgl-project/SpecForge) | +[z-lab](https://github.com/z-lab/dflash) + +## Architecture + +``` +Target Model (frozen) + │ + ├─ hidden_states[layer 1, 9, 17, 25, 33] ──► concat ──► FC + RMSNorm ──► target_hidden + │ │ + │ K/V injection + │ │ + └─ embed([anchor, mask, mask, ...]) ──► noise_embedding ──► DFlash Decoder (5 layers) + │ + lm_head ──► draft tokens +``` + +**Key components:** +- **Feature Fusion**: Multi-layer hidden states → Linear(num_layers × hidden_size, hidden_size) + RMSNorm +- **KV Injection**: In each draft decoder layer, K/V = concat(k_proj(target_hidden), k_proj(noise)) + with QK-norm. Q comes from noise only. +- **Parallel Drafting**: Position 0 is the anchor (known token), positions 1..B-1 are mask tokens + predicted in parallel. Bidirectional attention within the block. +- **Random Anchor Sampling**: During training, anchor positions are sampled randomly from + valid (assistant response) positions, not uniformly spaced. 
+ +**Draft model components** (Qwen3-based): +- `Qwen3MLP`, `Qwen3RMSNorm`, `Qwen3RotaryEmbedding` from transformers +- Sliding window attention supported via `config.layer_types` +- Independent of target model architecture + +## Training + +### Quick Start + +```bash +uv run launch.py --yaml examples/Qwen/Qwen3-8B/hf_online_dflash.yaml --yes +``` + +### Recipe + +See [`modelopt_recipes/general/speculative_decoding/dflash.yaml`](../../../modelopt_recipes/general/speculative_decoding/dflash.yaml) + +| Parameter | Default | Description | +|-----------|---------|-------------| +| `dflash.dflash_block_size` | 8 | Block size for parallel prediction | +| `dflash.dflash_num_anchors` | 512 | Random anchor positions per sample | +| `dflash.dflash_loss_decay_factor` | 4.0 | Exponential decay gamma (0 disables) | +| `dflash.dflash_self_logit_distillation` | true | Logit distillation from target | +| `dflash.dflash_architecture_config.num_hidden_layers` | 5 | Draft decoder layers | +| `dflash.dflash_architecture_config.mask_token_id` | auto | Token ID for masked positions | +| `training.answer_only_loss` | false | Mask loss on non-assistant tokens | + +### Loss Decay + +The exponential decay factor (gamma) weights early block positions higher than later ones. +If position 0 in a block is wrong, all subsequent positions are rejected in speculative +decoding. Decay aligns the training loss with what matters for acceptance rate. + +``` +weight[k] = exp(-k / gamma) for k = 0..B-1 +``` + +Paper recommendation: gamma=7 for block_size=16, gamma=4 for block_size=8. + +### Checkpoint Resume + +DFlash supports checkpoint resume transparently. The `DFlashModule._apply()` method +handles meta-tensor rotary buffers that arise during ModelOpt checkpoint restore — no +special resume logic needed in the training script. 
+ +### Export + +```bash +python scripts/export_hf_checkpoint.py \ + --model_path /path/to/training/output \ + --export_path /path/to/exported/model +``` + +Exports to z-lab compatible HF format (`config.json` + `model.safetensors`). + +## Results (Qwen3-8B) + +Trained on nvidia/Nemotron-Post-Training-Dataset-v2 (2M samples), 64 GPUs, 10 epochs. + +### Training Configuration + +| Parameter | Value | +|-----------|-------| +| Block Size | 8 | +| Sequence Length | 4096 | +| Anchors | 512 | +| Loss | KD + decay (gamma=4) | +| Total Steps | 306,620 | +| Final Per-Token Acc | 67.0% | + +### MT-Bench Per-Category AR (Online Validation, osl=512) + +| Category | 80K | 150K | 306K | +|----------|-----|------|------| +| math | 5.44 | 5.54 | **5.52** | +| extraction | 4.81 | 4.82 | **4.88** | +| coding | 4.40 | 4.53 | **4.60** | +| reasoning | 4.34 | 4.41 | **4.44** | +| stem | 4.05 | 4.15 | **4.17** | +| writing | 3.76 | 3.79 | **3.84** | +| roleplay | 3.58 | 3.73 | **3.78** | +| humanities | 3.55 | 3.62 | **3.65** | +| **ALL** | **4.24** | **4.32** | **4.36** | + +### Comparison with z-lab/Qwen3-8B-DFlash-b16 + +**ModelOpt eval (online validation, osl=512):** + +| Dataset | z-lab | ModelOpt | Diff | +|---------|-------|----------|------| +| gsm8k | 4.10 | **5.19** | **+1.09** | +| MT-Bench | 3.58 | **4.36** | **+0.78** | + +**z-lab official eval (dflash.benchmark, osl=512):** + +| Dataset | z-lab | ModelOpt | Diff | +|---------|-------|----------|------| +| gsm8k | **5.00** | 4.08 | -0.92 | +| MT-Bench | **3.28** | 2.99 | -0.29 | + +> z-lab trained with block_size=16; ModelOpt trained with block_size=8. 
+ +### Evaluation Methods + +| Method | Description | +|--------|-------------| +| **Fixed GT** | Pre-compute greedy ground truth, check draft against it | +| **Online GT** | Recompute ground truth after each accepted draft (context-dependent) | +| **z-lab official** | Actual speculative decoding with draft KV cache | + +Online GT is more accurate than Fixed GT (~+1.0 AR) because speculative decoding +acceptance depends on context-dependent verification, not a fixed reference sequence. + +### Key Findings + +| Finding | Evidence | +|---------|----------| +| Loss decay boosts AR | +0.12 AR at 55K (gamma=7, bs16); consistent across checkpoints | +| Longer sequences help | seq=4096 vs 512: +0.49 AR on AA-Synthetic | +| Online validation essential | Fixed GT underestimates by ~1.0 AR | +| Forward pass identical to z-lab | Max diff 0.5 (bf16); 6/7 draft tokens match | +| sdpa vs flash_attn: negligible | AR 3.31 vs 3.31; hidden states identical | + +## Open Items + +### Offline Training + +Online training requires the full target model in GPU memory alongside the draft model. +Offline training would pre-compute target hidden states and train the draft model separately. + +**Challenge**: DFlash uses random anchor sampling over full sequences, requiring hidden states +at ALL positions. For Qwen3-8B with 5 target layers and seq_len=4096, this is ~160MB per sample +in bf16. With 2M samples, full pre-computation would require ~320TB — not feasible. + +**Potential approaches:** +- Pre-sample anchor positions and store only relevant slices (limits randomness) +- Stream hidden states from disk with chunked loading +- Hybrid: quantized base model on CPU computes hidden states on-the-fly, draft on GPU +- Logit distillation adds another dimension: teacher logits at anchor+k-1 positions + need `[seq_len, vocab_size]` per sample (~600MB in bf16) + +### z-lab Eval Gap + +ModelOpt eval (online GT) gives higher AR than z-lab's official eval on our checkpoint +(5.19 vs 4.08 on gsm8k). 
The gap is likely from: +- z-lab uses draft KV cache (accumulates context across blocks); our eval re-runs from scratch +- z-lab's `acceptance_length + 1` counting (minimum 1 per step) +- `rope_theta` mismatch in exported config (was 10000 instead of 1000000 — now fixed) + +### Model Support Expansion + +Currently supports Qwen3 draft architecture. See `hf_dflash.py` module docstring for +instructions on adding: +- **Qwen3MoE**: Replace MLP with `Qwen3MoeMLP` via config flag +- **MLA (DeepseekV3/Kimi-K2)**: Requires MLA-aware KV injection with compressed K/V + +### FP8 / NVFP4 Quantization + +The DFlash export pipeline supports quantized checkpoints via ModelOpt PTQ, following +the same flow as EAGLE3: + +1. Train draft model (bf16) +2. Apply PTQ: `mtq.quantize(model, quant_cfg)` with `FP8_DEFAULT_CFG` or `NVFP4_DEFAULT_CFG` +3. Export: `export_hf_checkpoint.py` auto-detects quantization and writes scales + `quantization_config` + +The exporter's `has_quant_opt()` check and `_export_transformers_checkpoint()` handle +quantized weights transparently. No DFlash-specific quantization code is needed. + +TODO: Add a quantization recipe/script and validate FP8/NVFP4 AR impact. + +### Docker Local Testing + +The launcher example currently requires Slurm cluster access. A local Docker example +with `hf_local=` path mapping would enable development without cluster access. 
diff --git a/examples/speculative_decoding/doc/dflash_results.md b/examples/speculative_decoding/doc/dflash_results.md deleted file mode 100644 index a6a4f2d252..0000000000 --- a/examples/speculative_decoding/doc/dflash_results.md +++ /dev/null @@ -1,85 +0,0 @@ -# DFlash Block Diffusion — ModelOpt Training Results - -Qwen3-8B target model, trained on nvidia/Nemotron-Post-Training-Dataset-v2 (2M samples) - -## Key Metrics - -| Benchmark | Acceptance Rate | -|-----------|----------------| -| **gsm8k** | **5.19** | -| **MT-Bench** | **4.36** | - -> Online validation, block_size=8, osl=512 - -## Training Configuration - -| Parameter | Value | -|-----------|-------| -| Target Model | Qwen3-8B | -| Draft Layers | 5 | -| Block Size | 8 | -| Sequence Length | 4096 | -| Anchors per Sample | 512 | -| Loss | KD (logit distillation) + exponential decay (gamma=4) | -| Learning Rate | 6e-4 (linear decay) | -| Epochs | 10 | -| GPUs | 64 (8 nodes x 8 H100) | -| Total Steps | 306,620 | -| Final Loss | 1.129 | -| Final Per-Token Acc | 67.0% | - -## MT-Bench Per-Category AR (Online Validation) - -80 prompts, block_size=8, osl=512, steps=7 - -| Category | 80K | 150K | 306K (final) | -|----------|-----|------|-------------| -| math | 5.44 | 5.54 | **5.52** | -| extraction | 4.81 | 4.82 | **4.88** | -| coding | 4.40 | 4.53 | **4.60** | -| reasoning | 4.34 | 4.41 | **4.44** | -| stem | 4.05 | 4.15 | **4.17** | -| writing | 3.76 | 3.79 | **3.84** | -| roleplay | 3.58 | 3.73 | **3.78** | -| humanities | 3.55 | 3.62 | **3.65** | -| **ALL** | **4.24** | **4.32** | **4.36** | - -## Comparison with z-lab/Qwen3-8B-DFlash-b16 - -### ModelOpt Eval (online validation, osl=512) - -| Dataset | z-lab | ModelOpt (306K) | Diff | -|---------|-------|-----------------|------| -| gsm8k | 4.10 | **5.19** | **+1.09** | -| MT-Bench | 3.58 | **4.36** | **+0.78** | - -### z-lab Official Eval (dflash.benchmark, osl=512) - -| Dataset | z-lab | ModelOpt (306K) | Diff | 
-|---------|-------|-----------------|------| -| gsm8k | **5.00** | 4.08 | -0.92 | -| MT-Bench | **3.28** | 2.99 | -0.29 | - -> z-lab model trained with block_size=16. ModelOpt trained with block_size=8. - -## Evaluation Method Impact (gsm8k) - -| Eval Method | z-lab checkpoint | ModelOpt (306K) | -|-------------|-----------------|-----------------| -| Fixed GT (ModelOpt eval) | 2.95 | 4.23 | -| Online GT (ModelOpt eval) | 4.10 | **5.19** | -| z-lab official eval | **5.00** | 4.08 | - -- **Fixed GT**: pre-compute greedy ground truth, check draft against it. -- **Online GT**: recompute ground truth after each accepted draft (context-dependent). -- **z-lab official**: actual speculative decoding with draft KV cache. - -## Key Findings - -| Finding | Evidence | -|---------|----------| -| Loss decay boosts AR | +0.12 AR at 55K steps (gamma=7, bs16); consistent across all checkpoints | -| Longer sequences help | seq=4096 vs 512: +0.49 AR on AA-Synthetic at same checkpoint | -| Online validation essential | Fixed GT underestimates by ~1.0 AR; context-dependent GT matches actual spec-decode | -| Forward pass identical to z-lab | Max diff 0.5 (bf16 noise) on same mask_token_id; 6/7 draft tokens match | -| sdpa vs flash_attn: negligible | Overall AR 3.31 vs 3.31; hidden states identical, logits differ <2% | diff --git a/examples/speculative_decoding/scripts/ar_validate.py b/examples/speculative_decoding/scripts/ar_validate.py index 1ad7bec409..7e8e661e0c 100644 --- a/examples/speculative_decoding/scripts/ar_validate.py +++ b/examples/speculative_decoding/scripts/ar_validate.py @@ -13,7 +13,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +"""AR validation for speculative decoding models (EAGLE3, DFlash, Medusa). + +Supports per-category MT-Bench evaluation and online (context-dependent) validation. 
+""" + import argparse +from collections import defaultdict from accelerate import Accelerator from datasets import load_dataset @@ -27,52 +33,66 @@ mto.enable_huggingface_checkpointing() -def validate_ar(model, tokenizer, ds, steps=3, osl=20, num_samples=80, device=None): +def validate_ar( + model, tokenizer, ds, steps=3, osl=20, num_samples=80, device=None, +): + """Validate acceptance rate on MT-Bench prompts using online validation. + + Online validation recomputes ground truth after each accepted draft token + (context-dependent), matching actual speculative decoding behavior. + + Args: + model: Speculative decoding model (EAGLE3, DFlash, etc.) + tokenizer: Tokenizer for the model. + ds: MT-Bench dataset (HuggingFace dataset with 'prompt' and optional 'category'). + steps: Number of draft tokens per speculative step. + osl: Output sequence length. + num_samples: Max number of samples to evaluate. + device: Device to run on. + + Returns: + List of (category, ar) tuples. + """ validator = HFARValidation(model, tokenizer) num_samples = min(num_samples, len(ds)) - ars = [] + results = [] for i in tqdm(range(num_samples), desc="Validating AR"): prompt = ds[i]["prompt"][0] - input_ids = tokenizer(prompt, return_tensors="pt").input_ids - # Apply chat template to the prompt, continuing with assistant response + category = ds[i].get("category", "unknown") if hasattr(tokenizer, "apply_chat_template"): - chat_messages = [ - {"role": "user", "content": prompt}, - ] + chat_messages = [{"role": "user", "content": prompt}] prompt = tokenizer.apply_chat_template( chat_messages, tokenize=False, add_generation_prompt=True ) - input_ids = tokenizer(prompt, return_tensors="pt").input_ids + input_ids = tokenizer(prompt, return_tensors="pt").input_ids if device: input_ids = input_ids.to(device) - # validate AR - _, ar = validator.validate(osl, input_ids=input_ids, steps=steps) - ars.append(ar) - return ars + try: + _, ar = validator.validate_online(osl, input_ids=input_ids, 
steps=steps) + results.append((category, ar)) + except Exception: + pass + return results def main(): - parser = argparse.ArgumentParser() + parser = argparse.ArgumentParser(description="AR validation for speculative decoding models.") parser.add_argument("--model_path", type=str, required=True, help="Path to model directory") parser.add_argument("--trust_remote_code", action="store_true", help="Trust remote code") - parser.add_argument("--steps", type=int, default=3, help="Steps for AR validation") - parser.add_argument( - "--osl", type=int, default=32, help="Output sequence length for AR validation" - ) - parser.add_argument( - "--num_samples", type=int, default=80, help="Number of MT-Bench samples to use" - ) + parser.add_argument("--steps", type=int, default=3, help="Draft tokens per step") + parser.add_argument("--osl", type=int, default=32, help="Output sequence length") + parser.add_argument("--num_samples", type=int, default=80, help="Number of samples") + parser.add_argument("--per_category", action="store_true", help="Report per-category AR") parser.add_argument( "--ar_lower_bound", type=float, default=None, - help="AR lower bound for validation. 
If provided, will throw error if AR is below threshold.", + help="Error if AR is below this threshold.", ) args = parser.parse_args() accelerator = Accelerator() - # Load model and tokenizer model = load_vlm_or_llm( args.model_path, device_map="auto", trust_remote_code=args.trust_remote_code ) @@ -82,26 +102,37 @@ def main(): model.eval() model = accelerator.prepare(model) - # Load MT-Bench prompts from HuggingFace ds = load_dataset("HuggingFaceH4/mt_bench_prompts")["train"] - ars = validate_ar( - model, tokenizer, ds, args.steps, args.osl, args.num_samples, accelerator.device + results = validate_ar( + model, + tokenizer, + ds, + args.steps, + args.osl, + args.num_samples, + accelerator.device, ) - # Optionally, throw error if AR is below lower bound - if args.ar_lower_bound: - mean_ar = sum(ars) / len(ars) - if mean_ar < args.ar_lower_bound: + + if results and accelerator.is_main_process: + all_ars = [ar for _, ar in results] + avg_ar = sum(all_ars) / len(all_ars) + print(f"\n==== AR Validation Results (osl={args.osl}, steps={args.steps}) ====") + + if args.per_category: + cat_ars = defaultdict(list) + for cat, ar in results: + cat_ars[cat].append(ar) + for cat in sorted(cat_ars): + cat_avg = sum(cat_ars[cat]) / len(cat_ars[cat]) + print(f" {cat:>12}: {cat_avg:.4f}") + + print(f" {'ALL':>12}: {avg_ar:.4f}") + print(f" Samples: {len(results)}") + + if args.ar_lower_bound and avg_ar < args.ar_lower_bound: raise ValueError( - f"AR is below lower bound {args.ar_lower_bound}. Mean AR: {mean_ar:.4f}" + f"AR {avg_ar:.4f} is below lower bound {args.ar_lower_bound}." 
) - # Print results - if ars and accelerator.is_main_process: - avg_ar = sum(ars) / len(ars) - print("\n==== AR Validation Results on MT-Bench ====") - print(f"Number of samples: {len(ars)}") - print(f"Output Sequence Length: {args.osl}") - print(f"Steps: {args.steps}") - print(f"Average AR: {avg_ar:.4f}") if __name__ == "__main__": diff --git a/modelopt/torch/speculative/plugins/hf_dflash.py b/modelopt/torch/speculative/plugins/hf_dflash.py index 5ec932e987..c26d800bc9 100644 --- a/modelopt/torch/speculative/plugins/hf_dflash.py +++ b/modelopt/torch/speculative/plugins/hf_dflash.py @@ -413,18 +413,12 @@ def _auto_detect_mask_token_id(base_config): if vocab_size >= 128256: # Llama3 vocab size return 128002 # <|reserved_special_token_0|> - # Generic: try pad_token_id, then eos - pad_id = getattr(base_config, "pad_token_id", None) - eos_id = getattr(base_config, "eos_token_id", None) - if isinstance(eos_id, list): - eos_id = eos_id[0] - - # Prefer pad over eos (pad is less likely to interfere) - if pad_id is not None and pad_id != eos_id: - return pad_id - - # Last resort - return eos_id or 0 + # No suitable mask token found — user must provide one + raise ValueError( + f"Cannot auto-detect mask_token_id for model_type='{model_type}'. " + f"Please set dflash_architecture_config.mask_token_id explicitly in your config. " + f"The mask token should be an unused special token (not eos or pad)." + ) def _find_base_model_parts(self): """Locate base model submodules (backbone, embeddings, lm_head) by probing known paths.""" @@ -507,8 +501,7 @@ def modify(self, config): # 2. Auto-detect from model vocabulary: # - Qwen3/3.5: built-in [MASK] token # - Llama3: reserved_special_token_0 (128002) - # - Others: tokenizer.mask_token_id - # 3. Fallback to pad_token_id or eos_token_id (suboptimal) + # 3. 
Error — user must provide mask_token_id for unsupported models mask_id = config.dflash_architecture_config.get("mask_token_id", None) if mask_id is None: mask_id = self._auto_detect_mask_token_id(base_config) @@ -810,15 +803,55 @@ def forward( @torch.no_grad() def pseudo_speculative_generate(self, input_ids, steps=1): - """Generate draft tokens using one DFlash block. + """Generate draft tokens using one DFlash block for AR validation. + + This method implements a single speculative decoding step: + + 1. **Base model forward**: Run the full target model on ``input_ids`` to get: + - ``base_token``: greedy next token (argmax of last position logits) + - ``hidden_states``: intermediate hidden states from target layers + + 2. **Extract target hidden states**: Concatenate hidden states from + ``target_layer_ids`` (e.g., layers [1, 9, 17, 25, 33] for 5-layer draft). + Shape: ``[B, seq_len, num_layers * hidden_size]``. + + 3. **Build block input**: Create a block of ``block_size`` tokens where: + - Position 0 = ``base_token`` (the anchor/known token) + - Positions 1..block_size-1 = ``mask_token_id`` (unknown, to be predicted) + Embed this block via the base model's embedding layer. + + 4. **Position IDs**: Context positions ``[0..seq_len-1]`` followed by block + positions ``[seq_len..seq_len+block_size-1]``. The draft model's attention + uses RoPE on these positions so Q (block only) attends to K (context + block) + with correct relative position encoding. + + 5. **Draft forward**: Run ``DFlashModule`` with: + - ``noise_embedding``: embedded block tokens + - ``target_hidden``: extracted hidden states from step 2 + - ``position_ids``: context + block positions + - ``attention_mask=None``: no mask at inference (all positions attend freely) + The draft model's KV injection concatenates projected target_hidden as K/V + with the block's own K/V, enabling the draft to "see" the target's context. + + 6. 
**Decode**: Apply ``lm_head`` to draft hidden states at positions 1..block_size-1 + (skip position 0 which is the known anchor). Argmax gives draft tokens. + + 7. **Return**: ``(base_token, draft_tokens[:steps])`` — base token is always + returned; draft tokens are truncated to ``steps`` (default: block_size-1). + + Note: + This method re-runs the full target model from scratch on each call + (no KV cache). For AR validation, it is called repeatedly with growing + ``input_ids`` by ``AcceptanceRateValidation.validate()``. The ``steps`` + parameter should be set to ``block_size - 1`` for full block evaluation. - DFlash generates block_size-1 draft tokens in a single forward pass. - The `steps` parameter is used as the number of tokens to return - (capped at block_size-1). + Args: + input_ids: Input token IDs [B, seq_len]. + steps: Number of draft tokens to return (capped at block_size-1). Returns: base_token: Next token from base model [B, 1]. - draft_tokens: Draft tokens [B, min(steps, block_size-1)] or None. + draft_tokens: Draft tokens [B, min(steps, block_size-1)] or None if steps < 1. """ # Call the base model's inner model directly (avoids DynamicModule dispatch) model_output = self._base_model( diff --git a/tools/launcher/common/dflash/ar_eval_mtbench.sh b/tools/launcher/common/dflash/ar_eval_mtbench.sh index 3971f61b56..4062e6b2fe 100644 --- a/tools/launcher/common/dflash/ar_eval_mtbench.sh +++ b/tools/launcher/common/dflash/ar_eval_mtbench.sh @@ -14,212 +14,51 @@ # See the License for the specific language governing permissions and # limitations under the License. -# MT-Bench per-category AR evaluation for DFlash checkpoints. -# Evaluates the latest checkpoint using ModelOpt's pseudo_speculative_generate -# with online (context-dependent) ground truth validation. +# MT-Bench AR evaluation using scripts/ar_validate.py. +# Finds the latest checkpoint and runs per-category AR validation. 
# -# Required env vars: -# HF_MODEL_CKPT — path to the target HuggingFace model -# -# Args: -# --ckpt_dir Path to directory containing checkpoint-* subdirs -# --block_size Block size for DFlash (default: 16) -# --num_layers Number of draft decoder layers (default: 5) -# --mask_token_id Mask token ID (default: auto-detect from checkpoint) -# --osl Output sequence length (default: 512) -# --steps Draft steps per block (default: block_size-1) -# --online Use online validation (default: true) +# Args are passed directly to ar_validate.py (--model_path, --osl, --steps, etc.) +# If --model_path is not provided, auto-detects from --ckpt_dir. SCRIPT_DIR="$(dirname "$(readlink -f "$0")")" source ${SCRIPT_DIR}/../service_utils.sh pip install -r modules/Model-Optimizer/examples/speculative_decoding/requirements.txt 2>&1 | tail -3 -# Overlay DFlash code -for INSTALL_PATH in $(python3 -c " -import modelopt, os, site -paths = set() -paths.add(os.path.dirname(modelopt.__file__)) -for sp in site.getsitepackages(): - p = os.path.join(sp, 'modelopt') - if os.path.isdir(p): paths.add(p) -for p in paths: print(p) -"); do - cp -rf modules/Model-Optimizer/modelopt/torch/speculative/dflash ${INSTALL_PATH}/torch/speculative/ 2>/dev/null || true - cp -f modules/Model-Optimizer/modelopt/torch/speculative/plugins/hf_dflash.py ${INSTALL_PATH}/torch/speculative/plugins/ 2>/dev/null || true - cp -f modules/Model-Optimizer/modelopt/torch/speculative/plugins/__init__.py ${INSTALL_PATH}/torch/speculative/plugins/ 2>/dev/null || true - cp -f modules/Model-Optimizer/modelopt/torch/speculative/config.py ${INSTALL_PATH}/torch/speculative/ 2>/dev/null || true - cp -f modules/Model-Optimizer/modelopt/torch/speculative/mode.py ${INSTALL_PATH}/torch/speculative/ 2>/dev/null || true - cp -f modules/Model-Optimizer/modelopt/torch/speculative/utils.py ${INSTALL_PATH}/torch/speculative/ 2>/dev/null || true -done +trap 'error_handler $0 $LINENO' ERR -# Parse args +# Parse --ckpt_dir to find latest checkpoint 
(ar_validate.py expects --model_path) +ARGS=() CKPT_DIR="" -BLOCK_SIZE=16 -NUM_LAYERS=5 -MASK_TOKEN_ID="" -OSL=512 -STEPS="" -ONLINE=true - while [ $# -gt 0 ]; do case "$1" in - --ckpt_dir) shift; CKPT_DIR="$1" ;; - --block_size) shift; BLOCK_SIZE="$1" ;; - --num_layers) shift; NUM_LAYERS="$1" ;; - --mask_token_id) shift; MASK_TOKEN_ID="$1" ;; - --osl) shift; OSL="$1" ;; - --steps) shift; STEPS="$1" ;; - --online) shift; ONLINE="$1" ;; - *) ;; + --ckpt_dir) shift; CKPT_DIR="$1" ;; + *) ARGS+=("$1") ;; esac shift done -if [ -z "$STEPS" ]; then - STEPS=$((BLOCK_SIZE - 1)) -fi - -MODEL=${HF_MODEL_CKPT} - -echo "=== DFlash MT-Bench AR Evaluation ===" -echo "Checkpoint dir: ${CKPT_DIR}" -echo "Model: ${MODEL}" -echo "Block size: ${BLOCK_SIZE}, Layers: ${NUM_LAYERS}" -echo "OSL: ${OSL}, Steps: ${STEPS}, Online: ${ONLINE}" - -# Find latest checkpoint -LAST_CKPT=$(ls -d ${CKPT_DIR}/checkpoint-* 2>/dev/null | sort -t- -k2 -n | tail -1) -if [ -z "$LAST_CKPT" ]; then - # Check for top-level model +# Auto-detect model_path from ckpt_dir if not explicitly provided +MODEL_PATH="" +if [ -n "$CKPT_DIR" ]; then + # Find latest checkpoint subdir + LAST_CKPT=$(ls -d ${CKPT_DIR}/checkpoint-* 2>/dev/null | sort -t- -k2 -n | tail -1) if [ -f "${CKPT_DIR}/model.safetensors" ]; then - LAST_CKPT=${CKPT_DIR} - else - echo "ERROR: No checkpoints found in ${CKPT_DIR}" - exit 1 + MODEL_PATH="${CKPT_DIR}" + elif [ -n "$LAST_CKPT" ]; then + MODEL_PATH="${LAST_CKPT}" fi + echo "Auto-detected model_path: ${MODEL_PATH}" fi -echo "Evaluating: ${LAST_CKPT}" - -CUDA_VISIBLE_DEVICES=0 python3 -c " -import torch, glob, os, json -from transformers import AutoModelForCausalLM, AutoTokenizer -from safetensors.torch import load_file -from datasets import load_dataset -from collections import defaultdict -import modelopt.torch.opt as mto -import modelopt.torch.speculative as mtsp -from modelopt.torch.speculative.plugins.transformers import HFARValidation - -mto.enable_huggingface_checkpointing() - -MODEL = 
'${MODEL}' -CKPT_PATH = '${LAST_CKPT}' -BLOCK_SIZE = ${BLOCK_SIZE} -NUM_LAYERS = ${NUM_LAYERS} -MASK_TOKEN_ID_STR = '${MASK_TOKEN_ID}' -OSL = ${OSL} -STEPS = ${STEPS} -ONLINE = '${ONLINE}' == 'true' -# Auto-detect mask_token_id from checkpoint config -MASK_TOKEN_ID = int(MASK_TOKEN_ID_STR) if MASK_TOKEN_ID_STR else None -if MASK_TOKEN_ID is None: - cfg_path = os.path.join(CKPT_PATH, 'config.json') - if os.path.isfile(cfg_path): - with open(cfg_path) as f: - ckpt_cfg = json.load(f) - dflash_cfg = ckpt_cfg.get('dflash_config', {}) - MASK_TOKEN_ID = dflash_cfg.get('mask_token_id') - if MASK_TOKEN_ID is None: - MASK_TOKEN_ID = 151669 # default for Qwen3 - print(f'WARNING: Could not auto-detect mask_token_id, using default {MASK_TOKEN_ID}') -print(f'Using mask_token_id={MASK_TOKEN_ID}') - -# Use flash_attention_2 if available -try: - import flash_attn - ATTN_IMPL = 'flash_attention_2' -except ImportError: - ATTN_IMPL = 'sdpa' -print(f'Using attn_implementation={ATTN_IMPL}') - -tokenizer = AutoTokenizer.from_pretrained(MODEL, trust_remote_code=True) - -# Load MT-Bench by category -ds = load_dataset('/hf-local/HuggingFaceH4/mt_bench_prompts')['train'] -cat_samples = defaultdict(list) -for i in range(len(ds)): - cat = ds[i].get('category', 'unknown') - cat_samples[cat].append(ds[i]['prompt'][0]) -categories = sorted(cat_samples.keys()) -print(f'Categories: {categories}') -for c in categories: - print(f' {c}: {len(cat_samples[c])} samples') - -# Load model -model = AutoModelForCausalLM.from_pretrained( - MODEL, torch_dtype=torch.bfloat16, device_map={'': 0}, trust_remote_code=True, - attn_implementation=ATTN_IMPL, -) -config = { - 'dflash_block_size': BLOCK_SIZE, - 'dflash_architecture_config': { - 'num_hidden_layers': NUM_LAYERS, - 'mask_token_id': MASK_TOKEN_ID, - '_attn_implementation': ATTN_IMPL, - }, - 'dflash_use_torch_compile': False, -} -mtsp.convert(model, [('dflash', config)]) - -# Load weights -sf_files = sorted(glob.glob(os.path.join(CKPT_PATH, 
'model*.safetensors'))) -if sf_files: - state = {} - for f in sf_files: - state.update(load_file(f)) - dflash_keys = {k: v for k, v in state.items() if 'dflash_module' in k} - if dflash_keys: - model.load_state_dict(dflash_keys, strict=False) - print(f'Loaded {len(dflash_keys)} DFlash weights (with prefix)') - else: - model.dflash_module.load_state_dict(state, strict=False) - print(f'Loaded {len(state)} DFlash weights (no prefix)') -else: - print('ERROR: No safetensors found') - exit(1) - -model.eval() -validator = HFARValidation(model, tokenizer) - -# Evaluate per category -cat_ars = {} -all_ars = [] -for cat in categories: - ars = [] - for prompt in cat_samples[cat]: - chat = [{'role': 'user', 'content': prompt}] - text = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True) - input_ids = tokenizer(text, return_tensors='pt').input_ids.cuda() - try: - if ONLINE: - _, ar = validator.validate_online(osl=OSL, input_ids=input_ids, steps=STEPS) - else: - _, ar = validator.validate(osl=OSL, input_ids=input_ids, steps=STEPS) - ars.append(ar) - all_ars.append(ar) - except Exception as e: - print(f' ERROR [{cat}]: {e}') - cat_ars[cat] = sum(ars) / len(ars) if ars else 0.0 - -avg_all = sum(all_ars) / len(all_ars) if all_ars else 0.0 -mode_str = 'online' if ONLINE else 'fixed GT' +if [ -z "$MODEL_PATH" ]; then + echo "ERROR: No checkpoint found. Provide --ckpt_dir or --model_path." 
+ exit 1 +fi -print(f'\n=== Results (OSL={OSL}, steps={STEPS}, {mode_str}) ===') -for c in categories: - print(f' {c:>12}: {cat_ars[c]:.4f}') -print(f'{\"ALL\":>14}: {avg_all:.4f}') -" +CUDA_VISIBLE_DEVICES=0 python3 modules/Model-Optimizer/examples/speculative_decoding/scripts/ar_validate.py \ + --model_path "${MODEL_PATH}" \ + --per_category \ + "${ARGS[@]}" report_result "PASS: MT-Bench AR evaluation" diff --git a/tools/launcher/examples/Qwen/Qwen3-8B/hf_online_dflash.yaml b/tools/launcher/examples/Qwen/Qwen3-8B/hf_online_dflash.yaml index 4da2e35b61..f3bb543bde 100644 --- a/tools/launcher/examples/Qwen/Qwen3-8B/hf_online_dflash.yaml +++ b/tools/launcher/examples/Qwen/Qwen3-8B/hf_online_dflash.yaml @@ -46,13 +46,8 @@ pipeline: script: common/dflash/ar_eval_mtbench.sh args: - --ckpt_dir /scratchspace/dflash_bs16 - - --block_size 16 - - --num_layers 5 - --osl 512 - --steps 15 - - --online true - environment: - - HF_MODEL_CKPT: <> slurm_config: _factory_: "slurm_factory" nodes: 1 From 6267fb59d7f675432988ed34bae7712a9afa8872 Mon Sep 17 00:00:00 2001 From: Chenhan Yu Date: Wed, 8 Apr 2026 20:21:52 -0700 Subject: [PATCH 06/24] add: export, PTQ, and vLLM scripts; document open items - export.sh: standalone checkpoint export to z-lab format - ptq_and_export.sh: FP8/NVFP4 quantization via hf_ptq.py - Fix rope_theta export (prefer draft_config over base_config) - Document vLLM integration gap, FP8/NVFP4 flow in dflash.md Co-Authored-By: Claude Opus 4.6 (1M context) --- examples/speculative_decoding/doc/dflash.md | 11 ++++ .../torch/export/plugins/hf_spec_export.py | 4 +- tools/launcher/common/dflash/export.sh | 63 +++++++++++++++++++ .../launcher/common/dflash/ptq_and_export.sh | 59 +++++++++++++++++ 4 files changed, 136 insertions(+), 1 deletion(-) create mode 100644 tools/launcher/common/dflash/export.sh create mode 100644 tools/launcher/common/dflash/ptq_and_export.sh diff --git a/examples/speculative_decoding/doc/dflash.md 
b/examples/speculative_decoding/doc/dflash.md index e1e7007d44..4d1fb1bccd 100644 --- a/examples/speculative_decoding/doc/dflash.md +++ b/examples/speculative_decoding/doc/dflash.md @@ -200,6 +200,17 @@ quantized weights transparently. No DFlash-specific quantization code is needed. TODO: Add a quantization recipe/script and validate FP8/NVFP4 AR impact. +### vLLM Integration + +vLLM has a `DFlashProposer` in `v1/spec_decode/dflash.py`, but the model loader +does not yet recognize the `DFlashDraftModel` architecture from z-lab checkpoints. +The `specdec_bench` tool with `--speculative_algorithm EAGLE3` does not work for DFlash. + +Possible paths: +- Wait for vLLM to add `DFlashDraftModel` to their model registry +- Use vLLM's Python API directly with the DFlash proposer +- Convert checkpoint to a format vLLM recognizes (e.g., register as a custom model) + ### Docker Local Testing The launcher example currently requires Slurm cluster access. A local Docker example diff --git a/modelopt/torch/export/plugins/hf_spec_export.py b/modelopt/torch/export/plugins/hf_spec_export.py index 28a858d18c..b107a3cb5a 100644 --- a/modelopt/torch/export/plugins/hf_spec_export.py +++ b/modelopt/torch/export/plugins/hf_spec_export.py @@ -311,7 +311,9 @@ def _export_config(self): "initializer_range": getattr(base_config, "initializer_range", 0.02), "attention_bias": getattr(draft_config, "attention_bias", False), "attention_dropout": getattr(draft_config, "attention_dropout", 0.0), - "rope_theta": getattr(base_config, "rope_theta", 1000000.0), + "rope_theta": getattr( + draft_config, "rope_theta", getattr(base_config, "rope_theta", 1000000.0) + ), "rope_scaling": getattr(base_config, "rope_scaling", None), "tie_word_embeddings": False, "torch_dtype": "bfloat16", diff --git a/tools/launcher/common/dflash/export.sh b/tools/launcher/common/dflash/export.sh new file mode 100644 index 0000000000..730073596e --- /dev/null +++ b/tools/launcher/common/dflash/export.sh @@ -0,0 +1,63 @@ 
+#!/bin/bash +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Export speculative decoding checkpoint to deployment format. +# Auto-detects latest checkpoint and exports via export_hf_checkpoint.py. +# +# Args: +# --model_path Training output dir (auto-detects latest checkpoint) +# --export_path Destination directory + +SCRIPT_DIR="$(dirname "$(readlink -f "$0")")" +source ${SCRIPT_DIR}/../service_utils.sh + +pip install -r modules/Model-Optimizer/examples/speculative_decoding/requirements.txt 2>&1 | tail -3 + +trap 'error_handler $0 $LINENO' ERR + +# Auto-detect latest checkpoint +MODEL_PATH="" +EXPORT_PATH="" +while [ $# -gt 0 ]; do + case "$1" in + --model_path) shift; MODEL_PATH="$1" ;; + --export_path) shift; EXPORT_PATH="$1" ;; + *) ;; + esac + shift +done + +# Find latest checkpoint if model_path is a training dir +if [ ! 
-f "${MODEL_PATH}/model.safetensors" ]; then + LAST_CKPT=$(ls -d ${MODEL_PATH}/checkpoint-* 2>/dev/null | sort -t- -k2 -n | tail -1) + if [ -n "$LAST_CKPT" ]; then + echo "Using latest checkpoint: $LAST_CKPT" + MODEL_PATH="$LAST_CKPT" + fi +fi + +echo "=== Export ===" +echo "Model: ${MODEL_PATH}" +echo "Export: ${EXPORT_PATH}" + +CUDA_VISIBLE_DEVICES=0 python3 modules/Model-Optimizer/examples/speculative_decoding/scripts/export_hf_checkpoint.py \ + --model_path "${MODEL_PATH}" \ + --export_path "${EXPORT_PATH}" + +echo "Export contents:" +ls -lh ${EXPORT_PATH}/ + +report_result "PASS: Export" diff --git a/tools/launcher/common/dflash/ptq_and_export.sh b/tools/launcher/common/dflash/ptq_and_export.sh new file mode 100644 index 0000000000..ece933d686 --- /dev/null +++ b/tools/launcher/common/dflash/ptq_and_export.sh @@ -0,0 +1,59 @@ +#!/bin/bash +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# PTQ + export for speculative decoding checkpoints (EAGLE3, DFlash). +# Uses hf_ptq.py to quantize and export in one step. +# +# Args are passed directly to hf_ptq.py. 
+ +SCRIPT_DIR="$(dirname "$(readlink -f "$0")")" +source ${SCRIPT_DIR}/../service_utils.sh + +pip install -r modules/Model-Optimizer/examples/speculative_decoding/requirements.txt 2>&1 | tail -3 + +trap 'error_handler $0 $LINENO' ERR + +# Find latest checkpoint if model_dir points to a training output dir +MODEL_DIR="" +ARGS=() +while [ $# -gt 0 ]; do + case "$1" in + --model_dir) + shift + MODEL_DIR="$1" + # Auto-detect latest checkpoint + LAST_CKPT=$(ls -d ${MODEL_DIR}/checkpoint-* 2>/dev/null | sort -t- -k2 -n | tail -1) + if [ -f "${MODEL_DIR}/model.safetensors" ]; then + ARGS+=("--model_dir" "$MODEL_DIR") + elif [ -n "$LAST_CKPT" ]; then + echo "Using latest checkpoint: $LAST_CKPT" + ARGS+=("--model_dir" "$LAST_CKPT") + else + ARGS+=("--model_dir" "$MODEL_DIR") + fi + ;; + *) ARGS+=("$1") ;; + esac + shift +done + +echo "=== PTQ + Export ===" +echo "Args: ${ARGS[*]}" + +CUDA_VISIBLE_DEVICES=0 python3 modules/Model-Optimizer/examples/llm_ptq/hf_ptq.py \ + "${ARGS[@]}" + +report_result "PASS: PTQ + Export" From 36955dcf7079dd233625d49e020dd2231af9115a Mon Sep 17 00:00:00 2001 From: Chenhan Yu Date: Thu, 9 Apr 2026 05:52:45 -0700 Subject: [PATCH 07/24] add: vLLM DFlash deployment script + validated 386 tok/s - vllm_serve.sh: launch vLLM server with DFlash spec dec + benchmark - Validated on vllm nightly (v0.19.1+) with z-lab checkpoint - 386 tok/s on single H100 (Qwen3-8B + DFlash-b16, 15 spec tokens) - Update dflash.md with deployment instructions Co-Authored-By: Claude Opus 4.6 (1M context) --- examples/speculative_decoding/doc/dflash.md | 22 ++-- tools/launcher/common/dflash/vllm_serve.sh | 128 ++++++++++++++++++++ 2 files changed, 142 insertions(+), 8 deletions(-) create mode 100644 tools/launcher/common/dflash/vllm_serve.sh diff --git a/examples/speculative_decoding/doc/dflash.md b/examples/speculative_decoding/doc/dflash.md index 4d1fb1bccd..4a26db7728 100644 --- a/examples/speculative_decoding/doc/dflash.md +++ b/examples/speculative_decoding/doc/dflash.md 
@@ -200,16 +200,22 @@ quantized weights transparently. No DFlash-specific quantization code is needed. TODO: Add a quantization recipe/script and validate FP8/NVFP4 AR impact. -### vLLM Integration +### vLLM Deployment -vLLM has a `DFlashProposer` in `v1/spec_decode/dflash.py`, but the model loader -does not yet recognize the `DFlashDraftModel` architecture from z-lab checkpoints. -The `specdec_bench` tool with `--speculative_algorithm EAGLE3` does not work for DFlash. +DFlash speculative decoding is supported in vLLM nightly (v0.19.1+): -Possible paths: -- Wait for vLLM to add `DFlashDraftModel` to their model registry -- Use vLLM's Python API directly with the DFlash proposer -- Convert checkpoint to a format vLLM recognizes (e.g., register as a custom model) +```bash +vllm serve Qwen/Qwen3-8B \ + --speculative-config '{"method": "dflash", "model": "z-lab/Qwen3-8B-DFlash-b16", "num_speculative_tokens": 15}' \ + --attention-backend flash_attn \ + --max-num-batched-tokens 32768 +``` + +Validated: **386 tok/s** on single H100 with Qwen3-8B + DFlash-b16 (15 spec tokens). + +Note: requires `vllm/vllm-openai:nightly` — the `latest` tag (v0.19.0) does not include DFlash. +See [`tools/launcher/common/dflash/vllm_serve.sh`](../../../tools/launcher/common/dflash/vllm_serve.sh) +for a complete serve + benchmark script. ### Docker Local Testing diff --git a/tools/launcher/common/dflash/vllm_serve.sh b/tools/launcher/common/dflash/vllm_serve.sh new file mode 100644 index 0000000000..973a79f380 --- /dev/null +++ b/tools/launcher/common/dflash/vllm_serve.sh @@ -0,0 +1,128 @@ +#!/bin/bash +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Launch vLLM server with DFlash speculative decoding, run benchmark, then shut down. +# +# Required env vars: +# HF_MODEL_CKPT — target model path +# DRAFT_MODEL — DFlash draft model path +# +# Optional env vars: +# NUM_SPEC_TOKENS — number of speculative tokens (default: 15) +# VLLM_PORT — server port (default: 8000) +# MAX_BATCHED_TOKENS — max batched tokens (default: 32768) +# BENCHMARK_PROMPTS — path to benchmark prompts file + +SCRIPT_DIR="$(dirname "$(readlink -f "$0")")" +source ${SCRIPT_DIR}/../service_utils.sh 2>/dev/null || true + +trap 'kill $SERVER_PID 2>/dev/null; exit' EXIT ERR + +MODEL=${HF_MODEL_CKPT} +DRAFT=${DRAFT_MODEL} +NUM_SPEC=${NUM_SPEC_TOKENS:-15} +PORT=${VLLM_PORT:-8000} +MAX_TOKENS=${MAX_BATCHED_TOKENS:-32768} + +echo "=== vLLM DFlash Speculative Decoding ===" +echo "Target: ${MODEL}" +echo "Draft: ${DRAFT}" +echo "Spec tokens: ${NUM_SPEC}" + +# Start vLLM server in background +vllm serve ${MODEL} \ + --speculative-config "{\"method\": \"dflash\", \"model\": \"${DRAFT}\", \"num_speculative_tokens\": ${NUM_SPEC}}" \ + --attention-backend flash_attn \ + --max-num-batched-tokens ${MAX_TOKENS} \ + --port ${PORT} \ + & +SERVER_PID=$! + +# Wait for server to be ready +echo "Waiting for vLLM server to start..." +for i in $(seq 1 120); do + if curl -s http://localhost:${PORT}/health > /dev/null 2>&1; then + echo "Server ready after ${i}s" + break + fi + if ! kill -0 $SERVER_PID 2>/dev/null; then + echo "ERROR: Server process died" + wait $SERVER_PID + exit 1 + fi + sleep 1 +done + +if ! 
curl -s http://localhost:${PORT}/health > /dev/null 2>&1; then + echo "ERROR: Server failed to start within 120s" + kill $SERVER_PID 2>/dev/null + exit 1 +fi + +# Run a quick test +echo "" +echo "=== Quick generation test ===" +curl -s http://localhost:${PORT}/v1/completions \ + -H "Content-Type: application/json" \ + -d "{ + \"model\": \"${MODEL}\", + \"prompt\": \"What is 2+3?\", + \"max_tokens\": 64, + \"temperature\": 0 + }" | python3 -c "import json,sys; r=json.load(sys.stdin); print(r.get('choices',[{}])[0].get('text','ERROR')[:200]); print(f'Usage: {r.get(\"usage\",{})}')" + +# Run benchmark if prompts file provided +if [ -n "${BENCHMARK_PROMPTS}" ] && [ -f "${BENCHMARK_PROMPTS}" ]; then + echo "" + echo "=== MT-Bench Benchmark ===" + python3 -c " +import json, time, requests + +with open('${BENCHMARK_PROMPTS}') as f: + prompts = [json.loads(line) for line in f][:20] + +url = 'http://localhost:${PORT}/v1/completions' +times = [] +tokens = [] +for i, p in enumerate(prompts): + q = p.get('turns', [p.get('question', 'Hello')])[0] if isinstance(p, dict) else str(p) + start = time.time() + r = requests.post(url, json={ + 'model': '${MODEL}', + 'prompt': q, + 'max_tokens': 512, + 'temperature': 0, + }).json() + elapsed = time.time() - start + n = r.get('usage', {}).get('completion_tokens', 0) + times.append(elapsed) + tokens.append(n) + tps = n / elapsed if elapsed > 0 else 0 + print(f' [{i+1}/{len(prompts)}] {n} tokens in {elapsed:.1f}s = {tps:.1f} tok/s') + +total_tokens = sum(tokens) +total_time = sum(times) +print(f'\nTotal: {total_tokens} tokens in {total_time:.1f}s = {total_tokens/total_time:.1f} tok/s') +" +fi + +# Shut down server +echo "" +echo "Shutting down server..." 
+kill $SERVER_PID 2>/dev/null +wait $SERVER_PID 2>/dev/null || true + +echo "Done" From c20eabdffe7ff0c948fa55fb4000684fdb9456b4 Mon Sep 17 00:00:00 2001 From: Chenhan Yu Date: Thu, 9 Apr 2026 06:55:17 -0700 Subject: [PATCH 08/24] fix: vllm_serve.sh - add TP_SIZE, fix prompt parsing, 1024 max tokens Co-Authored-By: Claude Opus 4.6 (1M context) --- tools/launcher/common/dflash/vllm_serve.sh | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/tools/launcher/common/dflash/vllm_serve.sh b/tools/launcher/common/dflash/vllm_serve.sh index 973a79f380..1f108d2d38 100644 --- a/tools/launcher/common/dflash/vllm_serve.sh +++ b/tools/launcher/common/dflash/vllm_serve.sh @@ -36,17 +36,19 @@ DRAFT=${DRAFT_MODEL} NUM_SPEC=${NUM_SPEC_TOKENS:-15} PORT=${VLLM_PORT:-8000} MAX_TOKENS=${MAX_BATCHED_TOKENS:-32768} +TP=${TP_SIZE:-1} echo "=== vLLM DFlash Speculative Decoding ===" echo "Target: ${MODEL}" echo "Draft: ${DRAFT}" -echo "Spec tokens: ${NUM_SPEC}" +echo "Spec tokens: ${NUM_SPEC}, TP: ${TP}" # Start vLLM server in background vllm serve ${MODEL} \ --speculative-config "{\"method\": \"dflash\", \"model\": \"${DRAFT}\", \"num_speculative_tokens\": ${NUM_SPEC}}" \ --attention-backend flash_attn \ --max-num-batched-tokens ${MAX_TOKENS} \ + --tensor-parallel-size ${TP} \ --port ${PORT} \ & SERVER_PID=$! 
@@ -92,18 +94,18 @@ if [ -n "${BENCHMARK_PROMPTS}" ] && [ -f "${BENCHMARK_PROMPTS}" ]; then import json, time, requests with open('${BENCHMARK_PROMPTS}') as f: - prompts = [json.loads(line) for line in f][:20] + prompts = [json.loads(line) for line in f][:80] url = 'http://localhost:${PORT}/v1/completions' times = [] tokens = [] for i, p in enumerate(prompts): - q = p.get('turns', [p.get('question', 'Hello')])[0] if isinstance(p, dict) else str(p) + q = p.get('prompt', p.get('turns', [p.get('question', 'Hello')]))[0] if isinstance(p, dict) else str(p) start = time.time() r = requests.post(url, json={ 'model': '${MODEL}', 'prompt': q, - 'max_tokens': 512, + 'max_tokens': 1024, 'temperature': 0, }).json() elapsed = time.time() - start From 191c93fa4f3faa445180a8c9a82b0c3208d203eb Mon Sep 17 00:00:00 2001 From: Chenhan Yu Date: Thu, 9 Apr 2026 10:51:34 -0700 Subject: [PATCH 09/24] remove: HTML results page (lives in nmm-sandbox, not Model-Optimizer) Co-Authored-By: Claude Opus 4.6 (1M context) --- doc/results/dflash_results.html | 0 1 file changed, 0 insertions(+), 0 deletions(-) delete mode 100644 doc/results/dflash_results.html diff --git a/doc/results/dflash_results.html b/doc/results/dflash_results.html deleted file mode 100644 index e69de29bb2..0000000000 From e27a9222243e6242b2c45731d880a2b705ff647c Mon Sep 17 00:00:00 2001 From: Chenhan Yu Date: Thu, 9 Apr 2026 11:03:01 -0700 Subject: [PATCH 10/24] update: DFlash unit tests for _apply meta fix, sliding window, mask_token_id error New tests: - TestDFlashApplyMetaFix: _apply recreates meta rotary, noop when normal - TestDFlashConvert.test_convert_missing_mask_token_id_errors: ValueError - TestDFlashSlidingWindow: reads layer_types, defaults to None - TestBuildTargetLayerIds.test_layer_ids_match_zlab: [1,9,17,25,33] Co-Authored-By: Claude Opus 4.6 (1M context) --- .../speculative/plugins/test_hf_dflash.py | 148 +++++++++++++----- 1 file changed, 106 insertions(+), 42 deletions(-) diff --git 
a/tests/unit/torch/speculative/plugins/test_hf_dflash.py b/tests/unit/torch/speculative/plugins/test_hf_dflash.py index 50d3c9768b..6bb6c0bb36 100644 --- a/tests/unit/torch/speculative/plugins/test_hf_dflash.py +++ b/tests/unit/torch/speculative/plugins/test_hf_dflash.py @@ -21,6 +21,7 @@ import os from copy import deepcopy +import pytest import torch from _test_utils.torch.transformers_models import ( get_tiny_llama, @@ -34,6 +35,7 @@ from modelopt.torch.speculative.plugins.hf_dflash import ( DFlashModule, HFDFlashModel, + build_target_layer_ids, create_dflash_attention_mask, create_dflash_loss_mask, ) @@ -110,6 +112,14 @@ def test_convert_sets_mask_token_id(self): assert hasattr(model, "mask_token_id") assert model.mask_token_id == 0 + def test_convert_missing_mask_token_id_errors(self): + """Test that missing mask_token_id raises ValueError for unknown model.""" + model = get_tiny_llama(num_hidden_layers=4) + config = _get_dflash_config() + del config["dflash_architecture_config"]["mask_token_id"] + with pytest.raises(ValueError, match="Cannot auto-detect mask_token_id"): + mtsp.convert(model, [("dflash", config)]) + class TestDFlashSaveRestore: """Test DFlash model save and restore.""" @@ -129,49 +139,69 @@ def test_save_and_restore(self, tmp_path): tf_modelopt_state_and_output_tester(model_ref, model_test) +class TestDFlashApplyMetaFix: + """Test DFlashModule._apply handles meta-tensor rotary buffers. + + During checkpoint restore, rotary inv_freq buffers may be on meta device + (they are computed, not saved). _apply should re-create them on CPU. 
+ """ + + def test_apply_recreates_meta_rotary(self): + """Test that .to() recreates rotary_emb when buffers are on meta device.""" + model = get_tiny_llama(num_hidden_layers=4) + config = _get_dflash_config() + mtsp.convert(model, [("dflash", config)]) + + dflash_mod = model.dflash_module + # Simulate meta buffers (as happens during checkpoint restore) + for name, buf in list(dflash_mod.rotary_emb.named_buffers()): + dflash_mod.rotary_emb._buffers[name] = torch.empty_like(buf, device="meta") + + assert any(b.is_meta for b in dflash_mod.rotary_emb.buffers()) + + # .to() triggers _apply which should fix meta buffers + dflash_mod.to("cpu") + + assert not any(b.is_meta for b in dflash_mod.rotary_emb.buffers()) + + def test_apply_noop_when_no_meta(self): + """Test that .to() does not recreate rotary_emb when buffers are normal.""" + model = get_tiny_llama(num_hidden_layers=4) + config = _get_dflash_config() + mtsp.convert(model, [("dflash", config)]) + + dflash_mod = model.dflash_module + rotary_id_before = id(dflash_mod.rotary_emb) + dflash_mod.to("cpu") + assert id(dflash_mod.rotary_emb) == rotary_id_before + + class TestDFlashAttentionMask: """Test DFlash attention mask construction.""" def test_mask_shape(self): - """Test mask has shape [1, 1, L, 2L].""" mask = create_dflash_attention_mask(SEQ_LEN, BLOCK_SIZE, "cpu", torch.float32) assert mask.shape == (1, 1, SEQ_LEN, 2 * SEQ_LEN) def test_mask_context_strictly_previous_blocks(self): """Context (left half): block B can only see blocks 0..B-1.""" mask = create_dflash_attention_mask(8, 4, "cpu", torch.float32) - mask_2d = mask[0, 0] # [8, 16] - ctx_mask = mask_2d[:, :8] # context part - - # Block 0 (rows 0-3) should NOT see any context + mask_2d = mask[0, 0] + ctx_mask = mask_2d[:, :8] assert (ctx_mask[:4, :] < 0).all() - - # Block 1 (rows 4-7) should see block 0 context only - assert (ctx_mask[4:8, :4] == 0).all() # can see block 0 - assert (ctx_mask[4:8, 4:8] < 0).all() # cannot see own block + assert 
(ctx_mask[4:8, :4] == 0).all() + assert (ctx_mask[4:8, 4:8] < 0).all() def test_mask_noise_causal_within_block(self): - """Noise (right half): reverse-causal within same block, matching SpecForge. - - SpecForge uses j >= i: position 0 (anchor) sees all positions in block, - position B-1 sees only itself. Cross-block noise is fully masked. - """ + """Noise (right half): reverse-causal within same block (j >= i).""" mask = create_dflash_attention_mask(8, 4, "cpu", torch.float32) - mask_2d = mask[0, 0] - noise_mask = mask_2d[:, 8:] # noise part - - # Block 0, position 0: can see all positions in block (0-3) + noise_mask = mask[0, 0, :, 8:] assert (noise_mask[0, :4] == 0).all() - - # Block 0, position 3: can only see position 3 assert (noise_mask[3, :3] < 0).all() assert noise_mask[3, 3] == 0 - - # Block 1 cannot see block 0 noise assert (noise_mask[4:8, :4] < 0).all() def test_mask_values_are_zero_or_neg_inf(self): - """Test mask contains only 0 (attend) and -inf (mask).""" mask = create_dflash_attention_mask(SEQ_LEN, BLOCK_SIZE, "cpu", torch.float32) unique_vals = mask.unique() assert len(unique_vals) == 2 @@ -183,31 +213,26 @@ class TestDFlashLossMask: """Test DFlash loss mask construction.""" def test_loss_mask_shape(self): - """Test loss mask has shape [L].""" mask = create_dflash_loss_mask(SEQ_LEN, BLOCK_SIZE, "cpu") assert mask.shape == (SEQ_LEN,) def test_loss_mask_excludes_block_zero(self): - """Test all positions in block 0 are masked out.""" mask = create_dflash_loss_mask(SEQ_LEN, BLOCK_SIZE, "cpu") assert (mask[:BLOCK_SIZE] == 0).all() def test_loss_mask_excludes_block_starts(self): - """Test block start positions are masked.""" mask = create_dflash_loss_mask(SEQ_LEN, BLOCK_SIZE, "cpu") for i in range(0, SEQ_LEN, BLOCK_SIZE): - assert mask[i] == 0, f"Block start position {i} should be masked" + assert mask[i] == 0 def test_loss_mask_includes_non_start_positions(self): - """Test non-start positions in non-zero blocks are included.""" mask = 
create_dflash_loss_mask(SEQ_LEN, BLOCK_SIZE, "cpu") for b in range(1, SEQ_LEN // BLOCK_SIZE): for offset in range(1, BLOCK_SIZE): pos = b * BLOCK_SIZE + offset - assert mask[pos] == 1, f"Position {pos} should be in loss" + assert mask[pos] == 1 def test_loss_mask_count(self): - """Test total active positions matches expected count.""" mask = create_dflash_loss_mask(SEQ_LEN, BLOCK_SIZE, "cpu") num_blocks = SEQ_LEN // BLOCK_SIZE expected = (num_blocks - 1) * (BLOCK_SIZE - 1) @@ -218,26 +243,65 @@ class TestBuildTargetLayerIds: """Test target layer selection.""" def test_single_draft_layer(self): - """Test single draft layer selects middle target layer.""" - from modelopt.torch.speculative.plugins.hf_dflash import build_target_layer_ids - ids = build_target_layer_ids(32, 1) assert len(ids) == 1 - assert ids[0] == 16 # middle layer + assert ids[0] == 16 def test_multiple_draft_layers(self): - """Test multiple draft layers are monotonically increasing and in bounds.""" - from modelopt.torch.speculative.plugins.hf_dflash import build_target_layer_ids - ids = build_target_layer_ids(36, 5) assert len(ids) == 5 assert ids == sorted(ids) assert all(1 <= lid <= 33 for lid in ids) - def test_layer_ids_spread(self): - """Test layer IDs have no duplicates.""" - from modelopt.torch.speculative.plugins.hf_dflash import build_target_layer_ids - + def test_layer_ids_no_duplicates(self): ids = build_target_layer_ids(32, 5) - assert len(ids) == 5 assert len(set(ids)) == 5 + + def test_layer_ids_match_zlab(self): + """Test layer IDs match z-lab reference for Qwen3-8B (36 layers, 5 draft).""" + ids = build_target_layer_ids(36, 5) + assert ids == [1, 9, 17, 25, 33] + + +class TestDFlashSlidingWindow: + """Test sliding window attention support.""" + + def test_sliding_window_from_config(self): + """Test DFlashAttention reads sliding_window from config.layer_types.""" + from modelopt.torch.speculative.plugins.hf_dflash import DFlashAttention + from transformers import PretrainedConfig + + 
config = PretrainedConfig( + hidden_size=64, + num_attention_heads=4, + num_key_value_heads=4, + head_dim=16, + rms_norm_eps=1e-6, + attention_bias=False, + attention_dropout=0.0, + layer_types=["full_attention", "sliding_attention"], + sliding_window=256, + _attn_implementation="sdpa", + ) + attn_full = DFlashAttention(config, layer_idx=0) + attn_sliding = DFlashAttention(config, layer_idx=1) + assert attn_full.sliding_window is None + assert attn_sliding.sliding_window == 256 + + def test_no_sliding_window_without_config(self): + """Test DFlashAttention defaults to no sliding window.""" + from modelopt.torch.speculative.plugins.hf_dflash import DFlashAttention + from transformers import PretrainedConfig + + config = PretrainedConfig( + hidden_size=64, + num_attention_heads=4, + num_key_value_heads=4, + head_dim=16, + rms_norm_eps=1e-6, + attention_bias=False, + attention_dropout=0.0, + _attn_implementation="sdpa", + ) + attn = DFlashAttention(config, layer_idx=0) + assert attn.sliding_window is None From 437806bfb142db109dd2ddc7de40a9fa8151835e Mon Sep 17 00:00:00 2001 From: Chenhan Yu Date: Thu, 9 Apr 2026 11:44:26 -0700 Subject: [PATCH 11/24] update: dflash.md with AR evaluation methods and vLLM per-category results - Detailed explanation of 3 AR evaluation methods (ModelOpt online GT, vLLM, z-lab) - vLLM deployment results: 3.1x speedup, per-category accept length + TPS - ModelOpt wins 7/8 categories on accept length, 8/8 on TPS Co-Authored-By: Claude Opus 4.6 (1M context) --- examples/speculative_decoding/doc/dflash.md | 81 +++++++++++++++++---- 1 file changed, 66 insertions(+), 15 deletions(-) diff --git a/examples/speculative_decoding/doc/dflash.md b/examples/speculative_decoding/doc/dflash.md index 4a26db7728..6204f22174 100644 --- a/examples/speculative_decoding/doc/dflash.md +++ b/examples/speculative_decoding/doc/dflash.md @@ -132,26 +132,79 @@ Trained on nvidia/Nemotron-Post-Training-Dataset-v2 (2M samples), 64 GPUs, 10 ep > z-lab trained with 
block_size=16; ModelOpt trained with block_size=8. -### Evaluation Methods +### AR Evaluation Methods -| Method | Description | -|--------|-------------| -| **Fixed GT** | Pre-compute greedy ground truth, check draft against it | -| **Online GT** | Recompute ground truth after each accepted draft (context-dependent) | -| **z-lab official** | Actual speculative decoding with draft KV cache | +Three methods exist for measuring acceptance rate, producing different numbers: -Online GT is more accurate than Fixed GT (~+1.0 AR) because speculative decoding -acceptance depends on context-dependent verification, not a fixed reference sequence. +**1. ModelOpt Online GT** (`ar_validate.py --per_category`) + +The default evaluation method in ModelOpt. Uses `pseudo_speculative_generate` with +context-dependent (online) ground truth: + +1. Run base model on `input_ids` → get base token + hidden states +2. Build draft block: `[base_token, MASK, MASK, ...]` +3. Run DFlash draft forward → get `block_size-1` draft tokens +4. Run base model on `input_ids + base_token + draft_tokens` → verify each draft token + against what the base model would produce **given the accepted sequence so far** +5. Accept consecutive matches, append target's correction token on first mismatch +6. AR = total accepted tokens / number of speculative steps + +Key: ground truth is **recomputed after each accepted draft** (context-dependent), +matching actual speculative decoding behavior. Without this, fixed ground truth +underestimates AR by ~1.0 because it doesn't account for the draft model's +self-consistency — accepted draft tokens change the context for future predictions. + +**2. vLLM SpecDecoding** (`vllm serve --speculative-config`) + +The production evaluation. 
vLLM runs actual speculative decoding with full KV cache: +- Draft model proposes `num_speculative_tokens` tokens per step +- Target model verifies in parallel +- `Mean acceptance length` and `Per-position acceptance rate` reported in server metrics + +**3. z-lab Benchmark** (`dflash.benchmark --backend transformers`) + +z-lab's reference evaluation. Similar to vLLM but uses their own generation loop with +draft KV cache. Measures acceptance length as: +``` +acceptance_length = (draft[:, 1:] == posterior[:, :-1]).cumprod().sum() + 1 +``` + +### vLLM Deployment Results + +vLLM nightly (v0.19.1+), H100, MT-Bench 80 prompts, 1024 max tokens: + +| | Baseline | z-lab (bs16) | **ModelOpt (bs8)** | +|---|---------|-------------|-------------------| +| TP=1 tok/s | 145 | 422 | **443** | +| TP=8 tok/s | 377 | 919 | **1053** | +| Speedup (TP=1) | 1.0x | 2.9x | **3.1x** | + +**Per-Category (TP=8):** + +| Category | ModelOpt Accept | z-lab Accept | ModelOpt TPS | z-lab TPS | +|----------|----------------|-------------|-------------|-----------| +| math | **5.14** | 4.24 | **1238** | 1098 | +| coding | **4.03** | 3.52 | **1299** | 1269 | +| writing | **3.99** | 3.97 | **1002** | 903 | +| reasoning | **3.89** | 3.49 | **1188** | 1020 | +| roleplay | **3.88** | 3.37 | **1069** | 923 | +| extraction | **3.60** | 3.02 | **1002** | 789 | +| stem | 3.55 | **3.63** | **1027** | 914 | +| humanities | **3.05** | 2.68 | **786** | 672 | +| **ALL** | | | **1053** | 919 | + +ModelOpt wins acceptance length on 7/8 categories and TPS on 8/8 categories. 
### Key Findings | Finding | Evidence | |---------|----------| +| 3.1x speedup over baseline (TP=1) | 443 vs 145 tok/s on vLLM | +| 15% faster than z-lab | TP=1: 443 vs 422; TP=8: 1053 vs 919 | +| More efficient drafting | 44% vs 16.5% draft acceptance; fewer tokens drafted, more accepted | | Loss decay boosts AR | +0.12 AR at 55K (gamma=7, bs16); consistent across checkpoints | | Longer sequences help | seq=4096 vs 512: +0.49 AR on AA-Synthetic | -| Online validation essential | Fixed GT underestimates by ~1.0 AR | -| Forward pass identical to z-lab | Max diff 0.5 (bf16); 6/7 draft tokens match | -| sdpa vs flash_attn: negligible | AR 3.31 vs 3.31; hidden states identical | +| Online GT essential | Fixed GT underestimates by ~1.0 AR vs online GT | ## Open Items @@ -206,16 +259,14 @@ DFlash speculative decoding is supported in vLLM nightly (v0.19.1+): ```bash vllm serve Qwen/Qwen3-8B \ - --speculative-config '{"method": "dflash", "model": "z-lab/Qwen3-8B-DFlash-b16", "num_speculative_tokens": 15}' \ + --speculative-config '{"method": "dflash", "model": "path/to/dflash-checkpoint", "num_speculative_tokens": 7}' \ --attention-backend flash_attn \ --max-num-batched-tokens 32768 ``` -Validated: **386 tok/s** on single H100 with Qwen3-8B + DFlash-b16 (15 spec tokens). - Note: requires `vllm/vllm-openai:nightly` — the `latest` tag (v0.19.0) does not include DFlash. See [`tools/launcher/common/dflash/vllm_serve.sh`](../../../tools/launcher/common/dflash/vllm_serve.sh) -for a complete serve + benchmark script. +for serve + benchmark scripts. ### Docker Local Testing From eb0c3e738c40dd5d74dbff0079457016ee8e1745 Mon Sep 17 00:00:00 2001 From: Chenhan Yu Date: Thu, 9 Apr 2026 11:51:10 -0700 Subject: [PATCH 12/24] simplify: remove fixed GT vs online GT discussion from dflash.md Online GT is the only evaluation method. No need to discuss alternatives. 
Co-Authored-By: Claude Opus 4.6 (1M context) --- examples/speculative_decoding/doc/dflash.md | 77 +++------------------ 1 file changed, 8 insertions(+), 69 deletions(-) diff --git a/examples/speculative_decoding/doc/dflash.md b/examples/speculative_decoding/doc/dflash.md index 6204f22174..aac4f798a2 100644 --- a/examples/speculative_decoding/doc/dflash.md +++ b/examples/speculative_decoding/doc/dflash.md @@ -100,73 +100,21 @@ Trained on nvidia/Nemotron-Post-Training-Dataset-v2 (2M samples), 64 GPUs, 10 ep | Total Steps | 306,620 | | Final Per-Token Acc | 67.0% | -### MT-Bench Per-Category AR (Online Validation, osl=512) +### AR Evaluation -| Category | 80K | 150K | 306K | -|----------|-----|------|------| -| math | 5.44 | 5.54 | **5.52** | -| extraction | 4.81 | 4.82 | **4.88** | -| coding | 4.40 | 4.53 | **4.60** | -| reasoning | 4.34 | 4.41 | **4.44** | -| stem | 4.05 | 4.15 | **4.17** | -| writing | 3.76 | 3.79 | **3.84** | -| roleplay | 3.58 | 3.73 | **3.78** | -| humanities | 3.55 | 3.62 | **3.65** | -| **ALL** | **4.24** | **4.32** | **4.36** | - -### Comparison with z-lab/Qwen3-8B-DFlash-b16 - -**ModelOpt eval (online validation, osl=512):** - -| Dataset | z-lab | ModelOpt | Diff | -|---------|-------|----------|------| -| gsm8k | 4.10 | **5.19** | **+1.09** | -| MT-Bench | 3.58 | **4.36** | **+0.78** | - -**z-lab official eval (dflash.benchmark, osl=512):** - -| Dataset | z-lab | ModelOpt | Diff | -|---------|-------|----------|------| -| gsm8k | **5.00** | 4.08 | -0.92 | -| MT-Bench | **3.28** | 2.99 | -0.29 | - -> z-lab trained with block_size=16; ModelOpt trained with block_size=8. - -### AR Evaluation Methods - -Three methods exist for measuring acceptance rate, producing different numbers: - -**1. ModelOpt Online GT** (`ar_validate.py --per_category`) - -The default evaluation method in ModelOpt. 
Uses `pseudo_speculative_generate` with -context-dependent (online) ground truth: +AR is evaluated using `ar_validate.py` which calls `pseudo_speculative_generate` +with online (context-dependent) ground truth: 1. Run base model on `input_ids` → get base token + hidden states 2. Build draft block: `[base_token, MASK, MASK, ...]` 3. Run DFlash draft forward → get `block_size-1` draft tokens -4. Run base model on `input_ids + base_token + draft_tokens` → verify each draft token - against what the base model would produce **given the accepted sequence so far** -5. Accept consecutive matches, append target's correction token on first mismatch +4. Verify each draft token against the base model's prediction **given the + accepted sequence so far** (not a pre-computed fixed reference) +5. Accept consecutive matches, append target's correction on first mismatch 6. AR = total accepted tokens / number of speculative steps -Key: ground truth is **recomputed after each accepted draft** (context-dependent), -matching actual speculative decoding behavior. Without this, fixed ground truth -underestimates AR by ~1.0 because it doesn't account for the draft model's -self-consistency — accepted draft tokens change the context for future predictions. - -**2. vLLM SpecDecoding** (`vllm serve --speculative-config`) - -The production evaluation. vLLM runs actual speculative decoding with full KV cache: -- Draft model proposes `num_speculative_tokens` tokens per step -- Target model verifies in parallel -- `Mean acceptance length` and `Per-position acceptance rate` reported in server metrics - -**3. z-lab Benchmark** (`dflash.benchmark --backend transformers`) - -z-lab's reference evaluation. Similar to vLLM but uses their own generation loop with -draft KV cache. 
Measures acceptance length as: -``` -acceptance_length = (draft[:, 1:] == posterior[:, :-1]).cumprod().sum() + 1 +```bash +python scripts/ar_validate.py --model_path /path/to/checkpoint --per_category --osl 512 --steps 7 ``` ### vLLM Deployment Results @@ -204,7 +152,6 @@ ModelOpt wins acceptance length on 7/8 categories and TPS on 8/8 categories. | More efficient drafting | 44% vs 16.5% draft acceptance; fewer tokens drafted, more accepted | | Loss decay boosts AR | +0.12 AR at 55K (gamma=7, bs16); consistent across checkpoints | | Longer sequences help | seq=4096 vs 512: +0.49 AR on AA-Synthetic | -| Online GT essential | Fixed GT underestimates by ~1.0 AR vs online GT | ## Open Items @@ -224,14 +171,6 @@ in bf16. With 2M samples, full pre-computation would require ~320TB — not feas - Logit distillation adds another dimension: teacher logits at anchor+k-1 positions need `[seq_len, vocab_size]` per sample (~600MB in bf16) -### z-lab Eval Gap - -ModelOpt eval (online GT) gives higher AR than z-lab's official eval on our checkpoint -(5.19 vs 4.08 on gsm8k). The gap is likely from: -- z-lab uses draft KV cache (accumulates context across blocks); our eval re-runs from scratch -- z-lab's `acceptance_length + 1` counting (minimum 1 per step) -- `rope_theta` mismatch in exported config (was 10000 instead of 1000000 — now fixed) - ### Model Support Expansion Currently supports Qwen3 draft architecture. 
See `hf_dflash.py` module docstring for From af65f3e76b689e8643ad9b0364972441f0d551e1 Mon Sep 17 00:00:00 2001 From: Chenhan Yu Date: Thu, 9 Apr 2026 11:59:45 -0700 Subject: [PATCH 13/24] address PR review: clarify KV injection, anchors, decay, offline, vLLM Review comments from @benchislett: - Clarify Q/K/V roles in DFlashAttention docstring - Explain num_anchors tradeoff and scaling guidance - Distinguish DFlash decay from EAGLE3 decay - Update offline training note re: multi-layer vs last-layer hidden states - Remove --attention-backend from vLLM (auto-selected) - Clarify logit distillation description Co-Authored-By: Claude Opus 4.6 (1M context) --- examples/speculative_decoding/doc/dflash.md | 43 ++- .../torch/speculative/plugins/hf_dflash.py | 11 +- tools/launcher/common/dflash/vllm_serve.sh | 34 +- uv.lock | 290 +++++++++++++++++- 4 files changed, 341 insertions(+), 37 deletions(-) diff --git a/examples/speculative_decoding/doc/dflash.md b/examples/speculative_decoding/doc/dflash.md index aac4f798a2..2abc66c97d 100644 --- a/examples/speculative_decoding/doc/dflash.md +++ b/examples/speculative_decoding/doc/dflash.md @@ -50,24 +50,41 @@ See [`modelopt_recipes/general/speculative_decoding/dflash.yaml`](../../../model | Parameter | Default | Description | |-----------|---------|-------------| | `dflash.dflash_block_size` | 8 | Block size for parallel prediction | -| `dflash.dflash_num_anchors` | 512 | Random anchor positions per sample | -| `dflash.dflash_loss_decay_factor` | 4.0 | Exponential decay gamma (0 disables) | -| `dflash.dflash_self_logit_distillation` | true | Logit distillation from target | +| `dflash.dflash_num_anchors` | 512 | Random anchor positions per sample (see below) | +| `dflash.dflash_loss_decay_factor` | 4.0 | Exponential decay gamma (0 disables, see below) | +| `dflash.dflash_self_logit_distillation` | true | Use target model logits as soft labels (vs hard CE) | | `dflash.dflash_architecture_config.num_hidden_layers` | 5 | Draft decoder 
layers | | `dflash.dflash_architecture_config.mask_token_id` | auto | Token ID for masked positions | | `training.answer_only_loss` | false | Mask loss on non-assistant tokens | +### Random Anchor Sampling (`num_anchors`) + +During training, anchor positions are sampled randomly from valid (assistant response) +tokens in each batch, rather than dividing the sequence into fixed blocks. Each anchor +starts a block of `block_size` tokens where the draft model predicts positions 1..B-1. + +**Tradeoff:** Higher `num_anchors` = more training signal per sample but more compute. +Lower = faster iteration but less data efficiency. With `seq_len=4096` and `block_size=8`, +`num_anchors=512` means the model sees ~512 blocks per sample (covering ~4096 positions). +Scale proportionally: `num_anchors ≈ seq_len / block_size` gives full coverage. + ### Loss Decay The exponential decay factor (gamma) weights early block positions higher than later ones. -If position 0 in a block is wrong, all subsequent positions are rejected in speculative +If position 1 in a block is wrong, all subsequent positions are rejected in speculative decoding. Decay aligns the training loss with what matters for acceptance rate. ``` -weight[k] = exp(-k / gamma) for k = 0..B-1 +weight[k] = exp(-(k-1).clamp(min=0) / gamma) for k = 0..B-1 ``` -Paper recommendation: gamma=7 for block_size=16, gamma=4 for block_size=8. +Positions 0 (anchor, excluded by loss mask) and 1 get full weight (1.0). Later positions +decay: e.g., with `gamma=4` and `block_size=8`, position 7 contributes only 22% as +much as position 1. Paper recommendation: gamma=7 for block_size=16, gamma=4 for block_size=8. + +Note: this is different from EAGLE3's `eagle_loss_decay_factor` which multiplies loss by +`alpha^step` across TTT steps. DFlash decay operates within a single block, weighting +early positions higher because they gate acceptance of all later positions. 
### Checkpoint Resume @@ -160,16 +177,15 @@ ModelOpt wins acceptance length on 7/8 categories and TPS on 8/8 categories. Online training requires the full target model in GPU memory alongside the draft model. Offline training would pre-compute target hidden states and train the draft model separately. -**Challenge**: DFlash uses random anchor sampling over full sequences, requiring hidden states -at ALL positions. For Qwen3-8B with 5 target layers and seq_len=4096, this is ~160MB per sample -in bf16. With 2M samples, full pre-computation would require ~320TB — not feasible. +**Challenge**: DFlash needs hidden states from multiple target layers (not just the last) +at all positions for KV injection. EAGLE3 offline only stores last-layer hidden states +and reruns `lm_head` during training, but DFlash's feature fusion concatenates hidden +states from layers [1, 9, 17, 25, 33] — 5x the storage per position. **Potential approaches:** -- Pre-sample anchor positions and store only relevant slices (limits randomness) -- Stream hidden states from disk with chunked loading +- Store only the fused (post-FC) target hidden states instead of raw multi-layer states +- Pre-sample anchor positions and store only relevant slices - Hybrid: quantized base model on CPU computes hidden states on-the-fly, draft on GPU -- Logit distillation adds another dimension: teacher logits at anchor+k-1 positions - need `[seq_len, vocab_size]` per sample (~600MB in bf16) ### Model Support Expansion @@ -199,7 +215,6 @@ DFlash speculative decoding is supported in vLLM nightly (v0.19.1+): ```bash vllm serve Qwen/Qwen3-8B \ --speculative-config '{"method": "dflash", "model": "path/to/dflash-checkpoint", "num_speculative_tokens": 7}' \ - --attention-backend flash_attn \ --max-num-batched-tokens 32768 ``` diff --git a/modelopt/torch/speculative/plugins/hf_dflash.py b/modelopt/torch/speculative/plugins/hf_dflash.py index c26d800bc9..20add0f7c9 100644 --- a/modelopt/torch/speculative/plugins/hf_dflash.py 
+++ b/modelopt/torch/speculative/plugins/hf_dflash.py @@ -165,11 +165,16 @@ def _eager_attention(self, module, q, k, v, attention_mask, **kwargs): return attn_output, None def forward(self, hidden_states, target_hidden, position_embeddings, attention_mask=None): - """Forward with KV injection: Q from noise, K/V from context+noise.""" + """Forward with KV injection. + + Q is projected from the noise block (draft token embeddings: [anchor, mask, mask, ...]). + K and V are projected from the concatenation of target hidden states (context from the + base model) and noise block, so the draft can attend to both context and its own block. + """ bsz, q_len, _ = hidden_states.shape ctx_len = target_hidden.shape[1] - # Q from noise only, with QK-norm + # Q from noise block only (the draft tokens being predicted), with QK-norm q = self.q_proj(hidden_states).view(bsz, q_len, -1, self.head_dim) q = self.q_norm(q).transpose(1, 2) @@ -336,6 +341,8 @@ def create_dflash_attention_mask( # Create in f32 then cast, matching SpecForge. This ensures masked # positions get -inf in bf16 (f32 min overflows to -inf when cast), # not the largest finite negative bf16 value. + # TODO: This f32→bf16 cast pattern may be simplified once the mask behavior is stable. + # Consider creating directly in the target dtype with appropriate fill value. 
full_mask = torch.zeros(seq_len, 2 * seq_len, device=device, dtype=torch.float32) full_mask.masked_fill_(~full_mask_bool, torch.finfo(torch.float32).min) full_mask = full_mask.to(dtype=dtype) diff --git a/tools/launcher/common/dflash/vllm_serve.sh b/tools/launcher/common/dflash/vllm_serve.sh index 1f108d2d38..a7e5857f09 100644 --- a/tools/launcher/common/dflash/vllm_serve.sh +++ b/tools/launcher/common/dflash/vllm_serve.sh @@ -29,7 +29,8 @@ SCRIPT_DIR="$(dirname "$(readlink -f "$0")")" source ${SCRIPT_DIR}/../service_utils.sh 2>/dev/null || true -trap 'kill $SERVER_PID 2>/dev/null; exit' EXIT ERR +cleanup() { kill $SERVER_PID 2>/dev/null; sleep 2; kill -9 $SERVER_PID 2>/dev/null; } +trap cleanup EXIT MODEL=${HF_MODEL_CKPT} DRAFT=${DRAFT_MODEL} @@ -46,7 +47,6 @@ echo "Spec tokens: ${NUM_SPEC}, TP: ${TP}" # Start vLLM server in background vllm serve ${MODEL} \ --speculative-config "{\"method\": \"dflash\", \"model\": \"${DRAFT}\", \"num_speculative_tokens\": ${NUM_SPEC}}" \ - --attention-backend flash_attn \ --max-num-batched-tokens ${MAX_TOKENS} \ --tensor-parallel-size ${TP} \ --port ${PORT} \ @@ -92,15 +92,17 @@ if [ -n "${BENCHMARK_PROMPTS}" ] && [ -f "${BENCHMARK_PROMPTS}" ]; then echo "=== MT-Bench Benchmark ===" python3 -c " import json, time, requests +from collections import defaultdict with open('${BENCHMARK_PROMPTS}') as f: prompts = [json.loads(line) for line in f][:80] url = 'http://localhost:${PORT}/v1/completions' -times = [] -tokens = [] +cat_results = defaultdict(lambda: {'tokens': [], 'times': []}) + for i, p in enumerate(prompts): q = p.get('prompt', p.get('turns', [p.get('question', 'Hello')]))[0] if isinstance(p, dict) else str(p) + cat = p.get('category', 'unknown') if isinstance(p, dict) else 'unknown' start = time.time() r = requests.post(url, json={ 'model': '${MODEL}', @@ -110,14 +112,26 @@ for i, p in enumerate(prompts): }).json() elapsed = time.time() - start n = r.get('usage', {}).get('completion_tokens', 0) - times.append(elapsed) - 
tokens.append(n) + cat_results[cat]['tokens'].append(n) + cat_results[cat]['times'].append(elapsed) tps = n / elapsed if elapsed > 0 else 0 - print(f' [{i+1}/{len(prompts)}] {n} tokens in {elapsed:.1f}s = {tps:.1f} tok/s') + print(f' [{i+1}/{len(prompts)}] [{cat}] {n} tokens in {elapsed:.1f}s = {tps:.1f} tok/s') -total_tokens = sum(tokens) -total_time = sum(times) -print(f'\nTotal: {total_tokens} tokens in {total_time:.1f}s = {total_tokens/total_time:.1f} tok/s') +print(f'\n=== Per-Category Results ===') +print(f'{\"Category\":>12} | {\"Prompts\":>7} | {\"Tokens\":>8} | {\"Time(s)\":>8} | {\"TPS\":>8}') +print('-' * 55) +all_tokens = 0 +all_time = 0 +for cat in sorted(cat_results): + t = sum(cat_results[cat]['tokens']) + s = sum(cat_results[cat]['times']) + n = len(cat_results[cat]['tokens']) + tps = t / s if s > 0 else 0 + all_tokens += t + all_time += s + print(f'{cat:>12} | {n:>7} | {t:>8} | {s:>8.1f} | {tps:>8.1f}') +print('-' * 55) +print(f'{\"ALL\":>12} | {sum(len(v[\"tokens\"]) for v in cat_results.values()):>7} | {all_tokens:>8} | {all_time:>8.1f} | {all_tokens/all_time:>8.1f}') " fi diff --git a/uv.lock b/uv.lock index d890e361cb..2b2b274e13 100644 --- a/uv.lock +++ b/uv.lock @@ -20,9 +20,6 @@ resolution-markers = [ "python_full_version < '3.11' and platform_machine != 'aarch64' and sys_platform != 'darwin' and sys_platform != 'win32'", ] -[manifest] -overrides = [{ name = "torch", marker = "sys_platform == 'never'" }] - [[package]] name = "accelerate" version = "1.13.0" @@ -35,7 +32,7 @@ dependencies = [ { name = "psutil" }, { name = "pyyaml" }, { name = "safetensors" }, - { name = "torch", marker = "sys_platform == 'never'" }, + { name = "torch" }, ] sdist = { url = "https://files.pythonhosted.org/packages/ca/14/787e5498cd062640f0f3d92ef4ae4063174f76f9afd29d13fc52a319daae/accelerate-1.13.0.tar.gz", hash = "sha256:d631b4e0f5b3de4aff2d7e9e6857d164810dfc3237d54d017f075122d057b236", size = 402835, upload-time = "2026-03-04T19:34:12.359Z" } wheels = [ @@ 
-480,6 +477,21 @@ dependencies = [ ] sdist = { url = "https://files.pythonhosted.org/packages/54/27/01d9078a77b9e31b79b9716e66ca4db74f4744c5232bcb3e8769395c4280/cppimport-22.8.2.tar.gz", hash = "sha256:bbb4957102db41bc99ad72c233bce92f9d1fd91be352fc07878c4361033a401f", size = 26635, upload-time = "2022-08-02T16:50:36.872Z" } +[[package]] +name = "cuda-bindings" +version = "12.9.4" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "cuda-pathfinder", marker = "platform_machine != 'aarch64' and sys_platform != 'darwin' and sys_platform != 'win32'" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/7a/d8/b546104b8da3f562c1ff8ab36d130c8fe1dd6a045ced80b4f6ad74f7d4e1/cuda_bindings-12.9.4-cp310-cp310-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:4d3c842c2a4303b2a580fe955018e31aea30278be19795ae05226235268032e5", size = 12148218, upload-time = "2025-10-21T14:51:28.855Z" }, + { url = "https://files.pythonhosted.org/packages/45/e7/b47792cc2d01c7e1d37c32402182524774dadd2d26339bd224e0e913832e/cuda_bindings-12.9.4-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c912a3d9e6b6651853eed8eed96d6800d69c08e94052c292fec3f282c5a817c9", size = 12210593, upload-time = "2025-10-21T14:51:36.574Z" }, + { url = "https://files.pythonhosted.org/packages/a9/c1/dabe88f52c3e3760d861401bb994df08f672ec893b8f7592dc91626adcf3/cuda_bindings-12.9.4-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:fda147a344e8eaeca0c6ff113d2851ffca8f7dfc0a6c932374ee5c47caa649c8", size = 12151019, upload-time = "2025-10-21T14:51:43.167Z" }, + { url = "https://files.pythonhosted.org/packages/63/56/e465c31dc9111be3441a9ba7df1941fe98f4aa6e71e8788a3fb4534ce24d/cuda_bindings-12.9.4-cp313-cp313-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:32bdc5a76906be4c61eb98f546a6786c5773a881f3b166486449b5d141e4a39f", size = 11906628, upload-time = "2025-10-21T14:51:49.905Z" }, + { url = 
"https://files.pythonhosted.org/packages/a3/84/1e6be415e37478070aeeee5884c2022713c1ecc735e6d82d744de0252eee/cuda_bindings-12.9.4-cp313-cp313t-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:56e0043c457a99ac473ddc926fe0dc4046694d99caef633e92601ab52cbe17eb", size = 11925991, upload-time = "2025-10-21T14:51:56.535Z" }, +] + [[package]] name = "cuda-pathfinder" version = "1.4.3" @@ -554,7 +566,7 @@ dependencies = [ { name = "psutil", marker = "sys_platform != 'win32'" }, { name = "py-cpuinfo", marker = "sys_platform != 'win32'" }, { name = "pydantic", marker = "sys_platform != 'win32'" }, - { name = "torch", marker = "sys_platform == 'never'" }, + { name = "torch", marker = "sys_platform != 'win32'" }, { name = "tqdm", marker = "sys_platform != 'win32'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/19/11/46b9eb3806ca7a5e9bdddb7e873855a2d59a9f87f0675ae8231678d98434/deepspeed-0.18.8.tar.gz", hash = "sha256:e4e051a144b0c74270c46e4970139f9a86a61ff26959c5e463000c4a93b99304", size = 1647226, upload-time = "2026-03-13T18:49:48.568Z" } @@ -1311,7 +1323,9 @@ name = "networkx" version = "3.4.2" source = { registry = "https://pypi.org/simple" } resolution-markers = [ + "python_full_version < '3.11' and platform_machine == 'aarch64' and sys_platform == 'win32'", "(python_full_version < '3.11' and platform_machine == 'aarch64' and sys_platform != 'win32') or (python_full_version < '3.11' and sys_platform == 'darwin')", + "python_full_version < '3.11' and platform_machine != 'aarch64' and sys_platform == 'win32'", "python_full_version < '3.11' and platform_machine != 'aarch64' and sys_platform != 'darwin' and sys_platform != 'win32'", ] sdist = { url = "https://files.pythonhosted.org/packages/fd/1d/06475e1cd5264c0b870ea2cc6fdb3e37177c1e565c43f56ff17a10e3937f/networkx-3.4.2.tar.gz", hash = "sha256:307c3669428c5362aab27c8a1260aa8f47c4e91d3891f48be0141738d8d053e1", size = 2151368, upload-time = "2024-10-21T12:39:38.695Z" } @@ -1324,12 +1338,18 @@ name = 
"networkx" version = "3.6.1" source = { registry = "https://pypi.org/simple" } resolution-markers = [ + "python_full_version >= '3.13' and platform_machine == 'aarch64' and sys_platform == 'win32'", + "python_full_version == '3.12.*' and platform_machine == 'aarch64' and sys_platform == 'win32'", + "python_full_version == '3.11.*' and platform_machine == 'aarch64' and sys_platform == 'win32'", "(python_full_version >= '3.13' and platform_machine == 'aarch64' and sys_platform != 'win32') or (python_full_version >= '3.13' and sys_platform == 'darwin')", "(python_full_version == '3.12.*' and platform_machine == 'aarch64' and sys_platform != 'win32') or (python_full_version == '3.12.*' and sys_platform == 'darwin')", "(python_full_version == '3.11.*' and platform_machine == 'aarch64' and sys_platform != 'win32') or (python_full_version == '3.11.*' and sys_platform == 'darwin')", "python_full_version >= '3.13' and platform_machine != 'aarch64' and sys_platform != 'darwin' and sys_platform != 'win32'", "python_full_version == '3.12.*' and platform_machine != 'aarch64' and sys_platform != 'darwin' and sys_platform != 'win32'", "python_full_version == '3.11.*' and platform_machine != 'aarch64' and sys_platform != 'darwin' and sys_platform != 'win32'", + "python_full_version >= '3.13' and platform_machine != 'aarch64' and sys_platform == 'win32'", + "python_full_version == '3.12.*' and platform_machine != 'aarch64' and sys_platform == 'win32'", + "python_full_version == '3.11.*' and platform_machine != 'aarch64' and sys_platform == 'win32'", ] sdist = { url = "https://files.pythonhosted.org/packages/6a/51/63fe664f3908c97be9d2e4f1158eb633317598cfa6e1fc14af5383f17512/networkx-3.6.1.tar.gz", hash = "sha256:26b7c357accc0c8cde558ad486283728b65b6a95d85ee1cd66bafab4c8168509", size = 2517025, upload-time = "2025-12-08T17:02:39.908Z" } wheels = [ @@ -1526,6 +1546,108 @@ wheels = [ { url = 
"https://files.pythonhosted.org/packages/57/a7/b35835e278c18b85206834b3aa3abe68e77a98769c59233d1f6300284781/numpy-2.4.3-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:4b42639cdde6d24e732ff823a3fa5b701d8acad89c4142bc1d0bd6dc85200ba5", size = 12504685, upload-time = "2026-03-09T07:58:50.525Z" }, ] +[[package]] +name = "nvidia-cublas-cu12" +version = "12.8.4.1" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/dc/61/e24b560ab2e2eaeb3c839129175fb330dfcfc29e5203196e5541a4c44682/nvidia_cublas_cu12-12.8.4.1-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:8ac4e771d5a348c551b2a426eda6193c19aa630236b418086020df5ba9667142", size = 594346921, upload-time = "2025-03-07T01:44:31.254Z" }, +] + +[[package]] +name = "nvidia-cuda-cupti-cu12" +version = "12.8.90" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f8/02/2adcaa145158bf1a8295d83591d22e4103dbfd821bcaf6f3f53151ca4ffa/nvidia_cuda_cupti_cu12-12.8.90-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:ea0cb07ebda26bb9b29ba82cda34849e73c166c18162d3913575b0c9db9a6182", size = 10248621, upload-time = "2025-03-07T01:40:21.213Z" }, +] + +[[package]] +name = "nvidia-cuda-nvrtc-cu12" +version = "12.8.93" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/05/6b/32f747947df2da6994e999492ab306a903659555dddc0fbdeb9d71f75e52/nvidia_cuda_nvrtc_cu12-12.8.93-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl", hash = "sha256:a7756528852ef889772a84c6cd89d41dfa74667e24cca16bb31f8f061e3e9994", size = 88040029, upload-time = "2025-03-07T01:42:13.562Z" }, +] + +[[package]] +name = "nvidia-cuda-runtime-cu12" +version = "12.8.90" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/0d/9b/a997b638fcd068ad6e4d53b8551a7d30fe8b404d6f1804abf1df69838932/nvidia_cuda_runtime_cu12-12.8.90-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:adade8dcbd0edf427b7204d480d6066d33902cab2a4707dcfc48a2d0fd44ab90", size = 954765, upload-time = "2025-03-07T01:40:01.615Z" }, +] + +[[package]] +name = "nvidia-cudnn-cu12" +version = "9.10.2.21" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "nvidia-cublas-cu12", marker = "platform_machine != 'aarch64' and sys_platform != 'darwin' and sys_platform != 'win32'" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/ba/51/e123d997aa098c61d029f76663dedbfb9bc8dcf8c60cbd6adbe42f76d049/nvidia_cudnn_cu12-9.10.2.21-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:949452be657fa16687d0930933f032835951ef0892b37d2d53824d1a84dc97a8", size = 706758467, upload-time = "2025-06-06T21:54:08.597Z" }, +] + +[[package]] +name = "nvidia-cufft-cu12" +version = "11.3.3.83" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "nvidia-nvjitlink-cu12", marker = "platform_machine != 'aarch64' and sys_platform != 'darwin' and sys_platform != 'win32'" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/1f/13/ee4e00f30e676b66ae65b4f08cb5bcbb8392c03f54f2d5413ea99a5d1c80/nvidia_cufft_cu12-11.3.3.83-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:4d2dd21ec0b88cf61b62e6b43564355e5222e4a3fb394cac0db101f2dd0d4f74", size = 193118695, upload-time = "2025-03-07T01:45:27.821Z" }, +] + +[[package]] +name = "nvidia-cufile-cu12" +version = "1.13.1.3" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/bb/fe/1bcba1dfbfb8d01be8d93f07bfc502c93fa23afa6fd5ab3fc7c1df71038a/nvidia_cufile_cu12-1.13.1.3-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = 
"sha256:1d069003be650e131b21c932ec3d8969c1715379251f8d23a1860554b1cb24fc", size = 1197834, upload-time = "2025-03-07T01:45:50.723Z" }, +] + +[[package]] +name = "nvidia-curand-cu12" +version = "10.3.9.90" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/fb/aa/6584b56dc84ebe9cf93226a5cde4d99080c8e90ab40f0c27bda7a0f29aa1/nvidia_curand_cu12-10.3.9.90-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:b32331d4f4df5d6eefa0554c565b626c7216f87a06a4f56fab27c3b68a830ec9", size = 63619976, upload-time = "2025-03-07T01:46:23.323Z" }, +] + +[[package]] +name = "nvidia-cusolver-cu12" +version = "11.7.3.90" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "nvidia-cublas-cu12", marker = "platform_machine != 'aarch64' and sys_platform != 'darwin' and sys_platform != 'win32'" }, + { name = "nvidia-cusparse-cu12", marker = "platform_machine != 'aarch64' and sys_platform != 'darwin' and sys_platform != 'win32'" }, + { name = "nvidia-nvjitlink-cu12", marker = "platform_machine != 'aarch64' and sys_platform != 'darwin' and sys_platform != 'win32'" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/85/48/9a13d2975803e8cf2777d5ed57b87a0b6ca2cc795f9a4f59796a910bfb80/nvidia_cusolver_cu12-11.7.3.90-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:4376c11ad263152bd50ea295c05370360776f8c3427b30991df774f9fb26c450", size = 267506905, upload-time = "2025-03-07T01:47:16.273Z" }, +] + +[[package]] +name = "nvidia-cusparse-cu12" +version = "12.5.8.93" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "nvidia-nvjitlink-cu12", marker = "platform_machine != 'aarch64' and sys_platform != 'darwin' and sys_platform != 'win32'" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/c2/f5/e1854cb2f2bcd4280c44736c93550cc300ff4b8c95ebe370d0aa7d2b473d/nvidia_cusparse_cu12-12.5.8.93-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = 
"sha256:1ec05d76bbbd8b61b06a80e1eaf8cf4959c3d4ce8e711b65ebd0443bb0ebb13b", size = 288216466, upload-time = "2025-03-07T01:48:13.779Z" }, +] + +[[package]] +name = "nvidia-cusparselt-cu12" +version = "0.7.1" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/56/79/12978b96bd44274fe38b5dde5cfb660b1d114f70a65ef962bcbbed99b549/nvidia_cusparselt_cu12-0.7.1-py3-none-manylinux2014_x86_64.whl", hash = "sha256:f1bb701d6b930d5a7cea44c19ceb973311500847f81b634d802b7b539dc55623", size = 287193691, upload-time = "2025-02-26T00:15:44.104Z" }, +] + [[package]] name = "nvidia-ml-py" version = "13.595.45" @@ -1554,7 +1676,7 @@ dependencies = [ { name = "scipy", version = "1.15.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, { name = "scipy", version = "1.17.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, { name = "setuptools" }, - { name = "torch", marker = "sys_platform == 'never'" }, + { name = "torch" }, { name = "tqdm" }, ] @@ -1582,6 +1704,7 @@ all = [ { name = "peft" }, { name = "polygraphy" }, { name = "sentencepiece" }, + { name = "tiktoken" }, { name = "transformers" }, { name = "wonderwords" }, ] @@ -1589,6 +1712,7 @@ dev = [ { name = "accelerate" }, { name = "autodoc-pydantic" }, { name = "bandit", extra = ["toml"] }, + { name = "coverage", extra = ["toml"] }, { name = "cppimport" }, { name = "cupy-cuda12x", marker = "platform_machine != 'aarch64' and sys_platform != 'darwin'" }, { name = "datasets" }, @@ -1625,6 +1749,7 @@ dev = [ { name = "sphinx-inline-tabs" }, { name = "sphinx-rtd-theme" }, { name = "sphinx-togglebutton" }, + { name = "tiktoken" }, { name = "timm" }, { name = "torch-geometric" }, { name = "torchprofile" }, @@ -1652,6 +1777,7 @@ dev-lint = [ { name = "ruff" }, ] dev-test = [ + { name = "coverage", extra = ["toml"] }, { name = "pytest" }, { name = "pytest-cov" }, { name = "pytest-instafail" 
}, @@ -1672,6 +1798,7 @@ hf = [ { name = "nltk" }, { name = "peft" }, { name = "sentencepiece" }, + { name = "tiktoken" }, { name = "transformers" }, { name = "wonderwords" }, ] @@ -1697,6 +1824,7 @@ requires-dist = [ { name = "accelerate", marker = "extra == 'hf'", specifier = ">=1.0.0" }, { name = "autodoc-pydantic", marker = "extra == 'dev-docs'", specifier = ">=2.1.0" }, { name = "bandit", extras = ["toml"], marker = "extra == 'dev-lint'", specifier = "==1.7.9" }, + { name = "coverage", extras = ["toml"], marker = "extra == 'dev-test'", specifier = ">=7.13.0" }, { name = "cppimport", marker = "extra == 'onnx'" }, { name = "cupy-cuda12x", marker = "platform_machine != 'aarch64' and sys_platform != 'darwin' and extra == 'onnx'" }, { name = "datasets", marker = "extra == 'hf'", specifier = ">=3.0.0" }, @@ -1748,6 +1876,7 @@ requires-dist = [ { name = "sphinx-inline-tabs", marker = "extra == 'dev-docs'", specifier = ">=2023.4.21" }, { name = "sphinx-rtd-theme", marker = "extra == 'dev-docs'", specifier = "~=3.0.0" }, { name = "sphinx-togglebutton", marker = "extra == 'dev-docs'", specifier = ">=0.3.2" }, + { name = "tiktoken", marker = "extra == 'hf'" }, { name = "timm", marker = "extra == 'dev-test'" }, { name = "torch", specifier = ">=2.6" }, { name = "torch-geometric", marker = "extra == 'dev-test'" }, @@ -1756,11 +1885,43 @@ requires-dist = [ { name = "tox", marker = "extra == 'dev-test'", specifier = ">4.18" }, { name = "tox-current-env", marker = "extra == 'dev-test'", specifier = ">=0.0.12" }, { name = "tqdm" }, - { name = "transformers", marker = "extra == 'hf'", specifier = ">=4.53,<5.0" }, + { name = "transformers", marker = "extra == 'hf'", specifier = ">=4.56,<5.0" }, { name = "wonderwords", marker = "extra == 'hf'" }, ] provides-extras = ["onnx", "hf", "dev-lint", "dev-docs", "dev-test", "all", "dev"] +[[package]] +name = "nvidia-nccl-cu12" +version = "2.27.5" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/6e/89/f7a07dc961b60645dbbf42e80f2bc85ade7feb9a491b11a1e973aa00071f/nvidia_nccl_cu12-2.27.5-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:ad730cf15cb5d25fe849c6e6ca9eb5b76db16a80f13f425ac68d8e2e55624457", size = 322348229, upload-time = "2025-06-26T04:11:28.385Z" }, +] + +[[package]] +name = "nvidia-nvjitlink-cu12" +version = "12.8.93" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f6/74/86a07f1d0f42998ca31312f998bd3b9a7eff7f52378f4f270c8679c77fb9/nvidia_nvjitlink_cu12-12.8.93-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl", hash = "sha256:81ff63371a7ebd6e6451970684f916be2eab07321b73c9d244dc2b4da7f73b88", size = 39254836, upload-time = "2025-03-07T01:49:55.661Z" }, +] + +[[package]] +name = "nvidia-nvshmem-cu12" +version = "3.4.5" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b5/09/6ea3ea725f82e1e76684f0708bbedd871fc96da89945adeba65c3835a64c/nvidia_nvshmem_cu12-3.4.5-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:042f2500f24c021db8a06c5eec2539027d57460e1c1a762055a6554f72c369bd", size = 139103095, upload-time = "2025-09-06T00:32:31.266Z" }, +] + +[[package]] +name = "nvidia-nvtx-cu12" +version = "12.8.90" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a2/eb/86626c1bbc2edb86323022371c39aa48df6fd8b0a1647bc274577f72e90b/nvidia_nvtx_cu12-12.8.90-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:5b17e2001cc0d751a5bc2c6ec6d26ad95913324a4adb86788c944f8ce9ba441f", size = 89954, upload-time = "2025-03-07T01:42:44.131Z" }, +] + [[package]] name = "omegaconf" version = "2.3.0" @@ -2157,7 +2318,7 @@ dependencies = [ { name = "psutil" }, { name = "pyyaml" }, { name = "safetensors" }, - { name = "torch", marker = "sys_platform == 'never'" }, + { name = 
"torch" }, { name = "tqdm" }, { name = "transformers" }, ] @@ -3395,6 +3556,53 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/a2/09/77d55d46fd61b4a135c444fc97158ef34a095e5681d0a6c10b75bf356191/sympy-1.14.0-py3-none-any.whl", hash = "sha256:e091cc3e99d2141a0ba2847328f5479b05d94a6635cb96148ccb3f34671bd8f5", size = 6299353, upload-time = "2025-04-27T18:04:59.103Z" }, ] +[[package]] +name = "tiktoken" +version = "0.12.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "regex" }, + { name = "requests" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/7d/ab/4d017d0f76ec3171d469d80fc03dfbb4e48a4bcaddaa831b31d526f05edc/tiktoken-0.12.0.tar.gz", hash = "sha256:b18ba7ee2b093863978fcb14f74b3707cdc8d4d4d3836853ce7ec60772139931", size = 37806, upload-time = "2025-10-06T20:22:45.419Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/89/b3/2cb7c17b6c4cf8ca983204255d3f1d95eda7213e247e6947a0ee2c747a2c/tiktoken-0.12.0-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:3de02f5a491cfd179aec916eddb70331814bd6bf764075d39e21d5862e533970", size = 1051991, upload-time = "2025-10-06T20:21:34.098Z" }, + { url = "https://files.pythonhosted.org/packages/27/0f/df139f1df5f6167194ee5ab24634582ba9a1b62c6b996472b0277ec80f66/tiktoken-0.12.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:b6cfb6d9b7b54d20af21a912bfe63a2727d9cfa8fbda642fd8322c70340aad16", size = 995798, upload-time = "2025-10-06T20:21:35.579Z" }, + { url = "https://files.pythonhosted.org/packages/ef/5d/26a691f28ab220d5edc09b9b787399b130f24327ef824de15e5d85ef21aa/tiktoken-0.12.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:cde24cdb1b8a08368f709124f15b36ab5524aac5fa830cc3fdce9c03d4fb8030", size = 1129865, upload-time = "2025-10-06T20:21:36.675Z" }, + { url = "https://files.pythonhosted.org/packages/b2/94/443fab3d4e5ebecac895712abd3849b8da93b7b7dec61c7db5c9c7ebe40c/tiktoken-0.12.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = 
"sha256:6de0da39f605992649b9cfa6f84071e3f9ef2cec458d08c5feb1b6f0ff62e134", size = 1152856, upload-time = "2025-10-06T20:21:37.873Z" }, + { url = "https://files.pythonhosted.org/packages/54/35/388f941251b2521c70dd4c5958e598ea6d2c88e28445d2fb8189eecc1dfc/tiktoken-0.12.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:6faa0534e0eefbcafaccb75927a4a380463a2eaa7e26000f0173b920e98b720a", size = 1195308, upload-time = "2025-10-06T20:21:39.577Z" }, + { url = "https://files.pythonhosted.org/packages/f8/00/c6681c7f833dd410576183715a530437a9873fa910265817081f65f9105f/tiktoken-0.12.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:82991e04fc860afb933efb63957affc7ad54f83e2216fe7d319007dab1ba5892", size = 1255697, upload-time = "2025-10-06T20:21:41.154Z" }, + { url = "https://files.pythonhosted.org/packages/5f/d2/82e795a6a9bafa034bf26a58e68fe9a89eeaaa610d51dbeb22106ba04f0a/tiktoken-0.12.0-cp310-cp310-win_amd64.whl", hash = "sha256:6fb2995b487c2e31acf0a9e17647e3b242235a20832642bb7a9d1a181c0c1bb1", size = 879375, upload-time = "2025-10-06T20:21:43.201Z" }, + { url = "https://files.pythonhosted.org/packages/de/46/21ea696b21f1d6d1efec8639c204bdf20fde8bafb351e1355c72c5d7de52/tiktoken-0.12.0-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:6e227c7f96925003487c33b1b32265fad2fbcec2b7cf4817afb76d416f40f6bb", size = 1051565, upload-time = "2025-10-06T20:21:44.566Z" }, + { url = "https://files.pythonhosted.org/packages/c9/d9/35c5d2d9e22bb2a5f74ba48266fb56c63d76ae6f66e02feb628671c0283e/tiktoken-0.12.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:c06cf0fcc24c2cb2adb5e185c7082a82cba29c17575e828518c2f11a01f445aa", size = 995284, upload-time = "2025-10-06T20:21:45.622Z" }, + { url = "https://files.pythonhosted.org/packages/01/84/961106c37b8e49b9fdcf33fe007bb3a8fdcc380c528b20cc7fbba80578b8/tiktoken-0.12.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:f18f249b041851954217e9fd8e5c00b024ab2315ffda5ed77665a05fa91f42dc", size = 1129201, upload-time = 
"2025-10-06T20:21:47.074Z" }, + { url = "https://files.pythonhosted.org/packages/6a/d0/3d9275198e067f8b65076a68894bb52fd253875f3644f0a321a720277b8a/tiktoken-0.12.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:47a5bc270b8c3db00bb46ece01ef34ad050e364b51d406b6f9730b64ac28eded", size = 1152444, upload-time = "2025-10-06T20:21:48.139Z" }, + { url = "https://files.pythonhosted.org/packages/78/db/a58e09687c1698a7c592e1038e01c206569b86a0377828d51635561f8ebf/tiktoken-0.12.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:508fa71810c0efdcd1b898fda574889ee62852989f7c1667414736bcb2b9a4bd", size = 1195080, upload-time = "2025-10-06T20:21:49.246Z" }, + { url = "https://files.pythonhosted.org/packages/9e/1b/a9e4d2bf91d515c0f74afc526fd773a812232dd6cda33ebea7f531202325/tiktoken-0.12.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:a1af81a6c44f008cba48494089dd98cccb8b313f55e961a52f5b222d1e507967", size = 1255240, upload-time = "2025-10-06T20:21:50.274Z" }, + { url = "https://files.pythonhosted.org/packages/9d/15/963819345f1b1fb0809070a79e9dd96938d4ca41297367d471733e79c76c/tiktoken-0.12.0-cp311-cp311-win_amd64.whl", hash = "sha256:3e68e3e593637b53e56f7237be560f7a394451cb8c11079755e80ae64b9e6def", size = 879422, upload-time = "2025-10-06T20:21:51.734Z" }, + { url = "https://files.pythonhosted.org/packages/a4/85/be65d39d6b647c79800fd9d29241d081d4eeb06271f383bb87200d74cf76/tiktoken-0.12.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:b97f74aca0d78a1ff21b8cd9e9925714c15a9236d6ceacf5c7327c117e6e21e8", size = 1050728, upload-time = "2025-10-06T20:21:52.756Z" }, + { url = "https://files.pythonhosted.org/packages/4a/42/6573e9129bc55c9bf7300b3a35bef2c6b9117018acca0dc760ac2d93dffe/tiktoken-0.12.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:2b90f5ad190a4bb7c3eb30c5fa32e1e182ca1ca79f05e49b448438c3e225a49b", size = 994049, upload-time = "2025-10-06T20:21:53.782Z" }, + { url = 
"https://files.pythonhosted.org/packages/66/c5/ed88504d2f4a5fd6856990b230b56d85a777feab84e6129af0822f5d0f70/tiktoken-0.12.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:65b26c7a780e2139e73acc193e5c63ac754021f160df919add909c1492c0fb37", size = 1129008, upload-time = "2025-10-06T20:21:54.832Z" }, + { url = "https://files.pythonhosted.org/packages/f4/90/3dae6cc5436137ebd38944d396b5849e167896fc2073da643a49f372dc4f/tiktoken-0.12.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:edde1ec917dfd21c1f2f8046b86348b0f54a2c0547f68149d8600859598769ad", size = 1152665, upload-time = "2025-10-06T20:21:56.129Z" }, + { url = "https://files.pythonhosted.org/packages/a3/fe/26df24ce53ffde419a42f5f53d755b995c9318908288c17ec3f3448313a3/tiktoken-0.12.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:35a2f8ddd3824608b3d650a000c1ef71f730d0c56486845705a8248da00f9fe5", size = 1194230, upload-time = "2025-10-06T20:21:57.546Z" }, + { url = "https://files.pythonhosted.org/packages/20/cc/b064cae1a0e9fac84b0d2c46b89f4e57051a5f41324e385d10225a984c24/tiktoken-0.12.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:83d16643edb7fa2c99eff2ab7733508aae1eebb03d5dfc46f5565862810f24e3", size = 1254688, upload-time = "2025-10-06T20:21:58.619Z" }, + { url = "https://files.pythonhosted.org/packages/81/10/b8523105c590c5b8349f2587e2fdfe51a69544bd5a76295fc20f2374f470/tiktoken-0.12.0-cp312-cp312-win_amd64.whl", hash = "sha256:ffc5288f34a8bc02e1ea7047b8d041104791d2ddbf42d1e5fa07822cbffe16bd", size = 878694, upload-time = "2025-10-06T20:21:59.876Z" }, + { url = "https://files.pythonhosted.org/packages/00/61/441588ee21e6b5cdf59d6870f86beb9789e532ee9718c251b391b70c68d6/tiktoken-0.12.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:775c2c55de2310cc1bc9a3ad8826761cbdc87770e586fd7b6da7d4589e13dab3", size = 1050802, upload-time = "2025-10-06T20:22:00.96Z" }, + { url = 
"https://files.pythonhosted.org/packages/1f/05/dcf94486d5c5c8d34496abe271ac76c5b785507c8eae71b3708f1ad9b45a/tiktoken-0.12.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:a01b12f69052fbe4b080a2cfb867c4de12c704b56178edf1d1d7b273561db160", size = 993995, upload-time = "2025-10-06T20:22:02.788Z" }, + { url = "https://files.pythonhosted.org/packages/a0/70/5163fe5359b943f8db9946b62f19be2305de8c3d78a16f629d4165e2f40e/tiktoken-0.12.0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:01d99484dc93b129cd0964f9d34eee953f2737301f18b3c7257bf368d7615baa", size = 1128948, upload-time = "2025-10-06T20:22:03.814Z" }, + { url = "https://files.pythonhosted.org/packages/0c/da/c028aa0babf77315e1cef357d4d768800c5f8a6de04d0eac0f377cb619fa/tiktoken-0.12.0-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:4a1a4fcd021f022bfc81904a911d3df0f6543b9e7627b51411da75ff2fe7a1be", size = 1151986, upload-time = "2025-10-06T20:22:05.173Z" }, + { url = "https://files.pythonhosted.org/packages/a0/5a/886b108b766aa53e295f7216b509be95eb7d60b166049ce2c58416b25f2a/tiktoken-0.12.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:981a81e39812d57031efdc9ec59fa32b2a5a5524d20d4776574c4b4bd2e9014a", size = 1194222, upload-time = "2025-10-06T20:22:06.265Z" }, + { url = "https://files.pythonhosted.org/packages/f4/f8/4db272048397636ac7a078d22773dd2795b1becee7bc4922fe6207288d57/tiktoken-0.12.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:9baf52f84a3f42eef3ff4e754a0db79a13a27921b457ca9832cf944c6be4f8f3", size = 1255097, upload-time = "2025-10-06T20:22:07.403Z" }, + { url = "https://files.pythonhosted.org/packages/8e/32/45d02e2e0ea2be3a9ed22afc47d93741247e75018aac967b713b2941f8ea/tiktoken-0.12.0-cp313-cp313-win_amd64.whl", hash = "sha256:b8a0cd0c789a61f31bf44851defbd609e8dd1e2c8589c614cc1060940ef1f697", size = 879117, upload-time = "2025-10-06T20:22:08.418Z" }, + { url = 
"https://files.pythonhosted.org/packages/ce/76/994fc868f88e016e6d05b0da5ac24582a14c47893f4474c3e9744283f1d5/tiktoken-0.12.0-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:d5f89ea5680066b68bcb797ae85219c72916c922ef0fcdd3480c7d2315ffff16", size = 1050309, upload-time = "2025-10-06T20:22:10.939Z" }, + { url = "https://files.pythonhosted.org/packages/f6/b8/57ef1456504c43a849821920d582a738a461b76a047f352f18c0b26c6516/tiktoken-0.12.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:b4e7ed1c6a7a8a60a3230965bdedba8cc58f68926b835e519341413370e0399a", size = 993712, upload-time = "2025-10-06T20:22:12.115Z" }, + { url = "https://files.pythonhosted.org/packages/72/90/13da56f664286ffbae9dbcfadcc625439142675845baa62715e49b87b68b/tiktoken-0.12.0-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:fc530a28591a2d74bce821d10b418b26a094bf33839e69042a6e86ddb7a7fb27", size = 1128725, upload-time = "2025-10-06T20:22:13.541Z" }, + { url = "https://files.pythonhosted.org/packages/05/df/4f80030d44682235bdaecd7346c90f67ae87ec8f3df4a3442cb53834f7e4/tiktoken-0.12.0-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:06a9f4f49884139013b138920a4c393aa6556b2f8f536345f11819389c703ebb", size = 1151875, upload-time = "2025-10-06T20:22:14.559Z" }, + { url = "https://files.pythonhosted.org/packages/22/1f/ae535223a8c4ef4c0c1192e3f9b82da660be9eb66b9279e95c99288e9dab/tiktoken-0.12.0-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:04f0e6a985d95913cabc96a741c5ffec525a2c72e9df086ff17ebe35985c800e", size = 1194451, upload-time = "2025-10-06T20:22:15.545Z" }, + { url = "https://files.pythonhosted.org/packages/78/a7/f8ead382fce0243cb625c4f266e66c27f65ae65ee9e77f59ea1653b6d730/tiktoken-0.12.0-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:0ee8f9ae00c41770b5f9b0bb1235474768884ae157de3beb5439ca0fd70f3e25", size = 1253794, upload-time = "2025-10-06T20:22:16.624Z" }, + { url = 
"https://files.pythonhosted.org/packages/93/e0/6cc82a562bc6365785a3ff0af27a2a092d57c47d7a81d9e2295d8c36f011/tiktoken-0.12.0-cp313-cp313t-win_amd64.whl", hash = "sha256:dc2dd125a62cb2b3d858484d6c614d136b5b848976794edfb63688d539b8b93f", size = 878777, upload-time = "2025-10-06T20:22:18.036Z" }, +] + [[package]] name = "timm" version = "1.0.25" @@ -3403,7 +3611,7 @@ dependencies = [ { name = "huggingface-hub" }, { name = "pyyaml" }, { name = "safetensors" }, - { name = "torch", marker = "sys_platform == 'never'" }, + { name = "torch" }, { name = "torchvision" }, ] sdist = { url = "https://files.pythonhosted.org/packages/d7/2c/593109822fe735e637382aca6640c1102c19797f7791f1fd1dab2d6c3cb1/timm-1.0.25.tar.gz", hash = "sha256:47f59fc2754725735cc81bb83bcbfce5bec4ebd5d4bb9e69da57daa92fcfa768", size = 2414743, upload-time = "2026-02-23T16:49:00.137Z" } @@ -3491,15 +3699,63 @@ name = "torch" version = "2.10.0" source = { registry = "https://pypi.org/simple" } dependencies = [ + { name = "cuda-bindings", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "filelock" }, { name = "fsspec" }, { name = "jinja2" }, { name = "networkx", version = "3.4.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, { name = "networkx", version = "3.6.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, + { name = "nvidia-cublas-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cuda-cupti-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cuda-nvrtc-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cuda-runtime-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cudnn-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cufft-cu12", marker = 
"platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cufile-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-curand-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cusolver-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cusparse-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cusparselt-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-nccl-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-nvjitlink-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-nvshmem-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-nvtx-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "setuptools", marker = "python_full_version >= '3.12'" }, { name = "sympy" }, + { name = "triton", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "typing-extensions" }, ] +wheels = [ + { url = "https://files.pythonhosted.org/packages/5b/30/bfebdd8ec77db9a79775121789992d6b3b75ee5494971294d7b4b7c999bc/torch-2.10.0-2-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:2b980edd8d7c0a68c4e951ee1856334a43193f98730d97408fbd148c1a933313", size = 79411457, upload-time = "2026-02-10T21:44:59.189Z" }, + { url = "https://files.pythonhosted.org/packages/0f/8b/4b61d6e13f7108f36910df9ab4b58fd389cc2520d54d81b88660804aad99/torch-2.10.0-2-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:418997cb02d0a0f1497cf6a09f63166f9f5df9f3e16c8a716ab76a72127c714f", size = 79423467, upload-time = "2026-02-10T21:44:48.711Z" }, + { url = 
"https://files.pythonhosted.org/packages/d3/54/a2ba279afcca44bbd320d4e73675b282fcee3d81400ea1b53934efca6462/torch-2.10.0-2-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:13ec4add8c3faaed8d13e0574f5cd4a323c11655546f91fbe6afa77b57423574", size = 79498202, upload-time = "2026-02-10T21:44:52.603Z" }, + { url = "https://files.pythonhosted.org/packages/ec/23/2c9fe0c9c27f7f6cb865abcea8a4568f29f00acaeadfc6a37f6801f84cb4/torch-2.10.0-2-cp313-none-macosx_11_0_arm64.whl", hash = "sha256:e521c9f030a3774ed770a9c011751fb47c4d12029a3d6522116e48431f2ff89e", size = 79498254, upload-time = "2026-02-10T21:44:44.095Z" }, + { url = "https://files.pythonhosted.org/packages/16/ee/efbd56687be60ef9af0c9c0ebe106964c07400eade5b0af8902a1d8cd58c/torch-2.10.0-3-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:a1ff626b884f8c4e897c4c33782bdacdff842a165fee79817b1dd549fdda1321", size = 915510070, upload-time = "2026-03-11T14:16:39.386Z" }, + { url = "https://files.pythonhosted.org/packages/36/ab/7b562f1808d3f65414cd80a4f7d4bb00979d9355616c034c171249e1a303/torch-2.10.0-3-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:ac5bdcbb074384c66fa160c15b1ead77839e3fe7ed117d667249afce0acabfac", size = 915518691, upload-time = "2026-03-11T14:15:43.147Z" }, + { url = "https://files.pythonhosted.org/packages/b3/7a/abada41517ce0011775f0f4eacc79659bc9bc6c361e6bfe6f7052a6b9363/torch-2.10.0-3-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:98c01b8bb5e3240426dcde1446eed6f40c778091c8544767ef1168fc663a05a6", size = 915622781, upload-time = "2026-03-11T14:17:11.354Z" }, + { url = "https://files.pythonhosted.org/packages/ab/c6/4dfe238342ffdcec5aef1c96c457548762d33c40b45a1ab7033bb26d2ff2/torch-2.10.0-3-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:80b1b5bfe38eb0e9f5ff09f206dcac0a87aadd084230d4a36eea5ec5232c115b", size = 915627275, upload-time = "2026-03-11T14:16:11.325Z" }, + { url = 
"https://files.pythonhosted.org/packages/d8/f0/72bf18847f58f877a6a8acf60614b14935e2f156d942483af1ffc081aea0/torch-2.10.0-3-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:46b3574d93a2a8134b3f5475cfb98e2eb46771794c57015f6ad1fb795ec25e49", size = 915523474, upload-time = "2026-03-11T14:17:44.422Z" }, + { url = "https://files.pythonhosted.org/packages/0c/1a/c61f36cfd446170ec27b3a4984f072fd06dab6b5d7ce27e11adb35d6c838/torch-2.10.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:5276fa790a666ee8becaffff8acb711922252521b28fbce5db7db5cf9cb2026d", size = 145992962, upload-time = "2026-01-21T16:24:14.04Z" }, + { url = "https://files.pythonhosted.org/packages/b5/60/6662535354191e2d1555296045b63e4279e5a9dbad49acf55a5d38655a39/torch-2.10.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:aaf663927bcd490ae971469a624c322202a2a1e68936eb952535ca4cd3b90444", size = 915599237, upload-time = "2026-01-21T16:23:25.497Z" }, + { url = "https://files.pythonhosted.org/packages/40/b8/66bbe96f0d79be2b5c697b2e0b187ed792a15c6c4b8904613454651db848/torch-2.10.0-cp310-cp310-win_amd64.whl", hash = "sha256:a4be6a2a190b32ff5c8002a0977a25ea60e64f7ba46b1be37093c141d9c49aeb", size = 113720931, upload-time = "2026-01-21T16:24:23.743Z" }, + { url = "https://files.pythonhosted.org/packages/76/bb/d820f90e69cda6c8169b32a0c6a3ab7b17bf7990b8f2c680077c24a3c14c/torch-2.10.0-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:35e407430795c8d3edb07a1d711c41cc1f9eaddc8b2f1cc0a165a6767a8fb73d", size = 79411450, upload-time = "2026-01-21T16:25:30.692Z" }, + { url = "https://files.pythonhosted.org/packages/78/89/f5554b13ebd71e05c0b002f95148033e730d3f7067f67423026cc9c69410/torch-2.10.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:3282d9febd1e4e476630a099692b44fdc214ee9bf8ee5377732d9d9dfe5712e4", size = 145992610, upload-time = "2026-01-21T16:25:26.327Z" }, + { url = 
"https://files.pythonhosted.org/packages/ae/30/a3a2120621bf9c17779b169fc17e3dc29b230c29d0f8222f499f5e159aa8/torch-2.10.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:a2f9edd8dbc99f62bc4dfb78af7bf89499bca3d753423ac1b4e06592e467b763", size = 915607863, upload-time = "2026-01-21T16:25:06.696Z" }, + { url = "https://files.pythonhosted.org/packages/6f/3d/c87b33c5f260a2a8ad68da7147e105f05868c281c63d65ed85aa4da98c66/torch-2.10.0-cp311-cp311-win_amd64.whl", hash = "sha256:29b7009dba4b7a1c960260fc8ac85022c784250af43af9fb0ebafc9883782ebd", size = 113723116, upload-time = "2026-01-21T16:25:21.916Z" }, + { url = "https://files.pythonhosted.org/packages/61/d8/15b9d9d3a6b0c01b883787bd056acbe5cc321090d4b216d3ea89a8fcfdf3/torch-2.10.0-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:b7bd80f3477b830dd166c707c5b0b82a898e7b16f59a7d9d42778dd058272e8b", size = 79423461, upload-time = "2026-01-21T16:24:50.266Z" }, + { url = "https://files.pythonhosted.org/packages/cc/af/758e242e9102e9988969b5e621d41f36b8f258bb4a099109b7a4b4b50ea4/torch-2.10.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:5fd4117d89ffd47e3dcc71e71a22efac24828ad781c7e46aaaf56bf7f2796acf", size = 145996088, upload-time = "2026-01-21T16:24:44.171Z" }, + { url = "https://files.pythonhosted.org/packages/23/8e/3c74db5e53bff7ed9e34c8123e6a8bfef718b2450c35eefab85bb4a7e270/torch-2.10.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:787124e7db3b379d4f1ed54dd12ae7c741c16a4d29b49c0226a89bea50923ffb", size = 915711952, upload-time = "2026-01-21T16:23:53.503Z" }, + { url = "https://files.pythonhosted.org/packages/6e/01/624c4324ca01f66ae4c7cd1b74eb16fb52596dce66dbe51eff95ef9e7a4c/torch-2.10.0-cp312-cp312-win_amd64.whl", hash = "sha256:2c66c61f44c5f903046cc696d088e21062644cbe541c7f1c4eaae88b2ad23547", size = 113757972, upload-time = "2026-01-21T16:24:39.516Z" }, + { url = 
"https://files.pythonhosted.org/packages/c9/5c/dee910b87c4d5c0fcb41b50839ae04df87c1cfc663cf1b5fca7ea565eeaa/torch-2.10.0-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:6d3707a61863d1c4d6ebba7be4ca320f42b869ee657e9b2c21c736bf17000294", size = 79498198, upload-time = "2026-01-21T16:24:34.704Z" }, + { url = "https://files.pythonhosted.org/packages/c9/6f/f2e91e34e3fcba2e3fc8d8f74e7d6c22e74e480bbd1db7bc8900fdf3e95c/torch-2.10.0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:5c4d217b14741e40776dd7074d9006fd28b8a97ef5654db959d8635b2fe5f29b", size = 146004247, upload-time = "2026-01-21T16:24:29.335Z" }, + { url = "https://files.pythonhosted.org/packages/98/fb/5160261aeb5e1ee12ee95fe599d0541f7c976c3701d607d8fc29e623229f/torch-2.10.0-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:6b71486353fce0f9714ca0c9ef1c850a2ae766b409808acd58e9678a3edb7738", size = 915716445, upload-time = "2026-01-21T16:22:45.353Z" }, + { url = "https://files.pythonhosted.org/packages/6a/16/502fb1b41e6d868e8deb5b0e3ae926bbb36dab8ceb0d1b769b266ad7b0c3/torch-2.10.0-cp313-cp313-win_amd64.whl", hash = "sha256:c2ee399c644dc92ef7bc0d4f7e74b5360c37cdbe7c5ba11318dda49ffac2bc57", size = 113757050, upload-time = "2026-01-21T16:24:19.204Z" }, + { url = "https://files.pythonhosted.org/packages/1a/0b/39929b148f4824bc3ad6f9f72a29d4ad865bcf7ebfc2fa67584773e083d2/torch-2.10.0-cp313-cp313t-macosx_14_0_arm64.whl", hash = "sha256:3202429f58309b9fa96a614885eace4b7995729f44beb54d3e4a47773649d382", size = 79851305, upload-time = "2026-01-21T16:24:09.209Z" }, + { url = "https://files.pythonhosted.org/packages/d8/14/21fbce63bc452381ba5f74a2c0a959fdf5ad5803ccc0c654e752e0dbe91a/torch-2.10.0-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:aae1b29cd68e50a9397f5ee897b9c24742e9e306f88a807a27d617f07adb3bd8", size = 146005472, upload-time = "2026-01-21T16:22:29.022Z" }, + { url = 
"https://files.pythonhosted.org/packages/54/fd/b207d1c525cb570ef47f3e9f836b154685011fce11a2f444ba8a4084d042/torch-2.10.0-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:6021db85958db2f07ec94e1bc77212721ba4920c12a18dc552d2ae36a3eb163f", size = 915612644, upload-time = "2026-01-21T16:21:47.019Z" }, + { url = "https://files.pythonhosted.org/packages/36/53/0197f868c75f1050b199fe58f9bf3bf3aecac9b4e85cc9c964383d745403/torch-2.10.0-cp313-cp313t-win_amd64.whl", hash = "sha256:ff43db38af76fda183156153983c9a096fc4c78d0cd1e07b14a2314c7f01c2c8", size = 113997015, upload-time = "2026-01-21T16:23:00.767Z" }, + { url = "https://files.pythonhosted.org/packages/0e/13/e76b4d9c160e89fff48bf16b449ea324bda84745d2ab30294c37c2434c0d/torch-2.10.0-cp313-none-macosx_11_0_arm64.whl", hash = "sha256:cdf2a523d699b70d613243211ecaac14fe9c5df8a0b0a9c02add60fb2a413e0f", size = 79498248, upload-time = "2026-01-21T16:23:09.315Z" }, +] [[package]] name = "torch-geometric" @@ -3529,7 +3785,7 @@ source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, { name = "numpy", version = "2.4.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, - { name = "torch", marker = "sys_platform == 'never'" }, + { name = "torch" }, { name = "torchvision" }, ] sdist = { url = "https://files.pythonhosted.org/packages/6f/36/574c0c46e818533b78b3c09505211162918188325ab4165ef11a3f295755/torchprofile-0.0.4.tar.gz", hash = "sha256:96b6da17d752a06b02977e078aea95614893b31d4117dd5dcd081f30ce65611b", size = 4557, upload-time = "2021-06-22T04:58:03.592Z" } @@ -3545,7 +3801,7 @@ dependencies = [ { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, { name = "numpy", version = "2.4.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= 
'3.11'" }, { name = "pillow" }, - { name = "torch", marker = "sys_platform == 'never'" }, + { name = "torch" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/50/ae/cbf727421eb73f1cf907fbe5788326a08f111b3f6b6ddca15426b53fec9a/torchvision-0.25.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:a95c47abb817d4e90ea1a8e57bd0d728e3e6b533b3495ae77d84d883c4d11f56", size = 1874919, upload-time = "2026-01-21T16:27:47.617Z" }, @@ -3638,6 +3894,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/03/b8/e484ef633af3887baeeb4b6ad12743363af7cce68ae51e938e00aaa0529d/transformers-4.57.6-py3-none-any.whl", hash = "sha256:4c9e9de11333ddfe5114bc872c9f370509198acf0b87a832a0ab9458e2bd0550", size = 11993498, upload-time = "2026-01-16T10:38:31.289Z" }, ] +[[package]] +name = "triton" +version = "3.6.0" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/8c/f7/f1c9d3424ab199ac53c2da567b859bcddbb9c9e7154805119f8bd95ec36f/triton-3.6.0-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a6550fae429e0667e397e5de64b332d1e5695b73650ee75a6146e2e902770bea", size = 188105201, upload-time = "2026-01-20T16:00:29.272Z" }, + { url = "https://files.pythonhosted.org/packages/e0/12/b05ba554d2c623bffa59922b94b0775673de251f468a9609bc9e45de95e9/triton-3.6.0-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:e8e323d608e3a9bfcc2d9efcc90ceefb764a82b99dea12a86d643c72539ad5d3", size = 188214640, upload-time = "2026-01-20T16:00:35.869Z" }, + { url = "https://files.pythonhosted.org/packages/ab/a8/cdf8b3e4c98132f965f88c2313a4b493266832ad47fb52f23d14d4f86bb5/triton-3.6.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:74caf5e34b66d9f3a429af689c1c7128daba1d8208df60e81106b115c00d6fca", size = 188266850, upload-time = "2026-01-20T16:00:43.041Z" }, + { url = 
"https://files.pythonhosted.org/packages/f9/0b/37d991d8c130ce81a8728ae3c25b6e60935838e9be1b58791f5997b24a54/triton-3.6.0-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:10c7f76c6e72d2ef08df639e3d0d30729112f47a56b0c81672edc05ee5116ac9", size = 188289450, upload-time = "2026-01-20T16:00:49.136Z" }, + { url = "https://files.pythonhosted.org/packages/35/f8/9c66bfc55361ec6d0e4040a0337fb5924ceb23de4648b8a81ae9d33b2b38/triton-3.6.0-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:d002e07d7180fd65e622134fbd980c9a3d4211fb85224b56a0a0efbd422ab72f", size = 188400296, upload-time = "2026-01-20T16:00:56.042Z" }, +] + [[package]] name = "typing-extensions" version = "4.15.0" From eda33a681597f6f284b3fae819afd40537dbc64f Mon Sep 17 00:00:00 2001 From: Chenhan Yu Date: Thu, 9 Apr 2026 12:01:06 -0700 Subject: [PATCH 14/24] update: categorize open items by implementation status in dflash.md - Not Yet Implemented: offline training, MoE, MLA, Docker local - Implemented but Not Validated E2E: sliding window, FP8/NVFP4, resume - Validated: online training, multi-node, AR eval, vLLM, export, decay Co-Authored-By: Claude Opus 4.6 (1M context) --- examples/speculative_decoding/doc/dflash.md | 91 ++++++++------------- 1 file changed, 36 insertions(+), 55 deletions(-) diff --git a/examples/speculative_decoding/doc/dflash.md b/examples/speculative_decoding/doc/dflash.md index 2abc66c97d..09484af6e4 100644 --- a/examples/speculative_decoding/doc/dflash.md +++ b/examples/speculative_decoding/doc/dflash.md @@ -32,7 +32,7 @@ Target Model (frozen) **Draft model components** (Qwen3-based): - `Qwen3MLP`, `Qwen3RMSNorm`, `Qwen3RotaryEmbedding` from transformers -- Sliding window attention supported via `config.layer_types` +- Sliding window attention supported via `config.layer_types` *(implemented, not yet validated end-to-end)* - Independent of target model architecture ## Training @@ -172,57 +172,38 @@ ModelOpt wins acceptance length on 7/8 
categories and TPS on 8/8 categories. ## Open Items -### Offline Training - -Online training requires the full target model in GPU memory alongside the draft model. -Offline training would pre-compute target hidden states and train the draft model separately. - -**Challenge**: DFlash needs hidden states from multiple target layers (not just the last) -at all positions for KV injection. EAGLE3 offline only stores last-layer hidden states -and reruns `lm_head` during training, but DFlash's feature fusion concatenates hidden -states from layers [1, 9, 17, 25, 33] — 5x the storage per position. - -**Potential approaches:** -- Store only the fused (post-FC) target hidden states instead of raw multi-layer states -- Pre-sample anchor positions and store only relevant slices -- Hybrid: quantized base model on CPU computes hidden states on-the-fly, draft on GPU - -### Model Support Expansion - -Currently supports Qwen3 draft architecture. See `hf_dflash.py` module docstring for -instructions on adding: -- **Qwen3MoE**: Replace MLP with `Qwen3MoeMLP` via config flag -- **MLA (DeepseekV3/Kimi-K2)**: Requires MLA-aware KV injection with compressed K/V - -### FP8 / NVFP4 Quantization - -The DFlash export pipeline supports quantized checkpoints via ModelOpt PTQ, following -the same flow as EAGLE3: - -1. Train draft model (bf16) -2. Apply PTQ: `mtq.quantize(model, quant_cfg)` with `FP8_DEFAULT_CFG` or `NVFP4_DEFAULT_CFG` -3. Export: `export_hf_checkpoint.py` auto-detects quantization and writes scales + `quantization_config` - -The exporter's `has_quant_opt()` check and `_export_transformers_checkpoint()` handle -quantized weights transparently. No DFlash-specific quantization code is needed. - -TODO: Add a quantization recipe/script and validate FP8/NVFP4 AR impact. 
- -### vLLM Deployment - -DFlash speculative decoding is supported in vLLM nightly (v0.19.1+): - -```bash -vllm serve Qwen/Qwen3-8B \ - --speculative-config '{"method": "dflash", "model": "path/to/dflash-checkpoint", "num_speculative_tokens": 7}' \ - --max-num-batched-tokens 32768 -``` - -Note: requires `vllm/vllm-openai:nightly` — the `latest` tag (v0.19.0) does not include DFlash. -See [`tools/launcher/common/dflash/vllm_serve.sh`](../../../tools/launcher/common/dflash/vllm_serve.sh) -for serve + benchmark scripts. - -### Docker Local Testing - -The launcher example currently requires Slurm cluster access. A local Docker example -with `hf_local=` path mapping would enable development without cluster access. +### Not Yet Implemented + +- **Offline training**: DFlash needs multi-layer hidden states at all positions for KV + injection (5x storage vs EAGLE3's single-layer approach). Possible approaches: store + fused hidden states, pre-sample anchors, or hybrid CPU base + GPU draft. +- **Qwen3MoE draft**: Replace `Qwen3MLP` with `Qwen3MoeMLP` via config flag. See + `hf_dflash.py` module docstring for instructions. +- **MLA support (DeepseekV3/Kimi-K2)**: Requires MLA-aware KV injection with compressed K/V. +- **Docker local testing**: Launcher example requires Slurm. Need a local Docker example + with `hf_local=` path mapping. + +### Implemented but Not Yet Validated End-to-End + +- **Sliding window attention**: Code reads `config.layer_types` and sets `sliding_window` + per layer. Unit tested but not validated in a full training run with sliding window models. +- **FP8 / NVFP4 quantization**: Export pipeline supports quantized checkpoints via + `hf_ptq.py` (PTQ succeeded in testing). AR impact of quantization not yet measured. + The flow: train (bf16) → `mtq.quantize(model, quant_cfg)` → `export_hf_checkpoint.py`. +- **Checkpoint resume**: `DFlashModule._apply()` handles meta-tensor rotary buffers. + Validated in training runs but not covered by integration tests. 
+ +### Validated + +- **Online training**: E2E pipeline (train → export → eval) on sample-1K and sample-10K. +- **Multi-node DDP**: 8-node (64 GPU) training on full dataset, 10 epochs. +- **AR evaluation**: `ar_validate.py` with online GT, per-category MT-Bench. +- **vLLM deployment**: Speculative decoding with `vllm/vllm-openai:nightly` (v0.19.1+). + 3.1x speedup over baseline. Per-category benchmarks on MT-Bench. + ```bash + vllm serve Qwen/Qwen3-8B \ + --speculative-config '{"method": "dflash", "model": "path/to/checkpoint", "num_speculative_tokens": 7}' \ + --max-num-batched-tokens 32768 + ``` +- **Export**: z-lab compatible HF format, loadable by vLLM and z-lab benchmark. +- **Loss decay**: Validated +0.12 AR improvement with gamma=7 (bs16). From 2722877bbce53960d71f447917e468037413c101 Mon Sep 17 00:00:00 2001 From: Chenhan Yu Date: Thu, 9 Apr 2026 12:13:57 -0700 Subject: [PATCH 15/24] fix: address CodeRabbit review comments MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Replace print() with logging.getLogger() in hf_dflash.py - Fix base_token.item() batch safety (use base_token[0].item()) - Fix hardcoded torch_dtype in DFlash export (read from config) - Add TODO for f32→bf16 mask cast optimization Co-Authored-By: Claude Opus 4.6 (1M context) --- modelopt/torch/export/plugins/hf_spec_export.py | 2 +- modelopt/torch/speculative/plugins/hf_dflash.py | 17 +++++++++++------ 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/modelopt/torch/export/plugins/hf_spec_export.py b/modelopt/torch/export/plugins/hf_spec_export.py index b107a3cb5a..f298d86be1 100644 --- a/modelopt/torch/export/plugins/hf_spec_export.py +++ b/modelopt/torch/export/plugins/hf_spec_export.py @@ -316,7 +316,7 @@ def _export_config(self): ), "rope_scaling": getattr(base_config, "rope_scaling", None), "tie_word_embeddings": False, - "torch_dtype": "bfloat16", + "torch_dtype": str(getattr(base_config, "torch_dtype", 
torch.bfloat16)).replace("torch.", ""), "num_target_layers": getattr(base_config, "num_hidden_layers", 36), } diff --git a/modelopt/torch/speculative/plugins/hf_dflash.py b/modelopt/torch/speculative/plugins/hf_dflash.py index 20add0f7c9..85427befc2 100644 --- a/modelopt/torch/speculative/plugins/hf_dflash.py +++ b/modelopt/torch/speculative/plugins/hf_dflash.py @@ -52,11 +52,15 @@ lazy rope pattern needed for MLA models. """ +import logging + import torch import torch.nn.functional as F from torch import nn from transformers import PretrainedConfig, PreTrainedModel +logger = logging.getLogger(__name__) + # DFlash draft model uses Qwen3 components regardless of the target model. # This matches z-lab's implementation which inherits from Qwen3PreTrainedModel. from transformers.models.qwen3.modeling_qwen3 import Qwen3MLP as _MLP_CLS @@ -513,7 +517,7 @@ def modify(self, config): if mask_id is None: mask_id = self._auto_detect_mask_token_id(base_config) self.mask_token_id = mask_id[0] if isinstance(mask_id, list) else mask_id - print(f"DFlash mask_token_id: {self.mask_token_id}") + logger.info("DFlash mask_token_id: %s", self.mask_token_id) # Freeze base model if self.dflash_freeze_base_model: @@ -556,7 +560,7 @@ def modify(self, config): # Last resort: use the class two levels up (skip DFlash wrapper + DynamicModule) original_cls = type(self).__mro__[2] self._original_forward_cls = original_cls - print(f"DFlash: using {original_cls.__name__}.forward as base forward") + logger.info("DFlash: using %s.forward as base forward", original_cls.__name__) def get_exporter(self): """Get the exporter for the DFlash draft model.""" @@ -886,13 +890,14 @@ def pseudo_speculative_generate(self, input_ids, steps=1): th_dbg = torch.cat(sel, dim=-1) n_layers = len(base_outputs.hidden_states) th_norm = th_dbg.norm().item() - print( - f"[psg] hidden layers: {n_layers}, target_hidden: {th_dbg.shape}, norm: {th_norm:.2f}" + logger.info( + "[psg] hidden layers: %d, target_hidden: %s, norm: 
%.2f", + n_layers, th_dbg.shape, th_norm, ) - print(f"[psg] base_token: {base_token.item()}, mask_token_id: {self.mask_token_id}") + logger.info("[psg] base_token: %d, mask_token_id: %s", base_token[0].item(), self.mask_token_id) seq_len = input_ids.shape[1] blk = self.dflash_block_size - print(f"[psg] pos: ctx=[0..{seq_len - 1}], blk=[{seq_len}..{seq_len + blk - 1}]") + logger.info("[psg] pos: ctx=[0..%d], blk=[%d..%d]", seq_len - 1, seq_len, seq_len + blk - 1) selected = [base_outputs.hidden_states[lid + hid_offset] for lid in self.target_layer_ids] target_hidden = torch.cat(selected, dim=-1) From b9904396960c343ccfbc8ed11f55e9f77748c793 Mon Sep 17 00:00:00 2001 From: Chenhan Yu Date: Thu, 9 Apr 2026 12:15:11 -0700 Subject: [PATCH 16/24] add: illustrated anchor sampling example in dflash.md Shows fixed blocks (38% efficiency) vs random anchors (100% efficiency) with concrete loss_mask example. Co-Authored-By: Claude Opus 4.6 (1M context) --- examples/speculative_decoding/doc/dflash.md | 25 +++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/examples/speculative_decoding/doc/dflash.md b/examples/speculative_decoding/doc/dflash.md index 09484af6e4..f40c56ad3c 100644 --- a/examples/speculative_decoding/doc/dflash.md +++ b/examples/speculative_decoding/doc/dflash.md @@ -63,6 +63,31 @@ During training, anchor positions are sampled randomly from valid (assistant res tokens in each batch, rather than dividing the sequence into fixed blocks. Each anchor starts a block of `block_size` tokens where the draft model predicts positions 1..B-1. +``` +Sequence: [SYS] You helpful [USR] What 2+3? 
[AST] The answer is 5 +Position: 0 1 2 3 4 5 6 7 8 9 10 +loss_mask: 0 0 0 0 0 0 0 1 1 1 1 + ^^^^^^^^^^^^^^^^ + assistant response + +Fixed blocks (block_size=4): +Block 0: pos [0,1,2,3] anchor=0 → predict 1,2,3 → loss_mask=0,0,0 → ZERO LOSS +Block 1: pos [4,5,6,7] anchor=4 → predict 5,6,7 → loss_mask=0,0,1 → 1/3 useful +Block 2: pos [8,9,10,—] anchor=8 → predict 9,10,— → loss_mask=1,1,— → 2/2 useful + +Efficiency: 3/8 = 38% + +Random anchors (num_anchors=3, sampled from loss_mask=1): +Anchor 7: pos [7,8,9,10] → predict 8,9,10 → loss_mask=1,1,1 → 3/3 useful +Anchor 9: pos [9,10,—,—] → predict 10,—,— → loss_mask=1,—,— → 1/1 useful +Anchor 8: pos [8,9,10,—] → predict 9,10,— → loss_mask=1,1,— → 2/2 useful + +Efficiency: 6/6 = 100% +``` + +Random anchors guarantee every prediction is on assistant tokens. +Fixed blocks waste compute on prompt tokens where loss_mask=0. + **Tradeoff:** Higher `num_anchors` = more training signal per sample but more compute. Lower = faster iteration but less data efficiency. With `seq_len=4096` and `block_size=8`, `num_anchors=512` means the model sees ~512 blocks per sample (covering ~4096 positions). 
From 6a94085dd2d6a895fcc37c47241461f8f5135625 Mon Sep 17 00:00:00 2001 From: Chenhan Yu Date: Thu, 9 Apr 2026 12:17:01 -0700 Subject: [PATCH 17/24] clarify: architecture descriptions per reviewer feedback - Explain "noise" = block embeddings (anchor + mask tokens), not noise in ML sense - Explain why Q from block only, K/V from context + block - Explain parallel drafting benefit (one pass = B-1 tokens vs autoregressive) - Explain random anchor = ground truth token, not random prediction - Add cross-reference to illustrated anchor example Co-Authored-By: Claude Opus 4.6 (1M context) --- examples/speculative_decoding/doc/dflash.md | 23 ++++++++++++++------- 1 file changed, 16 insertions(+), 7 deletions(-) diff --git a/examples/speculative_decoding/doc/dflash.md b/examples/speculative_decoding/doc/dflash.md index f40c56ad3c..8b8f7ee09d 100644 --- a/examples/speculative_decoding/doc/dflash.md +++ b/examples/speculative_decoding/doc/dflash.md @@ -22,13 +22,22 @@ Target Model (frozen) ``` **Key components:** -- **Feature Fusion**: Multi-layer hidden states → Linear(num_layers × hidden_size, hidden_size) + RMSNorm -- **KV Injection**: In each draft decoder layer, K/V = concat(k_proj(target_hidden), k_proj(noise)) - with QK-norm. Q comes from noise only. -- **Parallel Drafting**: Position 0 is the anchor (known token), positions 1..B-1 are mask tokens - predicted in parallel. Bidirectional attention within the block. -- **Random Anchor Sampling**: During training, anchor positions are sampled randomly from - valid (assistant response) positions, not uniformly spaced. +- **Feature Fusion**: Multi-layer hidden states → Linear(num_layers × hidden_size, hidden_size) + RMSNorm. + Unlike EAGLE3 which uses a single layer's hidden state, DFlash concatenates hidden states from + multiple target layers (e.g., layers 1, 9, 17, 25, 33) to give the draft model richer context. 
+- **KV Injection**: In each draft decoder layer, K and V are projected from the concatenation of + target hidden states and the block's own embeddings. Q is projected from the block embeddings only + (the `[anchor, mask, mask, ...]` token embeddings). This lets the draft model attend to the + full target context while generating all block positions in parallel. +- **Parallel Drafting**: Position 0 is the anchor (the last accepted token — known and correct), + positions 1..B-1 are filled with a special mask token (similar to BERT's `[MASK]`). The draft + model predicts all B-1 unknown positions in a single forward pass, unlike autoregressive drafters + (EAGLE3) which predict one token at a time. Benefit: one forward pass produces B-1 draft tokens. +- **Random Anchor Sampling**: During training, anchors are sampled randomly from assistant response + positions (where `loss_mask=1`), not placed at fixed intervals. The anchor is the starting token + of each training block — it's always correct (from the ground truth) and the model learns to + predict the next B-1 tokens given this anchor and the target's hidden states. See the + [illustrated example](#random-anchor-sampling-num_anchors) below for why this improves efficiency. **Draft model components** (Qwen3-based): - `Qwen3MLP`, `Qwen3RMSNorm`, `Qwen3RotaryEmbedding` from transformers From 5b9c5094eb68723dd37475e06f772c3ab0484b99 Mon Sep 17 00:00:00 2001 From: Chenhan Yu Date: Thu, 9 Apr 2026 12:37:18 -0700 Subject: [PATCH 18/24] add: token-level KV injection illustration in dflash.md Shows how Q (block only) attends to K/V (context + block) with concrete token example: "The answer is" + block [is, MASK, MASK, MASK]. 
Co-Authored-By: Claude Opus 4.6 (1M context) --- examples/speculative_decoding/doc/dflash.md | 52 +++++++++++++++++++++ 1 file changed, 52 insertions(+) diff --git a/examples/speculative_decoding/doc/dflash.md b/examples/speculative_decoding/doc/dflash.md index 8b8f7ee09d..694bef57aa 100644 --- a/examples/speculative_decoding/doc/dflash.md +++ b/examples/speculative_decoding/doc/dflash.md @@ -39,6 +39,58 @@ Target Model (frozen) predict the next B-1 tokens given this anchor and the target's hidden states. See the [illustrated example](#random-anchor-sampling-num_anchors) below for why this improves efficiency. +**KV Injection (token-level example):** + +Given context `"The answer is"` and block_size=4 with anchor `"is"`: + +``` +Target model hidden states (from frozen base model): + h["The"] h["answer"] h["is"] ← target_hidden (ctx_len=3) + │ │ │ + └──── FC + RMSNorm ────┘ + │ + fused context features + +Block input (draft token embeddings): + embed("is") embed(MASK) embed(MASK) embed(MASK) ← noise_embedding (block_size=4) + pos=3 pos=4 pos=5 pos=6 + +In each DFlash decoder layer: + Q = q_proj(noise_embedding) ← shape [4, head_dim] + only the block tokens generate queries + + K = concat( ← shape [7, head_dim] + k_proj(fused_context), ← from target hidden [3 positions: "The","answer","is"] + k_proj(noise_embedding) ← from block tokens [4 positions: "is",MASK,MASK,MASK] + ) + + V = concat(v_proj(fused_context), v_proj(noise_embedding)) ← same shape as K + + Attention: Q (4 tokens) attends to K/V (7 tokens) + ┌─────────────────────────────────────────────────────────────┐ + │ K/V positions │ + │ context (from target) │ block (from draft) │ + │ "The" "answer" "is" │ "is" MASK MASK MASK │ + │ pos=0 pos=1 pos=2 │ pos=3 pos=4 pos=5 pos=6 │ + ├───────────────────────────────┼─────────────────────────────┤ + │ Q pos=3 ("is"): attends to all 7 K/V positions │ + │ Q pos=4 (MASK): attends to all 7 K/V positions │ + │ Q pos=5 (MASK): attends to all 7 K/V positions │ + │ Q pos=6 
(MASK): attends to all 7 K/V positions │ + └─────────────────────────────────────────────────────────────┘ + (bidirectional within block, no attention mask at inference) + + Output → lm_head → predictions: + pos=3: skip (anchor, already known) + pos=4: predict token after "is" → "5" + pos=5: predict token after "is 5" → "." + pos=6: predict token after "is 5." → "[EOS]" +``` + +The draft model sees the target's internal representation of the context (via KV injection) +without re-running the target model. This is what makes DFlash efficient — the expensive +target model forward pass happens once, and the lightweight draft model reuses its hidden states. + **Draft model components** (Qwen3-based): - `Qwen3MLP`, `Qwen3RMSNorm`, `Qwen3RotaryEmbedding` from transformers - Sliding window attention supported via `config.layer_types` *(implemented, not yet validated end-to-end)* From 83f398a86cccba0b1ad2a040f7b68be239238910 Mon Sep 17 00:00:00 2001 From: Chenhan Yu Date: Thu, 9 Apr 2026 12:43:36 -0700 Subject: [PATCH 19/24] add: training vs inference with attention mask visualization Training: 2-block attention mask showing context visibility rules - Block 0: no context, bidirectional within block - Block 1: sees context before anchor, bidirectional within block Inference: no mask, full context visibility Co-Authored-By: Claude Opus 4.6 (1M context) --- examples/speculative_decoding/doc/dflash.md | 74 +++++++++++++++++---- 1 file changed, 61 insertions(+), 13 deletions(-) diff --git a/examples/speculative_decoding/doc/dflash.md b/examples/speculative_decoding/doc/dflash.md index 694bef57aa..c6e27977c4 100644 --- a/examples/speculative_decoding/doc/dflash.md +++ b/examples/speculative_decoding/doc/dflash.md @@ -67,17 +67,15 @@ In each DFlash decoder layer: V = concat(v_proj(fused_context), v_proj(noise_embedding)) ← same shape as K Attention: Q (4 tokens) attends to K/V (7 tokens) - ┌─────────────────────────────────────────────────────────────┐ - │ K/V positions │ - │ 
context (from target) │ block (from draft) │ - │ "The" "answer" "is" │ "is" MASK MASK MASK │ - │ pos=0 pos=1 pos=2 │ pos=3 pos=4 pos=5 pos=6 │ - ├───────────────────────────────┼─────────────────────────────┤ - │ Q pos=3 ("is"): attends to all 7 K/V positions │ - │ Q pos=4 (MASK): attends to all 7 K/V positions │ - │ Q pos=5 (MASK): attends to all 7 K/V positions │ - │ Q pos=6 (MASK): attends to all 7 K/V positions │ - └─────────────────────────────────────────────────────────────┘ + + K/V: "The" "answer" "is" │ "is" MASK MASK MASK + pos0 pos1 pos2 │ pos3 pos4 pos5 pos6 + ───────────────────────────┼────────────────────────── + Q pos=3 "is" : ✓ ✓ ✓ │ ✓ ✓ ✓ ✓ + Q pos=4 MASK : ✓ ✓ ✓ │ ✓ ✓ ✓ ✓ + Q pos=5 MASK : ✓ ✓ ✓ │ ✓ ✓ ✓ ✓ + Q pos=6 MASK : ✓ ✓ ✓ │ ✓ ✓ ✓ ✓ + ─── context ─── │ ──── block ──────────── (bidirectional within block, no attention mask at inference) Output → lm_head → predictions: @@ -87,9 +85,59 @@ In each DFlash decoder layer: pos=6: predict token after "is 5." → "[EOS]" ``` +**Training vs Inference:** + +``` +TRAINING (2 anchors, block_size=4): + + Context tokens: "The" "answer" "is" "5" "." + Block 0 (anchor="The"): [The, MASK, MASK, MASK] + Block 1 (anchor="is"): [is, MASK, MASK, MASK] + + All blocks processed in ONE forward pass. Attention mask controls visibility: + + K/V (context) K/V (block 0) K/V (block 1) + "The" "ans" "is" "5" "." 
The M M M is M M M + c0 c1 c2 c3 c4 b0 b1 b2 b3 b4 b5 b6 b7 + Q ───────────────────────────────────────────────────────────────────────── + b0 "The" : ✗ ✗ ✗ ✗ ✗ ✓ ✓ ✓ ✓ ✗ ✗ ✗ ✗ + b1 MASK : ✗ ✗ ✗ ✗ ✗ ✓ ✓ ✓ ✓ ✗ ✗ ✗ ✗ + b2 MASK : ✗ ✗ ✗ ✗ ✗ ✓ ✓ ✓ ✓ ✗ ✗ ✗ ✗ + b3 MASK : ✗ ✗ ✗ ✗ ✗ ✓ ✓ ✓ ✓ ✗ ✗ ✗ ✗ + b4 "is" : ✓ ✓ ✗ ✗ ✗ ✗ ✗ ✗ ✗ ✓ ✓ ✓ ✓ + b5 MASK : ✓ ✓ ✗ ✗ ✗ ✗ ✗ ✗ ✗ ✓ ✓ ✓ ✓ + b6 MASK : ✓ ✓ ✗ ✗ ✗ ✗ ✗ ✗ ✗ ✓ ✓ ✓ ✓ + b7 MASK : ✓ ✓ ✗ ✗ ✗ ✗ ✗ ✗ ✗ ✓ ✓ ✓ ✓ + ── context ────── ── block 0 ────── ── block 1 ────── + + Block 0: first block sees NO context (✗), only its own block (bidirectional ✓) + Block 1: sees context before anchor "is" (c0,c1 ✓), NOT its own anchor or later + plus its own block (bidirectional ✓) + + Loss: computed on all non-anchor positions simultaneously. + No verification — ground truth labels known from training data. + +INFERENCE (one block at a time, NO attention mask): + + Step 1: target forward("The answer is") → base_token = "5" + block = [5, MASK, MASK, MASK] + + K/V: "The" "ans" "is" │ "5" MASK MASK MASK + Q ─────────────────────────────────┼────────────────────────── + "5" : ✓ ✓ ✓ │ ✓ ✓ ✓ ✓ + MASK : ✓ ✓ ✓ │ ✓ ✓ ✓ ✓ + MASK : ✓ ✓ ✓ │ ✓ ✓ ✓ ✓ + MASK : ✓ ✓ ✓ │ ✓ ✓ ✓ ✓ + + All ✓ — no mask at inference. Block sees full context freely. + Target verifies → accept 3 → sequence: "The answer is 5 . [EOS]" + + Step 2: next block with grown context (5 tokens) ... +``` + The draft model sees the target's internal representation of the context (via KV injection) -without re-running the target model. This is what makes DFlash efficient — the expensive -target model forward pass happens once, and the lightweight draft model reuses its hidden states. +without re-running the target model for drafting. The expensive target forward pass is +only needed for verification — the lightweight draft model reuses the target's hidden states. 
**Draft model components** (Qwen3-based): - `Qwen3MLP`, `Qwen3RMSNorm`, `Qwen3RotaryEmbedding` from transformers From 747a048f300de0630231e68553ba7ee5328826a5 Mon Sep 17 00:00:00 2001 From: Chenhan Yu Date: Thu, 9 Apr 2026 12:51:50 -0700 Subject: [PATCH 20/24] =?UTF-8?q?rename:=20AR=20Evaluation=20=E2=86=92=20H?= =?UTF-8?q?uggingFace=20AR=20Evaluation?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-Authored-By: Claude Opus 4.6 (1M context) --- examples/speculative_decoding/doc/dflash.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/speculative_decoding/doc/dflash.md b/examples/speculative_decoding/doc/dflash.md index c6e27977c4..79cb6ca461 100644 --- a/examples/speculative_decoding/doc/dflash.md +++ b/examples/speculative_decoding/doc/dflash.md @@ -251,7 +251,7 @@ Trained on nvidia/Nemotron-Post-Training-Dataset-v2 (2M samples), 64 GPUs, 10 ep | Total Steps | 306,620 | | Final Per-Token Acc | 67.0% | -### AR Evaluation +### HuggingFace AR Evaluation AR is evaluated using `ar_validate.py` which calls `pseudo_speculative_generate` with online (context-dependent) ground truth: From 6526c4a272b80e8264af516264efd17c27f576a4 Mon Sep 17 00:00:00 2001 From: Chenhan Yu Date: Thu, 9 Apr 2026 16:24:24 -0700 Subject: [PATCH 21/24] address PR review: simplify DFlash plugin per reviewer feedback - _get_attn_fn: use ALL_ATTENTION_FUNCTIONS directly, remove _eager_attention - _apply override: replace with register_load_state_dict_post_hook (narrower scope) - attention masks: create directly in target dtype, matching EAGLE convention - default_config.py: move static defaults (hidden_act, etc.) 
from modify() - _auto_detect_mask_token_id: simplify to tokenizer-based, remove model-specific heuristics - remove dead code: _original_forward_cls, _base_forward, mlp_bias - update unit tests for load_state_dict hook Co-Authored-By: Claude Opus 4.6 (1M context) Signed-off-by: Chenhan Yu --- .../speculative/dflash/default_config.py | 11 +- .../torch/speculative/plugins/hf_dflash.py | 255 ++++++------------ .../speculative/plugins/test_hf_dflash.py | 29 +- 3 files changed, 101 insertions(+), 194 deletions(-) diff --git a/modelopt/torch/speculative/dflash/default_config.py b/modelopt/torch/speculative/dflash/default_config.py index 5536e0d4df..1777de3f29 100644 --- a/modelopt/torch/speculative/dflash/default_config.py +++ b/modelopt/torch/speculative/dflash/default_config.py @@ -16,13 +16,20 @@ """Default DFlash architecture config. Model-specific settings (hidden_size, num_attention_heads, rope_*, etc.) -are inherited from the base model in HFDFlashModel.modify(). Only -DFlash-specific defaults are set here. +are inherited from the base model in HFDFlashModel.modify(). Static +defaults that don't depend on the base model are set here, similar to +``eagle/default_config.py``. """ default_dflash_config = { + # DFlash-specific "num_hidden_layers": 5, + # Architecture defaults (overridable by user config) + "hidden_act": "silu", "rms_norm_eps": 1e-06, + "initializer_range": 0.02, "attention_bias": False, "attention_dropout": 0.0, + "tie_word_embeddings": False, + "_attn_implementation": "sdpa", } diff --git a/modelopt/torch/speculative/plugins/hf_dflash.py b/modelopt/torch/speculative/plugins/hf_dflash.py index 85427befc2..d87eaf3b0a 100644 --- a/modelopt/torch/speculative/plugins/hf_dflash.py +++ b/modelopt/torch/speculative/plugins/hf_dflash.py @@ -63,14 +63,16 @@ # DFlash draft model uses Qwen3 components regardless of the target model. # This matches z-lab's implementation which inherits from Qwen3PreTrainedModel. 
-from transformers.models.qwen3.modeling_qwen3 import Qwen3MLP as _MLP_CLS -from transformers.models.qwen3.modeling_qwen3 import Qwen3RMSNorm as _NORM_CLS -from transformers.models.qwen3.modeling_qwen3 import Qwen3RotaryEmbedding as _ROTARY_CLS -from transformers.models.qwen3.modeling_qwen3 import rotate_half as _rotate_half -from transformers.utils import ModelOutput +from transformers.models.qwen3.modeling_qwen3 import Qwen3MLP as _MLP_CLS # noqa: E402, N814 +from transformers.models.qwen3.modeling_qwen3 import Qwen3RMSNorm as _NORM_CLS # noqa: E402, N814 +from transformers.models.qwen3.modeling_qwen3 import ( # noqa: E402 + Qwen3RotaryEmbedding as _ROTARY_CLS, # noqa: N814 +) +from transformers.models.qwen3.modeling_qwen3 import rotate_half as _rotate_half # noqa: E402 +from transformers.utils import ModelOutput # noqa: E402 -from ..dflash.conversion import DFlashDMRegistry -from ..dflash.dflash_model import DFlashModel +from ..dflash.conversion import DFlashDMRegistry # noqa: E402 +from ..dflash.dflash_model import DFlashModel # noqa: E402 __all__ = ["HFDFlashModel"] @@ -136,38 +138,15 @@ def __init__(self, config, layer_idx): self.sliding_window = None def _get_attn_fn(self): - """Lazily resolve the HF attention function.""" + """Lazily resolve the HF attention function (default: sdpa).""" if self._attn_fn is not None: return self._attn_fn - try: - from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS + from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS - impl = getattr(self.config, "_attn_implementation", "eager") - if impl and impl != "eager" and impl in ALL_ATTENTION_FUNCTIONS: - self._attn_fn = ALL_ATTENTION_FUNCTIONS[impl] - else: - self._attn_fn = self._eager_attention - except (ImportError, AttributeError): - self._attn_fn = self._eager_attention + impl = self.config._attn_implementation # default set in dflash/default_config.py + self._attn_fn = ALL_ATTENTION_FUNCTIONS.get(impl, ALL_ATTENTION_FUNCTIONS["sdpa"]) return 
self._attn_fn - def _eager_attention(self, module, q, k, v, attention_mask, **kwargs): - """Eager attention matching HF's eager_attention_forward.""" - scaling = kwargs.get("scaling", self.scaling) - n_rep = self.num_key_value_groups - if n_rep > 1: - k = k.repeat_interleave(n_rep, dim=1) - v = v.repeat_interleave(n_rep, dim=1) - attn_weights = torch.matmul(q, k.transpose(2, 3)) * scaling - if attention_mask is not None: - attn_weights = attn_weights + attention_mask - attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to( - q.dtype - ) - attn_output = torch.matmul(attn_weights, v) - attn_output = attn_output.transpose(1, 2).contiguous() - return attn_output, None - def forward(self, hidden_states, target_hidden, position_embeddings, attention_mask=None): """Forward with KV injection. @@ -266,39 +245,21 @@ def __init__(self, config): self.rotary_emb = _ROTARY_CLS(config=config) self._rotary_config = config # Stored for re-creating rotary_emb on resume - # Initialize weights matching HF PreTrainedModel (normal_ with initializer_range) - # SpecForge's DFlashDraftModel uses Qwen3PreTrainedModel.post_init() which does this. + # Explicit weight init is needed because DFlashModule is instantiated via + # mtsp.convert() AFTER the base model's post_init() has already run, so HF's + # automatic _init_weights walk doesn't reach these new layers. self._init_weights(config) - def _apply(self, fn, recurse=True): - """Override _apply to fix meta-tensor rotary buffers before device transfer. - - Why this is needed: - When resuming from a checkpoint, ModelOpt's ``enable_huggingface_checkpointing`` - restores the model architecture from ``modelopt_state.pth``. During this restore, - ``DFlashModule.__init__`` runs and creates ``rotary_emb`` with its ``inv_freq`` - buffer. However, ``inv_freq`` is a *computed* buffer (derived from ``rope_theta`` - and ``head_dim``), not a learned parameter, so it is NOT saved in - ``model.safetensors``. 
After ``from_pretrained`` loads the saved weights, all - learned parameters are materialized on CPU, but ``inv_freq`` remains on the - **meta device** (a placeholder with shape but no data). - - Later, HF Trainer calls ``model.to(device)`` which internally calls ``_apply`` - on every submodule. When ``_apply`` reaches the meta ``inv_freq`` buffer, it - raises ``NotImplementedError: Cannot copy out of meta tensor``. - - Fix: - Before ``super()._apply()`` transfers tensors to the target device, we check - if ``rotary_emb`` has any meta buffers. If so, we re-create it on CPU using - the stored config (``_rotary_config``). This produces a real ``inv_freq`` tensor - with correct values, which ``_apply`` can then safely move to GPU. - - This approach is transparent to the training script (``main.py``) — no - mode-specific resume logic is needed there. - """ - if hasattr(self, "rotary_emb") and any(b.is_meta for b in self.rotary_emb.buffers()): - self.rotary_emb = _ROTARY_CLS(config=self._rotary_config, device="cpu") - return super()._apply(fn, recurse) + # Fix meta-tensor rotary buffers after checkpoint loading. + # On resume, inv_freq (a computed buffer, not saved in checkpoint) stays on + # meta device. Re-create rotary_emb on CPU so .to(device) can proceed. + self.register_load_state_dict_post_hook(self._fix_meta_rotary_buffers) + + @staticmethod + def _fix_meta_rotary_buffers(module, incompatible_keys): + """Re-create rotary_emb if its buffers are on meta device (post state_dict load hook).""" + if hasattr(module, "rotary_emb") and any(b.is_meta for b in module.rotary_emb.buffers()): + module.rotary_emb = _ROTARY_CLS(config=module._rotary_config, device="cpu") def _init_weights(self, config): """Initialize weights matching HF PreTrainedModel._init_weights.""" @@ -342,14 +303,9 @@ def create_dflash_attention_mask( full_mask_bool = torch.cat([ctx_mask, noise_mask], dim=1) - # Create in f32 then cast, matching SpecForge. 
This ensures masked - # positions get -inf in bf16 (f32 min overflows to -inf when cast), - # not the largest finite negative bf16 value. - # TODO: This f32→bf16 cast pattern may be simplified once the mask behavior is stable. - # Consider creating directly in the target dtype with appropriate fill value. - full_mask = torch.zeros(seq_len, 2 * seq_len, device=device, dtype=torch.float32) - full_mask.masked_fill_(~full_mask_bool, torch.finfo(torch.float32).min) - full_mask = full_mask.to(dtype=dtype) + # Create additive mask directly in target dtype, matching EAGLE convention. + full_mask = torch.zeros(seq_len, 2 * seq_len, device=device, dtype=dtype) + full_mask.masked_fill_(~full_mask_bool, torch.finfo(dtype).min) return full_mask.unsqueeze(0).unsqueeze(0) # [1, 1, L, 2L] @@ -389,46 +345,26 @@ def _base_llm_config(self): @staticmethod def _auto_detect_mask_token_id(base_config): - """Auto-detect an appropriate mask token ID for DFlash. + """Auto-detect mask token ID from the base model's tokenizer. - Different model families use different strategies: - - Qwen3/3.5: built-in [MASK] token in vocabulary - - Llama3: reserved special tokens (128002 = reserved_special_token_0) - - Others: try tokenizer.mask_token_id, then fall back to pad/eos + Loads the tokenizer and returns ``tokenizer.mask_token_id`` if available. + Raises ValueError otherwise — the user must set mask_token_id explicitly. 
""" - model_type = getattr(base_config, "model_type", "") - vocab_size = getattr(base_config, "vocab_size", 0) - - # Qwen3/3.5: known mask token positions - if "qwen3" in model_type.lower() or "qwen" in model_type.lower(): - # Qwen3 vocab has dedicated mask tokens - # Qwen3.5-4B: 248070, Qwen3-8B: similar range - # Heuristic: eos_token_id + some offset, or check known values - eos = getattr(base_config, "eos_token_id", None) - if isinstance(eos, list): - eos = eos[0] - if eos and vocab_size > 200000: - # Large Qwen vocab — mask token is typically near end of special tokens - # Known: Qwen3.5 eos=248044, mask=248070 (offset ~26) - # Try common offsets - for offset in [26, 25, 24]: - candidate = eos + offset - if candidate < vocab_size: - return candidate - # Fallback for smaller Qwen models - if vocab_size > 150000: - return vocab_size - 250 # heuristic for Qwen special token region - - # Llama3: use reserved_special_token_0 (128002) - if "llama" in model_type.lower(): - if vocab_size >= 128256: # Llama3 vocab size - return 128002 # <|reserved_special_token_0|> - - # No suitable mask token found — user must provide one + from transformers import AutoTokenizer + + model_name = getattr(base_config, "_name_or_path", None) + if model_name: + try: + tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) + if tokenizer.mask_token_id is not None: + return tokenizer.mask_token_id + except Exception: + pass + raise ValueError( - f"Cannot auto-detect mask_token_id for model_type='{model_type}'. " - f"Please set dflash_architecture_config.mask_token_id explicitly in your config. " - f"The mask token should be an unused special token (not eos or pad)." + "Cannot auto-detect mask_token_id. " + "Please set dflash_architecture_config.mask_token_id explicitly in your config. " + "The mask token should be an unused special token (not eos or pad)." 
) def _find_base_model_parts(self): @@ -460,46 +396,34 @@ def modify(self, config): base_config = self._base_llm_config self.dflash_config = PretrainedConfig.from_dict(config.dflash_architecture_config) - # Inherit settings from base model, but only those NOT already in the user config. - # hidden_size and vocab_size MUST match. Others (heads, intermediate_size) can differ. - # This allows the draft model to have a different architecture than the base model. + # hidden_size and vocab_size MUST match the base model. self.dflash_config.hidden_size = base_config.hidden_size self.dflash_config.vocab_size = base_config.vocab_size - # These use base model defaults if not specified in dflash_architecture_config - for attr, default_from_base in [ - ("max_position_embeddings", True), - ("intermediate_size", True), - ("num_attention_heads", True), - ("num_key_value_heads", True), - ("hidden_act", True), - ("rope_theta", True), - ("rope_scaling", True), - ("rope_type", False), - ("position_embedding_type", False), - ("rope_interleaved", False), - ("rms_norm_eps", True), - ("attention_bias", False), - ("tie_word_embeddings", False), - ]: + # Inherit architecture settings from base model when not specified by user. + # Static defaults (hidden_act, attention_bias, etc.) are in dflash/default_config.py. 
+ _base_model_attrs = [ + "max_position_embeddings", + "intermediate_size", + "num_attention_heads", + "num_key_value_heads", + "rope_theta", + "rope_scaling", + "rope_type", + "rope_interleaved", + "rms_norm_eps", + ] + for attr in _base_model_attrs: if not hasattr(self.dflash_config, attr) or getattr(self.dflash_config, attr) is None: - if default_from_base and hasattr(base_config, attr): + if hasattr(base_config, attr): setattr(self.dflash_config, attr, getattr(base_config, attr)) - # Ensure required attrs have defaults - if not hasattr(self.dflash_config, "mlp_bias") or self.dflash_config.mlp_bias is None: - self.dflash_config.mlp_bias = False - self.dflash_config.head_dim = getattr( self.dflash_config, "head_dim", self.dflash_config.hidden_size // self.dflash_config.num_attention_heads, ) self.dflash_config.block_size = self.dflash_block_size - # Default to sdpa, matching SpecForge's DFlashDraftModel(Qwen3PreTrainedModel) - # which resolves to sdpa via post_init() - if self.dflash_config._attn_implementation is None: - self.dflash_config._attn_implementation = "sdpa" # Target layer IDs num_target_layers = base_config.num_hidden_layers @@ -509,10 +433,8 @@ def modify(self, config): # mask_token_id resolution order: # 1. Explicit in dflash_architecture_config (user override) - # 2. Auto-detect from model vocabulary: - # - Qwen3/3.5: built-in [MASK] token - # - Llama3: reserved_special_token_0 (128002) - # 3. Error — user must provide mask_token_id for unsupported models + # 2. Auto-detect from tokenizer (tokenizer.mask_token_id) + # 3. Error — user must provide mask_token_id mask_id = config.dflash_architecture_config.get("mask_token_id", None) if mask_id is None: mask_id = self._auto_detect_mask_token_id(base_config) @@ -534,44 +456,12 @@ def modify(self, config): self.is_quantized = False self._num_anchors = self.dflash_num_anchors - # Store bound reference to the original model class's forward. 
- # DynamicModule changes type(self) but the original class is in _original_cls. - # Find the original HF model class (e.g., Qwen3_5ForConditionalGeneration) - # by walking MRO and skipping DFlash/DynamicModule classes - skip_names = { - "HFDFlashModel", - "DFlashModel", - "DynamicModule", - "DFlashPreTrainedModel", - "DFlashDraftModel", - } - original_cls = None - for cls in type(self).__mro__: - if ( - hasattr(cls, "forward") - and cls.__name__ not in skip_names - and cls is not type(self) - and issubclass(cls, PreTrainedModel) - and cls is not PreTrainedModel - ): - original_cls = cls - break - if original_cls is None: - # Last resort: use the class two levels up (skip DFlash wrapper + DynamicModule) - original_cls = type(self).__mro__[2] - self._original_forward_cls = original_cls - logger.info("DFlash: using %s.forward as base forward", original_cls.__name__) - def get_exporter(self): """Get the exporter for the DFlash draft model.""" from modelopt.torch.export.plugins.hf_spec_export import DFlashExporter return DFlashExporter(self) - def _base_forward(self, **kwargs): - """Call the original model's forward, bypassing DFlash wrapper.""" - return self._original_forward_cls.forward(self, **kwargs) - def _sample_anchor_positions(self, seq_len, loss_mask, device): """Randomly sample anchor positions per sample, matching SpecForge PR #473. @@ -726,9 +616,8 @@ def forward( # Convert bool mask to float additive mask for SDPA dtype = target_hidden.dtype - attn_mask = torch.zeros(bsz, 1, q_len, kv_len, device=device, dtype=torch.float32) - attn_mask.masked_fill_(~final_mask, torch.finfo(torch.float32).min) - attn_mask = attn_mask.to(dtype=dtype) + attn_mask = torch.zeros(bsz, 1, q_len, kv_len, device=device, dtype=dtype) + attn_mask.masked_fill_(~final_mask, torch.finfo(dtype).min) # 7. 
Draft forward hidden = self.dflash_module( @@ -892,12 +781,18 @@ def pseudo_speculative_generate(self, input_ids, steps=1): th_norm = th_dbg.norm().item() logger.info( "[psg] hidden layers: %d, target_hidden: %s, norm: %.2f", - n_layers, th_dbg.shape, th_norm, + n_layers, + th_dbg.shape, + th_norm, + ) + logger.info( + "[psg] base_token: %d, mask_token_id: %s", base_token[0].item(), self.mask_token_id ) - logger.info("[psg] base_token: %d, mask_token_id: %s", base_token[0].item(), self.mask_token_id) seq_len = input_ids.shape[1] blk = self.dflash_block_size - logger.info("[psg] pos: ctx=[0..%d], blk=[%d..%d]", seq_len - 1, seq_len, seq_len + blk - 1) + logger.info( + "[psg] pos: ctx=[0..%d], blk=[%d..%d]", seq_len - 1, seq_len, seq_len + blk - 1 + ) selected = [base_outputs.hidden_states[lid + hid_offset] for lid in self.target_layer_ids] target_hidden = torch.cat(selected, dim=-1) diff --git a/tests/unit/torch/speculative/plugins/test_hf_dflash.py b/tests/unit/torch/speculative/plugins/test_hf_dflash.py index 6bb6c0bb36..20c4edc3eb 100644 --- a/tests/unit/torch/speculative/plugins/test_hf_dflash.py +++ b/tests/unit/torch/speculative/plugins/test_hf_dflash.py @@ -139,40 +139,43 @@ def test_save_and_restore(self, tmp_path): tf_modelopt_state_and_output_tester(model_ref, model_test) -class TestDFlashApplyMetaFix: - """Test DFlashModule._apply handles meta-tensor rotary buffers. +class TestDFlashMetaRotaryFix: + """Test load_state_dict post-hook fixes meta-tensor rotary buffers. During checkpoint restore, rotary inv_freq buffers may be on meta device - (they are computed, not saved). _apply should re-create them on CPU. + (they are computed, not saved). The post-hook should re-create them on CPU. 
""" - def test_apply_recreates_meta_rotary(self): - """Test that .to() recreates rotary_emb when buffers are on meta device.""" + def test_load_state_dict_fixes_meta_rotary(self): + """Test that load_state_dict recreates rotary_emb when buffers are on meta device.""" model = get_tiny_llama(num_hidden_layers=4) config = _get_dflash_config() mtsp.convert(model, [("dflash", config)]) dflash_mod = model.dflash_module + sd = dflash_mod.state_dict() + # Simulate meta buffers (as happens during checkpoint restore) for name, buf in list(dflash_mod.rotary_emb.named_buffers()): dflash_mod.rotary_emb._buffers[name] = torch.empty_like(buf, device="meta") assert any(b.is_meta for b in dflash_mod.rotary_emb.buffers()) - # .to() triggers _apply which should fix meta buffers - dflash_mod.to("cpu") + # load_state_dict triggers the post-hook which should fix meta buffers + dflash_mod.load_state_dict(sd, strict=False) assert not any(b.is_meta for b in dflash_mod.rotary_emb.buffers()) - def test_apply_noop_when_no_meta(self): - """Test that .to() does not recreate rotary_emb when buffers are normal.""" + def test_load_state_dict_noop_when_no_meta(self): + """Test that load_state_dict does not recreate rotary_emb when buffers are normal.""" model = get_tiny_llama(num_hidden_layers=4) config = _get_dflash_config() mtsp.convert(model, [("dflash", config)]) dflash_mod = model.dflash_module + sd = dflash_mod.state_dict() rotary_id_before = id(dflash_mod.rotary_emb) - dflash_mod.to("cpu") + dflash_mod.load_state_dict(sd, strict=False) assert id(dflash_mod.rotary_emb) == rotary_id_before @@ -268,9 +271,10 @@ class TestDFlashSlidingWindow: def test_sliding_window_from_config(self): """Test DFlashAttention reads sliding_window from config.layer_types.""" - from modelopt.torch.speculative.plugins.hf_dflash import DFlashAttention from transformers import PretrainedConfig + from modelopt.torch.speculative.plugins.hf_dflash import DFlashAttention + config = PretrainedConfig( hidden_size=64, 
num_attention_heads=4, @@ -290,9 +294,10 @@ def test_sliding_window_from_config(self): def test_no_sliding_window_without_config(self): """Test DFlashAttention defaults to no sliding window.""" - from modelopt.torch.speculative.plugins.hf_dflash import DFlashAttention from transformers import PretrainedConfig + from modelopt.torch.speculative.plugins.hf_dflash import DFlashAttention + config = PretrainedConfig( hidden_size=64, num_attention_heads=4, From 0dafbced9dc42b567fb90dda42c7246cebc5c219 Mon Sep 17 00:00:00 2001 From: Chenhan Yu Date: Thu, 9 Apr 2026 19:04:44 -0700 Subject: [PATCH 22/24] address PR review: refactor helpers, fix resume, reorganize scripts Code changes: - Extract _build_noise_embedding, _build_position_ids, _build_draft_attention_mask, _compute_loss helpers from forward() - Combine labels and attention_mask for loss mask (LabelSmoother.ignore_index) - Restore _apply override with one-shot flag for meta rotary fix (register_load_state_dict_post_hook didn't work: hook fires on DFlashModule but from_pretrained loads state dict on the parent model) - Restore --trust_remote_code in export_hf_checkpoint.py - Auto-export last checkpoint after training in dflash_online_training.sh Script reorganization: - Move ar_eval_mtbench.sh to common/specdec/ (mode-agnostic) - Move online_training.sh to common/specdec/dflash_online_training.sh - Add common/specdec/vllm_smoke_test.sh (generalized for eagle/dflash) - Remove common/dflash/export.sh (merged into training script) - Remove common/dflash/ptq_and_export.sh (deferred to next PR) - Replace common/dflash/vllm_serve.sh with generalized smoke test Co-Authored-By: Claude Opus 4.6 (1M context) Signed-off-by: Chenhan Yu --- .../scripts/export_hf_checkpoint.py | 5 +- .../torch/speculative/plugins/hf_dflash.py | 258 +++++++++++------- .../speculative/plugins/test_hf_dflash.py | 25 +- tools/launcher/common/dflash/export.sh | 63 ----- .../launcher/common/dflash/ptq_and_export.sh | 59 ---- 
tools/launcher/common/dflash/vllm_serve.sh | 144 ---------- .../{dflash => specdec}/ar_eval_mtbench.sh | 0 .../dflash_online_training.sh} | 28 ++ .../common/specdec/vllm_smoke_test.sh | 110 ++++++++ .../Qwen/Qwen3-8B/hf_online_dflash.yaml | 4 +- 10 files changed, 314 insertions(+), 382 deletions(-) delete mode 100644 tools/launcher/common/dflash/export.sh delete mode 100644 tools/launcher/common/dflash/ptq_and_export.sh delete mode 100644 tools/launcher/common/dflash/vllm_serve.sh rename tools/launcher/common/{dflash => specdec}/ar_eval_mtbench.sh (100%) rename tools/launcher/common/{dflash/online_training.sh => specdec/dflash_online_training.sh} (77%) create mode 100644 tools/launcher/common/specdec/vllm_smoke_test.sh diff --git a/examples/speculative_decoding/scripts/export_hf_checkpoint.py b/examples/speculative_decoding/scripts/export_hf_checkpoint.py index 925f4b73d0..c3ca75cc24 100644 --- a/examples/speculative_decoding/scripts/export_hf_checkpoint.py +++ b/examples/speculative_decoding/scripts/export_hf_checkpoint.py @@ -29,6 +29,7 @@ def parse_args(): description="Export a HF checkpoint (with ModelOpt state) for deployment." ) parser.add_argument("--model_path", type=str, default="Path of the trained checkpoint.") + parser.add_argument("--trust_remote_code", action="store_true", help="Trust remote code") parser.add_argument( "--export_path", type=str, default="Destination directory for exported files." 
) @@ -38,7 +39,9 @@ def parse_args(): mto.enable_huggingface_checkpointing() args = parse_args() -model = load_vlm_or_llm(args.model_path, torch_dtype="auto") +model = load_vlm_or_llm( + args.model_path, torch_dtype="auto", trust_remote_code=args.trust_remote_code +) model.eval() with torch.inference_mode(): export_speculative_decoding( diff --git a/modelopt/torch/speculative/plugins/hf_dflash.py b/modelopt/torch/speculative/plugins/hf_dflash.py index d87eaf3b0a..d5df9cdb36 100644 --- a/modelopt/torch/speculative/plugins/hf_dflash.py +++ b/modelopt/torch/speculative/plugins/hf_dflash.py @@ -58,6 +58,7 @@ import torch.nn.functional as F from torch import nn from transformers import PretrainedConfig, PreTrainedModel +from transformers.trainer_pt_utils import LabelSmoother logger = logging.getLogger(__name__) @@ -250,16 +251,20 @@ def __init__(self, config): # automatic _init_weights walk doesn't reach these new layers. self._init_weights(config) - # Fix meta-tensor rotary buffers after checkpoint loading. - # On resume, inv_freq (a computed buffer, not saved in checkpoint) stays on - # meta device. Re-create rotary_emb on CPU so .to(device) can proceed. - self.register_load_state_dict_post_hook(self._fix_meta_rotary_buffers) + self._meta_fixed = False - @staticmethod - def _fix_meta_rotary_buffers(module, incompatible_keys): - """Re-create rotary_emb if its buffers are on meta device (post state_dict load hook).""" - if hasattr(module, "rotary_emb") and any(b.is_meta for b in module.rotary_emb.buffers()): - module.rotary_emb = _ROTARY_CLS(config=module._rotary_config, device="cpu") + def _apply(self, fn, recurse=True): + """Fix meta-tensor rotary buffers on first .to(device) call, then no-op. + + On checkpoint resume, inv_freq (a computed buffer, not saved in checkpoint) + stays on meta device. Re-create rotary_emb on CPU before the first device + transfer so .to(device) can proceed. 
+ """ + if not self._meta_fixed and hasattr(self, "rotary_emb"): + if any(b.is_meta for b in self.rotary_emb.buffers()): + self.rotary_emb = _ROTARY_CLS(config=self._rotary_config, device="cpu") + self._meta_fixed = True + return super()._apply(fn, recurse) def _init_weights(self, config): """Initialize weights matching HF PreTrainedModel._init_weights.""" @@ -466,6 +471,10 @@ def _sample_anchor_positions(self, seq_len, loss_mask, device): """Randomly sample anchor positions per sample, matching SpecForge PR #473. Returns (anchor_positions [B, N], block_keep_mask [B, N]). + + Note: Fixed (uniform) anchors would allow caching masks/positions across steps, + but random anchors provide data augmentation during training. A fixed-anchor mode + could be added as a future optimization for inference or fine-tuning. """ bs = self.dflash_block_size bsz = loss_mask.shape[0] @@ -498,80 +507,12 @@ def _sample_anchor_positions(self, seq_len, loss_mask, device): anchors = torch.where(keep, anchors, torch.tensor(0, dtype=torch.long, device=device)) return anchors, keep - def forward( - self, - input_ids=None, - attention_mask=None, - position_ids=None, - past_key_values=None, - inputs_embeds=None, - labels=None, - use_cache=None, - output_attentions=None, - output_hidden_states=None, - cache_position=None, - **kwargs, - ): - """Training forward matching SpecForge latest (post-PR #473). 
- - Key changes from original PR #415: - - Random anchor sampling instead of uniform block division - - Bidirectional intra-block attention (no causal constraint) - - Context sees strictly before anchor position - - Label alignment: position k predicts token at anchor+k - - Optional loss decay weighting - """ - if not self.training: - return super().forward( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - labels=labels, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - cache_position=cache_position, - **kwargs, - ) - + def _build_noise_embedding(self, input_ids, anchor_positions, block_keep_mask, n_blocks): + """Build noise embeddings: anchor token at block start, mask_token elsewhere.""" bsz, seq_len = input_ids.shape block_size = self.dflash_block_size device = input_ids.device - # 1. Run base model → hidden states - with torch.no_grad(): - base_outputs = super().forward( - input_ids=input_ids, - attention_mask=attention_mask, - output_hidden_states=True, - ) - - offset = 1 - selected = [base_outputs.hidden_states[lid + offset] for lid in self.target_layer_ids] - target_hidden = torch.cat(selected, dim=-1) # [B, seq, num_layers * H] - - # 2. Build loss mask from labels or attention_mask - if labels is not None: - loss_mask = (labels != -100).float() - elif attention_mask is not None: - loss_mask = attention_mask.float() - else: - loss_mask = torch.ones(bsz, seq_len, device=device) - - # 3. 
Random anchor sampling (SpecForge PR #463/#473) - anchor_positions, block_keep_mask = self._sample_anchor_positions( - seq_len, loss_mask, device - ) - n_blocks = anchor_positions.shape[1] - - if n_blocks == 0 or not block_keep_mask.any(): - # Zero loss that still flows through dflash_module for DDP gradient sync - dummy = self.dflash_module.fc.weight.sum() * 0.0 - return ModelOutput(loss=dummy, logits=base_outputs.logits, train_acc=[[0.0]]) - - # 4. Create noise embeddings: anchor token at block start, mask_token elsewhere noise_ids = torch.full( (bsz, n_blocks * block_size), self.mask_token_id, dtype=torch.long, device=device ) @@ -585,15 +526,24 @@ def forward( anchor_tokens, torch.tensor(self.mask_token_id, dtype=torch.long, device=device), ) - noise_embedding = self._base_model_embeddings(noise_ids) + return self._base_model_embeddings(noise_ids) + + def _build_position_ids(self, seq_len, anchor_positions, device): + """Build position IDs: context [0..S-1], draft blocks [anchor+0..anchor+B-1].""" + bsz = anchor_positions.shape[0] + block_size = self.dflash_block_size - # 5. Position IDs: context [0..S-1], draft blocks [anchor+0..anchor+B-1] ctx_pos = torch.arange(seq_len, device=device).unsqueeze(0).expand(bsz, -1) offsets = torch.arange(block_size, device=device).view(1, 1, -1) draft_pos = (anchor_positions.unsqueeze(-1) + offsets).view(bsz, -1) - full_pos = torch.cat([ctx_pos, draft_pos], dim=1) + return torch.cat([ctx_pos, draft_pos], dim=1) - # 6. 
Attention mask: SDPA bool mask [B, 1, Q_LEN, KV_LEN] + def _build_draft_attention_mask( + self, seq_len, anchor_positions, block_keep_mask, n_blocks, dtype, device + ): + """Build SDPA attention mask: context (causal) + draft (bidirectional within block).""" + bsz = anchor_positions.shape[0] + block_size = self.dflash_block_size q_len = n_blocks * block_size kv_len = seq_len + q_len @@ -615,20 +565,30 @@ def forward( final_mask = (mask_ctx | mask_draft) & valid_block # [B, 1, Q, KV] # Convert bool mask to float additive mask for SDPA - dtype = target_hidden.dtype attn_mask = torch.zeros(bsz, 1, q_len, kv_len, device=device, dtype=dtype) attn_mask.masked_fill_(~final_mask, torch.finfo(dtype).min) + return attn_mask - # 7. Draft forward - hidden = self.dflash_module( - noise_embedding=noise_embedding, - target_hidden=target_hidden, - position_ids=full_pos, - attention_mask=attn_mask, - ) + def _compute_loss( + self, logits, input_ids, anchor_positions, block_keep_mask, loss_mask, base_logits=None + ): + """Compute weighted cross-entropy (or KD) loss and accuracy. - # 8. Loss: same-position prediction (position k predicts token at anchor+k) - logits = self._base_model_lm_head(hidden) + Args: + logits: Draft model output [B, N*block_size, vocab]. + input_ids: Original input token IDs [B, seq_len]. + anchor_positions: Anchor positions per block [B, N]. + block_keep_mask: Valid block mask [B, N]. + loss_mask: Token-level loss mask [B, seq_len]. + base_logits: Base model logits for KD loss [B, seq_len, vocab], or None for CE. + + Returns: + (loss, accuracy) tuple. 
+ """ + bsz, seq_len = input_ids.shape + block_size = self.dflash_block_size + n_blocks = anchor_positions.shape[1] + device = input_ids.device label_offsets = torch.arange(0, block_size, device=device).view(1, 1, -1) label_indices = anchor_positions.unsqueeze(-1) + label_offsets @@ -658,25 +618,20 @@ def forward( decay = torch.exp(-(k - 1).clamp(min=0).float() / self.dflash_loss_decay_factor) weight_mask = weight_mask * decay - # Cross entropy or logit distillation flat_logits = logits.view(-1, logits.size(-1)) flat_targets = target_ids.view(-1) flat_weights = weight_mask.view(-1) - valid_count = flat_weights.sum() + 1e-6 if valid_count > 1.0: - if self.dflash_self_logit_distillation: - # Teacher logits at position p predict token p+1 (autoregressive). - # Draft position k predicts token at anchor+k (same position). - # So teacher logits for token anchor+k are at position anchor+k-1. - base_logits = base_outputs.logits # [B, seq, vocab] + if base_logits is not None: + # KD loss: teacher logits for token anchor+k are at position anchor+k-1 teacher_indices = (safe_label_indices - 1).clamp(min=0) teacher_logits = torch.gather( base_logits.unsqueeze(1).expand(-1, n_blocks, -1, -1), 2, teacher_indices.unsqueeze(-1).expand(-1, -1, -1, base_logits.size(-1)), - ) # [B, N, block_size, vocab] + ) flat_teacher = teacher_logits.reshape(-1, base_logits.size(-1)).detach() target_soft = torch.softmax(flat_teacher, dim=-1) draft_logsoft = torch.log_softmax(flat_logits, dim=-1) @@ -695,6 +650,107 @@ def forward( loss = flat_logits.sum() * 0.0 accuracy = 0.0 + return loss, accuracy + + def forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + past_key_values=None, + inputs_embeds=None, + labels=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + cache_position=None, + **kwargs, + ): + """Training forward matching SpecForge latest (post-PR #473). 
+ + Key changes from original PR #415: + - Random anchor sampling instead of uniform block division + - Bidirectional intra-block attention (no causal constraint) + - Context sees strictly before anchor position + - Label alignment: position k predicts token at anchor+k + - Optional loss decay weighting + """ + if not self.training: + return super().forward( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + labels=labels, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + cache_position=cache_position, + **kwargs, + ) + + bsz, seq_len = input_ids.shape + device = input_ids.device + + # 1. Run base model → hidden states + with torch.no_grad(): + base_outputs = super().forward( + input_ids=input_ids, + attention_mask=attention_mask, + output_hidden_states=True, + ) + + offset = 1 + selected = [base_outputs.hidden_states[lid + offset] for lid in self.target_layer_ids] + target_hidden = torch.cat(selected, dim=-1) # [B, seq, num_layers * H] + + # 2. Build loss mask: combine labels (answer-only) and attention_mask (padding) + loss_mask = torch.ones(bsz, seq_len, device=device) + if labels is not None: + loss_mask = loss_mask * (labels != LabelSmoother.ignore_index).float() + if attention_mask is not None: + loss_mask = loss_mask * attention_mask.float() + + # 3. Random anchor sampling (SpecForge PR #463/#473) + anchor_positions, block_keep_mask = self._sample_anchor_positions( + seq_len, loss_mask, device + ) + n_blocks = anchor_positions.shape[1] + + if n_blocks == 0 or not block_keep_mask.any(): + # Zero loss that still flows through dflash_module for DDP gradient sync + dummy = self.dflash_module.fc.weight.sum() * 0.0 + return ModelOutput(loss=dummy, logits=base_outputs.logits, train_acc=[[0.0]]) + + # 4-6. 
Build draft inputs + noise_embedding = self._build_noise_embedding( + input_ids, anchor_positions, block_keep_mask, n_blocks + ) + full_pos = self._build_position_ids(seq_len, anchor_positions, device) + attn_mask = self._build_draft_attention_mask( + seq_len, anchor_positions, block_keep_mask, n_blocks, target_hidden.dtype, device + ) + + # 7. Draft forward + hidden = self.dflash_module( + noise_embedding=noise_embedding, + target_hidden=target_hidden, + position_ids=full_pos, + attention_mask=attn_mask, + ) + + # 8. Compute loss and accuracy + logits = self._base_model_lm_head(hidden) + loss, accuracy = self._compute_loss( + logits, + input_ids, + anchor_positions, + block_keep_mask, + loss_mask, + base_outputs.logits if self.dflash_self_logit_distillation else None, + ) + return ModelOutput( loss=loss, logits=base_outputs.logits, diff --git a/tests/unit/torch/speculative/plugins/test_hf_dflash.py b/tests/unit/torch/speculative/plugins/test_hf_dflash.py index 20c4edc3eb..8117e9a543 100644 --- a/tests/unit/torch/speculative/plugins/test_hf_dflash.py +++ b/tests/unit/torch/speculative/plugins/test_hf_dflash.py @@ -140,42 +140,43 @@ def test_save_and_restore(self, tmp_path): class TestDFlashMetaRotaryFix: - """Test load_state_dict post-hook fixes meta-tensor rotary buffers. + """Test _apply fixes meta-tensor rotary buffers on first .to() call. During checkpoint restore, rotary inv_freq buffers may be on meta device - (they are computed, not saved). The post-hook should re-create them on CPU. + (they are computed, not saved). _apply should re-create them once. 
""" - def test_load_state_dict_fixes_meta_rotary(self): - """Test that load_state_dict recreates rotary_emb when buffers are on meta device.""" + def test_to_fixes_meta_rotary(self): + """Test that .to() recreates rotary_emb when buffers are on meta device.""" model = get_tiny_llama(num_hidden_layers=4) config = _get_dflash_config() mtsp.convert(model, [("dflash", config)]) dflash_mod = model.dflash_module - sd = dflash_mod.state_dict() - # Simulate meta buffers (as happens during checkpoint restore) for name, buf in list(dflash_mod.rotary_emb.named_buffers()): dflash_mod.rotary_emb._buffers[name] = torch.empty_like(buf, device="meta") + dflash_mod._meta_fixed = False # Reset flag to simulate fresh resume assert any(b.is_meta for b in dflash_mod.rotary_emb.buffers()) - # load_state_dict triggers the post-hook which should fix meta buffers - dflash_mod.load_state_dict(sd, strict=False) + # .to() triggers _apply which should fix meta buffers + dflash_mod.to("cpu") assert not any(b.is_meta for b in dflash_mod.rotary_emb.buffers()) + assert dflash_mod._meta_fixed - def test_load_state_dict_noop_when_no_meta(self): - """Test that load_state_dict does not recreate rotary_emb when buffers are normal.""" + def test_to_noop_after_first_fix(self): + """Test that _apply skips check after first fix (one-shot).""" model = get_tiny_llama(num_hidden_layers=4) config = _get_dflash_config() mtsp.convert(model, [("dflash", config)]) dflash_mod = model.dflash_module - sd = dflash_mod.state_dict() rotary_id_before = id(dflash_mod.rotary_emb) - dflash_mod.load_state_dict(sd, strict=False) + dflash_mod._meta_fixed = True # Already fixed + dflash_mod.to("cpu") + # Should not recreate rotary_emb assert id(dflash_mod.rotary_emb) == rotary_id_before diff --git a/tools/launcher/common/dflash/export.sh b/tools/launcher/common/dflash/export.sh deleted file mode 100644 index 730073596e..0000000000 --- a/tools/launcher/common/dflash/export.sh +++ /dev/null @@ -1,63 +0,0 @@ -#!/bin/bash -# 
SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Export speculative decoding checkpoint to deployment format. -# Auto-detects latest checkpoint and exports via export_hf_checkpoint.py. -# -# Args: -# --model_path Training output dir (auto-detects latest checkpoint) -# --export_path Destination directory - -SCRIPT_DIR="$(dirname "$(readlink -f "$0")")" -source ${SCRIPT_DIR}/../service_utils.sh - -pip install -r modules/Model-Optimizer/examples/speculative_decoding/requirements.txt 2>&1 | tail -3 - -trap 'error_handler $0 $LINENO' ERR - -# Auto-detect latest checkpoint -MODEL_PATH="" -EXPORT_PATH="" -while [ $# -gt 0 ]; do - case "$1" in - --model_path) shift; MODEL_PATH="$1" ;; - --export_path) shift; EXPORT_PATH="$1" ;; - *) ;; - esac - shift -done - -# Find latest checkpoint if model_path is a training dir -if [ ! 
-f "${MODEL_PATH}/model.safetensors" ]; then - LAST_CKPT=$(ls -d ${MODEL_PATH}/checkpoint-* 2>/dev/null | sort -t- -k2 -n | tail -1) - if [ -n "$LAST_CKPT" ]; then - echo "Using latest checkpoint: $LAST_CKPT" - MODEL_PATH="$LAST_CKPT" - fi -fi - -echo "=== Export ===" -echo "Model: ${MODEL_PATH}" -echo "Export: ${EXPORT_PATH}" - -CUDA_VISIBLE_DEVICES=0 python3 modules/Model-Optimizer/examples/speculative_decoding/scripts/export_hf_checkpoint.py \ - --model_path "${MODEL_PATH}" \ - --export_path "${EXPORT_PATH}" - -echo "Export contents:" -ls -lh ${EXPORT_PATH}/ - -report_result "PASS: Export" diff --git a/tools/launcher/common/dflash/ptq_and_export.sh b/tools/launcher/common/dflash/ptq_and_export.sh deleted file mode 100644 index ece933d686..0000000000 --- a/tools/launcher/common/dflash/ptq_and_export.sh +++ /dev/null @@ -1,59 +0,0 @@ -#!/bin/bash -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# PTQ + export for speculative decoding checkpoints (EAGLE3, DFlash). -# Uses hf_ptq.py to quantize and export in one step. -# -# Args are passed directly to hf_ptq.py. 
- -SCRIPT_DIR="$(dirname "$(readlink -f "$0")")" -source ${SCRIPT_DIR}/../service_utils.sh - -pip install -r modules/Model-Optimizer/examples/speculative_decoding/requirements.txt 2>&1 | tail -3 - -trap 'error_handler $0 $LINENO' ERR - -# Find latest checkpoint if model_dir points to a training output dir -MODEL_DIR="" -ARGS=() -while [ $# -gt 0 ]; do - case "$1" in - --model_dir) - shift - MODEL_DIR="$1" - # Auto-detect latest checkpoint - LAST_CKPT=$(ls -d ${MODEL_DIR}/checkpoint-* 2>/dev/null | sort -t- -k2 -n | tail -1) - if [ -f "${MODEL_DIR}/model.safetensors" ]; then - ARGS+=("--model_dir" "$MODEL_DIR") - elif [ -n "$LAST_CKPT" ]; then - echo "Using latest checkpoint: $LAST_CKPT" - ARGS+=("--model_dir" "$LAST_CKPT") - else - ARGS+=("--model_dir" "$MODEL_DIR") - fi - ;; - *) ARGS+=("$1") ;; - esac - shift -done - -echo "=== PTQ + Export ===" -echo "Args: ${ARGS[*]}" - -CUDA_VISIBLE_DEVICES=0 python3 modules/Model-Optimizer/examples/llm_ptq/hf_ptq.py \ - "${ARGS[@]}" - -report_result "PASS: PTQ + Export" diff --git a/tools/launcher/common/dflash/vllm_serve.sh b/tools/launcher/common/dflash/vllm_serve.sh deleted file mode 100644 index a7e5857f09..0000000000 --- a/tools/launcher/common/dflash/vllm_serve.sh +++ /dev/null @@ -1,144 +0,0 @@ -#!/bin/bash -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -# Launch vLLM server with DFlash speculative decoding, run benchmark, then shut down. -# -# Required env vars: -# HF_MODEL_CKPT — target model path -# DRAFT_MODEL — DFlash draft model path -# -# Optional env vars: -# NUM_SPEC_TOKENS — number of speculative tokens (default: 15) -# VLLM_PORT — server port (default: 8000) -# MAX_BATCHED_TOKENS — max batched tokens (default: 32768) -# BENCHMARK_PROMPTS — path to benchmark prompts file - -SCRIPT_DIR="$(dirname "$(readlink -f "$0")")" -source ${SCRIPT_DIR}/../service_utils.sh 2>/dev/null || true - -cleanup() { kill $SERVER_PID 2>/dev/null; sleep 2; kill -9 $SERVER_PID 2>/dev/null; } -trap cleanup EXIT - -MODEL=${HF_MODEL_CKPT} -DRAFT=${DRAFT_MODEL} -NUM_SPEC=${NUM_SPEC_TOKENS:-15} -PORT=${VLLM_PORT:-8000} -MAX_TOKENS=${MAX_BATCHED_TOKENS:-32768} -TP=${TP_SIZE:-1} - -echo "=== vLLM DFlash Speculative Decoding ===" -echo "Target: ${MODEL}" -echo "Draft: ${DRAFT}" -echo "Spec tokens: ${NUM_SPEC}, TP: ${TP}" - -# Start vLLM server in background -vllm serve ${MODEL} \ - --speculative-config "{\"method\": \"dflash\", \"model\": \"${DRAFT}\", \"num_speculative_tokens\": ${NUM_SPEC}}" \ - --max-num-batched-tokens ${MAX_TOKENS} \ - --tensor-parallel-size ${TP} \ - --port ${PORT} \ - & -SERVER_PID=$! - -# Wait for server to be ready -echo "Waiting for vLLM server to start..." -for i in $(seq 1 120); do - if curl -s http://localhost:${PORT}/health > /dev/null 2>&1; then - echo "Server ready after ${i}s" - break - fi - if ! kill -0 $SERVER_PID 2>/dev/null; then - echo "ERROR: Server process died" - wait $SERVER_PID - exit 1 - fi - sleep 1 -done - -if ! 
curl -s http://localhost:${PORT}/health > /dev/null 2>&1; then - echo "ERROR: Server failed to start within 120s" - kill $SERVER_PID 2>/dev/null - exit 1 -fi - -# Run a quick test -echo "" -echo "=== Quick generation test ===" -curl -s http://localhost:${PORT}/v1/completions \ - -H "Content-Type: application/json" \ - -d "{ - \"model\": \"${MODEL}\", - \"prompt\": \"What is 2+3?\", - \"max_tokens\": 64, - \"temperature\": 0 - }" | python3 -c "import json,sys; r=json.load(sys.stdin); print(r.get('choices',[{}])[0].get('text','ERROR')[:200]); print(f'Usage: {r.get(\"usage\",{})}')" - -# Run benchmark if prompts file provided -if [ -n "${BENCHMARK_PROMPTS}" ] && [ -f "${BENCHMARK_PROMPTS}" ]; then - echo "" - echo "=== MT-Bench Benchmark ===" - python3 -c " -import json, time, requests -from collections import defaultdict - -with open('${BENCHMARK_PROMPTS}') as f: - prompts = [json.loads(line) for line in f][:80] - -url = 'http://localhost:${PORT}/v1/completions' -cat_results = defaultdict(lambda: {'tokens': [], 'times': []}) - -for i, p in enumerate(prompts): - q = p.get('prompt', p.get('turns', [p.get('question', 'Hello')]))[0] if isinstance(p, dict) else str(p) - cat = p.get('category', 'unknown') if isinstance(p, dict) else 'unknown' - start = time.time() - r = requests.post(url, json={ - 'model': '${MODEL}', - 'prompt': q, - 'max_tokens': 1024, - 'temperature': 0, - }).json() - elapsed = time.time() - start - n = r.get('usage', {}).get('completion_tokens', 0) - cat_results[cat]['tokens'].append(n) - cat_results[cat]['times'].append(elapsed) - tps = n / elapsed if elapsed > 0 else 0 - print(f' [{i+1}/{len(prompts)}] [{cat}] {n} tokens in {elapsed:.1f}s = {tps:.1f} tok/s') - -print(f'\n=== Per-Category Results ===') -print(f'{\"Category\":>12} | {\"Prompts\":>7} | {\"Tokens\":>8} | {\"Time(s)\":>8} | {\"TPS\":>8}') -print('-' * 55) -all_tokens = 0 -all_time = 0 -for cat in sorted(cat_results): - t = sum(cat_results[cat]['tokens']) - s = 
sum(cat_results[cat]['times']) - n = len(cat_results[cat]['tokens']) - tps = t / s if s > 0 else 0 - all_tokens += t - all_time += s - print(f'{cat:>12} | {n:>7} | {t:>8} | {s:>8.1f} | {tps:>8.1f}') -print('-' * 55) -print(f'{\"ALL\":>12} | {sum(len(v[\"tokens\"]) for v in cat_results.values()):>7} | {all_tokens:>8} | {all_time:>8.1f} | {all_tokens/all_time:>8.1f}') -" -fi - -# Shut down server -echo "" -echo "Shutting down server..." -kill $SERVER_PID 2>/dev/null -wait $SERVER_PID 2>/dev/null || true - -echo "Done" diff --git a/tools/launcher/common/dflash/ar_eval_mtbench.sh b/tools/launcher/common/specdec/ar_eval_mtbench.sh similarity index 100% rename from tools/launcher/common/dflash/ar_eval_mtbench.sh rename to tools/launcher/common/specdec/ar_eval_mtbench.sh diff --git a/tools/launcher/common/dflash/online_training.sh b/tools/launcher/common/specdec/dflash_online_training.sh similarity index 77% rename from tools/launcher/common/dflash/online_training.sh rename to tools/launcher/common/specdec/dflash_online_training.sh index 2bce9c9ded..8694cedf0b 100644 --- a/tools/launcher/common/dflash/online_training.sh +++ b/tools/launcher/common/specdec/dflash_online_training.sh @@ -96,3 +96,31 @@ set -x start_time=$(date +%s) accelerate launch --mixed_precision bf16 $MULTI_NODE_ARGS $MAIN_PY "$@" echo "Training time: $(( $(date +%s) - start_time )) seconds" +set +x + +# Export last checkpoint to deployment format (rank 0 only, single GPU) +if [ "${SLURM_PROCID:-0}" = "0" ]; then + OUTPUT_DIR=$(python3 -c " +import sys +for arg in sys.argv[1:]: + if arg.startswith('training.output_dir='): + print(arg.split('=', 1)[1]) + break +" "$@") + + if [ -n "$OUTPUT_DIR" ]; then + LAST_CKPT=$(ls -d ${OUTPUT_DIR}/checkpoint-* 2>/dev/null | sort -t- -k2 -n | tail -1) + if [ -n "$LAST_CKPT" ]; then + STEP=$(basename "$LAST_CKPT" | sed 's/checkpoint-//') + EXPORT_DIR="${OUTPUT_DIR}/exported-checkpoint-${STEP}" + echo "=== Exporting last checkpoint: ${LAST_CKPT} → ${EXPORT_DIR} ===" + 
#!/bin/bash
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Quick vLLM smoke test for speculative decoding (EAGLE3, DFlash, etc.).
# Launches server, sends a few test prompts, verifies responses, and shuts down.
#
# Required env vars:
#   HF_MODEL_CKPT — target model path
#   DRAFT_MODEL — draft model path
#
# Optional env vars:
#   SPEC_METHOD — speculative method: "eagle", "dflash", etc. (default: "eagle")
#   NUM_SPEC_TOKENS — number of speculative tokens (default: 15)
#   TP_SIZE — tensor parallel size (default: 1)
#   VLLM_PORT — server port (default: 8000)

SCRIPT_DIR="$(dirname "$(readlink -f "$0")")"
source ${SCRIPT_DIR}/../service_utils.sh 2>/dev/null || true

# Always tear the server down, whether we exit via success, failure, or signal.
cleanup() { kill $SERVER_PID 2>/dev/null; sleep 2; kill -9 $SERVER_PID 2>/dev/null; }
trap cleanup EXIT

MODEL=${HF_MODEL_CKPT}
DRAFT=${DRAFT_MODEL}
METHOD=${SPEC_METHOD:-eagle}
NUM_SPEC=${NUM_SPEC_TOKENS:-15}
PORT=${VLLM_PORT:-8000}
TP=${TP_SIZE:-1}

echo "=== vLLM Speculative Decoding Smoke Test ==="
echo "Method: ${METHOD}"
echo "Target: ${MODEL}"
echo "Draft:  ${DRAFT}"
echo "Spec tokens: ${NUM_SPEC}, TP: ${TP}"

# Build speculative config
SPEC_CONFIG="{\"method\": \"${METHOD}\", \"model\": \"${DRAFT}\", \"num_speculative_tokens\": ${NUM_SPEC}}"

# Start vLLM server in the background
vllm serve ${MODEL} \
    --speculative-config "${SPEC_CONFIG}" \
    --max-num-batched-tokens 32768 \
    --tensor-parallel-size ${TP} \
    --port ${PORT} \
    &
SERVER_PID=$!

# Wait for server (up to 180s); bail early if the process dies.
echo "Waiting for vLLM server..."
for i in $(seq 1 180); do
    if curl -s http://localhost:${PORT}/health > /dev/null 2>&1; then
        echo "Server ready after ${i}s"
        break
    fi
    if ! kill -0 $SERVER_PID 2>/dev/null; then
        echo "ERROR: Server died"; wait $SERVER_PID; exit 1
    fi
    sleep 1
done

if ! curl -s http://localhost:${PORT}/health > /dev/null 2>&1; then
    echo "ERROR: Server timeout"; exit 1
fi

# Run quick test prompts
echo ""
echo "=== Test Prompts ==="
PASS=0
FAIL=0
for PROMPT in \
    "What is 2+3? Answer with just the number." \
    "Write a haiku about mountains." \
    "Explain what a CPU is in one sentence."; do
    RESPONSE=$(curl -s http://localhost:${PORT}/v1/completions \
        -H "Content-Type: application/json" \
        -d "{\"model\": \"${MODEL}\", \"prompt\": \"${PROMPT}\", \"max_tokens\": 64, \"temperature\": 0}" \
        | python3 -c "import json,sys; r=json.load(sys.stdin); t=r.get('choices',[{}])[0].get('text',''); u=r.get('usage',{}); print(f'{t.strip()[:100]}|||{u.get(\"completion_tokens\",0)}')" 2>/dev/null)
    # BUGFIX: `cut -d` accepts only a single-character delimiter, so the
    # original `cut -d'|||'` aborted ("the delimiter must be a single
    # character"), leaving TEXT/TOKENS empty and failing every prompt.
    # Split on the literal "|||" sentinel with parameter expansion instead.
    TEXT="${RESPONSE%%|||*}"
    TOKENS="${RESPONSE##*|||}"
    if [ -n "$TEXT" ] && [ "$TOKENS" -gt 0 ] 2>/dev/null; then
        echo "  PASS: \"${PROMPT}\" → ${TOKENS} tokens"
        PASS=$((PASS + 1))
    else
        echo "  FAIL: \"${PROMPT}\" → empty or error"
        FAIL=$((FAIL + 1))
    fi
done

echo ""
echo "Results: ${PASS} passed, ${FAIL} failed"

if [ $FAIL -gt 0 ]; then
    echo "ERROR: Some prompts failed"
    exit 1
fi

echo "Done"
Remove redundant DFlashExporter.__init__ (just called super) - Merge duplicate estimate_ar if-branches in eagle_utils.py - Add answer_only_loss / chat template limitation note in dflash.md - Add TODO for epoch-seeded anchor sampling to enable mask caching - Add TODO for co-training (remove no_grad/eval) - Validate HEAD_NODE_IP in multi-node training script - Fix markdown lint: add language tags to fenced code blocks Co-Authored-By: Claude Opus 4.6 (1M context) Signed-off-by: Chenhan Yu --- examples/speculative_decoding/doc/dflash.md | 28 ++++++++++++++----- examples/speculative_decoding/eagle_utils.py | 14 ++++------ examples/speculative_decoding/main.py | 4 +-- .../torch/export/plugins/hf_spec_export.py | 8 ++---- .../torch/speculative/plugins/hf_dflash.py | 8 ++++-- .../common/specdec/dflash_online_training.sh | 4 +++ 6 files changed, 41 insertions(+), 25 deletions(-) diff --git a/examples/speculative_decoding/doc/dflash.md b/examples/speculative_decoding/doc/dflash.md index 79cb6ca461..c31e5abff0 100644 --- a/examples/speculative_decoding/doc/dflash.md +++ b/examples/speculative_decoding/doc/dflash.md @@ -9,7 +9,7 @@ Reference: [arXiv:2602.06036](https://arxiv.org/abs/2602.06036) | ## Architecture -``` +```text Target Model (frozen) │ ├─ hidden_states[layer 1, 9, 17, 25, 33] ──► concat ──► FC + RMSNorm ──► target_hidden @@ -43,7 +43,7 @@ Target Model (frozen) Given context `"The answer is"` and block_size=4 with anchor `"is"`: -``` +```text Target model hidden states (from frozen base model): h["The"] h["answer"] h["is"] ← target_hidden (ctx_len=3) │ │ │ @@ -87,7 +87,7 @@ In each DFlash decoder layer: **Training vs Inference:** -``` +```text TRAINING (2 anchors, block_size=4): Context tokens: "The" "answer" "is" "5" "." 
@@ -166,13 +166,25 @@ See [`modelopt_recipes/general/speculative_decoding/dflash.yaml`](../../../model | `dflash.dflash_architecture_config.mask_token_id` | auto | Token ID for masked positions | | `training.answer_only_loss` | false | Mask loss on non-assistant tokens | +> **Note on `answer_only_loss` and chat templates:** When `answer_only_loss=true`, the +> dataset loader replaces the tokenizer's chat template with a simplified version that has +> `{% generation %}` tags to identify assistant turns. This simplified template may not +> support all features of the original (e.g., tool use formatting, multi-turn system +> prompts). During serving, the draft model reuses the target model's original tokenizer +> and template, so there is no train/inference mismatch in the tokenization itself — only +> the loss masking during training uses the simplified template. However, if training data +> contains tool-use conversations with model-family-specific formatting, the simplified +> template may tokenize them differently, affecting which tokens get masked. For best +> results with tool-use data, set `answer_only_loss=false` or provide a custom +> `chat_template` that supports both generation tags and tool-use formatting. + ### Random Anchor Sampling (`num_anchors`) During training, anchor positions are sampled randomly from valid (assistant response) tokens in each batch, rather than dividing the sequence into fixed blocks. Each anchor starts a block of `block_size` tokens where the draft model predicts positions 1..B-1. -``` +```text Sequence: [SYS] You helpful [USR] What 2+3? [AST] The answer is 5 Position: 0 1 2 3 4 5 6 7 8 9 10 loss_mask: 0 0 0 0 0 0 0 1 1 1 1 @@ -208,7 +220,7 @@ The exponential decay factor (gamma) weights early block positions higher than l If position 1 in a block is wrong, all subsequent positions are rejected in speculative decoding. Decay aligns the training loss with what matters for acceptance rate. 
-``` +```text weight[k] = exp(-(k-1).clamp(min=0) / gamma) for k = 0..B-1 ``` @@ -324,8 +336,8 @@ ModelOpt wins acceptance length on 7/8 categories and TPS on 8/8 categories. - **FP8 / NVFP4 quantization**: Export pipeline supports quantized checkpoints via `hf_ptq.py` (PTQ succeeded in testing). AR impact of quantization not yet measured. The flow: train (bf16) → `mtq.quantize(model, quant_cfg)` → `export_hf_checkpoint.py`. -- **Checkpoint resume**: `DFlashModule._apply()` handles meta-tensor rotary buffers. - Validated in training runs but not covered by integration tests. +- **Checkpoint resume**: `DFlashModule._apply()` handles meta-tensor rotary buffers + (one-shot check on first `.to(device)` call). Validated in train+resume E2E tests. ### Validated @@ -334,10 +346,12 @@ ModelOpt wins acceptance length on 7/8 categories and TPS on 8/8 categories. - **AR evaluation**: `ar_validate.py` with online GT, per-category MT-Bench. - **vLLM deployment**: Speculative decoding with `vllm/vllm-openai:nightly` (v0.19.1+). 3.1x speedup over baseline. Per-category benchmarks on MT-Bench. + ```bash vllm serve Qwen/Qwen3-8B \ --speculative-config '{"method": "dflash", "model": "path/to/checkpoint", "num_speculative_tokens": 7}' \ --max-num-batched-tokens 32768 ``` + - **Export**: z-lab compatible HF format, loadable by vLLM and z-lab benchmark. - **Loss decay**: Validated +0.12 AR improvement with gamma=7 (bs16). 
diff --git a/examples/speculative_decoding/eagle_utils.py b/examples/speculative_decoding/eagle_utils.py index fed88401a0..2b08ec8096 100644 --- a/examples/speculative_decoding/eagle_utils.py +++ b/examples/speculative_decoding/eagle_utils.py @@ -137,7 +137,7 @@ def __call__(self, features: list[dict[str, Any]]) -> dict[str, Any]: return batch -def make_eagle_supervised_data_module( +def make_speculative_data_module( tokenizer: transformers.PreTrainedTokenizer, data_args, train_len=None, @@ -213,6 +213,11 @@ def on_log(self, args, state, control, **kwargs): print_rank_0(f"Step {state.global_step} Training Acc: [{acc_str}]") except Exception: print_rank_0(f"Step {state.global_step} Training Acc: {average_acc}") + # Log accuracy to HF Trainer's logs dict (picked up by TensorBoard) + logs = kwargs.get("logs") or {} + for i, draft_acc in enumerate(average_acc): + for j, step_acc in enumerate(draft_acc): + logs[f"train_acc/parallel_{i}_step_{j}"] = float(step_acc) if self.estimate_ar: # Calculate mean training AR since last log # NOTE: This is only an estimate of the real AR. 
@@ -226,13 +231,6 @@ def on_log(self, args, state, control, **kwargs): acc_cumprod *= draft_acc[-1] est_ar += acc_cumprod print_rank_0(f"Step {state.global_step} Estimated Training AR: {est_ar:.4f}") - - # Log accuracy to HF Trainer's logs dict (picked up by TensorBoard) - logs = kwargs.get("logs") or {} - for i, draft_acc in enumerate(average_acc): - for j, step_acc in enumerate(draft_acc): - logs[f"train_acc/parallel_{i}_step_{j}"] = float(step_acc) - if self.estimate_ar: logs["estimated_training_ar"] = est_ar # log to wandb diff --git a/examples/speculative_decoding/main.py b/examples/speculative_decoding/main.py index 047a70d864..5cee98fb51 100644 --- a/examples/speculative_decoding/main.py +++ b/examples/speculative_decoding/main.py @@ -40,7 +40,7 @@ from eagle_utils import ( EagleTrainerWithAccLog, EagleTrainingPlot, - make_eagle_supervised_data_module, + make_speculative_data_module, patch_ring_attention_for_ttt, ) from omegaconf import OmegaConf @@ -263,7 +263,7 @@ def train(): print_rank_0("Loading dataset...") if training_args.mode in ("eagle3", "dflash"): - data_module = make_eagle_supervised_data_module( + data_module = make_speculative_data_module( tokenizer, data_args, train_len=training_args.training_seq_len, diff --git a/modelopt/torch/export/plugins/hf_spec_export.py b/modelopt/torch/export/plugins/hf_spec_export.py index f298d86be1..82adea89df 100644 --- a/modelopt/torch/export/plugins/hf_spec_export.py +++ b/modelopt/torch/export/plugins/hf_spec_export.py @@ -253,10 +253,6 @@ class DFlashExporter(SpeculativeDecodingExporter): - config.json: Qwen3-style config with dflash_config field """ - def __init__(self, model: nn.Module): - """Initialize the DFlashExporter.""" - super().__init__(model) - def _extract_state_dict(self, full_state_dict: dict): """Extract DFlash module weights, stripping the dflash_module prefix.""" export_sd = {} @@ -316,7 +312,9 @@ def _export_config(self): ), "rope_scaling": getattr(base_config, "rope_scaling", None), 
"tie_word_embeddings": False, - "torch_dtype": str(getattr(base_config, "torch_dtype", torch.bfloat16)).replace("torch.", ""), + "torch_dtype": str(getattr(base_config, "torch_dtype", torch.bfloat16)).replace( + "torch.", "" + ), "num_target_layers": getattr(base_config, "num_hidden_layers", 36), } diff --git a/modelopt/torch/speculative/plugins/hf_dflash.py b/modelopt/torch/speculative/plugins/hf_dflash.py index d5df9cdb36..8f175cc0c9 100644 --- a/modelopt/torch/speculative/plugins/hf_dflash.py +++ b/modelopt/torch/speculative/plugins/hf_dflash.py @@ -472,9 +472,10 @@ def _sample_anchor_positions(self, seq_len, loss_mask, device): Returns (anchor_positions [B, N], block_keep_mask [B, N]). - Note: Fixed (uniform) anchors would allow caching masks/positions across steps, - but random anchors provide data augmentation during training. A fixed-anchor mode - could be added as a future optimization for inference or fine-tuning. + TODO: Fix the random seed per epoch (change between epochs) so that anchor + positions are deterministic within an epoch. This would allow caching the derived + masks and position IDs across steps while preserving the same data augmentation + effect. Currently, anchors are re-sampled every forward pass. """ bs = self.dflash_block_size bsz = loss_mask.shape[0] @@ -694,6 +695,7 @@ def forward( device = input_ids.device # 1. Run base model → hidden states + # TODO: For co-training the base model, remove no_grad and eval() switch. 
with torch.no_grad(): base_outputs = super().forward( input_ids=input_ids, diff --git a/tools/launcher/common/specdec/dflash_online_training.sh b/tools/launcher/common/specdec/dflash_online_training.sh index 8694cedf0b..654ba0185d 100644 --- a/tools/launcher/common/specdec/dflash_online_training.sh +++ b/tools/launcher/common/specdec/dflash_online_training.sh @@ -75,6 +75,10 @@ fi MAIN_PY=modules/Model-Optimizer/examples/speculative_decoding/main.py if [[ "$NUM_NODES" != "1" ]]; then + if [ -z "$HEAD_NODE_IP" ]; then + echo "ERROR: HEAD_NODE_IP is empty. Cannot launch multi-node training." + exit 1 + fi GPU_PER_NODE=${GPU_PER_NODE:-$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l)} TOTAL_GPU=$((NUM_NODES * GPU_PER_NODE)) echo "Total GPUs: $TOTAL_GPU (NUM_NODES: $NUM_NODES, GPU_PER_NODE: $GPU_PER_NODE)" From a04dc7aaf92fc4781f63bb4f1dd8a4a0500c1308 Mon Sep 17 00:00:00 2001 From: Chenhan Yu Date: Thu, 9 Apr 2026 21:22:33 -0700 Subject: [PATCH 24/24] fix: always check for meta rotary buffers in _apply The one-shot _meta_fixed flag could be set prematurely when from_pretrained calls .to() during weight loading. Remove the flag and always check -- the cost is negligible. Also remove ShareGPT format support from dataset loader -- only accept OpenAI messages format. Co-Authored-By: Claude Opus 4.6 (1M context) Signed-off-by: Chenhan Yu --- .../torch/speculative/plugins/hf_dflash.py | 13 ++-- .../utils/plugins/transformers_dataset.py | 61 ++++--------------- .../speculative/plugins/test_hf_dflash.py | 12 ++-- 3 files changed, 21 insertions(+), 65 deletions(-) diff --git a/modelopt/torch/speculative/plugins/hf_dflash.py b/modelopt/torch/speculative/plugins/hf_dflash.py index 8f175cc0c9..b7c0fa91f6 100644 --- a/modelopt/torch/speculative/plugins/hf_dflash.py +++ b/modelopt/torch/speculative/plugins/hf_dflash.py @@ -251,19 +251,14 @@ def __init__(self, config): # automatic _init_weights walk doesn't reach these new layers. 
self._init_weights(config) - self._meta_fixed = False - def _apply(self, fn, recurse=True): - """Fix meta-tensor rotary buffers on first .to(device) call, then no-op. + """Fix meta-tensor rotary buffers before device transfer. On checkpoint resume, inv_freq (a computed buffer, not saved in checkpoint) - stays on meta device. Re-create rotary_emb on CPU before the first device - transfer so .to(device) can proceed. + stays on meta device. Re-create rotary_emb on CPU so .to(device) can proceed. """ - if not self._meta_fixed and hasattr(self, "rotary_emb"): - if any(b.is_meta for b in self.rotary_emb.buffers()): - self.rotary_emb = _ROTARY_CLS(config=self._rotary_config, device="cpu") - self._meta_fixed = True + if hasattr(self, "rotary_emb") and any(b.is_meta for b in self.rotary_emb.buffers()): + self.rotary_emb = _ROTARY_CLS(config=self._rotary_config, device="cpu") return super()._apply(fn, recurse) def _init_weights(self, config): diff --git a/modelopt/torch/utils/plugins/transformers_dataset.py b/modelopt/torch/utils/plugins/transformers_dataset.py index b9a5367cd9..cda1c02ccc 100644 --- a/modelopt/torch/utils/plugins/transformers_dataset.py +++ b/modelopt/torch/utils/plugins/transformers_dataset.py @@ -33,26 +33,6 @@ IGNORE_TOKEN_ID = LabelSmoother.ignore_index -def _sharegpt_to_openai_messages(conversations: list[dict]): - """Optionally align sharedgpt format to openai format.""" - role_mapping = { - "user": "user", - "User": "user", - "human": "user", - "assistant": "assistant", - "Assistant": "assistant", - "gpt": "assistant", - "system": "system", - "System": "system", - } - messages = [] - for msg in conversations: - role = role_mapping[msg["role"]] - content = msg["content"] - messages.append({"role": role, "content": content}) - return messages - - class ShardedDataset(torch.utils.data.Dataset): """Subclass of torch.utils.data.Dataset to load data from HuggingFace dataset.""" @@ -388,29 +368,17 @@ def __call__(self, examples): batch.append(text) else: 
messages = example.get("messages", None) - conversations = example.get("conversations", None) - # Prefer whichever has an assistant turn for training - if messages and any(m.get("role") == "assistant" for m in messages): - batch.append(messages) - elif conversations: - converted = _sharegpt_to_openai_messages(conversations) - if not any(m.get("role") == "assistant" for m in converted): - print_rank_0( - "=== WARNING === Skipping sample with no assistant turn in conversations." - ) - continue - batch.append(converted) - elif messages: - if not any(m.get("role") == "assistant" for m in messages): - print_rank_0( - "=== WARNING === Skipping sample with no assistant turn in messages." - ) - continue - batch.append(messages) - else: + if not messages: raise ValueError( - "The sample must in either OpenAI messages format or ShareGPT conversations format." + "Sample must have a 'messages' field in OpenAI format " + "(list of {role, content} dicts)." ) + if not any(m.get("role") == "assistant" for m in messages): + print_rank_0( + "=== WARNING === Skipping sample with no assistant turn in messages." + ) + continue + batch.append(messages) if not batch: # All samples skipped — create a dummy batch with all-masked labels @@ -469,13 +437,10 @@ def __call__(self, examples): for example in examples: messages = example.get("messages", None) if messages is None: - conversations = example.get("conversations", None) - if conversations is None: - raise ValueError( - "The sample must in either OpenAI messages format or ShareGPT conversations format." - ) - else: - messages = _sharegpt_to_openai_messages(conversations) + raise ValueError( + "Sample must have a 'messages' field in OpenAI format " + "(list of {role, content} dicts)." 
+ ) copy_messages = copy.deepcopy(messages) diff --git a/tests/unit/torch/speculative/plugins/test_hf_dflash.py b/tests/unit/torch/speculative/plugins/test_hf_dflash.py index 8117e9a543..8e8c846583 100644 --- a/tests/unit/torch/speculative/plugins/test_hf_dflash.py +++ b/tests/unit/torch/speculative/plugins/test_hf_dflash.py @@ -140,10 +140,10 @@ def test_save_and_restore(self, tmp_path): class TestDFlashMetaRotaryFix: - """Test _apply fixes meta-tensor rotary buffers on first .to() call. + """Test _apply fixes meta-tensor rotary buffers on .to() calls. During checkpoint restore, rotary inv_freq buffers may be on meta device - (they are computed, not saved). _apply should re-create them once. + (they are computed, not saved). _apply should re-create them. """ def test_to_fixes_meta_rotary(self): @@ -156,7 +156,6 @@ def test_to_fixes_meta_rotary(self): # Simulate meta buffers (as happens during checkpoint restore) for name, buf in list(dflash_mod.rotary_emb.named_buffers()): dflash_mod.rotary_emb._buffers[name] = torch.empty_like(buf, device="meta") - dflash_mod._meta_fixed = False # Reset flag to simulate fresh resume assert any(b.is_meta for b in dflash_mod.rotary_emb.buffers()) @@ -164,19 +163,16 @@ def test_to_fixes_meta_rotary(self): dflash_mod.to("cpu") assert not any(b.is_meta for b in dflash_mod.rotary_emb.buffers()) - assert dflash_mod._meta_fixed - def test_to_noop_after_first_fix(self): - """Test that _apply skips check after first fix (one-shot).""" + def test_to_noop_when_no_meta(self): + """Test that .to() does not recreate rotary_emb when buffers are normal.""" model = get_tiny_llama(num_hidden_layers=4) config = _get_dflash_config() mtsp.convert(model, [("dflash", config)]) dflash_mod = model.dflash_module rotary_id_before = id(dflash_mod.rotary_emb) - dflash_mod._meta_fixed = True # Already fixed dflash_mod.to("cpu") - # Should not recreate rotary_emb assert id(dflash_mod.rotary_emb) == rotary_id_before