35 changes: 35 additions & 0 deletions .github/configs/nvidia-master.yaml
@@ -2537,6 +2537,41 @@ dsv4-fp4-b300-vllm:
        - { tp: 4, ep: 4, dp-attn: true, conc-start: 256, conc-end: 1024 }
        - { tp: 8, ep: 8, dp-attn: true, conc-start: 2048, conc-end: 2048 }

# DeepSeek-V4-Pro on B300 with vLLM + MTP speculative decoding. The
# deepseekv4-cu130 image registers DeepSeekV4MTPModel and exposes
# --speculative-config '{"method":"deepseek_mtp",...}'; the script
# defaults to num_speculative_tokens=1 (DSv4 publishes
# num_nextn_predict_layers=1). The bench harness passes --dsv4 so prompts
# are routed through encoding_dsv4.py, which is required for honest MTP
# acceptance numbers (random-token prompts silently regress acceptance).
#
# Two CONC bands (TP-only, mirroring dsv4-fp4-b300-sglang-mtp coverage):
#   A: TP=8 -- conc 1-8 (low concurrency, where MTP gains are largest)
#   B: TP=4 -- conc 4-32 (mid-batch)
# DP-attn (TP=4 ep=4 dp-attn / TP=8 ep=8 dp-attn) is wired in
# benchmarks/single_node/dsv4_fp4_b300_vllm_mtp.sh via DP_ATTENTION=true
# but is excluded from the initial sweep until the TP-only bands have
# published numbers; a sketch of the deferred rows follows the config
# below.
dsv4-fp4-b300-vllm-mtp:
  image: vllm/vllm-openai:deepseekv4-cu130
  model: deepseek-ai/DeepSeek-V4-Pro
  model-prefix: dsv4
  runner: b300
  precision: fp4
  framework: vllm
  multinode: false
  seq-len-configs:
    - isl: 1024
      osl: 1024
      search-space:
        - { tp: 8, conc-start: 1, conc-end: 8, spec-decoding: mtp }
        - { tp: 4, conc-start: 4, conc-end: 32, spec-decoding: mtp }
    - isl: 8192
      osl: 1024
      search-space:
        - { tp: 8, conc-start: 1, conc-end: 8, spec-decoding: mtp }
        - { tp: 4, conc-start: 4, conc-end: 32, spec-decoding: mtp }
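
# Hedged sketch of the deferred dp-attn rows -- mirrored from the
# dsv4-fp4-b300-vllm entry above with spec-decoding added; the conc bands
# are illustrative and would be tuned once the TP-only numbers land:
#   - { tp: 4, ep: 4, dp-attn: true, conc-start: 256, conc-end: 1024, spec-decoding: mtp }
#   - { tp: 8, ep: 8, dp-attn: true, conc-start: 2048, conc-end: 2048, spec-decoding: mtp }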

qwen3.5-fp8-h200-sglang:
  image: lmsysorg/sglang:v0.5.9-cu129-amd64
  model: Qwen/Qwen3.5-397B-A17B-FP8
139 changes: 139 additions & 0 deletions benchmarks/single_node/dsv4_fp4_b300_vllm_mtp.sh
@@ -0,0 +1,139 @@
#!/usr/bin/env bash

# DeepSeek-V4-Pro B300 single-node vLLM recipe with MTP (Multi-Token
# Prediction) speculative decoding. Mirrors the dsv4_fp4_b300_vllm.sh
# split: TP mode (dp-attn=false) runs without expert parallel; DP mode
# (dp-attn=true) enables expert parallel (EP_SIZE equals the TP value,
# which becomes the DP size).
#
# MTP plumbing: the deepseekv4-cu130 image registers DeepSeekV4MTPModel
# (vllm/model_executor/models/deepseek_v4_mtp.py) and remaps
# model_type=deepseek_v4 -> deepseek_mtp inside SpeculativeConfig
# (vllm/config/speculative.py). DSv4-Pro publishes
# num_nextn_predict_layers=1, so the MTP head is invoked once per spec
# step; vLLM accepts num_speculative_tokens > n_predict as long as it is
# divisible by n_predict (the head is reused). We start with k=1
# (matching the DSv4 b200 dynamo sglang MTP entry); bump SPEC_TOKENS to
# sweep.

source "$(dirname "$0")/../benchmark_lib.sh"

check_env_vars \
    MODEL \
    TP \
    DP_ATTENTION \
    CONC \
    ISL \
    OSL \
    MAX_MODEL_LEN \
    RANDOM_RANGE_RATIO \
    RESULT_FILENAME

if [[ -n "$SLURM_JOB_ID" ]]; then
    echo "JOB $SLURM_JOB_ID running on $SLURMD_NODENAME"
fi

nvidia-smi

hf download "$MODEL"

SERVER_LOG=/workspace/server.log
PORT=${PORT:-8888}

# DeepSeek-V4-Pro weights are large; engine startup can exceed the default
# 600s. Give it an hour to load.
export VLLM_ENGINE_READY_TIMEOUT_S=3600

PARALLEL_ARGS=(--tensor-parallel-size "$TP" --data-parallel-size 1)
if [ "${DP_ATTENTION}" = "true" ]; then
    PARALLEL_ARGS=(--tensor-parallel-size 1 --data-parallel-size "$TP")
fi
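# For TP=8 the block above resolves to either a TP-only layout
# (--tensor-parallel-size 8 --data-parallel-size 1) or, with
# DP_ATTENTION=true, attention data parallelism (--tensor-parallel-size 1
# --data-parallel-size 8); in the latter case experts shard across ranks
# once --enable-expert-parallel is added below (EP_SIZE > 1).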

EP_ARGS=()
if [ "${EP_SIZE:-1}" -gt 1 ]; then
    EP_ARGS=(--enable-expert-parallel)
fi

if [ "${DP_ATTENTION}" = "true" ]; then
MAX_NUM_BATCHED_TOKENS=2048
else
MAX_NUM_BATCHED_TOKENS=$(( ISL * 2 ))
fi
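# Chunked-prefill budget: fixed at 2048 under DP-attn (presumably because
# each attention replica sees only a share of the traffic), otherwise
# scaled to twice the input length.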

BENCHMARK_MAX_MODEL_LEN="$MAX_MODEL_LEN"
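# For the 1k/1k sweep, 2048 tokens per request plus scheduling headroom
# fit comfortably in a 4096-token window, so don't serve at the full
# MAX_MODEL_LEN.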
if [ "$ISL" -eq 1024 ] && [ "$OSL" -eq 1024 ]; then
BENCHMARK_MAX_MODEL_LEN=4096
fi

if [ "${EVAL_ONLY}" = "true" ]; then
EVAL_MAX_MODEL_LEN=$(compute_eval_context_length "$MODEL" "$BENCHMARK_MAX_MODEL_LEN")
export EVAL_MAX_MODEL_LEN
SERVE_MAX_MODEL_LEN="$EVAL_MAX_MODEL_LEN"
else
SERVE_MAX_MODEL_LEN="$BENCHMARK_MAX_MODEL_LEN"
fi

# MTP speculative decoding: DSv4 has one nextn predictor layer baked
# into the checkpoint. Bump for sweeps; vLLM reuses the head when
# SPEC_TOKENS > n_predict (SPEC_TOKENS must be divisible by n_predict,
# and n_predict=1 here, so any positive integer works).
SPEC_TOKENS="${SPEC_TOKENS:-1}"
SPECULATIVE_CONFIG="{\"method\":\"deepseek_mtp\",\"num_speculative_tokens\":${SPEC_TOKENS}}"
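
# Hedged guard (an addition, not in the mirrored recipe): document the
# divisibility constraint described above. N_PREDICT is a hypothetical
# env knob defaulting to DSv4-Pro's published num_nextn_predict_layers=1;
# with n_predict=1 every positive SPEC_TOKENS passes, but the check
# matters for checkpoints that ship more nextn layers.
N_PREDICT="${N_PREDICT:-1}"
if (( SPEC_TOKENS % N_PREDICT != 0 )); then
    echo "SPEC_TOKENS=$SPEC_TOKENS is not divisible by n_predict=$N_PREDICT" >&2
    exit 1
fi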

# Start GPU monitoring (power, temperature, clocks every second)
start_gpu_monitor

set -x
vllm serve "$MODEL" --host 0.0.0.0 --port "$PORT" \
    "${PARALLEL_ARGS[@]}" \
    --pipeline-parallel-size 1 \
    --kv-cache-dtype fp8 \
    --trust-remote-code \
    --block-size 256 \
    --no-enable-prefix-caching \
    "${EP_ARGS[@]}" \
    --speculative-config "$SPECULATIVE_CONFIG" \
    --compilation-config '{"cudagraph_mode":"FULL_AND_PIECEWISE","custom_ops":["all"]}' \
    --attention_config.use_fp4_indexer_cache True \
    --tokenizer-mode deepseek_v4 \
    --tool-call-parser deepseek_v4 \
    --enable-auto-tool-choice \
    --reasoning-parser deepseek_v4 \
    --max-cudagraph-capture-size 2048 \
    --max-model-len "$SERVE_MAX_MODEL_LEN" \
    --max-num-batched-tokens "$MAX_NUM_BATCHED_TOKENS" > "$SERVER_LOG" 2>&1 &

SERVER_PID=$!

# Wait for server to be ready
wait_for_server_ready --port "$PORT" --server-log "$SERVER_LOG" --server-pid "$SERVER_PID"

pip install -q datasets pandas

# --dsv4 routes prompts through encoding_dsv4.py (PR #1153), which emits the
# <bos><User>...<Assistant><think> framing DeepSeek-V4-Pro expects. The DSv4-Pro
# tokenizer ships without a jinja chat_template, so plain --use-chat-template
# would crash; --dsv4 sidesteps that and satisfies the AGENTS.md rule that all
# MTP scripts must benchmark against chat-formatted inputs (MTP acceptance
# silently regresses on raw random tokens).
run_benchmark_serving \
    --model "$MODEL" \
    --port "$PORT" \
    --backend vllm \
    --input-len "$ISL" \
    --output-len "$OSL" \
    --random-range-ratio "$RANDOM_RANGE_RATIO" \
    --num-prompts "$((CONC * 10))" \
    --max-concurrency "$CONC" \
    --result-filename "$RESULT_FILENAME" \
    --result-dir /workspace/ \
    --trust-remote-code \
    --dsv4
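
# Hedged convenience check (an addition, not part of the mirrored recipe):
# vLLM periodically logs speculative-decoding acceptance stats; surface
# the most recent lines so the silent-regression failure mode called out
# above shows up in the run output. Exact log wording varies by vLLM
# version, so match loosely.
grep -i "acceptance" "$SERVER_LOG" | tail -n 5 || true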

# After throughput, run evaluation only if RUN_EVAL is true
if [ "${RUN_EVAL}" = "true" ]; then
run_eval --framework lm-eval --port "$PORT"
append_lm_eval_summary
fi

# Stop GPU monitoring
stop_gpu_monitor
set +x
12 changes: 12 additions & 0 deletions perf-changelog.yaml
@@ -1918,3 +1918,15 @@
- "Three CONC bands: A=TP8 (1-8), B=TP4 (16-128), C=DP4 dp-attn (64-512); B/C overlap at conc 64,128"
- "Configs: 1k1k and 8k1k, no validation.py / launcher / yaml-field changes (knob-free)"
pr-link: https://github.com/SemiAnalysisAI/InferenceX/pull/1180

- config-keys:
    - dsv4-fp4-b300-vllm-mtp
  description:
    - "Add DeepSeek-V4-Pro FP4 B300 vLLM benchmark with MTP (Multi-Token Prediction) speculative decoding"
    - "Image: vllm/vllm-openai:deepseekv4-cu130 (registers DeepSeekV4MTPModel; speculative.py remaps model_type=deepseek_v4 -> deepseek_mtp)"
    - "Model: deepseek-ai/DeepSeek-V4-Pro"
    - "MTP via --speculative-config '{\"method\":\"deepseek_mtp\",\"num_speculative_tokens\":1}' (DSv4-Pro publishes num_nextn_predict_layers=1; the SPEC_TOKENS env var allows sweeping >1 since vLLM reuses the head when the count is divisible by n_predict)"
    - "Bench passes --dsv4 so MTP acceptance is measured against DSv4 chat-formatted prompts (encoding_dsv4.py from PR #1153); the DSv4-Pro tokenizer ships no jinja chat_template, so plain --use-chat-template would crash, and random-token prompts silently regress acceptance"
    - "Two CONC bands (TP-only, mirrors dsv4-fp4-b300-sglang-mtp coverage): TP8 conc 1-8, TP4 conc 4-32, for both 1k1k and 8k1k"
    - "DP-attn (TP=4 ep=4 / TP=8 ep=8) supported in script via DP_ATTENTION=true but excluded from initial sweep until TP-only numbers land"
  pr-link: https://github.com/SemiAnalysisAI/InferenceX/pull/1203