Frédéric Berdoz · Peer Rheinboldt · Roger Wattenhofer
Accepted at FPI @ NeurIPS 2025, SPIGM @ NeurIPS 2025, and AAAI 2026

SD² is a framework for steering pretrained drafters during speculative decoding to further increase alignment between drafter and verifier. This repo contains:
- Evaluation tools
- Training scripts
- Synthetic dataset generation

Configuration files and checkpoints used in the paper can be found in this Hugging Face collection.
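
For context: during speculative decoding, the verifier keeps a drafted token with probability min(1, p_verifier / p_drafter), so the closer the drafter's distribution is to the verifier's, the more draft tokens are accepted per step. Below is a minimal sketch of this standard acceptance rule (generic speculative sampling, not SD²-specific code):

```python
import torch

def accept_draft_token(token_id: int,
                       p_drafter: torch.Tensor,   # drafter next-token distribution, shape (vocab,)
                       p_verifier: torch.Tensor,  # verifier next-token distribution, shape (vocab,)
                       ) -> bool:
    """Standard speculative-sampling acceptance test for a single drafted token.

    The token survives with probability min(1, p_v / p_d), so better
    drafter-verifier alignment directly yields longer accepted drafts.
    """
    ratio = (p_verifier[token_id] / p_drafter[token_id]).clamp(max=1.0)
    return torch.rand(()).item() < ratio.item()
```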
- Create a new Python environment with Python 3.12:

  ```bash
  python -m venv ./.venv
  source ./.venv/bin/activate
  ```

- Install dependencies:

  ```bash
  pip install -r requirements.txt
  ```

- Note that you will need to log in to Hugging Face and have access to the Llama models. Training additionally requires a Weights & Biases account (see the snippet below).
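
Both logins can be done from Python; a minimal sketch using the standard `huggingface_hub` and `wandb` clients (both assumed to be pulled in via `requirements.txt`):

```python
# One-time authentication for the gated Llama models and for W&B logging.
from huggingface_hub import login
import wandb

login()        # prompts for a Hugging Face access token
wandb.login()  # prompts for a Weights & Biases API key
```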
- Open `eval.py` and add the configurations you'd like to try out to `configs` (a hypothetical entry is sketched after these steps).
- Run the script:

  ```bash
  python eval.py --out_file 'eval_results.json'

  # We recommend piping output to a separate file for later evaluation
  python eval.py --pattern "llama" --out_file "llama_results.json" > llama_out.log
  ```
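
A configuration pairs a drafter with a verifier checkpoint. The entry below is purely illustrative (the field names are hypothetical; check `eval.py` for the actual schema), using the model pair from the training command further down:

```python
# Hypothetical shape of an entry in the `configs` list in eval.py.
# Field names are illustrative only; only the model IDs are taken from the repo.
configs = [
    {
        "drafter": "meta-llama/llama-3.2-1b-instruct",
        "verifier": "meta-llama/llama-3.1-8b-instruct",
    },
]
```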
- Update `configs/experiment.yaml` with the correct data settings.
- Generate a synthetic dataset (a snippet for inspecting the output follows these steps):

  ```bash
  python create_synth_ds.py --bsz 128 --target_len 256 --use_ultrachat_prompts --model_name 'meta-llama/llama-3.1-8b-instruct'
  ```
- Run the training script (the KL objective selected by `--model.loss_method kl` is sketched below):

  ```bash
  python main.py fit --config configs/experiment.yaml \
      --data.path ./data/synthetic/llama-3.1-8b-instruct-ultrachat-prompts \
      --data.bsz 12 \
      --data.n_val 3000 \
      --trainer.max_epochs 6 \
      --trainer.val_check_interval 2000 \
      --trainer.accumulate_grad_batches 2 \
      --model.method guided-drafter \
      --model.lr_start 0.00001 \
      --model.lr_end 0.000001 \
      --model.warmup_steps 1000 \
      --model.loss_method kl \
      --model.drafter meta-llama/llama-3.2-1b-instruct \
      --model.finetune_drafter full \
      --model.guide_method merged \
      --model.d_layer all \
      --model.verifier meta-llama/llama-3.1-8b-instruct \
      --model.v_layer '[3,16,29]'
  ```
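
To sanity-check the synthetic dataset before training, here is a minimal sketch (assuming `create_synth_ds.py` saves a Hugging Face `datasets` dataset to the path passed via `--data.path`):

```python
# Quick inspection of the generated dataset (assumes it was saved with the
# Hugging Face `datasets` library; adjust if the repo uses another format).
from datasets import load_from_disk

ds = load_from_disk("./data/synthetic/llama-3.1-8b-instruct-ultrachat-prompts")
print(ds)  # splits, column names, and row counts
```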
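
The `--model.loss_method kl` flag selects a KL-based distillation objective. A minimal sketch of such a loss (illustrative; the KL direction and any layer-wise guide terms are assumptions, so see the repo for the exact objective):

```python
import torch
import torch.nn.functional as F

def kl_distillation_loss(drafter_logits: torch.Tensor,
                         verifier_logits: torch.Tensor) -> torch.Tensor:
    """KL(p_verifier || p_drafter), averaged over all token positions.

    Both inputs have shape (batch, seq_len, vocab); minimizing this pushes
    the drafter's next-token distribution toward the verifier's.
    """
    log_q = F.log_softmax(drafter_logits, dim=-1).flatten(0, 1)  # (B*T, V)
    p = F.softmax(verifier_logits, dim=-1).flatten(0, 1)         # (B*T, V)
    return F.kl_div(log_q, p, reduction="batchmean")
```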