diff --git a/examples/speculative_decoding/README.md b/examples/speculative_decoding/README.md index 820b587b91..673b2db0e0 100644 --- a/examples/speculative_decoding/README.md +++ b/examples/speculative_decoding/README.md @@ -350,3 +350,46 @@ More models coming soon! - 💡 [Release Notes](https://nvidia.github.io/Model-Optimizer/reference/0_changelog.html) - 🐛 [File a bug](https://github.com/NVIDIA/Model-Optimizer/issues/new?template=1_bug_report.md) - ✨ [File a Feature Request](https://github.com/NVIDIA/Model-Optimizer/issues/new?template=2_feature_request.md) + +## DFlash (Block Diffusion for Speculative Decoding) + +DFlash is a parallel speculative decoding method based on [Block Diffusion](https://arxiv.org/abs/2602.06036). +Unlike autoregressive draft models (EAGLE3), DFlash predicts an entire block of tokens in a single forward pass +using masked parallel prediction with KV injection from the target model's hidden states. + +### Quick Start + +For a complete end-to-end example (training + evaluation), see the +[launcher example](../../tools/launcher/examples/Qwen/Qwen3-8B/hf_online_dflash.yaml): + +```bash +uv run launch.py --yaml examples/Qwen/Qwen3-8B/hf_online_dflash.yaml --yes +``` + +### Key Configuration ([dflash.yaml](../../modelopt_recipes/general/speculative_decoding/dflash.yaml)) + +| Field | Default | Description | +|-------|---------|-------------| +| `dflash.dflash_block_size` | 8 | Block size for parallel prediction | +| `dflash.dflash_num_anchors` | 512 | Number of anchor positions per sample | +| `dflash.dflash_loss_decay_factor` | 4.0 | Exponential decay gamma (0 disables) | +| `dflash.dflash_self_logit_distillation` | true | Use logit distillation from target | +| `dflash.dflash_architecture_config.num_hidden_layers` | 5 | Draft decoder layers | +| `dflash.dflash_architecture_config.mask_token_id` | auto | Token ID for masked positions | +| `training.answer_only_loss` | false | Mask loss on non-assistant tokens | + +Qwen3 sliding window 
attention is automatically supported — draft layers inherit +`layer_types` and `sliding_window` from the config, matching the target model's +attention pattern. + +### Export + +```bash +python scripts/export_hf_checkpoint.py \ + --model_path /path/to/training/output \ + --export_path /path/to/exported/model +``` + +### Results + +See [doc/dflash.md](doc/dflash.md) for design details, benchmark results, and open items. diff --git a/examples/speculative_decoding/doc/dflash.md b/examples/speculative_decoding/doc/dflash.md new file mode 100644 index 0000000000..c31e5abff0 --- /dev/null +++ b/examples/speculative_decoding/doc/dflash.md @@ -0,0 +1,357 @@ +# DFlash — Block Diffusion for Speculative Decoding + +DFlash predicts an entire block of tokens in a single forward pass using masked parallel +prediction with KV injection from the target model's hidden states. + +Reference: [arXiv:2602.06036](https://arxiv.org/abs/2602.06036) | +[SpecForge](https://github.com/sgl-project/SpecForge) | +[z-lab](https://github.com/z-lab/dflash) + +## Architecture + +```text +Target Model (frozen) + │ + ├─ hidden_states[layer 1, 9, 17, 25, 33] ──► concat ──► FC + RMSNorm ──► target_hidden + │ │ + │ K/V injection + │ │ + └─ embed([anchor, mask, mask, ...]) ──► noise_embedding ──► DFlash Decoder (5 layers) + │ + lm_head ──► draft tokens +``` + +**Key components:** +- **Feature Fusion**: Multi-layer hidden states → Linear(num_layers × hidden_size, hidden_size) + RMSNorm. + Unlike EAGLE3 which uses a single layer's hidden state, DFlash concatenates hidden states from + multiple target layers (e.g., layers 1, 9, 17, 25, 33) to give the draft model richer context. +- **KV Injection**: In each draft decoder layer, K and V are projected from the concatenation of + target hidden states and the block's own embeddings. Q is projected from the block embeddings only + (the `[anchor, mask, mask, ...]` token embeddings). 
This lets the draft model attend to the + full target context while generating all block positions in parallel. +- **Parallel Drafting**: Position 0 is the anchor (the last accepted token — known and correct), + positions 1..B-1 are filled with a special mask token (similar to BERT's `[MASK]`). The draft + model predicts all B-1 unknown positions in a single forward pass, unlike autoregressive drafters + (EAGLE3) which predict one token at a time. Benefit: one forward pass produces B-1 draft tokens. +- **Random Anchor Sampling**: During training, anchors are sampled randomly from assistant response + positions (where `loss_mask=1`), not placed at fixed intervals. The anchor is the starting token + of each training block — it's always correct (from the ground truth) and the model learns to + predict the next B-1 tokens given this anchor and the target's hidden states. See the + [illustrated example](#random-anchor-sampling-num_anchors) below for why this improves efficiency. + +**KV Injection (token-level example):** + +Given context `"The answer is"` and block_size=4 with anchor `"is"`: + +```text +Target model hidden states (from frozen base model): + h["The"] h["answer"] h["is"] ← target_hidden (ctx_len=3) + │ │ │ + └──── FC + RMSNorm ────┘ + │ + fused context features + +Block input (draft token embeddings): + embed("is") embed(MASK) embed(MASK) embed(MASK) ← noise_embedding (block_size=4) + pos=3 pos=4 pos=5 pos=6 + +In each DFlash decoder layer: + Q = q_proj(noise_embedding) ← shape [4, head_dim] + only the block tokens generate queries + + K = concat( ← shape [7, head_dim] + k_proj(fused_context), ← from target hidden [3 positions: "The","answer","is"] + k_proj(noise_embedding) ← from block tokens [4 positions: "is",MASK,MASK,MASK] + ) + + V = concat(v_proj(fused_context), v_proj(noise_embedding)) ← same shape as K + + Attention: Q (4 tokens) attends to K/V (7 tokens) + + K/V: "The" "answer" "is" │ "is" MASK MASK MASK + pos0 pos1 pos2 │ pos3 pos4 pos5 pos6 + 
───────────────────────────┼────────────────────────── + Q pos=3 "is" : ✓ ✓ ✓ │ ✓ ✓ ✓ ✓ + Q pos=4 MASK : ✓ ✓ ✓ │ ✓ ✓ ✓ ✓ + Q pos=5 MASK : ✓ ✓ ✓ │ ✓ ✓ ✓ ✓ + Q pos=6 MASK : ✓ ✓ ✓ │ ✓ ✓ ✓ ✓ + ─── context ─── │ ──── block ──────────── + (bidirectional within block, no attention mask at inference) + + Output → lm_head → predictions: + pos=3: skip (anchor, already known) + pos=4: predict token after "is" → "5" + pos=5: predict token after "is 5" → "." + pos=6: predict token after "is 5." → "[EOS]" +``` + +**Training vs Inference:** + +```text +TRAINING (2 anchors, block_size=4): + + Context tokens: "The" "answer" "is" "5" "." + Block 0 (anchor="The"): [The, MASK, MASK, MASK] + Block 1 (anchor="is"): [is, MASK, MASK, MASK] + + All blocks processed in ONE forward pass. Attention mask controls visibility: + + K/V (context) K/V (block 0) K/V (block 1) + "The" "ans" "is" "5" "." The M M M is M M M + c0 c1 c2 c3 c4 b0 b1 b2 b3 b4 b5 b6 b7 + Q ───────────────────────────────────────────────────────────────────────── + b0 "The" : ✗ ✗ ✗ ✗ ✗ ✓ ✓ ✓ ✓ ✗ ✗ ✗ ✗ + b1 MASK : ✗ ✗ ✗ ✗ ✗ ✓ ✓ ✓ ✓ ✗ ✗ ✗ ✗ + b2 MASK : ✗ ✗ ✗ ✗ ✗ ✓ ✓ ✓ ✓ ✗ ✗ ✗ ✗ + b3 MASK : ✗ ✗ ✗ ✗ ✗ ✓ ✓ ✓ ✓ ✗ ✗ ✗ ✗ + b4 "is" : ✓ ✓ ✗ ✗ ✗ ✗ ✗ ✗ ✗ ✓ ✓ ✓ ✓ + b5 MASK : ✓ ✓ ✗ ✗ ✗ ✗ ✗ ✗ ✗ ✓ ✓ ✓ ✓ + b6 MASK : ✓ ✓ ✗ ✗ ✗ ✗ ✗ ✗ ✗ ✓ ✓ ✓ ✓ + b7 MASK : ✓ ✓ ✗ ✗ ✗ ✗ ✗ ✗ ✗ ✓ ✓ ✓ ✓ + ── context ────── ── block 0 ────── ── block 1 ────── + + Block 0: first block sees NO context (✗), only its own block (bidirectional ✓) + Block 1: sees context before anchor "is" (c0,c1 ✓), NOT its own anchor or later + plus its own block (bidirectional ✓) + + Loss: computed on all non-anchor positions simultaneously. + No verification — ground truth labels known from training data. 
+ +INFERENCE (one block at a time, NO attention mask): + + Step 1: target forward("The answer is") → base_token = "5" + block = [5, MASK, MASK, MASK] + + K/V: "The" "ans" "is" │ "5" MASK MASK MASK + Q ─────────────────────────────────┼────────────────────────── + "5" : ✓ ✓ ✓ │ ✓ ✓ ✓ ✓ + MASK : ✓ ✓ ✓ │ ✓ ✓ ✓ ✓ + MASK : ✓ ✓ ✓ │ ✓ ✓ ✓ ✓ + MASK : ✓ ✓ ✓ │ ✓ ✓ ✓ ✓ + + All ✓ — no mask at inference. Block sees full context freely. + Target verifies → accept 3 → sequence: "The answer is 5 . [EOS]" + + Step 2: next block with grown context (5 tokens) ... +``` + +The draft model sees the target's internal representation of the context (via KV injection) +without re-running the target model for drafting. The expensive target forward pass is +only needed for verification — the lightweight draft model reuses the target's hidden states. + +**Draft model components** (Qwen3-based): +- `Qwen3MLP`, `Qwen3RMSNorm`, `Qwen3RotaryEmbedding` from transformers +- Sliding window attention supported via `config.layer_types` *(implemented, not yet validated end-to-end)* +- Independent of target model architecture + +## Training + +### Quick Start + +```bash +uv run launch.py --yaml examples/Qwen/Qwen3-8B/hf_online_dflash.yaml --yes +``` + +### Recipe + +See [`modelopt_recipes/general/speculative_decoding/dflash.yaml`](../../../modelopt_recipes/general/speculative_decoding/dflash.yaml) + +| Parameter | Default | Description | +|-----------|---------|-------------| +| `dflash.dflash_block_size` | 8 | Block size for parallel prediction | +| `dflash.dflash_num_anchors` | 512 | Random anchor positions per sample (see below) | +| `dflash.dflash_loss_decay_factor` | 4.0 | Exponential decay gamma (0 disables, see below) | +| `dflash.dflash_self_logit_distillation` | true | Use target model logits as soft labels (vs hard CE) | +| `dflash.dflash_architecture_config.num_hidden_layers` | 5 | Draft decoder layers | +| `dflash.dflash_architecture_config.mask_token_id` | auto | Token ID for masked 
positions | +| `training.answer_only_loss` | false | Mask loss on non-assistant tokens | + +> **Note on `answer_only_loss` and chat templates:** When `answer_only_loss=true`, the +> dataset loader replaces the tokenizer's chat template with a simplified version that has +> `{% generation %}` tags to identify assistant turns. This simplified template may not +> support all features of the original (e.g., tool use formatting, multi-turn system +> prompts). During serving, the draft model reuses the target model's original tokenizer +> and template, so there is no train/inference mismatch in the tokenization itself — only +> the loss masking during training uses the simplified template. However, if training data +> contains tool-use conversations with model-family-specific formatting, the simplified +> template may tokenize them differently, affecting which tokens get masked. For best +> results with tool-use data, set `answer_only_loss=false` or provide a custom +> `chat_template` that supports both generation tags and tool-use formatting. + +### Random Anchor Sampling (`num_anchors`) + +During training, anchor positions are sampled randomly from valid (assistant response) +tokens in each batch, rather than dividing the sequence into fixed blocks. Each anchor +starts a block of `block_size` tokens where the draft model predicts positions 1..B-1. + +```text +Sequence: [SYS] You helpful [USR] What 2+3? 
[AST] The answer is 5 +Position: 0 1 2 3 4 5 6 7 8 9 10 +loss_mask: 0 0 0 0 0 0 0 1 1 1 1 + ^^^^^^^^^^^^^^^^ + assistant response + +Fixed blocks (block_size=4): +Block 0: pos [0,1,2,3] anchor=0 → predict 1,2,3 → loss_mask=0,0,0 → ZERO LOSS +Block 1: pos [4,5,6,7] anchor=4 → predict 5,6,7 → loss_mask=0,0,1 → 1/3 useful +Block 2: pos [8,9,10,—] anchor=8 → predict 9,10,— → loss_mask=1,1,— → 2/2 useful + +Efficiency: 3/8 = 38% + +Random anchors (num_anchors=3, sampled from loss_mask=1): +Anchor 7: pos [7,8,9,10] → predict 8,9,10 → loss_mask=1,1,1 → 3/3 useful +Anchor 9: pos [9,10,—,—] → predict 10,—,— → loss_mask=1,—,— → 1/1 useful +Anchor 8: pos [8,9,10,—] → predict 9,10,— → loss_mask=1,1,— → 2/2 useful + +Efficiency: 6/6 = 100% +``` + +Random anchors guarantee every prediction is on assistant tokens. +Fixed blocks waste compute on prompt tokens where loss_mask=0. + +**Tradeoff:** Higher `num_anchors` = more training signal per sample but more compute. +Lower = faster iteration but less data efficiency. With `seq_len=4096` and `block_size=8`, +`num_anchors=512` means the model sees ~512 blocks per sample (covering ~4096 positions). +Scale proportionally: `num_anchors ≈ seq_len / block_size` gives full coverage. + +### Loss Decay + +The exponential decay factor (gamma) weights early block positions higher than later ones. +If position 1 in a block is wrong, all subsequent positions are rejected in speculative +decoding. Decay aligns the training loss with what matters for acceptance rate. + +```text +weight[k] = exp(-(k-1).clamp(min=0) / gamma) for k = 0..B-1 +``` + +Positions 0 (anchor, excluded by loss mask) and 1 get full weight (1.0). Later positions +decay: e.g., with `gamma=4` and `block_size=8`, position 7 contributes only 22% as +much as position 1. Paper recommendation: gamma=7 for block_size=16, gamma=4 for block_size=8. + +Note: this is different from EAGLE3's `eagle_loss_decay_factor` which multiplies loss by +`alpha^step` across TTT steps. 
DFlash decay operates within a single block, weighting +early positions higher because they gate acceptance of all later positions. + +### Checkpoint Resume + +DFlash supports checkpoint resume transparently. The `DFlashModule._apply()` method +handles meta-tensor rotary buffers that arise during ModelOpt checkpoint restore — no +special resume logic needed in the training script. + +### Export + +```bash +python scripts/export_hf_checkpoint.py \ + --model_path /path/to/training/output \ + --export_path /path/to/exported/model +``` + +Exports to z-lab compatible HF format (`config.json` + `model.safetensors`). + +## Results (Qwen3-8B) + +Trained on nvidia/Nemotron-Post-Training-Dataset-v2 (2M samples), 64 GPUs, 10 epochs. + +### Training Configuration + +| Parameter | Value | +|-----------|-------| +| Block Size | 8 | +| Sequence Length | 4096 | +| Anchors | 512 | +| Loss | KD + decay (gamma=4) | +| Total Steps | 306,620 | +| Final Per-Token Acc | 67.0% | + +### HuggingFace AR Evaluation + +AR is evaluated using `ar_validate.py` which calls `pseudo_speculative_generate` +with online (context-dependent) ground truth: + +1. Run base model on `input_ids` → get base token + hidden states +2. Build draft block: `[base_token, MASK, MASK, ...]` +3. Run DFlash draft forward → get `block_size-1` draft tokens +4. Verify each draft token against the base model's prediction **given the + accepted sequence so far** (not a pre-computed fixed reference) +5. Accept consecutive matches, append target's correction on first mismatch +6. 
AR = total accepted tokens / number of speculative steps + +```bash +python scripts/ar_validate.py --model_path /path/to/checkpoint --per_category --osl 512 --steps 7 +``` + +### vLLM Deployment Results + +vLLM nightly (v0.19.1+), H100, MT-Bench 80 prompts, 1024 max tokens: + +| | Baseline | z-lab (bs16) | **ModelOpt (bs8)** | +|---|---------|-------------|-------------------| +| TP=1 tok/s | 145 | 422 | **443** | +| TP=8 tok/s | 377 | 919 | **1053** | +| Speedup (TP=1) | 1.0x | 2.9x | **3.1x** | + +**Per-Category (TP=8):** + +| Category | ModelOpt Accept | z-lab Accept | ModelOpt TPS | z-lab TPS | +|----------|----------------|-------------|-------------|-----------| +| math | **5.14** | 4.24 | **1238** | 1098 | +| coding | **4.03** | 3.52 | **1299** | 1269 | +| writing | **3.99** | 3.97 | **1002** | 903 | +| reasoning | **3.89** | 3.49 | **1188** | 1020 | +| roleplay | **3.88** | 3.37 | **1069** | 923 | +| extraction | **3.60** | 3.02 | **1002** | 789 | +| stem | 3.55 | **3.63** | **1027** | 914 | +| humanities | **3.05** | 2.68 | **786** | 672 | +| **ALL** | | | **1053** | 919 | + +ModelOpt wins acceptance length on 7/8 categories and TPS on 8/8 categories. + +### Key Findings + +| Finding | Evidence | +|---------|----------| +| 3.1x speedup over baseline (TP=1) | 443 vs 145 tok/s on vLLM | +| 15% faster than z-lab | TP=1: 443 vs 422; TP=8: 1053 vs 919 | +| More efficient drafting | 44% vs 16.5% draft acceptance; fewer tokens drafted, more accepted | +| Loss decay boosts AR | +0.12 AR at 55K (gamma=7, bs16); consistent across checkpoints | +| Longer sequences help | seq=4096 vs 512: +0.49 AR on AA-Synthetic | + +## Open Items + +### Not Yet Implemented + +- **Offline training**: DFlash needs multi-layer hidden states at all positions for KV + injection (5x storage vs EAGLE3's single-layer approach). Possible approaches: store + fused hidden states, pre-sample anchors, or hybrid CPU base + GPU draft. 
+- **Qwen3MoE draft**: Replace `Qwen3MLP` with `Qwen3MoeMLP` via config flag. See + `hf_dflash.py` module docstring for instructions. +- **MLA support (DeepseekV3/Kimi-K2)**: Requires MLA-aware KV injection with compressed K/V. +- **Docker local testing**: Launcher example requires Slurm. Need a local Docker example + with `hf_local=` path mapping. + +### Implemented but Not Yet Validated End-to-End + +- **Sliding window attention**: Code reads `config.layer_types` and sets `sliding_window` + per layer. Unit tested but not validated in a full training run with sliding window models. +- **FP8 / NVFP4 quantization**: Export pipeline supports quantized checkpoints via + `hf_ptq.py` (PTQ succeeded in testing). AR impact of quantization not yet measured. + The flow: train (bf16) → `mtq.quantize(model, quant_cfg)` → `export_hf_checkpoint.py`. +- **Checkpoint resume**: `DFlashModule._apply()` handles meta-tensor rotary buffers + (one-shot check on first `.to(device)` call). Validated in train+resume E2E tests. + +### Validated + +- **Online training**: E2E pipeline (train → export → eval) on sample-1K and sample-10K. +- **Multi-node DDP**: 8-node (64 GPU) training on full dataset, 10 epochs. +- **AR evaluation**: `ar_validate.py` with online GT, per-category MT-Bench. +- **vLLM deployment**: Speculative decoding with `vllm/vllm-openai:nightly` (v0.19.1+). + 3.1x speedup over baseline. Per-category benchmarks on MT-Bench. + + ```bash + vllm serve Qwen/Qwen3-8B \ + --speculative-config '{"method": "dflash", "model": "path/to/checkpoint", "num_speculative_tokens": 7}' \ + --max-num-batched-tokens 32768 + ``` + +- **Export**: z-lab compatible HF format, loadable by vLLM and z-lab benchmark. +- **Loss decay**: Validated +0.12 AR improvement with gamma=7 (bs16). 
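The loss-decay weighting validated above can be sketched in a few lines of Python. This is a minimal illustration of the `weight[k] = exp(-max(k-1, 0) / gamma)` rule from the Loss Decay section; the function name is illustrative, not part of the training code's API:

```python
import math

def dflash_loss_weights(block_size: int, gamma: float) -> list[float]:
    """Per-position loss weights within one block.

    Position 0 is the anchor (excluded by the loss mask) and position 1
    gets full weight; later positions decay exponentially, since a miss
    at position k rejects every draft token after it.
    """
    if gamma == 0:
        # gamma=0 disables decay: uniform weights
        return [1.0] * block_size
    return [math.exp(-max(k - 1, 0) / gamma) for k in range(block_size)]

weights = dflash_loss_weights(block_size=8, gamma=4.0)
# Positions 0 and 1 get weight 1.0; position 7 gets exp(-6/4) ≈ 0.22,
# matching the "22% as much as position 1" figure above.
```

With `gamma=7` and `block_size=16` (the paper's other recommended pairing), the final position similarly retains `exp(-14/7) ≈ 0.14` of the weight of position 1.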
diff --git a/examples/speculative_decoding/eagle_utils.py b/examples/speculative_decoding/eagle_utils.py index 99c8ef4e03..2b08ec8096 100644 --- a/examples/speculative_decoding/eagle_utils.py +++ b/examples/speculative_decoding/eagle_utils.py @@ -137,10 +137,11 @@ def __call__(self, features: list[dict[str, Any]]) -> dict[str, Any]: return batch -def make_eagle_supervised_data_module( +def make_speculative_data_module( tokenizer: transformers.PreTrainedTokenizer, data_args, train_len=None, + answer_only_loss=False, ) -> dict: if data_args.offline_data_path is None: train_dataset = ShardedDataset("json", data_files=data_args.data_path) @@ -150,6 +151,7 @@ def make_eagle_supervised_data_module( tokenizer=tokenizer, train_len=train_len, return_labels=True, + answer_only_loss=answer_only_loss, ) else: data_collator = VisionLanguageDataCollator( @@ -205,6 +207,17 @@ def on_log(self, args, state, control, **kwargs): if not hasattr(state, "training_accs") or len(state.training_accs) == 0: return control average_acc = np.mean(state.training_accs, axis=0) + # Always print accuracy to console + try: + acc_str = ", ".join(f"{a:.4f}" for a in np.array(average_acc).flatten()) + print_rank_0(f"Step {state.global_step} Training Acc: [{acc_str}]") + except Exception: + print_rank_0(f"Step {state.global_step} Training Acc: {average_acc}") + # Log accuracy to HF Trainer's logs dict (picked up by TensorBoard) + logs = kwargs.get("logs") or {} + for i, draft_acc in enumerate(average_acc): + for j, step_acc in enumerate(draft_acc): + logs[f"train_acc/parallel_{i}_step_{j}"] = float(step_acc) if self.estimate_ar: # Calculate mean training AR since last log # NOTE: This is only an estimate of the real AR. 
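The estimated-AR logging patched in the hunk below reduces to an expected accepted length per speculative step. A self-contained sketch of that recurrence (assuming per-step draft accuracies as plain floats; the function name is illustrative — the actual callback accumulates `acc_cumprod *= draft_acc[-1]` per parallel head):

```python
def estimate_ar(step_accs: list[float]) -> float:
    """Estimate AR (accepted tokens per target forward pass) from
    per-step draft accuracies.

    A draft token at step k is accepted only if every earlier step was
    also accepted, so its expected contribution is the running product
    of the per-step accuracies. The target's own token always counts.
    """
    est_ar, acc_cumprod = 1.0, 1.0
    for acc in step_accs:
        acc_cumprod *= acc
        est_ar += acc_cumprod
    return est_ar

print(round(estimate_ar([0.8, 0.7, 0.6]), 3))  # 1 + 0.8 + 0.56 + 0.336 = 2.696
```

As the NOTE in the callback says, this is only an estimate: it assumes per-step accuracies are independent, whereas real acceptance is correlated across positions.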
@@ -218,19 +231,12 @@ def on_log(self, args, state, control, **kwargs): acc_cumprod *= draft_acc[-1] est_ar += acc_cumprod print_rank_0(f"Step {state.global_step} Estimated Training AR: {est_ar:.4f}") + logs["estimated_training_ar"] = est_ar # log to wandb - if wandb and is_master(): - logs = kwargs.get("logs") or {} + if hasattr(wandb, "init") and is_master(): if logs: wandb.log({k: v for k, v in logs.items() if v is not None}, step=state.global_step) - for i, draft_acc in enumerate(average_acc): - for j, step_acc in enumerate(draft_acc): - wandb.log( - {f"parallel_{i}_step_{j}_train_acc": step_acc}, step=state.global_step - ) - if self.estimate_ar: - wandb.log({"estimated_training_ar": est_ar}, step=state.global_step) # reset training_accs state.training_accs = [] @@ -250,7 +256,7 @@ def on_step_end(self, args, state, control, **kwargs): device=kwargs["model"].device, ) print_rank_0(f"Step {state.global_step} AR: {sum(ars) / len(ars):.4f}") - if wandb and is_master(): + if hasattr(wandb, "init") and is_master(): wandb.log({"validate_ar": sum(ars) / len(ars)}, step=state.global_step) except Exception: print_rank_0("AR validation not available.") diff --git a/examples/speculative_decoding/main.py b/examples/speculative_decoding/main.py index 694aa3303f..5cee98fb51 100644 --- a/examples/speculative_decoding/main.py +++ b/examples/speculative_decoding/main.py @@ -40,7 +40,7 @@ from eagle_utils import ( EagleTrainerWithAccLog, EagleTrainingPlot, - make_eagle_supervised_data_module, + make_speculative_data_module, patch_ring_attention_for_ttt, ) from omegaconf import OmegaConf @@ -103,7 +103,7 @@ class TrainingArguments(transformers.TrainingArguments): ) }, ) - mode: Literal["eagle3", "medusa"] = "eagle3" + mode: Literal["eagle3", "medusa", "dflash"] = "eagle3" estimate_ar: bool = field( default=False, metadata={"help": "Whether to estimate AR using training accuracy to log."} ) @@ -133,8 +133,8 @@ def _parse_cli() -> tuple[str, list[str]]: return args.config, 
overrides -def _load_config(config_path: str, overrides: list[str] = ()) -> tuple[dict, dict]: - """Load training config from a YAML file with sections: model, data, training, eagle. +def _load_config(config_path: str, overrides: list[str] = ()) -> tuple[dict, dict, dict]: + """Load training config from a YAML file with sections: model, data, training, eagle/dflash. *overrides* are OmegaConf dotlist entries (e.g. ``["model.model_name_or_path=xxx"]``) applied on top of the YAML. @@ -142,15 +142,16 @@ def _load_config(config_path: str, overrides: list[str] = ()) -> tuple[dict, dic Returns: hf_cfg: Flat dict from model/data/training sections, for HfArgumentParser.parse_dict() eagle_cfg: Eagle section dict (EagleConfig fields), passed directly to mtsp.convert() + dflash_cfg: DFlash section dict (DFlashConfig fields), passed directly to mtsp.convert() """ merged = OmegaConf.load(config_path) if overrides: merged = OmegaConf.merge(merged, OmegaConf.from_dotlist(list(overrides))) cfg = OmegaConf.to_container(merged, resolve=True) - # Eagle section maps directly to EagleConfig fields — no field enumeration needed. - # eagle_architecture_config is a nested dict and is included as-is. + # Eagle/DFlash sections map directly to config fields — no field enumeration needed. eagle_cfg = cfg.get("eagle", {}) + dflash_cfg = cfg.get("dflash", {}) hf_cfg = { **cfg.get("model", {}), @@ -162,12 +163,12 @@ def _load_config(config_path: str, overrides: list[str] = ()) -> tuple[dict, dic cp_size = hf_cfg.get("cp_size", 1) hf_cfg["dp_shard_size"] = torch.cuda.device_count() // cp_size - return hf_cfg, eagle_cfg + return hf_cfg, eagle_cfg, dflash_cfg def train(): config_path, overrides = _parse_cli() - hf_cfg, eagle_cfg = _load_config(config_path, overrides) + hf_cfg, eagle_cfg, dflash_cfg = _load_config(config_path, overrides) parser = transformers.HfArgumentParser( ( @@ -193,7 +194,10 @@ def train(): patch_ring_attention_for_ttt() # Specific patch to accelerate 1.12.0. 
Removable after move to 1.13.0 training_args.parallelism_config.sp_backend = None - print_rank_0(f"arguments: {model_args}, {training_args}, {medusa_args}, eagle_cfg={eagle_cfg}") + print_rank_0( + f"arguments: {model_args}, {training_args}, {medusa_args}, " + f"eagle_cfg={eagle_cfg}, dflash_cfg={dflash_cfg}" + ) # Detect checkpoint to resume from last_checkpoint = ( @@ -251,13 +255,19 @@ def train(): ) model.eagle_module.d2t = torch.load(data_args.draft_vocab_cache, weights_only=True) print_rank_0(f"Loaded draft vocab cache from {data_args.draft_vocab_cache}.") + elif training_args.mode == "dflash": + # dflash_cfg maps directly to DFlashConfig fields. + mtsp.convert(model, [("dflash", dflash_cfg)]) else: raise Exception(f"{training_args.mode} is not supported!") print_rank_0("Loading dataset...") - if training_args.mode == "eagle3": - data_module = make_eagle_supervised_data_module( - tokenizer, data_args, train_len=training_args.training_seq_len + if training_args.mode in ("eagle3", "dflash"): + data_module = make_speculative_data_module( + tokenizer, + data_args, + train_len=training_args.training_seq_len, + answer_only_loss=(training_args.mode == "dflash"), ) trainer = EagleTrainerWithAccLog( diff --git a/examples/speculative_decoding/scripts/ar_validate.py b/examples/speculative_decoding/scripts/ar_validate.py index 1ad7bec409..7e8e661e0c 100644 --- a/examples/speculative_decoding/scripts/ar_validate.py +++ b/examples/speculative_decoding/scripts/ar_validate.py @@ -13,7 +13,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +"""AR validation for speculative decoding models (EAGLE3, DFlash, Medusa). + +Supports per-category MT-Bench evaluation and online (context-dependent) validation. 
+""" + import argparse +from collections import defaultdict from accelerate import Accelerator from datasets import load_dataset @@ -27,52 +33,66 @@ mto.enable_huggingface_checkpointing() -def validate_ar(model, tokenizer, ds, steps=3, osl=20, num_samples=80, device=None): +def validate_ar( + model, tokenizer, ds, steps=3, osl=20, num_samples=80, device=None, +): + """Validate acceptance rate on MT-Bench prompts using online validation. + + Online validation recomputes ground truth after each accepted draft token + (context-dependent), matching actual speculative decoding behavior. + + Args: + model: Speculative decoding model (EAGLE3, DFlash, etc.) + tokenizer: Tokenizer for the model. + ds: MT-Bench dataset (HuggingFace dataset with 'prompt' and optional 'category'). + steps: Number of draft tokens per speculative step. + osl: Output sequence length. + num_samples: Max number of samples to evaluate. + device: Device to run on. + + Returns: + List of (category, ar) tuples. + """ validator = HFARValidation(model, tokenizer) num_samples = min(num_samples, len(ds)) - ars = [] + results = [] for i in tqdm(range(num_samples), desc="Validating AR"): prompt = ds[i]["prompt"][0] - input_ids = tokenizer(prompt, return_tensors="pt").input_ids - # Apply chat template to the prompt, continuing with assistant response + category = ds[i].get("category", "unknown") if hasattr(tokenizer, "apply_chat_template"): - chat_messages = [ - {"role": "user", "content": prompt}, - ] + chat_messages = [{"role": "user", "content": prompt}] prompt = tokenizer.apply_chat_template( chat_messages, tokenize=False, add_generation_prompt=True ) - input_ids = tokenizer(prompt, return_tensors="pt").input_ids + input_ids = tokenizer(prompt, return_tensors="pt").input_ids if device: input_ids = input_ids.to(device) - # validate AR - _, ar = validator.validate(osl, input_ids=input_ids, steps=steps) - ars.append(ar) - return ars + try: + _, ar = validator.validate_online(osl, input_ids=input_ids, 
steps=steps) + results.append((category, ar)) + except Exception: + pass + return results def main(): - parser = argparse.ArgumentParser() + parser = argparse.ArgumentParser(description="AR validation for speculative decoding models.") parser.add_argument("--model_path", type=str, required=True, help="Path to model directory") parser.add_argument("--trust_remote_code", action="store_true", help="Trust remote code") - parser.add_argument("--steps", type=int, default=3, help="Steps for AR validation") - parser.add_argument( - "--osl", type=int, default=32, help="Output sequence length for AR validation" - ) - parser.add_argument( - "--num_samples", type=int, default=80, help="Number of MT-Bench samples to use" - ) + parser.add_argument("--steps", type=int, default=3, help="Draft tokens per step") + parser.add_argument("--osl", type=int, default=32, help="Output sequence length") + parser.add_argument("--num_samples", type=int, default=80, help="Number of samples") + parser.add_argument("--per_category", action="store_true", help="Report per-category AR") parser.add_argument( "--ar_lower_bound", type=float, default=None, - help="AR lower bound for validation. 
If provided, will throw error if AR is below threshold.", + help="Error if AR is below this threshold.", ) args = parser.parse_args() accelerator = Accelerator() - # Load model and tokenizer model = load_vlm_or_llm( args.model_path, device_map="auto", trust_remote_code=args.trust_remote_code ) @@ -82,26 +102,37 @@ def main(): model.eval() model = accelerator.prepare(model) - # Load MT-Bench prompts from HuggingFace ds = load_dataset("HuggingFaceH4/mt_bench_prompts")["train"] - ars = validate_ar( - model, tokenizer, ds, args.steps, args.osl, args.num_samples, accelerator.device + results = validate_ar( + model, + tokenizer, + ds, + args.steps, + args.osl, + args.num_samples, + accelerator.device, ) - # Optionally, throw error if AR is below lower bound - if args.ar_lower_bound: - mean_ar = sum(ars) / len(ars) - if mean_ar < args.ar_lower_bound: + + if results and accelerator.is_main_process: + all_ars = [ar for _, ar in results] + avg_ar = sum(all_ars) / len(all_ars) + print(f"\n==== AR Validation Results (osl={args.osl}, steps={args.steps}) ====") + + if args.per_category: + cat_ars = defaultdict(list) + for cat, ar in results: + cat_ars[cat].append(ar) + for cat in sorted(cat_ars): + cat_avg = sum(cat_ars[cat]) / len(cat_ars[cat]) + print(f" {cat:>12}: {cat_avg:.4f}") + + print(f" {'ALL':>12}: {avg_ar:.4f}") + print(f" Samples: {len(results)}") + + if args.ar_lower_bound and avg_ar < args.ar_lower_bound: raise ValueError( - f"AR is below lower bound {args.ar_lower_bound}. Mean AR: {mean_ar:.4f}" + f"AR {avg_ar:.4f} is below lower bound {args.ar_lower_bound}." 
) - # Print results - if ars and accelerator.is_main_process: - avg_ar = sum(ars) / len(ars) - print("\n==== AR Validation Results on MT-Bench ====") - print(f"Number of samples: {len(ars)}") - print(f"Output Sequence Length: {args.osl}") - print(f"Steps: {args.steps}") - print(f"Average AR: {avg_ar:.4f}") if __name__ == "__main__": diff --git a/examples/speculative_decoding/scripts/export_hf_checkpoint.py b/examples/speculative_decoding/scripts/export_hf_checkpoint.py index 2771ab1513..c3ca75cc24 100644 --- a/examples/speculative_decoding/scripts/export_hf_checkpoint.py +++ b/examples/speculative_decoding/scripts/export_hf_checkpoint.py @@ -44,5 +44,8 @@ def parse_args(): ) model.eval() with torch.inference_mode(): - export_speculative_decoding(model, export_dir=args.export_path) + export_speculative_decoding( + model, + export_dir=args.export_path, + ) print(f"Exported checkpoint to {args.export_path}") diff --git a/modelopt/torch/export/plugins/hf_spec_export.py b/modelopt/torch/export/plugins/hf_spec_export.py index aca19a1580..82adea89df 100644 --- a/modelopt/torch/export/plugins/hf_spec_export.py +++ b/modelopt/torch/export/plugins/hf_spec_export.py @@ -27,7 +27,7 @@ from .hf_spec_configs import kimik2_eagle_template_config, llama_eagle_template_config -ALL_SPEC_MODES = ["eagle"] +ALL_SPEC_MODES = ["eagle", "dflash"] LLAMA_EAGLE_SINGLE_LAYER = { "required": { @@ -243,3 +243,107 @@ def _extract_state_dict(self, full_state_dict: dict): export_sd.pop(f"parallel_draft_heads.medusa_heads.{i}.{j}.linear.bias") ) return export_sd + + +class DFlashExporter(SpeculativeDecodingExporter): + """Draft model exporter for DFlash. 
+ + Exports in z-lab compatible format: + - model.safetensors: draft module weights (no prefix) + - config.json: Qwen3-style config with dflash_config field + """ + + def _extract_state_dict(self, full_state_dict: dict): + """Extract DFlash module weights, stripping the dflash_module prefix.""" + export_sd = {} + for key, value in full_state_dict.items(): + if "dflash_module." in key: + export_key = key.split("dflash_module.", 1)[1] + # Skip rotary embedding buffers (not needed, recomputed) + if "rotary_emb" in export_key: + continue + export_sd[export_key] = value.clone() + return export_sd + + def _export_config(self): + """Build config.json matching z-lab DFlash format.""" + model = self.model + base_config = ( + getattr(model.config, "text_config", None) + or getattr(model.config, "llm_config", None) + or model.config + ) + draft_config = model.dflash_config + + config = { + "architectures": ["DFlashDraftModel"], + "model_type": getattr(base_config, "model_type", "qwen3"), + "block_size": model.dflash_block_size, + "dflash_config": { + "mask_token_id": model.mask_token_id, + "target_layer_ids": list(model.target_layer_ids), + }, + # Architecture dimensions + "hidden_size": getattr(draft_config, "hidden_size", base_config.hidden_size), + "num_hidden_layers": draft_config.num_hidden_layers, + "num_attention_heads": getattr( + draft_config, "num_attention_heads", base_config.num_attention_heads + ), + "num_key_value_heads": getattr( + draft_config, "num_key_value_heads", base_config.num_key_value_heads + ), + "head_dim": getattr( + draft_config, + "head_dim", + base_config.hidden_size // base_config.num_attention_heads, + ), + "intermediate_size": getattr( + draft_config, "intermediate_size", base_config.intermediate_size + ), + "hidden_act": getattr(draft_config, "hidden_act", "silu"), + "rms_norm_eps": getattr(draft_config, "rms_norm_eps", 1e-6), + "vocab_size": base_config.vocab_size, + "max_position_embeddings": getattr(base_config, "max_position_embeddings", 
32768), + "initializer_range": getattr(base_config, "initializer_range", 0.02), + "attention_bias": getattr(draft_config, "attention_bias", False), + "attention_dropout": getattr(draft_config, "attention_dropout", 0.0), + "rope_theta": getattr( + draft_config, "rope_theta", getattr(base_config, "rope_theta", 1000000.0) + ), + "rope_scaling": getattr(base_config, "rope_scaling", None), + "tie_word_embeddings": False, + "torch_dtype": str(getattr(base_config, "torch_dtype", torch.bfloat16)).replace( + "torch.", "" + ), + "num_target_layers": getattr(base_config, "num_hidden_layers", 36), + } + + # Add layer_types if present (Qwen3-style) + if hasattr(draft_config, "layer_types"): + config["layer_types"] = draft_config.layer_types + else: + config["layer_types"] = ["full_attention"] * draft_config.num_hidden_layers + + return config + + def export(self, export_dir: Path | str, dtype: torch.dtype | None = None): + """Export the DFlash draft model to deployment format.""" + export_dir = Path(export_dir) + export_dir.mkdir(parents=True, exist_ok=True) + + # Export state dict + full_sd = self.model.state_dict() + drafter_sd = self._extract_state_dict(full_sd) + if dtype is not None: + drafter_sd = {k: v.to(dtype) for k, v in drafter_sd.items()} + save_file(drafter_sd, f"{export_dir}/model.safetensors") + + # Export config + drafter_config = self._export_config() + with open(f"{export_dir}/config.json", "w") as f: + json.dump(drafter_config, f, indent=2) + + print( + f"Exported DFlash draft model: {len(drafter_sd)} tensors, " + f"config keys: {list(drafter_config.keys())[:5]}..." 
+ ) diff --git a/modelopt/torch/speculative/config.py b/modelopt/torch/speculative/config.py index 69491c6599..5202865efb 100644 --- a/modelopt/torch/speculative/config.py +++ b/modelopt/torch/speculative/config.py @@ -46,6 +46,61 @@ } +def _get_dflash_default_config(): + from .dflash.default_config import default_dflash_config + + return default_dflash_config + + +DFLASH_DEFAULT_CFG = { + "algorithm": "dflash", + "config": { + "dflash_architecture_config": {}, # merged with default at convert time + }, +} + + +class DFlashConfig(ModeloptBaseConfig): + """DFlash config for block-wise parallel speculative decoding.""" + + dflash_block_size: int = ModeloptField( + default=16, + description="Block size for parallel prediction. Draft predicts this many tokens per block.", + ) + + dflash_freeze_base_model: bool = ModeloptField( + default=True, description="Whether to freeze base model during DFlash module training." + ) + + dflash_self_logit_distillation: bool = ModeloptField( + default=True, description="Whether to use logit distillation from base model." + ) + + dflash_loss_decay_factor: float = ModeloptField( + default=0.0, + description="Gamma for exponential loss decay weighting (paper Eq.4). " + "Suggested: 7 for block_size=16, 5 for 10, 4 for 8. 0 disables.", + ) + + dflash_num_anchors: int = ModeloptField( + default=512, + description="Number of random anchor positions sampled per sequence during training.", + ) + + dflash_report_acc: bool = ModeloptField( + default=True, description="Whether to report eval accuracy." + ) + + dflash_architecture_config: dict = ModeloptField( + default={}, description="Config for the DFlash draft module architecture." 
+ ) + + dflash_use_torch_compile: bool = ModeloptField( + default=True, + description="Whether to use torch.compile on DFlash forward/loss methods.", + ) + + class MedusaConfig(ModeloptBaseConfig): """Medusa config.""" diff --git a/modelopt/torch/speculative/dflash/__init__.py b/modelopt/torch/speculative/dflash/__init__.py new file mode 100644 index 0000000000..912b8d47a2 --- /dev/null +++ b/modelopt/torch/speculative/dflash/__init__.py @@ -0,0 +1,20 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""DFlash Optimization Method.""" + +from .conversion import * +from .default_config import * +from .dflash_model import * diff --git a/modelopt/torch/speculative/dflash/conversion.py b/modelopt/torch/speculative/dflash/conversion.py new file mode 100644 index 0000000000..943be90ca0 --- /dev/null +++ b/modelopt/torch/speculative/dflash/conversion.py @@ -0,0 +1,58 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""DFlash conversion/restore utilities."""
+
+from torch import nn
+
+from modelopt.torch.opt.conversion import ModelLikeModule
+from modelopt.torch.opt.dynamic import _DMRegistryCls
+from modelopt.torch.opt.mode import ConvertReturnType, MetadataDict
+
+from ..config import DFlashConfig
+
+DFlashDMRegistry = _DMRegistryCls(prefix="DFlash")  # global instance for the registry
+
+
+def convert_to_dflash_model(model: nn.Module, config: DFlashConfig) -> ConvertReturnType:
+    """Convert the model to a DFlash model as per `config`."""
+    model = model.init_modellike() if isinstance(model, ModelLikeModule) else model
+
+    original_cls = type(model)
+    if original_cls not in DFlashDMRegistry:
+        for cls in DFlashDMRegistry._registry:
+            if issubclass(original_cls, cls):
+                DFlashDMRegistry.register({original_cls: "base_model_class"})(DFlashDMRegistry[cls])
+                break
+
+    # merge custom config with default config (lazy import to avoid circular)
+    from .default_config import default_dflash_config
+
+    custom_config = config.dflash_architecture_config
+    config.dflash_architecture_config = {**default_dflash_config, **custom_config}
+
+    dflash_model = DFlashDMRegistry.convert(model)
+    dflash_model.modify(config)
+
+    metadata = {}
+    return dflash_model, metadata
+
+
+def restore_dflash_model(
+    model: nn.Module, config: DFlashConfig, metadata: MetadataDict
+) -> nn.Module:
+    """Function for restoring a previously converted model to a DFlash model."""
+    assert not metadata, "No metadata expected!"
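`convert_to_dflash_model` merges the user's `dflash_architecture_config` over the static defaults with a plain dict union, so user keys win on collision. A minimal sketch of that merge semantics (the default dict is abridged from `default_config.py`; the `mask_token_id` value is a hypothetical placeholder):

```python
# Abridged from dflash/default_config.py
default_dflash_config = {
    "num_hidden_layers": 5,
    "hidden_act": "silu",
    "rms_norm_eps": 1e-06,
    "_attn_implementation": "sdpa",
}

# User override, e.g. from dflash_architecture_config in a recipe yaml
custom_config = {"num_hidden_layers": 3, "mask_token_id": 12345}  # 12345 is hypothetical

# Right-hand operand wins on key collisions
merged = {**default_dflash_config, **custom_config}
print(merged["num_hidden_layers"])  # 3 (user override)
print(merged["hidden_act"])         # silu (default retained)
```

This is why the `DFLASH_DEFAULT_CFG` entry can ship an empty `dflash_architecture_config`: unspecified keys simply fall through to the defaults at convert time.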
+    return convert_to_dflash_model(model, config)[0]
diff --git a/modelopt/torch/speculative/dflash/default_config.py b/modelopt/torch/speculative/dflash/default_config.py
new file mode 100644
index 0000000000..1777de3f29
--- /dev/null
+++ b/modelopt/torch/speculative/dflash/default_config.py
@@ -0,0 +1,35 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Default DFlash architecture config.
+
+Model-specific settings (hidden_size, num_attention_heads, rope_*, etc.)
+are inherited from the base model in HFDFlashModel.modify(). Static
+defaults that don't depend on the base model are set here, similar to
+``eagle/default_config.py``.
+"""
+
+default_dflash_config = {
+    # DFlash-specific
+    "num_hidden_layers": 5,
+    # Architecture defaults (overridable by user config)
+    "hidden_act": "silu",
+    "rms_norm_eps": 1e-06,
+    "initializer_range": 0.02,
+    "attention_bias": False,
+    "attention_dropout": 0.0,
+    "tie_word_embeddings": False,
+    "_attn_implementation": "sdpa",
+}
diff --git a/modelopt/torch/speculative/dflash/dflash_model.py b/modelopt/torch/speculative/dflash/dflash_model.py
new file mode 100644
index 0000000000..0a10f065eb
--- /dev/null
+++ b/modelopt/torch/speculative/dflash/dflash_model.py
@@ -0,0 +1,36 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""DFlash model to support block-wise parallel speculative decoding.""" + +from modelopt.torch.opt.dynamic import DynamicModule + + +class DFlashModel(DynamicModule): + """Base DFlash Model.""" + + def _setup(self): + """Register temporary attributes for the DFlash module.""" + self._register_temp_attribute("dflash_module", None) + + def modify(self, config): + """Base DFlash Model modify function. Child class should implement the details.""" + self.dflash_block_size = config.dflash_block_size + self.dflash_freeze_base_model = config.dflash_freeze_base_model + self.dflash_loss_decay_factor = config.dflash_loss_decay_factor + self.dflash_self_logit_distillation = config.dflash_self_logit_distillation + self.dflash_num_anchors = config.dflash_num_anchors + self.dflash_report_acc = config.dflash_report_acc + self.dflash_use_torch_compile = config.dflash_use_torch_compile diff --git a/modelopt/torch/speculative/mode.py b/modelopt/torch/speculative/mode.py index 866449e155..ae965354a9 100644 --- a/modelopt/torch/speculative/mode.py +++ b/modelopt/torch/speculative/mode.py @@ -23,7 +23,8 @@ _ModeRegistryCls, ) -from .config import EagleConfig, MedusaConfig +from .config import DFlashConfig, EagleConfig, MedusaConfig +from .dflash.conversion import convert_to_dflash_model, restore_dflash_model from .eagle.conversion import convert_to_eagle_model, restore_eagle_model from 
.medusa.conversion import convert_to_medusa_model, restore_medusa_model @@ -58,6 +59,34 @@ def restore(self) -> RestoreEntrypoint: return restore_medusa_model +@SpeculativeDecodingModeRegistry.register_mode +class DFlashModeDescriptor(ModeDescriptor): + """Class to describe the ``"dflash"`` mode. + + The properties of this mode can be inspected via the source code. + """ + + @property + def name(self) -> str: + """Returns the value (str representation) of the mode.""" + return "dflash" + + @property + def config_class(self) -> type[ModeloptBaseConfig]: + """Specifies the config class for the mode.""" + return DFlashConfig + + @property + def convert(self) -> ConvertEntrypoint: + """The mode's entrypoint for converting a model.""" + return convert_to_dflash_model + + @property + def restore(self) -> RestoreEntrypoint: + """The mode's entrypoint for restoring a model.""" + return restore_dflash_model + + @SpeculativeDecodingModeRegistry.register_mode class EagleModeDescriptor(ModeDescriptor): """Class to describe the ``"eagle"`` mode. diff --git a/modelopt/torch/speculative/plugins/__init__.py b/modelopt/torch/speculative/plugins/__init__.py index 5e3f4bff2f..d59aed37d5 100644 --- a/modelopt/torch/speculative/plugins/__init__.py +++ b/modelopt/torch/speculative/plugins/__init__.py @@ -31,3 +31,6 @@ with import_plugin("transformers"): from .transformers import * + +with import_plugin("hf_dflash"): + from .hf_dflash import * diff --git a/modelopt/torch/speculative/plugins/hf_dflash.py b/modelopt/torch/speculative/plugins/hf_dflash.py new file mode 100644 index 0000000000..b7c0fa91f6 --- /dev/null +++ b/modelopt/torch/speculative/plugins/hf_dflash.py @@ -0,0 +1,890 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""DFlash speculative decoding plugin for HuggingFace models. + +Matches the reference SpecForge implementation (github.com/sgl-project/SpecForge). + +Architecture: +- Feature Fusion: multi-layer target hidden states → FC + RMSNorm +- KV Injection: fused features as K/V in every draft layer with QK-norm +- Parallel Drafting: mask_token_id for unknown positions, bidirectional within blocks +- Random anchor sampling with exponential loss decay +- Logit distillation from target model + +Reference: "DFlash: Block Diffusion for Flash Speculative Decoding" (arXiv:2602.06036) + +Draft model components: + The draft model currently uses Qwen3 components (MLP, RMSNorm, RotaryEmbedding) + from ``transformers.models.qwen3``, matching z-lab's reference checkpoint format. + Qwen3 sliding window attention is supported via ``config.layer_types``. + The draft architecture is independent of the target model — any target model can + be used as long as it provides hidden states. + + To add support for other draft architectures: + + Qwen3MoE (MoE MLP): + 1. Import ``Qwen3MoeMLP`` from ``transformers.models.qwen3_moe`` + 2. Add a config flag (e.g., ``use_moe``) in ``dflash_architecture_config`` + 3. In ``DFlashDecoderLayer.__init__``, select MLP based on the flag + RMSNorm, RotaryEmbedding, and attention are shared across Qwen3 variants. + + MLA (Multi-head Latent Attention, e.g., DeepseekV3/Kimi-K2): + MLA compresses K/V into a low-rank latent space. To support MLA in DFlash: + 1. 
Replace ``DFlashAttention`` with an MLA-aware variant that handles + compressed KV injection (project target_hidden through MLA's down/up + projections before concatenating with noise K/V) + 2. Handle lazy rope initialization (see ``_setup_kimi_k2_decoder`` in + ``modelopt.torch.speculative.utils`` for the EAGLE3 approach) + 3. The ``_apply`` meta buffer fix in ``DFlashModule`` already handles the + lazy rope pattern needed for MLA models. +""" + +import logging + +import torch +import torch.nn.functional as F +from torch import nn +from transformers import PretrainedConfig, PreTrainedModel +from transformers.trainer_pt_utils import LabelSmoother + +logger = logging.getLogger(__name__) + +# DFlash draft model uses Qwen3 components regardless of the target model. +# This matches z-lab's implementation which inherits from Qwen3PreTrainedModel. +from transformers.models.qwen3.modeling_qwen3 import Qwen3MLP as _MLP_CLS # noqa: E402, N814 +from transformers.models.qwen3.modeling_qwen3 import Qwen3RMSNorm as _NORM_CLS # noqa: E402, N814 +from transformers.models.qwen3.modeling_qwen3 import ( # noqa: E402 + Qwen3RotaryEmbedding as _ROTARY_CLS, # noqa: N814 +) +from transformers.models.qwen3.modeling_qwen3 import rotate_half as _rotate_half # noqa: E402 +from transformers.utils import ModelOutput # noqa: E402 + +from ..dflash.conversion import DFlashDMRegistry # noqa: E402 +from ..dflash.dflash_model import DFlashModel # noqa: E402 + +__all__ = ["HFDFlashModel"] + + +def build_target_layer_ids(num_target_layers, num_draft_layers): + """Select layers uniformly from the target model for feature extraction.""" + if num_draft_layers == 1: + return [num_target_layers // 2] + start = 1 + end = num_target_layers - 3 + span = end - start + return [round(start + (i * span) / (num_draft_layers - 1)) for i in range(num_draft_layers)] + + +def apply_rotary_pos_emb(q, k, cos, sin): + """Apply RoPE. 
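The uniform layer-selection rule in `build_target_layer_ids` can be checked in isolation. The snippet below re-implements the function standalone; with a 36-layer target and 5 draft layers (the Qwen3-8B setting) it reproduces the layer IDs shown in the architecture diagram:

```python
def build_target_layer_ids(num_target_layers, num_draft_layers):
    """Select target layers uniformly for multi-layer feature fusion."""
    # Single draft layer: take the middle target layer
    if num_draft_layers == 1:
        return [num_target_layers // 2]
    # Otherwise spread uniformly over [1, num_target_layers - 3]
    start = 1
    end = num_target_layers - 3
    span = end - start
    return [round(start + (i * span) / (num_draft_layers - 1)) for i in range(num_draft_layers)]

print(build_target_layer_ids(36, 5))  # [1, 9, 17, 25, 33]
print(build_target_layer_ids(36, 1))  # [18]
```

Note the selection deliberately skips the first and last few target layers, whose features are respectively too close to the embeddings and too specialized for next-token prediction.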
Q uses last q_len positions, K uses all positions.""" + cos = cos.unsqueeze(1) # [B, 1, seq, dim] + sin = sin.unsqueeze(1) + q_len = q.size(2) + q_embed = (q * cos[:, :, -q_len:, :]) + (_rotate_half(q) * sin[:, :, -q_len:, :]) + k_embed = (k * cos) + (_rotate_half(k) * sin) + return q_embed, k_embed + + +class DFlashAttention(nn.Module): + """Attention with KV injection, using HF's attention dispatch for exact SpecForge parity.""" + + def __init__(self, config, layer_idx): + """Initialize DFlash attention with KV injection projections and QK-norm.""" + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr( + config, "head_dim", config.hidden_size // config.num_attention_heads + ) + self.num_heads = config.num_attention_heads + self.num_kv_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_kv_heads + self.scaling = self.head_dim**-0.5 + self.attention_dropout = getattr(config, "attention_dropout", 0.0) + self.is_causal = False + + attn_bias = getattr(config, "attention_bias", False) + self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=attn_bias) + self.k_proj = nn.Linear( + config.hidden_size, self.num_kv_heads * self.head_dim, bias=attn_bias + ) + self.v_proj = nn.Linear( + config.hidden_size, self.num_kv_heads * self.head_dim, bias=attn_bias + ) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, config.hidden_size, bias=attn_bias) + + self.q_norm = _NORM_CLS(self.head_dim, eps=config.rms_norm_eps) + self.k_norm = _NORM_CLS(self.head_dim, eps=config.rms_norm_eps) + + # Resolve HF attention function matching SpecForge's dispatch + self._attn_fn = None + # Qwen3 uses sliding window attention on some layers (config.layer_types) + if hasattr(config, "layer_types") and hasattr(config, "sliding_window"): + is_sliding = config.layer_types[layer_idx] == "sliding_attention" + self.sliding_window = config.sliding_window if is_sliding else None + else: + 
self.sliding_window = None + + def _get_attn_fn(self): + """Lazily resolve the HF attention function (default: sdpa).""" + if self._attn_fn is not None: + return self._attn_fn + from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS + + impl = self.config._attn_implementation # default set in dflash/default_config.py + self._attn_fn = ALL_ATTENTION_FUNCTIONS.get(impl, ALL_ATTENTION_FUNCTIONS["sdpa"]) + return self._attn_fn + + def forward(self, hidden_states, target_hidden, position_embeddings, attention_mask=None): + """Forward with KV injection. + + Q is projected from the noise block (draft token embeddings: [anchor, mask, mask, ...]). + K and V are projected from the concatenation of target hidden states (context from the + base model) and noise block, so the draft can attend to both context and its own block. + """ + bsz, q_len, _ = hidden_states.shape + ctx_len = target_hidden.shape[1] + + # Q from noise block only (the draft tokens being predicted), with QK-norm + q = self.q_proj(hidden_states).view(bsz, q_len, -1, self.head_dim) + q = self.q_norm(q).transpose(1, 2) + + # K from context + noise, with QK-norm + k_ctx = self.k_proj(target_hidden) + k_noise = self.k_proj(hidden_states) + k = torch.cat([k_ctx, k_noise], dim=1).view(bsz, ctx_len + q_len, -1, self.head_dim) + k = self.k_norm(k).transpose(1, 2) + + # V from context + noise (no norm) + v_ctx = self.v_proj(target_hidden) + v_noise = self.v_proj(hidden_states) + v = ( + torch.cat([v_ctx, v_noise], dim=1) + .view(bsz, ctx_len + q_len, -1, self.head_dim) + .transpose(1, 2) + ) + + # RoPE + cos, sin = position_embeddings + q, k = apply_rotary_pos_emb(q, k, cos, sin) + + # Use HF's attention dispatch (handles GQA internally) + attn_fn = self._get_attn_fn() + attn_output, _ = attn_fn( + self, + q, + k, + v, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + sliding_window=self.sliding_window, + ) + attn_output = attn_output.reshape(bsz, 
q_len, -1) + return self.o_proj(attn_output) + + +class DFlashDecoderLayer(nn.Module): + """Draft decoder layer with KV injection.""" + + def __init__(self, config, layer_idx): + """Initialize decoder layer with attention, MLP, and layer norms.""" + super().__init__() + self.self_attn = DFlashAttention(config, layer_idx) + self.mlp = _MLP_CLS(config) + self.input_layernorm = _NORM_CLS(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = _NORM_CLS(config.hidden_size, eps=config.rms_norm_eps) + + def forward(self, hidden_states, target_hidden, position_embeddings, attention_mask=None): + """Forward pass with residual connections.""" + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + hidden_states = self.self_attn( + hidden_states, target_hidden, position_embeddings, attention_mask + ) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + return hidden_states + + +class DFlashModule(nn.Module): + """DFlash draft module matching SpecForge DFlashDraftModel.""" + + def __init__(self, config): + """Initialize DFlash module with feature fusion, decoder layers, and rotary embeddings.""" + super().__init__() + self.config = config + self.block_size = config.block_size + + # Feature fusion + num_fused_layers = len(config.target_layer_ids) + self.fc = nn.Linear(num_fused_layers * config.hidden_size, config.hidden_size, bias=False) + self.hidden_norm = _NORM_CLS(config.hidden_size, eps=config.rms_norm_eps) + + # Decoder layers + self.layers = nn.ModuleList( + [DFlashDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = _NORM_CLS(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = _ROTARY_CLS(config=config) + self._rotary_config = config # Stored for re-creating rotary_emb on resume + + # Explicit 
weight init is needed because DFlashModule is instantiated via + # mtsp.convert() AFTER the base model's post_init() has already run, so HF's + # automatic _init_weights walk doesn't reach these new layers. + self._init_weights(config) + + def _apply(self, fn, recurse=True): + """Fix meta-tensor rotary buffers before device transfer. + + On checkpoint resume, inv_freq (a computed buffer, not saved in checkpoint) + stays on meta device. Re-create rotary_emb on CPU so .to(device) can proceed. + """ + if hasattr(self, "rotary_emb") and any(b.is_meta for b in self.rotary_emb.buffers()): + self.rotary_emb = _ROTARY_CLS(config=self._rotary_config, device="cpu") + return super()._apply(fn, recurse) + + def _init_weights(self, config): + """Initialize weights matching HF PreTrainedModel._init_weights.""" + std = getattr(config, "initializer_range", 0.02) + for module in self.modules(): + if isinstance(module, nn.Linear): + nn.init.normal_(module.weight, mean=0.0, std=std) + if module.bias is not None: + nn.init.zeros_(module.bias) + + def forward(self, noise_embedding, target_hidden, position_ids, attention_mask=None): + """Forward matching SpecForge DFlashDraftModel.forward.""" + hidden_states = noise_embedding + target_hidden = self.hidden_norm(self.fc(target_hidden)) + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + for layer in self.layers: + hidden_states = layer(hidden_states, target_hidden, position_embeddings, attention_mask) + + return self.norm(hidden_states) + + +def create_dflash_attention_mask( + seq_len, block_size, device, dtype +): # Legacy: used for inference only + """Create [L, 2L] attention mask matching SpecForge. + + Context (cols 0..L-1): Block B sees blocks 0..B-1 (strictly previous). + Noise (cols L..2L-1): causal within same block only. 
+ """ + indices = torch.arange(seq_len, device=device) + block_ids = indices // block_size + + q_block_ids = block_ids.unsqueeze(1) # [L, 1] + k_block_ids = block_ids.unsqueeze(0) # [1, L] + + ctx_mask = k_block_ids < q_block_ids + same_block = q_block_ids == k_block_ids + causal = indices.unsqueeze(0) >= indices.unsqueeze(1) # matching SpecForge: j >= i + noise_mask = same_block & causal + + full_mask_bool = torch.cat([ctx_mask, noise_mask], dim=1) + + # Create additive mask directly in target dtype, matching EAGLE convention. + full_mask = torch.zeros(seq_len, 2 * seq_len, device=device, dtype=dtype) + full_mask.masked_fill_(~full_mask_bool, torch.finfo(dtype).min) + + return full_mask.unsqueeze(0).unsqueeze(0) # [1, 1, L, 2L] + + +def create_dflash_loss_mask(seq_len, block_size, device): # Legacy: used for inference only + """Create loss mask: exclude Block 0 and block starts.""" + positions = torch.arange(seq_len, device=device) + block_ids = positions // block_size + is_block_0 = block_ids == 0 + is_block_start = (positions % block_size) == 0 + return (~is_block_0 & ~is_block_start).float() + + +@DFlashDMRegistry.register({PreTrainedModel: "hf.PreTrainedModel"}) +class HFDFlashModel(DFlashModel): + """DFlash Model matching SpecForge OnlineDFlashModel.""" + + @property + def _base_model(self): + return self.get_submodule(self.base_model_path) + + @property + def _base_model_embeddings(self): + return self.get_submodule(self.base_model_embeddings_path) + + @property + def _base_model_lm_head(self): + return self.get_submodule(self.base_model_lm_head_path) + + @property + def _base_llm_config(self): + return ( + getattr(self.config, "text_config", None) + or getattr(self.config, "llm_config", None) + or self.config + ) + + @staticmethod + def _auto_detect_mask_token_id(base_config): + """Auto-detect mask token ID from the base model's tokenizer. + + Loads the tokenizer and returns ``tokenizer.mask_token_id`` if available. 
+ Raises ValueError otherwise — the user must set mask_token_id explicitly. + """ + from transformers import AutoTokenizer + + model_name = getattr(base_config, "_name_or_path", None) + if model_name: + try: + tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) + if tokenizer.mask_token_id is not None: + return tokenizer.mask_token_id + except Exception: + pass + + raise ValueError( + "Cannot auto-detect mask_token_id. " + "Please set dflash_architecture_config.mask_token_id explicitly in your config. " + "The mask token should be an unused special token (not eos or pad)." + ) + + def _find_base_model_parts(self): + """Locate base model submodules (backbone, embeddings, lm_head) by probing known paths.""" + for name, paths in { + "base_model_path": ["model.language_model", "model", "backbone"], + "base_model_embeddings_path": [ + "model.embed_tokens", + "backbone.embeddings", + "model.language_model.embed_tokens", + ], + "base_model_lm_head_path": ["lm_head", "language_model.lm_head"], + }.items(): + for path in paths: + try: + submodule = self.get_submodule(path) + assert isinstance(submodule, torch.nn.Module) + setattr(self, name, path) + break + except Exception: + continue + else: + raise ValueError(f"Part {name} not found in model") + + def modify(self, config): + """Initialize DFlash draft module.""" + super().modify(config) + + base_config = self._base_llm_config + self.dflash_config = PretrainedConfig.from_dict(config.dflash_architecture_config) + + # hidden_size and vocab_size MUST match the base model. + self.dflash_config.hidden_size = base_config.hidden_size + self.dflash_config.vocab_size = base_config.vocab_size + + # Inherit architecture settings from base model when not specified by user. + # Static defaults (hidden_act, attention_bias, etc.) are in dflash/default_config.py. 
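The inheritance loop that follows this comment copies base-model attributes onto the draft config only when the draft config leaves them unset. The getattr/setattr fallback pattern can be sketched with plain namespaces (attribute values below are hypothetical, not taken from any real checkpoint):

```python
from types import SimpleNamespace

base_config = SimpleNamespace(rope_theta=1000000.0, intermediate_size=12288)
draft_config = SimpleNamespace(rope_theta=None)  # None counts as "unset"

for attr in ["rope_theta", "intermediate_size"]:
    # Inherit only if missing or None on the draft config
    if not hasattr(draft_config, attr) or getattr(draft_config, attr) is None:
        if hasattr(base_config, attr):
            setattr(draft_config, attr, getattr(base_config, attr))

print(draft_config.rope_theta)         # 1000000.0 (None replaced)
print(draft_config.intermediate_size)  # 12288 (missing attr filled in)
```

A user-specified value on the draft config (anything not None) survives the loop untouched, which is what lets `dflash_architecture_config` override individual architecture fields.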
+ _base_model_attrs = [ + "max_position_embeddings", + "intermediate_size", + "num_attention_heads", + "num_key_value_heads", + "rope_theta", + "rope_scaling", + "rope_type", + "rope_interleaved", + "rms_norm_eps", + ] + for attr in _base_model_attrs: + if not hasattr(self.dflash_config, attr) or getattr(self.dflash_config, attr) is None: + if hasattr(base_config, attr): + setattr(self.dflash_config, attr, getattr(base_config, attr)) + + self.dflash_config.head_dim = getattr( + self.dflash_config, + "head_dim", + self.dflash_config.hidden_size // self.dflash_config.num_attention_heads, + ) + self.dflash_config.block_size = self.dflash_block_size + + # Target layer IDs + num_target_layers = base_config.num_hidden_layers + num_draft_layers = self.dflash_config.num_hidden_layers + self.target_layer_ids = build_target_layer_ids(num_target_layers, num_draft_layers) + self.dflash_config.target_layer_ids = self.target_layer_ids + + # mask_token_id resolution order: + # 1. Explicit in dflash_architecture_config (user override) + # 2. Auto-detect from tokenizer (tokenizer.mask_token_id) + # 3. 
Error — user must provide mask_token_id + mask_id = config.dflash_architecture_config.get("mask_token_id", None) + if mask_id is None: + mask_id = self._auto_detect_mask_token_id(base_config) + self.mask_token_id = mask_id[0] if isinstance(mask_id, list) else mask_id + logger.info("DFlash mask_token_id: %s", self.mask_token_id) + + # Freeze base model + if self.dflash_freeze_base_model: + for param in self.parameters(): + param.requires_grad = False + + self._find_base_model_parts() + + self.dflash_module = DFlashModule(self.dflash_config) + self.dflash_module.to(self._base_model.dtype).to( + next(self._base_model.layers[-1].parameters()).device + ) + + self.is_quantized = False + self._num_anchors = self.dflash_num_anchors + + def get_exporter(self): + """Get the exporter for the DFlash draft model.""" + from modelopt.torch.export.plugins.hf_spec_export import DFlashExporter + + return DFlashExporter(self) + + def _sample_anchor_positions(self, seq_len, loss_mask, device): + """Randomly sample anchor positions per sample, matching SpecForge PR #473. + + Returns (anchor_positions [B, N], block_keep_mask [B, N]). + + TODO: Fix the random seed per epoch (change between epochs) so that anchor + positions are deterministic within an epoch. This would allow caching the derived + masks and position IDs across steps while preserving the same data augmentation + effect. Currently, anchors are re-sampled every forward pass. 
+ """ + bs = self.dflash_block_size + bsz = loss_mask.shape[0] + max_anchor = max(seq_len - bs, 0) + num_anchors = getattr(self, "_num_anchors", 512) + + valid = loss_mask[:, : max_anchor + 1] > 0.5 + valid_counts = valid.sum(dim=1) + max_n = min(num_anchors, int(valid_counts.max().item()) - 1) + + if max_n <= 0: + # No valid anchors — return empty + anchors = torch.zeros(bsz, 1, dtype=torch.long, device=device) + keep = torch.zeros(bsz, 1, dtype=torch.bool, device=device) + return anchors, keep + + indices = torch.arange(max_anchor + 1, device=device).unsqueeze(0).expand(bsz, -1) + masked_indices = torch.where(valid, indices, torch.tensor(seq_len + 1, device=device)) + + random_vals = torch.rand(bsz, max_anchor + 1, device=device) + random_vals = torch.where(valid, random_vals, torch.tensor(2.0, device=device)) + + _, sorted_idx = random_vals.sort(dim=1) + gathered = torch.gather(masked_indices, 1, sorted_idx) + anchors = gathered[:, :max_n].sort(dim=1).values + + keep = torch.arange(max_n, device=device).unsqueeze(0) < valid_counts.unsqueeze(1).clamp( + max=max_n + ) + anchors = torch.where(keep, anchors, torch.tensor(0, dtype=torch.long, device=device)) + return anchors, keep + + def _build_noise_embedding(self, input_ids, anchor_positions, block_keep_mask, n_blocks): + """Build noise embeddings: anchor token at block start, mask_token elsewhere.""" + bsz, seq_len = input_ids.shape + block_size = self.dflash_block_size + device = input_ids.device + + noise_ids = torch.full( + (bsz, n_blocks * block_size), self.mask_token_id, dtype=torch.long, device=device + ) + block_starts = torch.arange(n_blocks, device=device) * block_size + block_starts_exp = block_starts.unsqueeze(0).expand(bsz, -1) + valid_anchors = anchor_positions.clamp(0, seq_len - 1) + anchor_tokens = torch.gather(input_ids, 1, valid_anchors) + batch_idx = torch.arange(bsz, device=device).unsqueeze(1).expand(bsz, n_blocks) + noise_ids[batch_idx, block_starts_exp] = torch.where( + block_keep_mask, + 
anchor_tokens, + torch.tensor(self.mask_token_id, dtype=torch.long, device=device), + ) + return self._base_model_embeddings(noise_ids) + + def _build_position_ids(self, seq_len, anchor_positions, device): + """Build position IDs: context [0..S-1], draft blocks [anchor+0..anchor+B-1].""" + bsz = anchor_positions.shape[0] + block_size = self.dflash_block_size + + ctx_pos = torch.arange(seq_len, device=device).unsqueeze(0).expand(bsz, -1) + offsets = torch.arange(block_size, device=device).view(1, 1, -1) + draft_pos = (anchor_positions.unsqueeze(-1) + offsets).view(bsz, -1) + return torch.cat([ctx_pos, draft_pos], dim=1) + + def _build_draft_attention_mask( + self, seq_len, anchor_positions, block_keep_mask, n_blocks, dtype, device + ): + """Build SDPA attention mask: context (causal) + draft (bidirectional within block).""" + bsz = anchor_positions.shape[0] + block_size = self.dflash_block_size + q_len = n_blocks * block_size + kv_len = seq_len + q_len + + q_indices = torch.arange(q_len, device=device).view(1, 1, -1, 1) + kv_indices = torch.arange(kv_len, device=device).view(1, 1, 1, -1) + q_block_ids = q_indices // block_size + + anchor_exp = anchor_positions.view(bsz, 1, n_blocks, 1).repeat_interleave(block_size, dim=2) + + # Context: kv < S and kv < anchor + mask_ctx = (kv_indices < seq_len) & (kv_indices < anchor_exp) + # Draft: kv >= S and same block + is_draft = kv_indices >= seq_len + kv_block_ids = (kv_indices - seq_len) // block_size + mask_draft = is_draft & (q_block_ids == kv_block_ids) + # Valid block + valid_block = block_keep_mask.view(bsz, 1, n_blocks, 1).repeat_interleave(block_size, dim=2) + + final_mask = (mask_ctx | mask_draft) & valid_block # [B, 1, Q, KV] + + # Convert bool mask to float additive mask for SDPA + attn_mask = torch.zeros(bsz, 1, q_len, kv_len, device=device, dtype=dtype) + attn_mask.masked_fill_(~final_mask, torch.finfo(dtype).min) + return attn_mask + + def _compute_loss( + self, logits, input_ids, anchor_positions, 
block_keep_mask, loss_mask, base_logits=None + ): + """Compute weighted cross-entropy (or KD) loss and accuracy. + + Args: + logits: Draft model output [B, N*block_size, vocab]. + input_ids: Original input token IDs [B, seq_len]. + anchor_positions: Anchor positions per block [B, N]. + block_keep_mask: Valid block mask [B, N]. + loss_mask: Token-level loss mask [B, seq_len]. + base_logits: Base model logits for KD loss [B, seq_len, vocab], or None for CE. + + Returns: + (loss, accuracy) tuple. + """ + bsz, seq_len = input_ids.shape + block_size = self.dflash_block_size + n_blocks = anchor_positions.shape[1] + device = input_ids.device + + label_offsets = torch.arange(0, block_size, device=device).view(1, 1, -1) + label_indices = anchor_positions.unsqueeze(-1) + label_offsets + valid_label = label_indices < seq_len + safe_label_indices = label_indices.clamp(max=seq_len - 1) + + target_ids = torch.gather( + input_ids.unsqueeze(1).expand(-1, n_blocks, -1), 2, safe_label_indices + ) + + # Weight mask: valid block * in bounds * exclude anchor (pos 0) * loss_mask + weight_mask = block_keep_mask.unsqueeze(-1).expand(-1, -1, block_size).float() + weight_mask = weight_mask * valid_label.float() + pos_in_block = torch.arange(block_size, device=device).view(1, 1, -1) + weight_mask = weight_mask * (pos_in_block > 0).float() + + orig_loss_mask = torch.gather( + loss_mask.unsqueeze(1).expand(-1, n_blocks, -1), 2, safe_label_indices + ) + weight_mask = weight_mask * orig_loss_mask + + binary_eval_mask = weight_mask.view(-1) + + # Optional loss decay + if self.dflash_loss_decay_factor > 0: + k = torch.arange(block_size, device=device).view(1, 1, -1) + decay = torch.exp(-(k - 1).clamp(min=0).float() / self.dflash_loss_decay_factor) + weight_mask = weight_mask * decay + + flat_logits = logits.view(-1, logits.size(-1)) + flat_targets = target_ids.view(-1) + flat_weights = weight_mask.view(-1) + valid_count = flat_weights.sum() + 1e-6 + + if valid_count > 1.0: + if base_logits is not 
None: + # KD loss: teacher logits for token anchor+k are at position anchor+k-1 + teacher_indices = (safe_label_indices - 1).clamp(min=0) + teacher_logits = torch.gather( + base_logits.unsqueeze(1).expand(-1, n_blocks, -1, -1), + 2, + teacher_indices.unsqueeze(-1).expand(-1, -1, -1, base_logits.size(-1)), + ) + flat_teacher = teacher_logits.reshape(-1, base_logits.size(-1)).detach() + target_soft = torch.softmax(flat_teacher, dim=-1) + draft_logsoft = torch.log_softmax(flat_logits, dim=-1) + kd_loss = -(target_soft * draft_logsoft).sum(dim=-1) + loss = (kd_loss * flat_weights).sum() / valid_count + else: + loss_per_token = F.cross_entropy(flat_logits, flat_targets, reduction="none") + loss = (loss_per_token * flat_weights).sum() / valid_count + + with torch.no_grad(): + preds = flat_logits.argmax(dim=-1) + correct = (preds == flat_targets) & (binary_eval_mask > 0.5) + accuracy = correct.sum().float() / (binary_eval_mask.sum() + 1e-6) + accuracy = accuracy.item() + else: + loss = flat_logits.sum() * 0.0 + accuracy = 0.0 + + return loss, accuracy + + def forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + past_key_values=None, + inputs_embeds=None, + labels=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + cache_position=None, + **kwargs, + ): + """Training forward matching SpecForge latest (post-PR #473). 
+ + Key changes from original PR #415: + - Random anchor sampling instead of uniform block division + - Bidirectional intra-block attention (no causal constraint) + - Context sees strictly before anchor position + - Label alignment: position k predicts token at anchor+k + - Optional loss decay weighting + """ + if not self.training: + return super().forward( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + labels=labels, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + cache_position=cache_position, + **kwargs, + ) + + bsz, seq_len = input_ids.shape + device = input_ids.device + + # 1. Run base model → hidden states + # TODO: For co-training the base model, remove no_grad and eval() switch. + with torch.no_grad(): + base_outputs = super().forward( + input_ids=input_ids, + attention_mask=attention_mask, + output_hidden_states=True, + ) + + offset = 1 + selected = [base_outputs.hidden_states[lid + offset] for lid in self.target_layer_ids] + target_hidden = torch.cat(selected, dim=-1) # [B, seq, num_layers * H] + + # 2. Build loss mask: combine labels (answer-only) and attention_mask (padding) + loss_mask = torch.ones(bsz, seq_len, device=device) + if labels is not None: + loss_mask = loss_mask * (labels != LabelSmoother.ignore_index).float() + if attention_mask is not None: + loss_mask = loss_mask * attention_mask.float() + + # 3. Random anchor sampling (SpecForge PR #463/#473) + anchor_positions, block_keep_mask = self._sample_anchor_positions( + seq_len, loss_mask, device + ) + n_blocks = anchor_positions.shape[1] + + if n_blocks == 0 or not block_keep_mask.any(): + # Zero loss that still flows through dflash_module for DDP gradient sync + dummy = self.dflash_module.fc.weight.sum() * 0.0 + return ModelOutput(loss=dummy, logits=base_outputs.logits, train_acc=[[0.0]]) + + # 4-6. 
Build draft inputs + noise_embedding = self._build_noise_embedding( + input_ids, anchor_positions, block_keep_mask, n_blocks + ) + full_pos = self._build_position_ids(seq_len, anchor_positions, device) + attn_mask = self._build_draft_attention_mask( + seq_len, anchor_positions, block_keep_mask, n_blocks, target_hidden.dtype, device + ) + + # 7. Draft forward + hidden = self.dflash_module( + noise_embedding=noise_embedding, + target_hidden=target_hidden, + position_ids=full_pos, + attention_mask=attn_mask, + ) + + # 8. Compute loss and accuracy + logits = self._base_model_lm_head(hidden) + loss, accuracy = self._compute_loss( + logits, + input_ids, + anchor_positions, + block_keep_mask, + loss_mask, + base_outputs.logits if self.dflash_self_logit_distillation else None, + ) + + return ModelOutput( + loss=loss, + logits=base_outputs.logits, + train_acc=[[accuracy]], + ) + + @torch.no_grad() + def pseudo_speculative_generate(self, input_ids, steps=1): + """Generate draft tokens using one DFlash block for AR validation. + + This method implements a single speculative decoding step: + + 1. **Base model forward**: Run the full target model on ``input_ids`` to get: + - ``base_token``: greedy next token (argmax of last position logits) + - ``hidden_states``: intermediate hidden states from target layers + + 2. **Extract target hidden states**: Concatenate hidden states from + ``target_layer_ids`` (e.g., layers [1, 9, 17, 25, 33] for 5-layer draft). + Shape: ``[B, seq_len, num_layers * hidden_size]``. + + 3. **Build block input**: Create a block of ``block_size`` tokens where: + - Position 0 = ``base_token`` (the anchor/known token) + - Positions 1..block_size-1 = ``mask_token_id`` (unknown, to be predicted) + Embed this block via the base model's embedding layer. + + 4. **Position IDs**: Context positions ``[0..seq_len-1]`` followed by block + positions ``[seq_len..seq_len+block_size-1]``. 
The draft model's attention + uses RoPE on these positions so Q (block only) attends to K (context + block) + with correct relative position encoding. + + 5. **Draft forward**: Run ``DFlashModule`` with: + - ``noise_embedding``: embedded block tokens + - ``target_hidden``: extracted hidden states from step 2 + - ``position_ids``: context + block positions + - ``attention_mask=None``: no mask at inference (all positions attend freely) + The draft model's KV injection concatenates projected target_hidden as K/V + with the block's own K/V, enabling the draft to "see" the target's context. + + 6. **Decode**: Apply ``lm_head`` to draft hidden states at positions 1..block_size-1 + (skip position 0 which is the known anchor). Argmax gives draft tokens. + + 7. **Return**: ``(base_token, draft_tokens[:steps])`` — base token is always + returned; draft tokens are truncated to ``steps`` (default: block_size-1). + + Note: + This method re-runs the full target model from scratch on each call + (no KV cache). For AR validation, it is called repeatedly with growing + ``input_ids`` by ``AcceptanceRateValidation.validate()``. The ``steps`` + parameter should be set to ``block_size - 1`` for full block evaluation. + + Args: + input_ids: Input token IDs [B, seq_len]. + steps: Number of draft tokens to return (capped at block_size-1). + + Returns: + base_token: Next token from base model [B, 1]. + draft_tokens: Draft tokens [B, min(steps, block_size-1)] or None if steps < 1. 
+ """ + # Call the base model's inner model directly (avoids DynamicModule dispatch) + model_output = self._base_model( + input_ids=input_ids, + output_hidden_states=True, + ) + # Compute logits via lm_head + base_logits = self._base_model_lm_head(model_output.last_hidden_state) + # Build output with hidden_states + base_outputs = ModelOutput( + logits=base_logits, + hidden_states=model_output.hidden_states, + ) + base_logits = base_outputs.logits + base_token = base_logits[:, -1:, :].argmax(dim=-1).to(input_ids.device) + + if steps < 1: + return base_token, None + + # Extract target hidden states (raw, before FC projection) + hid_offset = 1 + if not hasattr(self, "_psg_debug"): + self._psg_debug = True + sel = [base_outputs.hidden_states[lid + hid_offset] for lid in self.target_layer_ids] + th_dbg = torch.cat(sel, dim=-1) + n_layers = len(base_outputs.hidden_states) + th_norm = th_dbg.norm().item() + logger.info( + "[psg] hidden layers: %d, target_hidden: %s, norm: %.2f", + n_layers, + th_dbg.shape, + th_norm, + ) + logger.info( + "[psg] base_token: %d, mask_token_id: %s", base_token[0].item(), self.mask_token_id + ) + seq_len = input_ids.shape[1] + blk = self.dflash_block_size + logger.info( + "[psg] pos: ctx=[0..%d], blk=[%d..%d]", seq_len - 1, seq_len, seq_len + blk - 1 + ) + selected = [base_outputs.hidden_states[lid + hid_offset] for lid in self.target_layer_ids] + target_hidden = torch.cat(selected, dim=-1) + + block_size = self.dflash_block_size + bsz = input_ids.shape[0] + seq_len = input_ids.shape[1] + device = input_ids.device + + # Block: first token is base_token (anchor), rest are mask + block_ids = torch.full( + (bsz, block_size), self.mask_token_id, dtype=torch.long, device=device + ) + block_ids[:, 0] = base_token.squeeze(-1) + noise_embedding = self._base_model_embeddings(block_ids) + + # Position IDs: training uses [0..L-1, 0..L-1] where noise positions + # mirror context positions. 
At inference, block predicts tokens at + # seq_len..seq_len+B-1, so noise positions continue from ctx_len. + ctx_len = target_hidden.shape[1] + ctx_positions = torch.arange(ctx_len, device=device) + block_positions = torch.arange(ctx_len, ctx_len + block_size, device=device) + pos_ids = torch.cat([ctx_positions, block_positions]).unsqueeze(0).expand(bsz, -1) + + # No attention mask at inference — matching SpecForge's spec_generate + # which uses KV cache with no mask. All positions attend freely to + # context and each other within the block. + + # Draft forward + draft_hidden = self.dflash_module( + noise_embedding=noise_embedding, + target_hidden=target_hidden, + position_ids=pos_ids, + attention_mask=None, + ) + + # Logits on positions 1..block_size-1 (skip anchor at position 0) + draft_logits = self._base_model_lm_head(draft_hidden[:, 1:, :]) + draft_tokens = draft_logits.argmax(dim=-1) # [B, block_size-1] + + # Return up to `steps` tokens + num_tokens = min(steps, block_size - 1) + return base_token, draft_tokens[:, :num_tokens] diff --git a/modelopt/torch/speculative/utils.py b/modelopt/torch/speculative/utils.py index 9e167c8dc9..ca8bbcd0af 100644 --- a/modelopt/torch/speculative/utils.py +++ b/modelopt/torch/speculative/utils.py @@ -376,6 +376,84 @@ def validate( return ground_truth, ar + def validate_online( + self, + osl, + prompt=None, + input_ids=None, + steps=1, + ): + """Validate AR with online (context-dependent) ground truth. + + Instead of pre-computing a fixed ground truth, this method verifies + draft tokens against the target model's response to the current + sequence (including previously accepted draft tokens). This matches + the actual speculative decoding verification loop. 
+ + Args: + osl: Output sequence length. + prompt: Text prompt (alternative to input_ids). + input_ids: Tokenized input IDs. + steps: Number of draft tokens per step. + """ + if input_ids is None: + input_ids = self.tokenize(prompt) + + isl = input_ids.shape[1] + max_len = isl + osl + total_accepted = 0 + cnt = 0 + + while input_ids.shape[1] < max_len: + cnt += 1 + + # Generate base token + draft tokens + input_id, draft_tokens = self.model.pseudo_speculative_generate(input_ids, steps=steps) + draft_tokens = self.check_data_consistency_across_ranks(draft_tokens) + input_id = self.check_data_consistency_across_ranks(input_id) + + # Append base token + input_ids = torch.cat((input_ids, input_id), dim=-1) + + if draft_tokens is None or input_ids.shape[1] >= max_len: + total_accepted += 1 # base token + continue + + # Build candidate sequence with draft tokens appended + candidate = torch.cat((input_ids, draft_tokens), dim=-1) + + # Get target model's response to the candidate sequence + with torch.no_grad(): + target_output = self.model._base_model(candidate) + target_logits = self.model._base_model_lm_head(target_output.last_hidden_state) + # posterior[i] = target's prediction given candidate[:i+1] + # For positions where we placed draft tokens, compare + # target's prediction at position i-1 with draft token at i + posterior = target_logits.argmax(dim=-1) + + # Check acceptance: compare draft[i] with posterior at input_ids_len-1+i + accepted = 0 + pos = input_ids.shape[1] - 1 # position of base token in candidate + for i in range(draft_tokens.shape[-1]): + if pos + i >= candidate.shape[1] - 1: + break + # Elementwise compare on a 1-element bool tensor; assumes batch size 1. + if posterior[:, pos + i] == draft_tokens[:, i]: + accepted += 1 + input_ids = torch.cat((input_ids, draft_tokens[:, i : i + 1]), dim=-1) + else: + # Rejected — append target's token instead + input_ids = torch.cat((input_ids, posterior[:, pos + i : pos + i + 1]), dim=-1) + accepted += 1 # target's token counts + break + + if input_ids.shape[1] >= max_len: + break + + 
total_accepted += 1 + accepted # base token + accepted drafts + + ar = total_accepted / cnt if cnt > 0 else 0.0 + return input_ids, ar + @contextlib.contextmanager def temporary_set_config_value(config, field, value): diff --git a/modelopt/torch/utils/plugins/transformers_dataset.py b/modelopt/torch/utils/plugins/transformers_dataset.py index e147ebf2c2..cda1c02ccc 100644 --- a/modelopt/torch/utils/plugins/transformers_dataset.py +++ b/modelopt/torch/utils/plugins/transformers_dataset.py @@ -33,26 +33,6 @@ IGNORE_TOKEN_ID = LabelSmoother.ignore_index -def _sharegpt_to_openai_messages(conversations: list[dict]): - """Optionally align sharedgpt format to openai format.""" - role_mapping = { - "user": "user", - "User": "user", - "human": "user", - "assistant": "assistant", - "Assistant": "assistant", - "gpt": "assistant", - "system": "system", - "System": "system", - } - messages = [] - for msg in conversations: - role = role_mapping[msg["role"]] - content = msg["content"] - messages.append({"role": role, "content": content}) - return messages - - class ShardedDataset(torch.utils.data.Dataset): """Subclass of torch.utils.data.Dataset to load data from HuggingFace dataset.""" @@ -153,6 +133,9 @@ def __init__( if self.tokenizer.chat_template is None: raise ValueError("No valid chat template!") + if self.answer_only_loss: + self._ensure_generation_tags() + def _post_process_tokenizer(self): if self.tokenizer.pad_token_id is None: print_rank_0("The tokenizer has no pad_token_id, using eos_token_id instead.") @@ -171,6 +154,166 @@ def _post_process_chat_template(self): REMOVE_THINK_CHAT_TEMPLATE, "" ) + # Simplified chat templates with {% generation %} tags for answer_only_loss. + # + # PURPOSE: + # HuggingFace's return_assistant_tokens_mask requires {% generation %} / + # {% endgeneration %} tags in the Jinja chat template to identify which tokens + # belong to assistant responses. Many models (Qwen3, Llama3) ship without these + # tags. 
These simplified templates add them so that answer_only_loss works + # reliably without regex fallbacks. + # + # HOW IT WORKS: + # When answer_only_loss=True, _ensure_generation_tags() detects the model's + # template style (ChatML, Llama3) and replaces the tokenizer's chat_template + # with one of these simplified versions. The {% generation %} tags tell HF + # exactly which tokens are assistant content for loss masking. + # + # WHAT IS PRESERVED: + # - System / user / assistant role formatting (exact token match) + # - Multi-turn conversation structure + # - <think> block injection on last assistant turn (Qwen3-style, chatml_think) + # - Content is output as-is — training data with <think> blocks is handled correctly + # + # WHAT IS DROPPED (vs original model templates): + # - Tool call formatting (tool_call XML tags, function signatures) + # - Multi-step tool response handling + # - reasoning_content vs content splitting logic + # - enable_thinking parameter support + # - VLM/multimodal content handling + # + # LIMITATIONS: + # - Training data with tool_call messages will not be formatted correctly. + # Use the original template with manually added {% generation %} tags for + # tool-use training data. + # - The chatml_think variant adds <think>\n\n</think>\n\n only to the last + # assistant turn (matching Qwen3 behavior). Non-last turns without <think> + # in their content will differ from the original template, which also + # conditionally adds think wrappers based on multi-step reasoning context. + # - Only ChatML (<|im_start|>/<|im_end|>) and Llama3 + # (<|start_header_id|>/<|eot_id|>) styles are supported. Other template + # styles fall back to regex-based assistant span detection. + # + # TO USE A CUSTOM TEMPLATE INSTEAD: + # Pass chat_template= to LanguageDataCollator with your own template that + # includes {% generation %}...{% endgeneration %} around assistant content. 
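The label-masking step this comment block documents can be sketched in plain Python (a toy stand-in for the tensor code in `_process_chat_sample`; the constant mirrors `LabelSmoother.ignore_index`):

```python
IGNORE_TOKEN_ID = -100  # same value as transformers' LabelSmoother.ignore_index

def answer_only_labels(input_ids: list[int], assistant_mask: list[int]) -> list[int]:
    """Shift labels left by one, then ignore every non-assistant position."""
    # Next-token labels: position i predicts input_ids[i + 1]; last position has no target.
    labels = input_ids[1:] + [IGNORE_TOKEN_ID]
    # assistant_mask[i] == 1 marks tokens inside {% generation %} spans.
    return [
        tok if keep else IGNORE_TOKEN_ID
        for tok, keep in zip(labels, assistant_mask)
    ]

# User tokens (positions 0-1) are masked; only assistant positions keep labels.
print(answer_only_labels([11, 12, 13, 14], [0, 0, 1, 1]))  # → [-100, -100, 14, -100]
```

In the real collator the same masking is done on tensors via `labels[assistant_mask == 0] = IGNORE_TOKEN_ID` after the shift.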
+ _GENERATION_TEMPLATES = { + # Basic ChatML without <think> injection (Phi, older Qwen, generic ChatML) + "chatml": ( + "{% for message in messages %}" + "{% if message['role'] == 'system' %}" + "{{ '<|im_start|>system\n' + message['content'] + '<|im_end|>\n' }}" + "{% elif message['role'] == 'user' %}" + "{{ '<|im_start|>user\n' + message['content'] + '<|im_end|>\n' }}" + "{% elif message['role'] == 'assistant' %}" + "{{ '<|im_start|>assistant\n' }}" + "{% generation %}" + "{{ message['content'] }}" + "{% endgeneration %}" + "{{ '<|im_end|>\n' }}" + "{% endif %}" + "{% endfor %}" + "{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}" + ), + # ChatML with <think> wrapper on last assistant turn (Qwen3-style) + "chatml_think": ( + "{% for message in messages %}" + "{% if message['role'] == 'system' %}" + "{{ '<|im_start|>system\n' + message['content'] + '<|im_end|>\n' }}" + "{% elif message['role'] == 'user' %}" + "{{ '<|im_start|>user\n' + message['content'] + '<|im_end|>\n' }}" + "{% elif message['role'] == 'assistant' %}" + "{{ '<|im_start|>assistant\n' }}" + "{% generation %}" + "{% if loop.last and not message['content'].startswith('<think>') %}" + "{{ '<think>\n\n</think>\n\n' }}" + "{% endif %}" + "{{ message['content'] }}" + "{% endgeneration %}" + "{{ '<|im_end|>\n' }}" + "{% endif %}" + "{% endfor %}" + "{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}" + ), + "llama3": ( + "{% for message in messages %}" + "{% if message['role'] == 'system' %}" + "{{ '<|start_header_id|>system<|end_header_id|>\n\n' + message['content'] + '<|eot_id|>' }}" + "{% elif message['role'] == 'user' %}" + "{{ '<|start_header_id|>user<|end_header_id|>\n\n' + message['content'] + '<|eot_id|>' }}" + "{% elif message['role'] == 'assistant' %}" + "{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% generation %}" + "{{ message['content'] }}{% endgeneration %}{{ '<|eot_id|>' }}" + "{% endif %}" + "{% endfor %}" + "{% if add_generation_prompt %}{{ 
'<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}" + ), + } + + def _ensure_generation_tags(self): + """Ensure chat template has {% generation %} tags for answer_only_loss. + + If the template already has generation tags, no action is taken. + Otherwise, detect the template style and replace with a simplified + version that includes proper generation tags. + """ + template = self.tokenizer.chat_template + if template is None: + return + + if "{% generation %}" in template or "{%generation%}" in template: + return + + # Detect template style and replace with generation-tagged version + old_template = template + if "<|im_start|>" in template and "<|im_end|>" in template: + # Check if original template injects <think> (Qwen3-style) + style = "chatml_think" if "<think>" in template else "chatml" + elif "<|start_header_id|>" in template and "<|eot_id|>" in template: + style = "llama3" + else: + print_rank_0( + "=== WARNING === Cannot auto-inject {% generation %} tags for this chat " + "template. answer_only_loss will not work correctly. Provide a template " + "with {% generation %} tags via the chat_template parameter." + ) + return + + new_template = self._GENERATION_TEMPLATES[style] + self.tokenizer.chat_template = new_template + + # Verify + try: + test_msgs = [ + [ + {"role": "user", "content": "hi"}, + {"role": "assistant", "content": "hello"}, + ] + ] + result = self.tokenizer.apply_chat_template( + test_msgs, + return_dict=True, + return_assistant_tokens_mask=True, + padding=True, + return_tensors="pt", + ) + mask = result.get("assistant_masks", None) + if mask is not None and mask.any(): + print_rank_0( + f"Replaced chat template with {style} generation-tagged version " + f"for answer_only_loss." + ) + return + except Exception: + pass + + # Revert on failure + self.tokenizer.chat_template = old_template + print_rank_0( + f"=== WARNING === Failed to apply {style} generation template. " + "answer_only_loss will not work correctly." 
+ ) + def _process_chat_sample(self, examples: list): tokenized_examples = self.tokenizer.apply_chat_template( examples, @@ -186,6 +329,20 @@ def _process_chat_sample(self, examples: list): input_ids = tokenized_examples["input_ids"] labels = input_ids.new_full(input_ids.shape, IGNORE_TOKEN_ID) labels[..., :-1] = input_ids[..., 1:] + if self.answer_only_loss: + if "assistant_masks" in tokenized_examples: + assistant_mask = tokenized_examples["assistant_masks"] + if isinstance(assistant_mask, torch.Tensor) and assistant_mask.any(): + labels[assistant_mask == 0] = IGNORE_TOKEN_ID + else: + # All assistant content truncated or no assistant in batch — mask all + labels[:] = IGNORE_TOKEN_ID + else: + raise ValueError( + "answer_only_loss requires {% generation %} tags in the chat " + "template but assistant_masks was not returned by the tokenizer. " + "Ensure _ensure_generation_tags() ran successfully." + ) tokenized_examples["labels"] = labels return tokenized_examples @@ -211,16 +368,23 @@ def __call__(self, examples): batch.append(text) else: messages = example.get("messages", None) - if messages is None: - conversations = example.get("conversations", None) - if conversations is None: - raise ValueError( - "The sample must in either OpenAI messages format or ShareGPT conversations format." - ) - else: - messages = _sharegpt_to_openai_messages(conversations) + if not messages: + raise ValueError( + "Sample must have a 'messages' field in OpenAI format " + "(list of {role, content} dicts)." + ) + if not any(m.get("role") == "assistant" for m in messages): + print_rank_0( + "=== WARNING === Skipping sample with no assistant turn in messages." 
+ ) + continue batch.append(messages) + if not batch: + # All samples skipped — create a dummy batch with all-masked labels + # so the training step produces zero loss without crashing DDP + batch = [[{"role": "user", "content": ""}, {"role": "assistant", "content": ""}]] # type: ignore[list-item] + return self._process_chat_sample(batch) @@ -273,13 +437,10 @@ def __call__(self, examples): for example in examples: messages = example.get("messages", None) if messages is None: - conversations = example.get("conversations", None) - if conversations is None: - raise ValueError( - "The sample must in either OpenAI messages format or ShareGPT conversations format." - ) - else: - messages = _sharegpt_to_openai_messages(conversations) + raise ValueError( + "Sample must have a 'messages' field in OpenAI format " + "(list of {role, content} dicts)." + ) copy_messages = copy.deepcopy(messages) diff --git a/modelopt_recipes/general/speculative_decoding/dflash.yaml b/modelopt_recipes/general/speculative_decoding/dflash.yaml new file mode 100644 index 0000000000..90d161e95d --- /dev/null +++ b/modelopt_recipes/general/speculative_decoding/dflash.yaml @@ -0,0 +1,55 @@ +# Base config for DFlash training. Override fields via OmegaConf dotlist on the CLI. 
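The OmegaConf dotlist override mentioned in the recipe header behaves roughly like this stdlib sketch (illustration only; real OmegaConf merging also coerces value types, which this toy version leaves as strings):

```python
def apply_dotlist(cfg: dict, overrides: list[str]) -> dict:
    """Apply OmegaConf-style 'a.b.c=value' overrides to a nested dict."""
    for item in overrides:
        path, _, value = item.partition("=")
        *parents, leaf = path.split(".")
        node = cfg
        for key in parents:
            # Create intermediate sections as needed, then descend.
            node = node.setdefault(key, {})
        node[leaf] = value  # OmegaConf additionally infers int/float/bool types
    return cfg

cfg = {"dflash": {"dflash_block_size": 8}, "training": {"mode": "dflash"}}
apply_dotlist(cfg, ["dflash.dflash_block_size=16", "training.num_train_epochs=5"])
```

Untouched fields (here `training.mode`) are preserved, so a recipe can ship defaults and the CLI only names what changes.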
+ +# maps to ModelArguments (main.py) +model: + model_name_or_path: + trust_remote_code: false + use_fake_base_for_offline: false + +# maps to DataArguments (main.py) +data: + data_path: + offline_data_path: + +# maps to TrainingArguments (main.py) +training: + # --- commonly modified --- + mode: dflash + output_dir: + num_train_epochs: 10 + per_device_train_batch_size: 1 + learning_rate: 6.0e-4 + warmup_steps: 100 + training_seq_len: 4096 + logging_steps: 100 + save_steps: 5000 + cp_size: 1 + dp_shard_size: 1 + disable_tqdm: true + estimate_ar: false + ar_validate_steps: 0 + + # --- rarely modified --- + do_eval: false + lr_scheduler_type: linear + save_strategy: steps + weight_decay: 0.0 + dataloader_drop_last: true + bf16: true + tf32: true + remove_unused_columns: false + ddp_find_unused_parameters: true + ddp_timeout: 1800 + report_to: tensorboard + +# maps to DFlashConfig (modelopt/torch/speculative/config.py). +dflash: + dflash_block_size: 8 + dflash_num_anchors: 512 + dflash_use_torch_compile: false + dflash_self_logit_distillation: true + dflash_loss_decay_factor: 4.0 + dflash_architecture_config: + num_hidden_layers: 5 + # mask_token_id: auto-detected from model vocab (override for specific models) + # sliding_window and layer_types are inherited from base model config automatically diff --git a/tests/gpu/torch/speculative/plugins/test_hf_dflash.py b/tests/gpu/torch/speculative/plugins/test_hf_dflash.py new file mode 100644 index 0000000000..230b67c45d --- /dev/null +++ b/tests/gpu/torch/speculative/plugins/test_hf_dflash.py @@ -0,0 +1,196 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""GPU tests for DFlash speculative decoding plugin. + +These tests require a CUDA GPU. CPU-only tests are in tests/unit/. +""" + +from copy import deepcopy + +import pytest +import torch +from _test_utils.torch.transformers_models import get_tiny_llama + +import modelopt.torch.speculative as mtsp +from modelopt.torch.speculative.config import DFLASH_DEFAULT_CFG + +BLOCK_SIZE = 4 +NUM_DRAFT_LAYERS = 2 +SEQ_LEN = 16 # must be multiple of BLOCK_SIZE + + +def _get_dflash_config(block_size=BLOCK_SIZE, num_layers=NUM_DRAFT_LAYERS): + """Create a DFlash config for testing.""" + config = deepcopy(DFLASH_DEFAULT_CFG["config"]) + config["dflash_block_size"] = block_size + config["dflash_use_torch_compile"] = False + config["dflash_architecture_config"] = { + "num_hidden_layers": num_layers, + "mask_token_id": 0, + } + return config + + +@pytest.fixture +def dflash_model(): + """Create a tiny DFlash model on GPU.""" + model = get_tiny_llama(num_hidden_layers=4) + config = _get_dflash_config() + mtsp.convert(model, [("dflash", config)]) + model = model.cuda() + return model + + +class TestDFlashModuleGPU: + """Test DFlash draft module forward pass on GPU.""" + + def test_dflash_module_forward_shape(self, dflash_model): + """Test that draft module produces correct output shape.""" + model = dflash_model + bsz = 2 + hidden_size = model.config.hidden_size + num_layers = len(model.target_layer_ids) + + dtype = next(model.dflash_module.parameters()).dtype + target_hidden = torch.randn( + bsz, SEQ_LEN, num_layers * hidden_size, device="cuda", dtype=dtype + ) + noise_emb 
= torch.randn(bsz, SEQ_LEN, hidden_size, device="cuda", dtype=dtype) + pos_ids = ( + torch.cat([torch.arange(SEQ_LEN), torch.arange(SEQ_LEN)]) + .unsqueeze(0) + .expand(bsz, -1) + .cuda() + ) + + output = model.dflash_module( + noise_embedding=noise_emb, + target_hidden=target_hidden, + position_ids=pos_ids, + attention_mask=None, + ) + assert output.shape == (bsz, SEQ_LEN, hidden_size) + + def test_dflash_module_deterministic(self, dflash_model): + """Test that draft module produces identical outputs for same input.""" + model = dflash_model + model.eval() + bsz = 1 + hidden_size = model.config.hidden_size + num_layers = len(model.target_layer_ids) + + dtype = next(model.dflash_module.parameters()).dtype + target_hidden = torch.randn( + bsz, SEQ_LEN, num_layers * hidden_size, device="cuda", dtype=dtype + ) + noise_emb = torch.randn(bsz, SEQ_LEN, hidden_size, device="cuda", dtype=dtype) + pos_ids = torch.cat([torch.arange(SEQ_LEN), torch.arange(SEQ_LEN)]).unsqueeze(0).cuda() + + with torch.no_grad(): + out1 = model.dflash_module( + noise_embedding=noise_emb, + target_hidden=target_hidden, + position_ids=pos_ids, + ) + out2 = model.dflash_module( + noise_embedding=noise_emb, + target_hidden=target_hidden, + position_ids=pos_ids, + ) + assert torch.allclose(out1, out2) + + +class TestDFlashTrainingForwardGPU: + """Test DFlash training forward pass end-to-end on GPU.""" + + @pytest.fixture + def model(self): + """Create a tiny DFlash model in training mode on GPU.""" + model = get_tiny_llama(num_hidden_layers=4) + config = _get_dflash_config() + mtsp.convert(model, [("dflash", config)]) + model = model.cuda() + model.train() + return model + + def test_training_forward_returns_loss(self, model): + """Test that training forward returns a differentiable loss.""" + bsz = 2 + input_ids = torch.randint(0, model.config.vocab_size, (bsz, SEQ_LEN), device="cuda") + attention_mask = torch.ones(bsz, SEQ_LEN, dtype=torch.long, device="cuda") + + output = 
model(input_ids=input_ids, attention_mask=attention_mask) + assert hasattr(output, "loss") + assert output.loss.requires_grad + + def test_training_forward_returns_accuracy(self, model): + """Test that training forward returns train_acc.""" + bsz = 2 + input_ids = torch.randint(0, model.config.vocab_size, (bsz, SEQ_LEN), device="cuda") + attention_mask = torch.ones(bsz, SEQ_LEN, dtype=torch.long, device="cuda") + + output = model(input_ids=input_ids, attention_mask=attention_mask) + assert hasattr(output, "train_acc") + + def test_training_forward_with_labels(self, model): + """Test that labels are used for response-only loss masking.""" + bsz = 2 + input_ids = torch.randint(0, model.config.vocab_size, (bsz, SEQ_LEN), device="cuda") + attention_mask = torch.ones(bsz, SEQ_LEN, dtype=torch.long, device="cuda") + + # Labels with -100 for first half (masked), real labels for second half + labels = torch.full((bsz, SEQ_LEN), -100, dtype=torch.long, device="cuda") + labels[:, SEQ_LEN // 2 :] = input_ids[:, SEQ_LEN // 2 :] + + output = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels) + assert hasattr(output, "loss") + assert output.loss.requires_grad + + def test_training_forward_all_masked_labels(self, model): + """Test that all-masked labels produce zero loss without crashing.""" + bsz = 2 + input_ids = torch.randint(0, model.config.vocab_size, (bsz, SEQ_LEN), device="cuda") + attention_mask = torch.ones(bsz, SEQ_LEN, dtype=torch.long, device="cuda") + labels = torch.full((bsz, SEQ_LEN), -100, dtype=torch.long, device="cuda") + + output = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels) + assert output.loss.item() == 0.0 + + def test_training_backward(self, model): + """Test that gradients flow to dflash_module.""" + bsz = 2 + input_ids = torch.randint(0, model.config.vocab_size, (bsz, SEQ_LEN), device="cuda") + attention_mask = torch.ones(bsz, SEQ_LEN, dtype=torch.long, device="cuda") + + output = 
model(input_ids=input_ids, attention_mask=attention_mask) + output.loss.backward() + + has_grad = False + for name, param in model.dflash_module.named_parameters(): + if param.grad is not None and param.grad.abs().sum() > 0: + has_grad = True + break + assert has_grad, "DFlash module should receive gradients" + + def test_eval_forward_uses_base_model(self, model): + """In eval mode, forward should use base model (not DFlash training).""" + model.eval() + bsz = 1 + input_ids = torch.randint(0, model.config.vocab_size, (bsz, SEQ_LEN), device="cuda") + + with torch.no_grad(): + output = model(input_ids=input_ids) + assert output.logits.shape == (bsz, SEQ_LEN, model.config.vocab_size) diff --git a/tests/unit/torch/speculative/plugins/test_hf_dflash.py b/tests/unit/torch/speculative/plugins/test_hf_dflash.py new file mode 100644 index 0000000000..8e8c846583 --- /dev/null +++ b/tests/unit/torch/speculative/plugins/test_hf_dflash.py @@ -0,0 +1,309 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""CPU unit tests for DFlash speculative decoding plugin. + +GPU-dependent tests (training forward, module forward) are in tests/gpu/. 
+""" + +import os +from copy import deepcopy + +import pytest +import torch +from _test_utils.torch.transformers_models import ( + get_tiny_llama, + tf_modelopt_state_and_output_tester, +) +from transformers import AutoModelForCausalLM + +import modelopt.torch.opt as mto +import modelopt.torch.speculative as mtsp +from modelopt.torch.speculative.config import DFLASH_DEFAULT_CFG +from modelopt.torch.speculative.plugins.hf_dflash import ( + DFlashModule, + HFDFlashModel, + build_target_layer_ids, + create_dflash_attention_mask, + create_dflash_loss_mask, +) + +BLOCK_SIZE = 4 +NUM_DRAFT_LAYERS = 2 +SEQ_LEN = 16 # must be multiple of BLOCK_SIZE + + +def _get_dflash_config(block_size=BLOCK_SIZE, num_layers=NUM_DRAFT_LAYERS): + """Create a DFlash config for testing.""" + config = deepcopy(DFLASH_DEFAULT_CFG["config"]) + config["dflash_block_size"] = block_size + config["dflash_use_torch_compile"] = False + config["dflash_architecture_config"] = { + "num_hidden_layers": num_layers, + "mask_token_id": 0, # use token 0 as mask for tiny model + } + return config + + +class TestDFlashConvert: + """Test DFlash model conversion.""" + + def test_convert_creates_dflash_model(self): + """Test that convert produces an HFDFlashModel.""" + model = get_tiny_llama(num_hidden_layers=4) + config = _get_dflash_config() + mtsp.convert(model, [("dflash", config)]) + assert isinstance(model, HFDFlashModel) + + def test_convert_creates_dflash_module(self): + """Test that convert attaches a DFlashModule.""" + model = get_tiny_llama(num_hidden_layers=4) + config = _get_dflash_config() + mtsp.convert(model, [("dflash", config)]) + assert hasattr(model, "dflash_module") + assert isinstance(model.dflash_module, DFlashModule) + + def test_convert_freezes_base_model(self): + """Test that base model parameters are frozen after convert.""" + model = get_tiny_llama(num_hidden_layers=4) + config = _get_dflash_config() + mtsp.convert(model, [("dflash", config)]) + for name, param in 
model.named_parameters(): + if "dflash_module" not in name: + assert not param.requires_grad, f"Base param {name} should be frozen" + + def test_convert_dflash_module_trainable(self): + """Test that DFlash module parameters are trainable after convert.""" + model = get_tiny_llama(num_hidden_layers=4) + config = _get_dflash_config() + mtsp.convert(model, [("dflash", config)]) + dflash_params = [(n, p) for n, p in model.named_parameters() if "dflash_module" in n] + assert len(dflash_params) > 0 + for name, param in dflash_params: + assert param.requires_grad, f"DFlash param {name} should be trainable" + + def test_convert_sets_target_layer_ids(self): + """Test that target layer IDs are set correctly.""" + model = get_tiny_llama(num_hidden_layers=8) + config = _get_dflash_config(num_layers=3) + mtsp.convert(model, [("dflash", config)]) + assert hasattr(model, "target_layer_ids") + assert len(model.target_layer_ids) == 3 + for lid in model.target_layer_ids: + assert 0 <= lid < 8 + + def test_convert_sets_mask_token_id(self): + """Test that mask_token_id is set from config.""" + model = get_tiny_llama(num_hidden_layers=4) + config = _get_dflash_config() + mtsp.convert(model, [("dflash", config)]) + assert hasattr(model, "mask_token_id") + assert model.mask_token_id == 0 + + def test_convert_missing_mask_token_id_errors(self): + """Test that missing mask_token_id raises ValueError for unknown model.""" + model = get_tiny_llama(num_hidden_layers=4) + config = _get_dflash_config() + del config["dflash_architecture_config"]["mask_token_id"] + with pytest.raises(ValueError, match="Cannot auto-detect mask_token_id"): + mtsp.convert(model, [("dflash", config)]) + + +class TestDFlashSaveRestore: + """Test DFlash model save and restore.""" + + def test_save_and_restore(self, tmp_path): + """Test round-trip save/load preserves modelopt state and outputs.""" + mto.enable_huggingface_checkpointing() + model_ref = get_tiny_llama(num_hidden_layers=4) + config = _get_dflash_config() + 
mtsp.convert(model_ref, [("dflash", config)]) + + model_ref.save_pretrained(tmp_path / "modelopt_model") + assert os.path.exists(tmp_path / "modelopt_model/modelopt_state.pth") + + model_test = AutoModelForCausalLM.from_pretrained(tmp_path / "modelopt_model") + assert isinstance(model_test, HFDFlashModel) + tf_modelopt_state_and_output_tester(model_ref, model_test) + + +class TestDFlashMetaRotaryFix: + """Test _apply fixes meta-tensor rotary buffers on .to() calls. + + During checkpoint restore, rotary inv_freq buffers may be on meta device + (they are computed, not saved). _apply should re-create them. + """ + + def test_to_fixes_meta_rotary(self): + """Test that .to() recreates rotary_emb when buffers are on meta device.""" + model = get_tiny_llama(num_hidden_layers=4) + config = _get_dflash_config() + mtsp.convert(model, [("dflash", config)]) + + dflash_mod = model.dflash_module + # Simulate meta buffers (as happens during checkpoint restore) + for name, buf in list(dflash_mod.rotary_emb.named_buffers()): + dflash_mod.rotary_emb._buffers[name] = torch.empty_like(buf, device="meta") + + assert any(b.is_meta for b in dflash_mod.rotary_emb.buffers()) + + # .to() triggers _apply which should fix meta buffers + dflash_mod.to("cpu") + + assert not any(b.is_meta for b in dflash_mod.rotary_emb.buffers()) + + def test_to_noop_when_no_meta(self): + """Test that .to() does not recreate rotary_emb when buffers are normal.""" + model = get_tiny_llama(num_hidden_layers=4) + config = _get_dflash_config() + mtsp.convert(model, [("dflash", config)]) + + dflash_mod = model.dflash_module + rotary_id_before = id(dflash_mod.rotary_emb) + dflash_mod.to("cpu") + assert id(dflash_mod.rotary_emb) == rotary_id_before + + +class TestDFlashAttentionMask: + """Test DFlash attention mask construction.""" + + def test_mask_shape(self): + mask = create_dflash_attention_mask(SEQ_LEN, BLOCK_SIZE, "cpu", torch.float32) + assert mask.shape == (1, 1, SEQ_LEN, 2 * SEQ_LEN) + + def 
test_mask_context_strictly_previous_blocks(self): + """Context (left half): block B can only see blocks 0..B-1.""" + mask = create_dflash_attention_mask(8, 4, "cpu", torch.float32) + mask_2d = mask[0, 0] + ctx_mask = mask_2d[:, :8] + assert (ctx_mask[:4, :] < 0).all() + assert (ctx_mask[4:8, :4] == 0).all() + assert (ctx_mask[4:8, 4:8] < 0).all() + + def test_mask_noise_causal_within_block(self): + """Noise (right half): reverse-causal within same block (j >= i).""" + mask = create_dflash_attention_mask(8, 4, "cpu", torch.float32) + noise_mask = mask[0, 0, :, 8:] + assert (noise_mask[0, :4] == 0).all() + assert (noise_mask[3, :3] < 0).all() + assert noise_mask[3, 3] == 0 + assert (noise_mask[4:8, :4] < 0).all() + + def test_mask_values_are_zero_or_neg_inf(self): + mask = create_dflash_attention_mask(SEQ_LEN, BLOCK_SIZE, "cpu", torch.float32) + unique_vals = mask.unique() + assert len(unique_vals) == 2 + assert 0.0 in unique_vals + assert unique_vals.min() == torch.finfo(torch.float32).min + + +class TestDFlashLossMask: + """Test DFlash loss mask construction.""" + + def test_loss_mask_shape(self): + mask = create_dflash_loss_mask(SEQ_LEN, BLOCK_SIZE, "cpu") + assert mask.shape == (SEQ_LEN,) + + def test_loss_mask_excludes_block_zero(self): + mask = create_dflash_loss_mask(SEQ_LEN, BLOCK_SIZE, "cpu") + assert (mask[:BLOCK_SIZE] == 0).all() + + def test_loss_mask_excludes_block_starts(self): + mask = create_dflash_loss_mask(SEQ_LEN, BLOCK_SIZE, "cpu") + for i in range(0, SEQ_LEN, BLOCK_SIZE): + assert mask[i] == 0 + + def test_loss_mask_includes_non_start_positions(self): + mask = create_dflash_loss_mask(SEQ_LEN, BLOCK_SIZE, "cpu") + for b in range(1, SEQ_LEN // BLOCK_SIZE): + for offset in range(1, BLOCK_SIZE): + pos = b * BLOCK_SIZE + offset + assert mask[pos] == 1 + + def test_loss_mask_count(self): + mask = create_dflash_loss_mask(SEQ_LEN, BLOCK_SIZE, "cpu") + num_blocks = SEQ_LEN // BLOCK_SIZE + expected = (num_blocks - 1) * (BLOCK_SIZE - 1) + assert 
mask.sum().item() == expected + + +class TestBuildTargetLayerIds: + """Test target layer selection.""" + + def test_single_draft_layer(self): + ids = build_target_layer_ids(32, 1) + assert len(ids) == 1 + assert ids[0] == 16 + + def test_multiple_draft_layers(self): + ids = build_target_layer_ids(36, 5) + assert len(ids) == 5 + assert ids == sorted(ids) + assert all(1 <= lid <= 33 for lid in ids) + + def test_layer_ids_no_duplicates(self): + ids = build_target_layer_ids(32, 5) + assert len(set(ids)) == 5 + + def test_layer_ids_match_zlab(self): + """Test layer IDs match z-lab reference for Qwen3-8B (36 layers, 5 draft).""" + ids = build_target_layer_ids(36, 5) + assert ids == [1, 9, 17, 25, 33] + + +class TestDFlashSlidingWindow: + """Test sliding window attention support.""" + + def test_sliding_window_from_config(self): + """Test DFlashAttention reads sliding_window from config.layer_types.""" + from transformers import PretrainedConfig + + from modelopt.torch.speculative.plugins.hf_dflash import DFlashAttention + + config = PretrainedConfig( + hidden_size=64, + num_attention_heads=4, + num_key_value_heads=4, + head_dim=16, + rms_norm_eps=1e-6, + attention_bias=False, + attention_dropout=0.0, + layer_types=["full_attention", "sliding_attention"], + sliding_window=256, + _attn_implementation="sdpa", + ) + attn_full = DFlashAttention(config, layer_idx=0) + attn_sliding = DFlashAttention(config, layer_idx=1) + assert attn_full.sliding_window is None + assert attn_sliding.sliding_window == 256 + + def test_no_sliding_window_without_config(self): + """Test DFlashAttention defaults to no sliding window.""" + from transformers import PretrainedConfig + + from modelopt.torch.speculative.plugins.hf_dflash import DFlashAttention + + config = PretrainedConfig( + hidden_size=64, + num_attention_heads=4, + num_key_value_heads=4, + head_dim=16, + rms_norm_eps=1e-6, + attention_bias=False, + attention_dropout=0.0, + _attn_implementation="sdpa", + ) + attn = 
DFlashAttention(config, layer_idx=0) + assert attn.sliding_window is None diff --git a/tools/launcher/common/specdec/ar_eval_mtbench.sh b/tools/launcher/common/specdec/ar_eval_mtbench.sh new file mode 100644 index 0000000000..4062e6b2fe --- /dev/null +++ b/tools/launcher/common/specdec/ar_eval_mtbench.sh @@ -0,0 +1,64 @@ +#!/bin/bash +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# MT-Bench AR evaluation using scripts/ar_validate.py. +# Finds the latest checkpoint and runs per-category AR validation. +# +# Args are passed directly to ar_validate.py (--model_path, --osl, --steps, etc.) +# If --model_path is not provided, auto-detects from --ckpt_dir. 
+
+SCRIPT_DIR="$(dirname "$(readlink -f "$0")")"
+source ${SCRIPT_DIR}/../service_utils.sh
+
+pip install -r modules/Model-Optimizer/examples/speculative_decoding/requirements.txt 2>&1 | tail -3
+
+trap 'error_handler $0 $LINENO' ERR
+
+# Parse --ckpt_dir and --model_path (ar_validate.py expects --model_path)
+ARGS=()
+CKPT_DIR=""
+MODEL_PATH=""
+while [ $# -gt 0 ]; do
+    case "$1" in
+        --ckpt_dir) shift; CKPT_DIR="$1" ;;
+        --model_path) shift; MODEL_PATH="$1" ;;
+        *) ARGS+=("$1") ;;
+    esac
+    shift
+done
+
+# Auto-detect model_path from ckpt_dir if not explicitly provided
+if [ -z "$MODEL_PATH" ] && [ -n "$CKPT_DIR" ]; then
+    # Find latest checkpoint subdir
+    LAST_CKPT=$(ls -d ${CKPT_DIR}/checkpoint-* 2>/dev/null | sort -t- -k2 -n | tail -1)
+    if [ -f "${CKPT_DIR}/model.safetensors" ]; then
+        MODEL_PATH="${CKPT_DIR}"
+    elif [ -n "$LAST_CKPT" ]; then
+        MODEL_PATH="${LAST_CKPT}"
+    fi
+    echo "Auto-detected model_path: ${MODEL_PATH}"
+fi
+
+if [ -z "$MODEL_PATH" ]; then
+    echo "ERROR: No checkpoint found. Provide --ckpt_dir or --model_path."
+    exit 1
+fi
+
+CUDA_VISIBLE_DEVICES=0 python3 modules/Model-Optimizer/examples/speculative_decoding/scripts/ar_validate.py \
+    --model_path "${MODEL_PATH}" \
+    --per_category \
+    "${ARGS[@]}"
+
+report_result "PASS: MT-Bench AR evaluation"
diff --git a/tools/launcher/common/specdec/dflash_online_training.sh b/tools/launcher/common/specdec/dflash_online_training.sh
new file mode 100644
index 0000000000..654ba0185d
--- /dev/null
+++ b/tools/launcher/common/specdec/dflash_online_training.sh
@@ -0,0 +1,130 @@
+#!/bin/bash
+
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# DFlash online training script for the ModelOpt Launcher.
+# Trains a DFlash draft model using accelerate launch + main.py --config.
+#
+# All training config comes from the YAML recipe (--config) and OmegaConf overrides.
+# All args are passed directly to main.py (--config + key=value overrides).
+#
+# Multi-node env vars (set by Slurm or user):
+#   NUM_NODES — number of nodes (default: 1)
+#   HEAD_NODE_IP — head node IP (auto-detected if not set)
+#
+# Usage from YAML:
+#   script: common/specdec/dflash_online_training.sh
+#   args:
+#     - --config modules/Model-Optimizer/modelopt_recipes/general/speculative_decoding/dflash.yaml
+#     - model.model_name_or_path=/hf-local/Qwen/Qwen3-8B
+#     - data.data_path=/path/to/data.jsonl
+#     - training.output_dir=/scratchspace/dflash
+#   environment:
+#     - NUM_NODES: "8"

+SCRIPT_DIR="$(dirname "$(readlink -f "$0")")"
+source ${SCRIPT_DIR}/../service_utils.sh
+
+pip install -r modules/Model-Optimizer/examples/speculative_decoding/requirements.txt
+# Quote the version spec so ">=" is not parsed as a shell redirection.
+pip install "huggingface-hub>=1.2.1"
+export PATH=$PATH:/workspace/.local/bin
+
+###################################################################################################
+
+trap 'error_handler $0 $LINENO' ERR
+
+# Auto-detect head node IP for multi-node training
+NUM_NODES=${NUM_NODES:-1}
+if [ -z "$HEAD_NODE_IP" ] && [[ "$NUM_NODES" != "1" ]]; then
+    HEAD_NODE_IP=$(scontrol show hostnames "$SLURM_JOB_NODELIST" 2>/dev/null | head -1)
+    HEAD_NODE_IP=${HEAD_NODE_IP:-$SLURM_LAUNCH_NODE_IPADDR}
+    if [ -z "$HEAD_NODE_IP" ] && [ -n "$SLURM_JOB_NODELIST" ]; then
+        HEAD_NODE_IP=$(python3 -c 
" +import socket, re, os +nl = os.environ.get('SLURM_JOB_NODELIST', '') +m = re.match(r'([a-zA-Z0-9-]+?)(?:\[(\d+))?', nl) +if m: + host = m.group(1) + (m.group(2) or '') + try: + print(socket.gethostbyname(host)) + except: + print(host) +" 2>/dev/null) + fi + if [ -z "$HEAD_NODE_IP" ] && [ "${SLURM_PROCID:-0}" = "0" ]; then + HEAD_NODE_IP=$(hostname -I 2>/dev/null | awk '{print $1}') + fi + export HEAD_NODE_IP + echo "Auto-detected HEAD_NODE_IP: ${HEAD_NODE_IP}" +fi + +# Build accelerate launch command +MAIN_PY=modules/Model-Optimizer/examples/speculative_decoding/main.py + +if [[ "$NUM_NODES" != "1" ]]; then + if [ -z "$HEAD_NODE_IP" ]; then + echo "ERROR: HEAD_NODE_IP is empty. Cannot launch multi-node training." + exit 1 + fi + GPU_PER_NODE=${GPU_PER_NODE:-$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l)} + TOTAL_GPU=$((NUM_NODES * GPU_PER_NODE)) + echo "Total GPUs: $TOTAL_GPU (NUM_NODES: $NUM_NODES, GPU_PER_NODE: $GPU_PER_NODE)" + MULTI_NODE_ARGS="--num_processes $TOTAL_GPU \ + --num_machines $NUM_NODES \ + --machine_rank $SLURM_PROCID \ + --rdzv_backend c10d \ + --main_process_ip $HEAD_NODE_IP \ + --main_process_port 29500" +else + TOTAL_GPU=$(python3 -c "import torch; print(torch.cuda.device_count())") + echo "Total GPUs: $TOTAL_GPU (single node)" + MULTI_NODE_ARGS="" +fi + +export TOKENIZERS_PARALLELISM=False + +set -x +start_time=$(date +%s) +accelerate launch --mixed_precision bf16 $MULTI_NODE_ARGS $MAIN_PY "$@" +echo "Training time: $(( $(date +%s) - start_time )) seconds" +set +x + +# Export last checkpoint to deployment format (rank 0 only, single GPU) +if [ "${SLURM_PROCID:-0}" = "0" ]; then + OUTPUT_DIR=$(python3 -c " +import sys +for arg in sys.argv[1:]: + if arg.startswith('training.output_dir='): + print(arg.split('=', 1)[1]) + break +" "$@") + + if [ -n "$OUTPUT_DIR" ]; then + LAST_CKPT=$(ls -d ${OUTPUT_DIR}/checkpoint-* 2>/dev/null | sort -t- -k2 -n | tail -1) + if [ -n "$LAST_CKPT" ]; then + STEP=$(basename "$LAST_CKPT" | sed 
's/checkpoint-//') + EXPORT_DIR="${OUTPUT_DIR}/exported-checkpoint-${STEP}" + echo "=== Exporting last checkpoint: ${LAST_CKPT} → ${EXPORT_DIR} ===" + CUDA_VISIBLE_DEVICES=0 python3 modules/Model-Optimizer/examples/speculative_decoding/scripts/export_hf_checkpoint.py \ + --model_path "${LAST_CKPT}" \ + --export_path "${EXPORT_DIR}" + echo "Export contents:" + ls -lh "${EXPORT_DIR}/" + else + echo "No checkpoints found in ${OUTPUT_DIR}, skipping export" + fi + fi +fi diff --git a/tools/launcher/common/specdec/vllm_smoke_test.sh b/tools/launcher/common/specdec/vllm_smoke_test.sh new file mode 100644 index 0000000000..ccd59cf094 --- /dev/null +++ b/tools/launcher/common/specdec/vllm_smoke_test.sh @@ -0,0 +1,110 @@ +#!/bin/bash +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Quick vLLM smoke test for speculative decoding (EAGLE3, DFlash, etc.). +# Launches server, sends a few test prompts, verifies responses, and shuts down. +# +# Required env vars: +# HF_MODEL_CKPT — target model path +# DRAFT_MODEL — draft model path +# +# Optional env vars: +# SPEC_METHOD — speculative method: "eagle", "dflash", etc. 
(default: "eagle") +# NUM_SPEC_TOKENS — number of speculative tokens (default: 15) +# TP_SIZE — tensor parallel size (default: 1) +# VLLM_PORT — server port (default: 8000) + +SCRIPT_DIR="$(dirname "$(readlink -f "$0")")" +source ${SCRIPT_DIR}/../service_utils.sh 2>/dev/null || true + +cleanup() { kill $SERVER_PID 2>/dev/null; sleep 2; kill -9 $SERVER_PID 2>/dev/null; } +trap cleanup EXIT + +MODEL=${HF_MODEL_CKPT} +DRAFT=${DRAFT_MODEL} +METHOD=${SPEC_METHOD:-eagle} +NUM_SPEC=${NUM_SPEC_TOKENS:-15} +PORT=${VLLM_PORT:-8000} +TP=${TP_SIZE:-1} + +echo "=== vLLM Speculative Decoding Smoke Test ===" +echo "Method: ${METHOD}" +echo "Target: ${MODEL}" +echo "Draft: ${DRAFT}" +echo "Spec tokens: ${NUM_SPEC}, TP: ${TP}" + +# Build speculative config +SPEC_CONFIG="{\"method\": \"${METHOD}\", \"model\": \"${DRAFT}\", \"num_speculative_tokens\": ${NUM_SPEC}}" + +# Start vLLM server +vllm serve ${MODEL} \ + --speculative-config "${SPEC_CONFIG}" \ + --max-num-batched-tokens 32768 \ + --tensor-parallel-size ${TP} \ + --port ${PORT} \ + & +SERVER_PID=$! + +# Wait for server +echo "Waiting for vLLM server..." +for i in $(seq 1 180); do + if curl -s http://localhost:${PORT}/health > /dev/null 2>&1; then + echo "Server ready after ${i}s" + break + fi + if ! kill -0 $SERVER_PID 2>/dev/null; then + echo "ERROR: Server died"; wait $SERVER_PID; exit 1 + fi + sleep 1 +done + +if ! curl -s http://localhost:${PORT}/health > /dev/null 2>&1; then + echo "ERROR: Server timeout"; exit 1 +fi + +# Run quick test prompts +echo "" +echo "=== Test Prompts ===" +PASS=0 +FAIL=0 +for PROMPT in \ + "What is 2+3? Answer with just the number." \ + "Write a haiku about mountains." 
\
+    "Explain what a CPU is in one sentence."; do
+    RESPONSE=$(curl -s http://localhost:${PORT}/v1/completions \
+        -H "Content-Type: application/json" \
+        -d "{\"model\": \"${MODEL}\", \"prompt\": \"${PROMPT}\", \"max_tokens\": 64, \"temperature\": 0}" \
+        | python3 -c "import json,sys; r=json.load(sys.stdin); t=r.get('choices',[{}])[0].get('text',''); u=r.get('usage',{}); print(f'{t.strip()[:100]}|||{u.get(\"completion_tokens\",0)}')" 2>/dev/null)
+    # cut(1) only accepts a single-character delimiter, so split on the
+    # literal "|||" separator with bash parameter expansion instead.
+    TEXT="${RESPONSE%%|||*}"
+    TOKENS="${RESPONSE##*|||}"
+    if [ -n "$TEXT" ] && [ "$TOKENS" -gt 0 ] 2>/dev/null; then
+        echo "  PASS: \"${PROMPT}\" → ${TOKENS} tokens"
+        PASS=$((PASS + 1))
+    else
+        echo "  FAIL: \"${PROMPT}\" → empty or error"
+        FAIL=$((FAIL + 1))
+    fi
+done
+
+echo ""
+echo "Results: ${PASS} passed, ${FAIL} failed"
+
+if [ $FAIL -gt 0 ]; then
+    echo "ERROR: Some prompts failed"
+    exit 1
+fi
+
+echo "Done"
diff --git a/tools/launcher/examples/Qwen/Qwen3-8B/hf_online_dflash.yaml b/tools/launcher/examples/Qwen/Qwen3-8B/hf_online_dflash.yaml
new file mode 100644
index 0000000000..a6e44af244
--- /dev/null
+++ b/tools/launcher/examples/Qwen/Qwen3-8B/hf_online_dflash.yaml
@@ -0,0 +1,55 @@
+# DFlash online speculative decoding training for Qwen3-8B.
+#
+# 2-step pipeline:
+#   task_0: Online DFlash training (1 node, 8 GPUs)
+#   task_1: MT-Bench per-category AR evaluation (1 GPU)
+#
+# Reference: "DFlash: Block Diffusion for Flash Speculative Decoding" (arXiv:2602.06036)
+#
+# Usage:
+#   uv run launch.py --yaml examples/Qwen/Qwen3-8B/hf_online_dflash.yaml --yes
+#   uv run slurm.py --yaml modules/Model-Optimizer/tools/launcher/examples/Qwen/Qwen3-8B/hf_online_dflash.yaml --yes
+
+job_name: Qwen3-8B_DFlash_online
+pipeline:
+  global_vars:
+    hf_model: /hf-local/Qwen/Qwen3-8B
+
+  # Step 1: Online DFlash training
+  task_0:
+    script: common/specdec/dflash_online_training.sh
+    args:
+      - --config modules/Model-Optimizer/modelopt_recipes/general/speculative_decoding/dflash.yaml
+      - model.model_name_or_path=<>
+      - data.data_path=/hf-local/modelopt/Speculative-Decoding-Dataset-v1-Qwen3-8B/sample-1K.jsonl
+      - training.output_dir=/scratchspace/dflash_bs16
+      - training.num_train_epochs=1
+      - training.training_seq_len=4096
+      - training.save_steps=5000
+      - training.logging_steps=1000
+      - training.disable_tqdm=true
+      - training.answer_only_loss=true
+      - dflash.dflash_block_size=16
+      - dflash.dflash_num_anchors=512
+      - dflash.dflash_loss_decay_factor=7
+      - dflash.dflash_architecture_config.mask_token_id=151669
+      - dflash.dflash_architecture_config.num_hidden_layers=5
+    environment:
+    slurm_config:
+      _factory_: "slurm_factory"
+      nodes: 1
+      ntasks_per_node: 1
+      gpus_per_node: 8
+
+  # Step 2: MT-Bench per-category AR evaluation
+  task_1:
+    script: common/specdec/ar_eval_mtbench.sh
+    args:
+      - --ckpt_dir /scratchspace/dflash_bs16
+      - --osl 512
+      - --steps 15
+    slurm_config:
+      _factory_: "slurm_factory"
+      nodes: 1
+      ntasks_per_node: 1
+      gpus_per_node: 1
diff --git a/uv.lock b/uv.lock
index d890e361cb..2b2b274e13 100644
--- a/uv.lock
+++ b/uv.lock
@@ -20,9 +20,6 @@ resolution-markers = [
     "python_full_version < '3.11' and platform_machine != 'aarch64' and sys_platform != 'darwin' and sys_platform != 'win32'",
 ]
 
-[manifest]
-overrides = [{ 
name = "torch", marker = "sys_platform == 'never'" }] - [[package]] name = "accelerate" version = "1.13.0" @@ -35,7 +32,7 @@ dependencies = [ { name = "psutil" }, { name = "pyyaml" }, { name = "safetensors" }, - { name = "torch", marker = "sys_platform == 'never'" }, + { name = "torch" }, ] sdist = { url = "https://files.pythonhosted.org/packages/ca/14/787e5498cd062640f0f3d92ef4ae4063174f76f9afd29d13fc52a319daae/accelerate-1.13.0.tar.gz", hash = "sha256:d631b4e0f5b3de4aff2d7e9e6857d164810dfc3237d54d017f075122d057b236", size = 402835, upload-time = "2026-03-04T19:34:12.359Z" } wheels = [ @@ -480,6 +477,21 @@ dependencies = [ ] sdist = { url = "https://files.pythonhosted.org/packages/54/27/01d9078a77b9e31b79b9716e66ca4db74f4744c5232bcb3e8769395c4280/cppimport-22.8.2.tar.gz", hash = "sha256:bbb4957102db41bc99ad72c233bce92f9d1fd91be352fc07878c4361033a401f", size = 26635, upload-time = "2022-08-02T16:50:36.872Z" } +[[package]] +name = "cuda-bindings" +version = "12.9.4" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "cuda-pathfinder", marker = "platform_machine != 'aarch64' and sys_platform != 'darwin' and sys_platform != 'win32'" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/7a/d8/b546104b8da3f562c1ff8ab36d130c8fe1dd6a045ced80b4f6ad74f7d4e1/cuda_bindings-12.9.4-cp310-cp310-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:4d3c842c2a4303b2a580fe955018e31aea30278be19795ae05226235268032e5", size = 12148218, upload-time = "2025-10-21T14:51:28.855Z" }, + { url = "https://files.pythonhosted.org/packages/45/e7/b47792cc2d01c7e1d37c32402182524774dadd2d26339bd224e0e913832e/cuda_bindings-12.9.4-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c912a3d9e6b6651853eed8eed96d6800d69c08e94052c292fec3f282c5a817c9", size = 12210593, upload-time = "2025-10-21T14:51:36.574Z" }, + { url = 
"https://files.pythonhosted.org/packages/a9/c1/dabe88f52c3e3760d861401bb994df08f672ec893b8f7592dc91626adcf3/cuda_bindings-12.9.4-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:fda147a344e8eaeca0c6ff113d2851ffca8f7dfc0a6c932374ee5c47caa649c8", size = 12151019, upload-time = "2025-10-21T14:51:43.167Z" }, + { url = "https://files.pythonhosted.org/packages/63/56/e465c31dc9111be3441a9ba7df1941fe98f4aa6e71e8788a3fb4534ce24d/cuda_bindings-12.9.4-cp313-cp313-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:32bdc5a76906be4c61eb98f546a6786c5773a881f3b166486449b5d141e4a39f", size = 11906628, upload-time = "2025-10-21T14:51:49.905Z" }, + { url = "https://files.pythonhosted.org/packages/a3/84/1e6be415e37478070aeeee5884c2022713c1ecc735e6d82d744de0252eee/cuda_bindings-12.9.4-cp313-cp313t-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:56e0043c457a99ac473ddc926fe0dc4046694d99caef633e92601ab52cbe17eb", size = 11925991, upload-time = "2025-10-21T14:51:56.535Z" }, +] + [[package]] name = "cuda-pathfinder" version = "1.4.3" @@ -554,7 +566,7 @@ dependencies = [ { name = "psutil", marker = "sys_platform != 'win32'" }, { name = "py-cpuinfo", marker = "sys_platform != 'win32'" }, { name = "pydantic", marker = "sys_platform != 'win32'" }, - { name = "torch", marker = "sys_platform == 'never'" }, + { name = "torch", marker = "sys_platform != 'win32'" }, { name = "tqdm", marker = "sys_platform != 'win32'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/19/11/46b9eb3806ca7a5e9bdddb7e873855a2d59a9f87f0675ae8231678d98434/deepspeed-0.18.8.tar.gz", hash = "sha256:e4e051a144b0c74270c46e4970139f9a86a61ff26959c5e463000c4a93b99304", size = 1647226, upload-time = "2026-03-13T18:49:48.568Z" } @@ -1311,7 +1323,9 @@ name = "networkx" version = "3.4.2" source = { registry = "https://pypi.org/simple" } resolution-markers = [ + "python_full_version < '3.11' and platform_machine == 'aarch64' and sys_platform == 'win32'", 
"(python_full_version < '3.11' and platform_machine == 'aarch64' and sys_platform != 'win32') or (python_full_version < '3.11' and sys_platform == 'darwin')", + "python_full_version < '3.11' and platform_machine != 'aarch64' and sys_platform == 'win32'", "python_full_version < '3.11' and platform_machine != 'aarch64' and sys_platform != 'darwin' and sys_platform != 'win32'", ] sdist = { url = "https://files.pythonhosted.org/packages/fd/1d/06475e1cd5264c0b870ea2cc6fdb3e37177c1e565c43f56ff17a10e3937f/networkx-3.4.2.tar.gz", hash = "sha256:307c3669428c5362aab27c8a1260aa8f47c4e91d3891f48be0141738d8d053e1", size = 2151368, upload-time = "2024-10-21T12:39:38.695Z" } @@ -1324,12 +1338,18 @@ name = "networkx" version = "3.6.1" source = { registry = "https://pypi.org/simple" } resolution-markers = [ + "python_full_version >= '3.13' and platform_machine == 'aarch64' and sys_platform == 'win32'", + "python_full_version == '3.12.*' and platform_machine == 'aarch64' and sys_platform == 'win32'", + "python_full_version == '3.11.*' and platform_machine == 'aarch64' and sys_platform == 'win32'", "(python_full_version >= '3.13' and platform_machine == 'aarch64' and sys_platform != 'win32') or (python_full_version >= '3.13' and sys_platform == 'darwin')", "(python_full_version == '3.12.*' and platform_machine == 'aarch64' and sys_platform != 'win32') or (python_full_version == '3.12.*' and sys_platform == 'darwin')", "(python_full_version == '3.11.*' and platform_machine == 'aarch64' and sys_platform != 'win32') or (python_full_version == '3.11.*' and sys_platform == 'darwin')", "python_full_version >= '3.13' and platform_machine != 'aarch64' and sys_platform != 'darwin' and sys_platform != 'win32'", "python_full_version == '3.12.*' and platform_machine != 'aarch64' and sys_platform != 'darwin' and sys_platform != 'win32'", "python_full_version == '3.11.*' and platform_machine != 'aarch64' and sys_platform != 'darwin' and sys_platform != 'win32'", + "python_full_version >= '3.13' 
and platform_machine != 'aarch64' and sys_platform == 'win32'", + "python_full_version == '3.12.*' and platform_machine != 'aarch64' and sys_platform == 'win32'", + "python_full_version == '3.11.*' and platform_machine != 'aarch64' and sys_platform == 'win32'", ] sdist = { url = "https://files.pythonhosted.org/packages/6a/51/63fe664f3908c97be9d2e4f1158eb633317598cfa6e1fc14af5383f17512/networkx-3.6.1.tar.gz", hash = "sha256:26b7c357accc0c8cde558ad486283728b65b6a95d85ee1cd66bafab4c8168509", size = 2517025, upload-time = "2025-12-08T17:02:39.908Z" } wheels = [ @@ -1526,6 +1546,108 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/57/a7/b35835e278c18b85206834b3aa3abe68e77a98769c59233d1f6300284781/numpy-2.4.3-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:4b42639cdde6d24e732ff823a3fa5b701d8acad89c4142bc1d0bd6dc85200ba5", size = 12504685, upload-time = "2026-03-09T07:58:50.525Z" }, ] +[[package]] +name = "nvidia-cublas-cu12" +version = "12.8.4.1" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/dc/61/e24b560ab2e2eaeb3c839129175fb330dfcfc29e5203196e5541a4c44682/nvidia_cublas_cu12-12.8.4.1-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:8ac4e771d5a348c551b2a426eda6193c19aa630236b418086020df5ba9667142", size = 594346921, upload-time = "2025-03-07T01:44:31.254Z" }, +] + +[[package]] +name = "nvidia-cuda-cupti-cu12" +version = "12.8.90" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f8/02/2adcaa145158bf1a8295d83591d22e4103dbfd821bcaf6f3f53151ca4ffa/nvidia_cuda_cupti_cu12-12.8.90-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:ea0cb07ebda26bb9b29ba82cda34849e73c166c18162d3913575b0c9db9a6182", size = 10248621, upload-time = "2025-03-07T01:40:21.213Z" }, +] + +[[package]] +name = "nvidia-cuda-nvrtc-cu12" +version = "12.8.93" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/05/6b/32f747947df2da6994e999492ab306a903659555dddc0fbdeb9d71f75e52/nvidia_cuda_nvrtc_cu12-12.8.93-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl", hash = "sha256:a7756528852ef889772a84c6cd89d41dfa74667e24cca16bb31f8f061e3e9994", size = 88040029, upload-time = "2025-03-07T01:42:13.562Z" }, +] + +[[package]] +name = "nvidia-cuda-runtime-cu12" +version = "12.8.90" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/0d/9b/a997b638fcd068ad6e4d53b8551a7d30fe8b404d6f1804abf1df69838932/nvidia_cuda_runtime_cu12-12.8.90-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:adade8dcbd0edf427b7204d480d6066d33902cab2a4707dcfc48a2d0fd44ab90", size = 954765, upload-time = "2025-03-07T01:40:01.615Z" }, +] + +[[package]] +name = "nvidia-cudnn-cu12" +version = "9.10.2.21" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "nvidia-cublas-cu12", marker = "platform_machine != 'aarch64' and sys_platform != 'darwin' and sys_platform != 'win32'" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/ba/51/e123d997aa098c61d029f76663dedbfb9bc8dcf8c60cbd6adbe42f76d049/nvidia_cudnn_cu12-9.10.2.21-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:949452be657fa16687d0930933f032835951ef0892b37d2d53824d1a84dc97a8", size = 706758467, upload-time = "2025-06-06T21:54:08.597Z" }, +] + +[[package]] +name = "nvidia-cufft-cu12" +version = "11.3.3.83" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "nvidia-nvjitlink-cu12", marker = "platform_machine != 'aarch64' and sys_platform != 'darwin' and sys_platform != 'win32'" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/1f/13/ee4e00f30e676b66ae65b4f08cb5bcbb8392c03f54f2d5413ea99a5d1c80/nvidia_cufft_cu12-11.3.3.83-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = 
"sha256:4d2dd21ec0b88cf61b62e6b43564355e5222e4a3fb394cac0db101f2dd0d4f74", size = 193118695, upload-time = "2025-03-07T01:45:27.821Z" }, +] + +[[package]] +name = "nvidia-cufile-cu12" +version = "1.13.1.3" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/bb/fe/1bcba1dfbfb8d01be8d93f07bfc502c93fa23afa6fd5ab3fc7c1df71038a/nvidia_cufile_cu12-1.13.1.3-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:1d069003be650e131b21c932ec3d8969c1715379251f8d23a1860554b1cb24fc", size = 1197834, upload-time = "2025-03-07T01:45:50.723Z" }, +] + +[[package]] +name = "nvidia-curand-cu12" +version = "10.3.9.90" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/fb/aa/6584b56dc84ebe9cf93226a5cde4d99080c8e90ab40f0c27bda7a0f29aa1/nvidia_curand_cu12-10.3.9.90-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:b32331d4f4df5d6eefa0554c565b626c7216f87a06a4f56fab27c3b68a830ec9", size = 63619976, upload-time = "2025-03-07T01:46:23.323Z" }, +] + +[[package]] +name = "nvidia-cusolver-cu12" +version = "11.7.3.90" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "nvidia-cublas-cu12", marker = "platform_machine != 'aarch64' and sys_platform != 'darwin' and sys_platform != 'win32'" }, + { name = "nvidia-cusparse-cu12", marker = "platform_machine != 'aarch64' and sys_platform != 'darwin' and sys_platform != 'win32'" }, + { name = "nvidia-nvjitlink-cu12", marker = "platform_machine != 'aarch64' and sys_platform != 'darwin' and sys_platform != 'win32'" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/85/48/9a13d2975803e8cf2777d5ed57b87a0b6ca2cc795f9a4f59796a910bfb80/nvidia_cusolver_cu12-11.7.3.90-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:4376c11ad263152bd50ea295c05370360776f8c3427b30991df774f9fb26c450", size = 267506905, upload-time = "2025-03-07T01:47:16.273Z" }, +] + +[[package]] +name = 
"nvidia-cusparse-cu12" +version = "12.5.8.93" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "nvidia-nvjitlink-cu12", marker = "platform_machine != 'aarch64' and sys_platform != 'darwin' and sys_platform != 'win32'" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/c2/f5/e1854cb2f2bcd4280c44736c93550cc300ff4b8c95ebe370d0aa7d2b473d/nvidia_cusparse_cu12-12.5.8.93-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:1ec05d76bbbd8b61b06a80e1eaf8cf4959c3d4ce8e711b65ebd0443bb0ebb13b", size = 288216466, upload-time = "2025-03-07T01:48:13.779Z" }, +] + +[[package]] +name = "nvidia-cusparselt-cu12" +version = "0.7.1" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/56/79/12978b96bd44274fe38b5dde5cfb660b1d114f70a65ef962bcbbed99b549/nvidia_cusparselt_cu12-0.7.1-py3-none-manylinux2014_x86_64.whl", hash = "sha256:f1bb701d6b930d5a7cea44c19ceb973311500847f81b634d802b7b539dc55623", size = 287193691, upload-time = "2025-02-26T00:15:44.104Z" }, +] + [[package]] name = "nvidia-ml-py" version = "13.595.45" @@ -1554,7 +1676,7 @@ dependencies = [ { name = "scipy", version = "1.15.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, { name = "scipy", version = "1.17.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, { name = "setuptools" }, - { name = "torch", marker = "sys_platform == 'never'" }, + { name = "torch" }, { name = "tqdm" }, ] @@ -1582,6 +1704,7 @@ all = [ { name = "peft" }, { name = "polygraphy" }, { name = "sentencepiece" }, + { name = "tiktoken" }, { name = "transformers" }, { name = "wonderwords" }, ] @@ -1589,6 +1712,7 @@ dev = [ { name = "accelerate" }, { name = "autodoc-pydantic" }, { name = "bandit", extra = ["toml"] }, + { name = "coverage", extra = ["toml"] }, { name = "cppimport" }, { name = "cupy-cuda12x", marker = "platform_machine != 
'aarch64' and sys_platform != 'darwin'" }, { name = "datasets" }, @@ -1625,6 +1749,7 @@ dev = [ { name = "sphinx-inline-tabs" }, { name = "sphinx-rtd-theme" }, { name = "sphinx-togglebutton" }, + { name = "tiktoken" }, { name = "timm" }, { name = "torch-geometric" }, { name = "torchprofile" }, @@ -1652,6 +1777,7 @@ dev-lint = [ { name = "ruff" }, ] dev-test = [ + { name = "coverage", extra = ["toml"] }, { name = "pytest" }, { name = "pytest-cov" }, { name = "pytest-instafail" }, @@ -1672,6 +1798,7 @@ hf = [ { name = "nltk" }, { name = "peft" }, { name = "sentencepiece" }, + { name = "tiktoken" }, { name = "transformers" }, { name = "wonderwords" }, ] @@ -1697,6 +1824,7 @@ requires-dist = [ { name = "accelerate", marker = "extra == 'hf'", specifier = ">=1.0.0" }, { name = "autodoc-pydantic", marker = "extra == 'dev-docs'", specifier = ">=2.1.0" }, { name = "bandit", extras = ["toml"], marker = "extra == 'dev-lint'", specifier = "==1.7.9" }, + { name = "coverage", extras = ["toml"], marker = "extra == 'dev-test'", specifier = ">=7.13.0" }, { name = "cppimport", marker = "extra == 'onnx'" }, { name = "cupy-cuda12x", marker = "platform_machine != 'aarch64' and sys_platform != 'darwin' and extra == 'onnx'" }, { name = "datasets", marker = "extra == 'hf'", specifier = ">=3.0.0" }, @@ -1748,6 +1876,7 @@ requires-dist = [ { name = "sphinx-inline-tabs", marker = "extra == 'dev-docs'", specifier = ">=2023.4.21" }, { name = "sphinx-rtd-theme", marker = "extra == 'dev-docs'", specifier = "~=3.0.0" }, { name = "sphinx-togglebutton", marker = "extra == 'dev-docs'", specifier = ">=0.3.2" }, + { name = "tiktoken", marker = "extra == 'hf'" }, { name = "timm", marker = "extra == 'dev-test'" }, { name = "torch", specifier = ">=2.6" }, { name = "torch-geometric", marker = "extra == 'dev-test'" }, @@ -1756,11 +1885,43 @@ requires-dist = [ { name = "tox", marker = "extra == 'dev-test'", specifier = ">4.18" }, { name = "tox-current-env", marker = "extra == 'dev-test'", specifier = 
">=0.0.12" }, { name = "tqdm" }, - { name = "transformers", marker = "extra == 'hf'", specifier = ">=4.53,<5.0" }, + { name = "transformers", marker = "extra == 'hf'", specifier = ">=4.56,<5.0" }, { name = "wonderwords", marker = "extra == 'hf'" }, ] provides-extras = ["onnx", "hf", "dev-lint", "dev-docs", "dev-test", "all", "dev"] +[[package]] +name = "nvidia-nccl-cu12" +version = "2.27.5" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/6e/89/f7a07dc961b60645dbbf42e80f2bc85ade7feb9a491b11a1e973aa00071f/nvidia_nccl_cu12-2.27.5-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:ad730cf15cb5d25fe849c6e6ca9eb5b76db16a80f13f425ac68d8e2e55624457", size = 322348229, upload-time = "2025-06-26T04:11:28.385Z" }, +] + +[[package]] +name = "nvidia-nvjitlink-cu12" +version = "12.8.93" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f6/74/86a07f1d0f42998ca31312f998bd3b9a7eff7f52378f4f270c8679c77fb9/nvidia_nvjitlink_cu12-12.8.93-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl", hash = "sha256:81ff63371a7ebd6e6451970684f916be2eab07321b73c9d244dc2b4da7f73b88", size = 39254836, upload-time = "2025-03-07T01:49:55.661Z" }, +] + +[[package]] +name = "nvidia-nvshmem-cu12" +version = "3.4.5" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b5/09/6ea3ea725f82e1e76684f0708bbedd871fc96da89945adeba65c3835a64c/nvidia_nvshmem_cu12-3.4.5-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:042f2500f24c021db8a06c5eec2539027d57460e1c1a762055a6554f72c369bd", size = 139103095, upload-time = "2025-09-06T00:32:31.266Z" }, +] + +[[package]] +name = "nvidia-nvtx-cu12" +version = "12.8.90" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/a2/eb/86626c1bbc2edb86323022371c39aa48df6fd8b0a1647bc274577f72e90b/nvidia_nvtx_cu12-12.8.90-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:5b17e2001cc0d751a5bc2c6ec6d26ad95913324a4adb86788c944f8ce9ba441f", size = 89954, upload-time = "2025-03-07T01:42:44.131Z" }, +] + [[package]] name = "omegaconf" version = "2.3.0" @@ -2157,7 +2318,7 @@ dependencies = [ { name = "psutil" }, { name = "pyyaml" }, { name = "safetensors" }, - { name = "torch", marker = "sys_platform == 'never'" }, + { name = "torch" }, { name = "tqdm" }, { name = "transformers" }, ] @@ -3395,6 +3556,53 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/a2/09/77d55d46fd61b4a135c444fc97158ef34a095e5681d0a6c10b75bf356191/sympy-1.14.0-py3-none-any.whl", hash = "sha256:e091cc3e99d2141a0ba2847328f5479b05d94a6635cb96148ccb3f34671bd8f5", size = 6299353, upload-time = "2025-04-27T18:04:59.103Z" }, ] +[[package]] +name = "tiktoken" +version = "0.12.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "regex" }, + { name = "requests" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/7d/ab/4d017d0f76ec3171d469d80fc03dfbb4e48a4bcaddaa831b31d526f05edc/tiktoken-0.12.0.tar.gz", hash = "sha256:b18ba7ee2b093863978fcb14f74b3707cdc8d4d4d3836853ce7ec60772139931", size = 37806, upload-time = "2025-10-06T20:22:45.419Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/89/b3/2cb7c17b6c4cf8ca983204255d3f1d95eda7213e247e6947a0ee2c747a2c/tiktoken-0.12.0-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:3de02f5a491cfd179aec916eddb70331814bd6bf764075d39e21d5862e533970", size = 1051991, upload-time = "2025-10-06T20:21:34.098Z" }, + { url = "https://files.pythonhosted.org/packages/27/0f/df139f1df5f6167194ee5ab24634582ba9a1b62c6b996472b0277ec80f66/tiktoken-0.12.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:b6cfb6d9b7b54d20af21a912bfe63a2727d9cfa8fbda642fd8322c70340aad16", size = 995798, 
upload-time = "2025-10-06T20:21:35.579Z" }, + { url = "https://files.pythonhosted.org/packages/ef/5d/26a691f28ab220d5edc09b9b787399b130f24327ef824de15e5d85ef21aa/tiktoken-0.12.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:cde24cdb1b8a08368f709124f15b36ab5524aac5fa830cc3fdce9c03d4fb8030", size = 1129865, upload-time = "2025-10-06T20:21:36.675Z" }, + { url = "https://files.pythonhosted.org/packages/b2/94/443fab3d4e5ebecac895712abd3849b8da93b7b7dec61c7db5c9c7ebe40c/tiktoken-0.12.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:6de0da39f605992649b9cfa6f84071e3f9ef2cec458d08c5feb1b6f0ff62e134", size = 1152856, upload-time = "2025-10-06T20:21:37.873Z" }, + { url = "https://files.pythonhosted.org/packages/54/35/388f941251b2521c70dd4c5958e598ea6d2c88e28445d2fb8189eecc1dfc/tiktoken-0.12.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:6faa0534e0eefbcafaccb75927a4a380463a2eaa7e26000f0173b920e98b720a", size = 1195308, upload-time = "2025-10-06T20:21:39.577Z" }, + { url = "https://files.pythonhosted.org/packages/f8/00/c6681c7f833dd410576183715a530437a9873fa910265817081f65f9105f/tiktoken-0.12.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:82991e04fc860afb933efb63957affc7ad54f83e2216fe7d319007dab1ba5892", size = 1255697, upload-time = "2025-10-06T20:21:41.154Z" }, + { url = "https://files.pythonhosted.org/packages/5f/d2/82e795a6a9bafa034bf26a58e68fe9a89eeaaa610d51dbeb22106ba04f0a/tiktoken-0.12.0-cp310-cp310-win_amd64.whl", hash = "sha256:6fb2995b487c2e31acf0a9e17647e3b242235a20832642bb7a9d1a181c0c1bb1", size = 879375, upload-time = "2025-10-06T20:21:43.201Z" }, + { url = "https://files.pythonhosted.org/packages/de/46/21ea696b21f1d6d1efec8639c204bdf20fde8bafb351e1355c72c5d7de52/tiktoken-0.12.0-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:6e227c7f96925003487c33b1b32265fad2fbcec2b7cf4817afb76d416f40f6bb", size = 1051565, upload-time = "2025-10-06T20:21:44.566Z" }, + { url = 
"https://files.pythonhosted.org/packages/c9/d9/35c5d2d9e22bb2a5f74ba48266fb56c63d76ae6f66e02feb628671c0283e/tiktoken-0.12.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:c06cf0fcc24c2cb2adb5e185c7082a82cba29c17575e828518c2f11a01f445aa", size = 995284, upload-time = "2025-10-06T20:21:45.622Z" }, + { url = "https://files.pythonhosted.org/packages/01/84/961106c37b8e49b9fdcf33fe007bb3a8fdcc380c528b20cc7fbba80578b8/tiktoken-0.12.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:f18f249b041851954217e9fd8e5c00b024ab2315ffda5ed77665a05fa91f42dc", size = 1129201, upload-time = "2025-10-06T20:21:47.074Z" }, + { url = "https://files.pythonhosted.org/packages/6a/d0/3d9275198e067f8b65076a68894bb52fd253875f3644f0a321a720277b8a/tiktoken-0.12.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:47a5bc270b8c3db00bb46ece01ef34ad050e364b51d406b6f9730b64ac28eded", size = 1152444, upload-time = "2025-10-06T20:21:48.139Z" }, + { url = "https://files.pythonhosted.org/packages/78/db/a58e09687c1698a7c592e1038e01c206569b86a0377828d51635561f8ebf/tiktoken-0.12.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:508fa71810c0efdcd1b898fda574889ee62852989f7c1667414736bcb2b9a4bd", size = 1195080, upload-time = "2025-10-06T20:21:49.246Z" }, + { url = "https://files.pythonhosted.org/packages/9e/1b/a9e4d2bf91d515c0f74afc526fd773a812232dd6cda33ebea7f531202325/tiktoken-0.12.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:a1af81a6c44f008cba48494089dd98cccb8b313f55e961a52f5b222d1e507967", size = 1255240, upload-time = "2025-10-06T20:21:50.274Z" }, + { url = "https://files.pythonhosted.org/packages/9d/15/963819345f1b1fb0809070a79e9dd96938d4ca41297367d471733e79c76c/tiktoken-0.12.0-cp311-cp311-win_amd64.whl", hash = "sha256:3e68e3e593637b53e56f7237be560f7a394451cb8c11079755e80ae64b9e6def", size = 879422, upload-time = "2025-10-06T20:21:51.734Z" }, + { url = 
"https://files.pythonhosted.org/packages/a4/85/be65d39d6b647c79800fd9d29241d081d4eeb06271f383bb87200d74cf76/tiktoken-0.12.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:b97f74aca0d78a1ff21b8cd9e9925714c15a9236d6ceacf5c7327c117e6e21e8", size = 1050728, upload-time = "2025-10-06T20:21:52.756Z" }, + { url = "https://files.pythonhosted.org/packages/4a/42/6573e9129bc55c9bf7300b3a35bef2c6b9117018acca0dc760ac2d93dffe/tiktoken-0.12.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:2b90f5ad190a4bb7c3eb30c5fa32e1e182ca1ca79f05e49b448438c3e225a49b", size = 994049, upload-time = "2025-10-06T20:21:53.782Z" }, + { url = "https://files.pythonhosted.org/packages/66/c5/ed88504d2f4a5fd6856990b230b56d85a777feab84e6129af0822f5d0f70/tiktoken-0.12.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:65b26c7a780e2139e73acc193e5c63ac754021f160df919add909c1492c0fb37", size = 1129008, upload-time = "2025-10-06T20:21:54.832Z" }, + { url = "https://files.pythonhosted.org/packages/f4/90/3dae6cc5436137ebd38944d396b5849e167896fc2073da643a49f372dc4f/tiktoken-0.12.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:edde1ec917dfd21c1f2f8046b86348b0f54a2c0547f68149d8600859598769ad", size = 1152665, upload-time = "2025-10-06T20:21:56.129Z" }, + { url = "https://files.pythonhosted.org/packages/a3/fe/26df24ce53ffde419a42f5f53d755b995c9318908288c17ec3f3448313a3/tiktoken-0.12.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:35a2f8ddd3824608b3d650a000c1ef71f730d0c56486845705a8248da00f9fe5", size = 1194230, upload-time = "2025-10-06T20:21:57.546Z" }, + { url = "https://files.pythonhosted.org/packages/20/cc/b064cae1a0e9fac84b0d2c46b89f4e57051a5f41324e385d10225a984c24/tiktoken-0.12.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:83d16643edb7fa2c99eff2ab7733508aae1eebb03d5dfc46f5565862810f24e3", size = 1254688, upload-time = "2025-10-06T20:21:58.619Z" }, + { url = 
"https://files.pythonhosted.org/packages/81/10/b8523105c590c5b8349f2587e2fdfe51a69544bd5a76295fc20f2374f470/tiktoken-0.12.0-cp312-cp312-win_amd64.whl", hash = "sha256:ffc5288f34a8bc02e1ea7047b8d041104791d2ddbf42d1e5fa07822cbffe16bd", size = 878694, upload-time = "2025-10-06T20:21:59.876Z" }, + { url = "https://files.pythonhosted.org/packages/00/61/441588ee21e6b5cdf59d6870f86beb9789e532ee9718c251b391b70c68d6/tiktoken-0.12.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:775c2c55de2310cc1bc9a3ad8826761cbdc87770e586fd7b6da7d4589e13dab3", size = 1050802, upload-time = "2025-10-06T20:22:00.96Z" }, + { url = "https://files.pythonhosted.org/packages/1f/05/dcf94486d5c5c8d34496abe271ac76c5b785507c8eae71b3708f1ad9b45a/tiktoken-0.12.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:a01b12f69052fbe4b080a2cfb867c4de12c704b56178edf1d1d7b273561db160", size = 993995, upload-time = "2025-10-06T20:22:02.788Z" }, + { url = "https://files.pythonhosted.org/packages/a0/70/5163fe5359b943f8db9946b62f19be2305de8c3d78a16f629d4165e2f40e/tiktoken-0.12.0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:01d99484dc93b129cd0964f9d34eee953f2737301f18b3c7257bf368d7615baa", size = 1128948, upload-time = "2025-10-06T20:22:03.814Z" }, + { url = "https://files.pythonhosted.org/packages/0c/da/c028aa0babf77315e1cef357d4d768800c5f8a6de04d0eac0f377cb619fa/tiktoken-0.12.0-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:4a1a4fcd021f022bfc81904a911d3df0f6543b9e7627b51411da75ff2fe7a1be", size = 1151986, upload-time = "2025-10-06T20:22:05.173Z" }, + { url = "https://files.pythonhosted.org/packages/a0/5a/886b108b766aa53e295f7216b509be95eb7d60b166049ce2c58416b25f2a/tiktoken-0.12.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:981a81e39812d57031efdc9ec59fa32b2a5a5524d20d4776574c4b4bd2e9014a", size = 1194222, upload-time = "2025-10-06T20:22:06.265Z" }, + { url = 
"https://files.pythonhosted.org/packages/f4/f8/4db272048397636ac7a078d22773dd2795b1becee7bc4922fe6207288d57/tiktoken-0.12.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:9baf52f84a3f42eef3ff4e754a0db79a13a27921b457ca9832cf944c6be4f8f3", size = 1255097, upload-time = "2025-10-06T20:22:07.403Z" }, + { url = "https://files.pythonhosted.org/packages/8e/32/45d02e2e0ea2be3a9ed22afc47d93741247e75018aac967b713b2941f8ea/tiktoken-0.12.0-cp313-cp313-win_amd64.whl", hash = "sha256:b8a0cd0c789a61f31bf44851defbd609e8dd1e2c8589c614cc1060940ef1f697", size = 879117, upload-time = "2025-10-06T20:22:08.418Z" }, + { url = "https://files.pythonhosted.org/packages/ce/76/994fc868f88e016e6d05b0da5ac24582a14c47893f4474c3e9744283f1d5/tiktoken-0.12.0-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:d5f89ea5680066b68bcb797ae85219c72916c922ef0fcdd3480c7d2315ffff16", size = 1050309, upload-time = "2025-10-06T20:22:10.939Z" }, + { url = "https://files.pythonhosted.org/packages/f6/b8/57ef1456504c43a849821920d582a738a461b76a047f352f18c0b26c6516/tiktoken-0.12.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:b4e7ed1c6a7a8a60a3230965bdedba8cc58f68926b835e519341413370e0399a", size = 993712, upload-time = "2025-10-06T20:22:12.115Z" }, + { url = "https://files.pythonhosted.org/packages/72/90/13da56f664286ffbae9dbcfadcc625439142675845baa62715e49b87b68b/tiktoken-0.12.0-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:fc530a28591a2d74bce821d10b418b26a094bf33839e69042a6e86ddb7a7fb27", size = 1128725, upload-time = "2025-10-06T20:22:13.541Z" }, + { url = "https://files.pythonhosted.org/packages/05/df/4f80030d44682235bdaecd7346c90f67ae87ec8f3df4a3442cb53834f7e4/tiktoken-0.12.0-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:06a9f4f49884139013b138920a4c393aa6556b2f8f536345f11819389c703ebb", size = 1151875, upload-time = "2025-10-06T20:22:14.559Z" }, + { url = 
"https://files.pythonhosted.org/packages/22/1f/ae535223a8c4ef4c0c1192e3f9b82da660be9eb66b9279e95c99288e9dab/tiktoken-0.12.0-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:04f0e6a985d95913cabc96a741c5ffec525a2c72e9df086ff17ebe35985c800e", size = 1194451, upload-time = "2025-10-06T20:22:15.545Z" }, + { url = "https://files.pythonhosted.org/packages/78/a7/f8ead382fce0243cb625c4f266e66c27f65ae65ee9e77f59ea1653b6d730/tiktoken-0.12.0-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:0ee8f9ae00c41770b5f9b0bb1235474768884ae157de3beb5439ca0fd70f3e25", size = 1253794, upload-time = "2025-10-06T20:22:16.624Z" }, + { url = "https://files.pythonhosted.org/packages/93/e0/6cc82a562bc6365785a3ff0af27a2a092d57c47d7a81d9e2295d8c36f011/tiktoken-0.12.0-cp313-cp313t-win_amd64.whl", hash = "sha256:dc2dd125a62cb2b3d858484d6c614d136b5b848976794edfb63688d539b8b93f", size = 878777, upload-time = "2025-10-06T20:22:18.036Z" }, +] + [[package]] name = "timm" version = "1.0.25" @@ -3403,7 +3611,7 @@ dependencies = [ { name = "huggingface-hub" }, { name = "pyyaml" }, { name = "safetensors" }, - { name = "torch", marker = "sys_platform == 'never'" }, + { name = "torch" }, { name = "torchvision" }, ] sdist = { url = "https://files.pythonhosted.org/packages/d7/2c/593109822fe735e637382aca6640c1102c19797f7791f1fd1dab2d6c3cb1/timm-1.0.25.tar.gz", hash = "sha256:47f59fc2754725735cc81bb83bcbfce5bec4ebd5d4bb9e69da57daa92fcfa768", size = 2414743, upload-time = "2026-02-23T16:49:00.137Z" } @@ -3491,15 +3699,63 @@ name = "torch" version = "2.10.0" source = { registry = "https://pypi.org/simple" } dependencies = [ + { name = "cuda-bindings", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "filelock" }, { name = "fsspec" }, { name = "jinja2" }, { name = "networkx", version = "3.4.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, { name = "networkx", version = "3.6.1", source = { registry = "https://pypi.org/simple" 
}, marker = "python_full_version >= '3.11'" }, + { name = "nvidia-cublas-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cuda-cupti-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cuda-nvrtc-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cuda-runtime-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cudnn-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cufft-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cufile-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-curand-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cusolver-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cusparse-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cusparselt-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-nccl-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-nvjitlink-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-nvshmem-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-nvtx-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "setuptools", marker = "python_full_version >= '3.12'" }, { name = "sympy" }, + { name = "triton", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "typing-extensions" }, ] +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/5b/30/bfebdd8ec77db9a79775121789992d6b3b75ee5494971294d7b4b7c999bc/torch-2.10.0-2-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:2b980edd8d7c0a68c4e951ee1856334a43193f98730d97408fbd148c1a933313", size = 79411457, upload-time = "2026-02-10T21:44:59.189Z" }, + { url = "https://files.pythonhosted.org/packages/0f/8b/4b61d6e13f7108f36910df9ab4b58fd389cc2520d54d81b88660804aad99/torch-2.10.0-2-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:418997cb02d0a0f1497cf6a09f63166f9f5df9f3e16c8a716ab76a72127c714f", size = 79423467, upload-time = "2026-02-10T21:44:48.711Z" }, + { url = "https://files.pythonhosted.org/packages/d3/54/a2ba279afcca44bbd320d4e73675b282fcee3d81400ea1b53934efca6462/torch-2.10.0-2-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:13ec4add8c3faaed8d13e0574f5cd4a323c11655546f91fbe6afa77b57423574", size = 79498202, upload-time = "2026-02-10T21:44:52.603Z" }, + { url = "https://files.pythonhosted.org/packages/ec/23/2c9fe0c9c27f7f6cb865abcea8a4568f29f00acaeadfc6a37f6801f84cb4/torch-2.10.0-2-cp313-none-macosx_11_0_arm64.whl", hash = "sha256:e521c9f030a3774ed770a9c011751fb47c4d12029a3d6522116e48431f2ff89e", size = 79498254, upload-time = "2026-02-10T21:44:44.095Z" }, + { url = "https://files.pythonhosted.org/packages/16/ee/efbd56687be60ef9af0c9c0ebe106964c07400eade5b0af8902a1d8cd58c/torch-2.10.0-3-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:a1ff626b884f8c4e897c4c33782bdacdff842a165fee79817b1dd549fdda1321", size = 915510070, upload-time = "2026-03-11T14:16:39.386Z" }, + { url = "https://files.pythonhosted.org/packages/36/ab/7b562f1808d3f65414cd80a4f7d4bb00979d9355616c034c171249e1a303/torch-2.10.0-3-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:ac5bdcbb074384c66fa160c15b1ead77839e3fe7ed117d667249afce0acabfac", size = 915518691, upload-time = "2026-03-11T14:15:43.147Z" }, + { url = 
"https://files.pythonhosted.org/packages/b3/7a/abada41517ce0011775f0f4eacc79659bc9bc6c361e6bfe6f7052a6b9363/torch-2.10.0-3-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:98c01b8bb5e3240426dcde1446eed6f40c778091c8544767ef1168fc663a05a6", size = 915622781, upload-time = "2026-03-11T14:17:11.354Z" }, + { url = "https://files.pythonhosted.org/packages/ab/c6/4dfe238342ffdcec5aef1c96c457548762d33c40b45a1ab7033bb26d2ff2/torch-2.10.0-3-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:80b1b5bfe38eb0e9f5ff09f206dcac0a87aadd084230d4a36eea5ec5232c115b", size = 915627275, upload-time = "2026-03-11T14:16:11.325Z" }, + { url = "https://files.pythonhosted.org/packages/d8/f0/72bf18847f58f877a6a8acf60614b14935e2f156d942483af1ffc081aea0/torch-2.10.0-3-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:46b3574d93a2a8134b3f5475cfb98e2eb46771794c57015f6ad1fb795ec25e49", size = 915523474, upload-time = "2026-03-11T14:17:44.422Z" }, + { url = "https://files.pythonhosted.org/packages/0c/1a/c61f36cfd446170ec27b3a4984f072fd06dab6b5d7ce27e11adb35d6c838/torch-2.10.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:5276fa790a666ee8becaffff8acb711922252521b28fbce5db7db5cf9cb2026d", size = 145992962, upload-time = "2026-01-21T16:24:14.04Z" }, + { url = "https://files.pythonhosted.org/packages/b5/60/6662535354191e2d1555296045b63e4279e5a9dbad49acf55a5d38655a39/torch-2.10.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:aaf663927bcd490ae971469a624c322202a2a1e68936eb952535ca4cd3b90444", size = 915599237, upload-time = "2026-01-21T16:23:25.497Z" }, + { url = "https://files.pythonhosted.org/packages/40/b8/66bbe96f0d79be2b5c697b2e0b187ed792a15c6c4b8904613454651db848/torch-2.10.0-cp310-cp310-win_amd64.whl", hash = "sha256:a4be6a2a190b32ff5c8002a0977a25ea60e64f7ba46b1be37093c141d9c49aeb", size = 113720931, upload-time = "2026-01-21T16:24:23.743Z" }, + { url = 
"https://files.pythonhosted.org/packages/76/bb/d820f90e69cda6c8169b32a0c6a3ab7b17bf7990b8f2c680077c24a3c14c/torch-2.10.0-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:35e407430795c8d3edb07a1d711c41cc1f9eaddc8b2f1cc0a165a6767a8fb73d", size = 79411450, upload-time = "2026-01-21T16:25:30.692Z" }, + { url = "https://files.pythonhosted.org/packages/78/89/f5554b13ebd71e05c0b002f95148033e730d3f7067f67423026cc9c69410/torch-2.10.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:3282d9febd1e4e476630a099692b44fdc214ee9bf8ee5377732d9d9dfe5712e4", size = 145992610, upload-time = "2026-01-21T16:25:26.327Z" }, + { url = "https://files.pythonhosted.org/packages/ae/30/a3a2120621bf9c17779b169fc17e3dc29b230c29d0f8222f499f5e159aa8/torch-2.10.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:a2f9edd8dbc99f62bc4dfb78af7bf89499bca3d753423ac1b4e06592e467b763", size = 915607863, upload-time = "2026-01-21T16:25:06.696Z" }, + { url = "https://files.pythonhosted.org/packages/6f/3d/c87b33c5f260a2a8ad68da7147e105f05868c281c63d65ed85aa4da98c66/torch-2.10.0-cp311-cp311-win_amd64.whl", hash = "sha256:29b7009dba4b7a1c960260fc8ac85022c784250af43af9fb0ebafc9883782ebd", size = 113723116, upload-time = "2026-01-21T16:25:21.916Z" }, + { url = "https://files.pythonhosted.org/packages/61/d8/15b9d9d3a6b0c01b883787bd056acbe5cc321090d4b216d3ea89a8fcfdf3/torch-2.10.0-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:b7bd80f3477b830dd166c707c5b0b82a898e7b16f59a7d9d42778dd058272e8b", size = 79423461, upload-time = "2026-01-21T16:24:50.266Z" }, + { url = "https://files.pythonhosted.org/packages/cc/af/758e242e9102e9988969b5e621d41f36b8f258bb4a099109b7a4b4b50ea4/torch-2.10.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:5fd4117d89ffd47e3dcc71e71a22efac24828ad781c7e46aaaf56bf7f2796acf", size = 145996088, upload-time = "2026-01-21T16:24:44.171Z" }, + { url = 
"https://files.pythonhosted.org/packages/23/8e/3c74db5e53bff7ed9e34c8123e6a8bfef718b2450c35eefab85bb4a7e270/torch-2.10.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:787124e7db3b379d4f1ed54dd12ae7c741c16a4d29b49c0226a89bea50923ffb", size = 915711952, upload-time = "2026-01-21T16:23:53.503Z" }, + { url = "https://files.pythonhosted.org/packages/6e/01/624c4324ca01f66ae4c7cd1b74eb16fb52596dce66dbe51eff95ef9e7a4c/torch-2.10.0-cp312-cp312-win_amd64.whl", hash = "sha256:2c66c61f44c5f903046cc696d088e21062644cbe541c7f1c4eaae88b2ad23547", size = 113757972, upload-time = "2026-01-21T16:24:39.516Z" }, + { url = "https://files.pythonhosted.org/packages/c9/5c/dee910b87c4d5c0fcb41b50839ae04df87c1cfc663cf1b5fca7ea565eeaa/torch-2.10.0-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:6d3707a61863d1c4d6ebba7be4ca320f42b869ee657e9b2c21c736bf17000294", size = 79498198, upload-time = "2026-01-21T16:24:34.704Z" }, + { url = "https://files.pythonhosted.org/packages/c9/6f/f2e91e34e3fcba2e3fc8d8f74e7d6c22e74e480bbd1db7bc8900fdf3e95c/torch-2.10.0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:5c4d217b14741e40776dd7074d9006fd28b8a97ef5654db959d8635b2fe5f29b", size = 146004247, upload-time = "2026-01-21T16:24:29.335Z" }, + { url = "https://files.pythonhosted.org/packages/98/fb/5160261aeb5e1ee12ee95fe599d0541f7c976c3701d607d8fc29e623229f/torch-2.10.0-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:6b71486353fce0f9714ca0c9ef1c850a2ae766b409808acd58e9678a3edb7738", size = 915716445, upload-time = "2026-01-21T16:22:45.353Z" }, + { url = "https://files.pythonhosted.org/packages/6a/16/502fb1b41e6d868e8deb5b0e3ae926bbb36dab8ceb0d1b769b266ad7b0c3/torch-2.10.0-cp313-cp313-win_amd64.whl", hash = "sha256:c2ee399c644dc92ef7bc0d4f7e74b5360c37cdbe7c5ba11318dda49ffac2bc57", size = 113757050, upload-time = "2026-01-21T16:24:19.204Z" }, + { url = 
"https://files.pythonhosted.org/packages/1a/0b/39929b148f4824bc3ad6f9f72a29d4ad865bcf7ebfc2fa67584773e083d2/torch-2.10.0-cp313-cp313t-macosx_14_0_arm64.whl", hash = "sha256:3202429f58309b9fa96a614885eace4b7995729f44beb54d3e4a47773649d382", size = 79851305, upload-time = "2026-01-21T16:24:09.209Z" }, + { url = "https://files.pythonhosted.org/packages/d8/14/21fbce63bc452381ba5f74a2c0a959fdf5ad5803ccc0c654e752e0dbe91a/torch-2.10.0-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:aae1b29cd68e50a9397f5ee897b9c24742e9e306f88a807a27d617f07adb3bd8", size = 146005472, upload-time = "2026-01-21T16:22:29.022Z" }, + { url = "https://files.pythonhosted.org/packages/54/fd/b207d1c525cb570ef47f3e9f836b154685011fce11a2f444ba8a4084d042/torch-2.10.0-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:6021db85958db2f07ec94e1bc77212721ba4920c12a18dc552d2ae36a3eb163f", size = 915612644, upload-time = "2026-01-21T16:21:47.019Z" }, + { url = "https://files.pythonhosted.org/packages/36/53/0197f868c75f1050b199fe58f9bf3bf3aecac9b4e85cc9c964383d745403/torch-2.10.0-cp313-cp313t-win_amd64.whl", hash = "sha256:ff43db38af76fda183156153983c9a096fc4c78d0cd1e07b14a2314c7f01c2c8", size = 113997015, upload-time = "2026-01-21T16:23:00.767Z" }, + { url = "https://files.pythonhosted.org/packages/0e/13/e76b4d9c160e89fff48bf16b449ea324bda84745d2ab30294c37c2434c0d/torch-2.10.0-cp313-none-macosx_11_0_arm64.whl", hash = "sha256:cdf2a523d699b70d613243211ecaac14fe9c5df8a0b0a9c02add60fb2a413e0f", size = 79498248, upload-time = "2026-01-21T16:23:09.315Z" }, +] [[package]] name = "torch-geometric" @@ -3529,7 +3785,7 @@ source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, { name = "numpy", version = "2.4.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, - { name = "torch", marker = "sys_platform == 'never'" }, + 
{ name = "torch" }, { name = "torchvision" }, ] sdist = { url = "https://files.pythonhosted.org/packages/6f/36/574c0c46e818533b78b3c09505211162918188325ab4165ef11a3f295755/torchprofile-0.0.4.tar.gz", hash = "sha256:96b6da17d752a06b02977e078aea95614893b31d4117dd5dcd081f30ce65611b", size = 4557, upload-time = "2021-06-22T04:58:03.592Z" } @@ -3545,7 +3801,7 @@ dependencies = [ { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, { name = "numpy", version = "2.4.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, { name = "pillow" }, - { name = "torch", marker = "sys_platform == 'never'" }, + { name = "torch" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/50/ae/cbf727421eb73f1cf907fbe5788326a08f111b3f6b6ddca15426b53fec9a/torchvision-0.25.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:a95c47abb817d4e90ea1a8e57bd0d728e3e6b533b3495ae77d84d883c4d11f56", size = 1874919, upload-time = "2026-01-21T16:27:47.617Z" }, @@ -3638,6 +3894,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/03/b8/e484ef633af3887baeeb4b6ad12743363af7cce68ae51e938e00aaa0529d/transformers-4.57.6-py3-none-any.whl", hash = "sha256:4c9e9de11333ddfe5114bc872c9f370509198acf0b87a832a0ab9458e2bd0550", size = 11993498, upload-time = "2026-01-16T10:38:31.289Z" }, ] +[[package]] +name = "triton" +version = "3.6.0" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/8c/f7/f1c9d3424ab199ac53c2da567b859bcddbb9c9e7154805119f8bd95ec36f/triton-3.6.0-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a6550fae429e0667e397e5de64b332d1e5695b73650ee75a6146e2e902770bea", size = 188105201, upload-time = "2026-01-20T16:00:29.272Z" }, + { url = 
"https://files.pythonhosted.org/packages/e0/12/b05ba554d2c623bffa59922b94b0775673de251f468a9609bc9e45de95e9/triton-3.6.0-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:e8e323d608e3a9bfcc2d9efcc90ceefb764a82b99dea12a86d643c72539ad5d3", size = 188214640, upload-time = "2026-01-20T16:00:35.869Z" }, + { url = "https://files.pythonhosted.org/packages/ab/a8/cdf8b3e4c98132f965f88c2313a4b493266832ad47fb52f23d14d4f86bb5/triton-3.6.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:74caf5e34b66d9f3a429af689c1c7128daba1d8208df60e81106b115c00d6fca", size = 188266850, upload-time = "2026-01-20T16:00:43.041Z" }, + { url = "https://files.pythonhosted.org/packages/f9/0b/37d991d8c130ce81a8728ae3c25b6e60935838e9be1b58791f5997b24a54/triton-3.6.0-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:10c7f76c6e72d2ef08df639e3d0d30729112f47a56b0c81672edc05ee5116ac9", size = 188289450, upload-time = "2026-01-20T16:00:49.136Z" }, + { url = "https://files.pythonhosted.org/packages/35/f8/9c66bfc55361ec6d0e4040a0337fb5924ceb23de4648b8a81ae9d33b2b38/triton-3.6.0-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:d002e07d7180fd65e622134fbd980c9a3d4211fb85224b56a0a0efbd422ab72f", size = 188400296, upload-time = "2026-01-20T16:00:56.042Z" }, +] + [[package]] name = "typing-extensions" version = "4.15.0"