Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 16 additions & 15 deletions .github/configs/nvidia-master.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1670,39 +1670,40 @@ dsr1-fp4-b200-sglang:
- { tp: 8, ep: 8, conc-start: 4, conc-end: 16 }

dsv4-fp4-b200-sglang:
image: lmsysorg/sglang:deepseek-v4-blackwell
image: lmsysorg/sglang:deepseek-v4-blackwell@sha256:df18bfc4aa9ecf59451002b49ba00cae58042de9e2a96378bbd21b404dd62c7b
model: deepseek-ai/DeepSeek-V4-Pro
model-prefix: dsv4
runner: b200-dsv4
precision: fp4
framework: sglang
multinode: false
# Three recipes from https://docs.sglang.io/cookbook/autoregressive/DeepSeek/DeepSeek-V4
# are selected inside benchmarks/single_node/dsv4_fp4_b200.sh by CONC:
# low-latency (CONC <= 32): TP-only
# balanced (32 < CONC <= 128): + DP-attn
# max-throughput (CONC > 128): + DP-attn
# Split so result filenames (ep=, dpa=) accurately reflect the recipe.
# Two recipes from https://docs.sglang.io/cookbook/autoregressive/DeepSeek/DeepSeek-V4
# are selected inside benchmarks/single_node/dsv4_fp4_b200.sh by DP_ATTENTION:
# low-latency (DP_ATTENTION=false): TP-only, flashinfer_mxfp4
# DP-attention (DP_ATTENTION=true): DP-attn + DeepEP + mega_moe opts
# The DP-attention recipe covers both "balanced" (conc 64-128) and
# "max-throughput" (conc 256+) CONC ranges with identical flags;
# only --max-running-requests scales with CONC.
# ep is implicit in sglang: --moe-a2a-backend deepep forces ep_size=tp_size,
# while low-latency leaves ep_size at the default of 1.
seq-len-configs:
- isl: 1024
osl: 1024
search-space:
# low-latency
- { tp: 8, ep: 1, conc-start: 4, conc-end: 32 }
# balanced
# low-latency (DP_ATTENTION=false)
- { tp: 8, ep: 1, conc-start: 1, conc-end: 32 }
# DP-attention (DP_ATTENTION=true) — balanced CONC range
- { tp: 8, ep: 8, dp-attn: true, conc-start: 64, conc-end: 128 }
# max-throughput
# DP-attention (DP_ATTENTION=true) — max-throughput CONC range
- { tp: 8, ep: 8, dp-attn: true, conc-start: 256, conc-end: 1024 }
- isl: 8192
osl: 1024
search-space:
# low-latency
- { tp: 8, ep: 1, conc-start: 4, conc-end: 32 }
# balanced
# low-latency (DP_ATTENTION=false)
- { tp: 8, ep: 1, conc-start: 1, conc-end: 32 }
# DP-attention (DP_ATTENTION=true) — balanced CONC range
- { tp: 8, ep: 8, dp-attn: true, conc-start: 64, conc-end: 128 }
# max-throughput
# DP-attention (DP_ATTENTION=true) — max-throughput CONC range
- { tp: 8, ep: 8, dp-attn: true, conc-start: 256, conc-end: 512 }

# NOTE: At the time of submission, https://cookbook.sglang.io/autoregressive/DeepSeek/DeepSeek-R1
Expand Down
70 changes: 37 additions & 33 deletions benchmarks/single_node/dsv4_fp4_b200.sh
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ source "$(dirname "$0")/../benchmark_lib.sh"
check_env_vars \
MODEL \
TP \
DP_ATTENTION \
CONC \
ISL \
OSL \
Expand All @@ -19,7 +20,13 @@ hf download "$MODEL"

nvidia-smi

# Common SGLANG env vars (apply to every config).
export SGLANG_JIT_DEEPGEMM_PRECOMPILE=0
export SGLANG_OPT_SWA_SPLIT_LEAF_ON_INSERT=1
export SGLANG_OPT_USE_JIT_NORM=1
export SGLANG_OPT_USE_JIT_INDEXER_METADATA=1
export SGLANG_OPT_USE_TOPK_V2=1
export SGLANG_OPT_USE_CUSTOM_ALL_REDUCE_V2=1

# TODO(Cam): the lmsysorg/sglang:deepseek-v4-blackwell image installs sglang
# editable at /workspace/sglang/python; prior sglang tags used /sgl-workspace/sglang.
Expand All @@ -30,7 +37,7 @@ export SGLANG_JIT_DEEPGEMM_PRECOMPILE=0
SERVER_LOG="$PWD/server.log"
PORT=${PORT:-8888}

echo "TP: $TP, CONC: $CONC, ISL: $ISL, OSL: $OSL"
echo "TP: $TP, DP_ATTENTION: $DP_ATTENTION, CONC: $CONC, ISL: $ISL, OSL: $OSL"

EVAL_CONTEXT_ARGS=""
if [ "${EVAL_ONLY}" = "true" ]; then
Expand All @@ -40,47 +47,41 @@ fi

start_gpu_monitor --output "$PWD/gpu_metrics.csv"

# Three recipes from https://docs.sglang.io/cookbook/autoregressive/DeepSeek/DeepSeek-V4
# (spec-decoding / MTP and prefix-caching flags dropped for the baseline):
# - low-latency (CONC <= 32): TP-only, chunked-prefill, disable autotune
# - balanced (32 < CONC <= 128): + DP-attn, max-running-requests=128
# - max-throughput (CONC > 128): + DP-attn, max-running-requests=256
# Pick the parallelism + MoE backend based on DP_ATTENTION (mirrors the vllm
# script's pattern). DP_ATTENTION=true turns on EP-MoE (deepep) and the related
# mega_moe optimizations; DP_ATTENTION=false (low-latency, TP-only) uses
# flashinfer_mxfp4.
DEEPEP_CONFIG='{"normal_dispatch":{"num_sms":96},"normal_combine":{"num_sms":96}}'

if [[ $CONC -le 32 ]]; then
RECIPE=low-latency
RECIPE_FLAGS=(
--moe-runner-backend flashinfer_mxfp4
--chunked-prefill-size 4096
--disable-flashinfer-autotune
--mem-fraction-static 0.82
)
elif [[ $CONC -le 128 ]]; then
RECIPE=balanced
export SGLANG_DEEPEP_NUM_MAX_DISPATCH_TOKENS_PER_RANK=256
RECIPE_FLAGS=(
if [ "${DP_ATTENTION}" = "true" ]; then
export SGLANG_OPT_USE_DEEPGEMM_MEGA_MOE=1
export SGLANG_OPT_FIX_HASH_MEGA_MOE=1
export SGLANG_OPT_USE_FAST_MASK_EP=1
export SGLANG_OPT_FIX_MEGA_MOE_MEMORY=1
export SGLANG_OPT_DEEPGEMM_MEGA_MOE_NUM_MAX_TOKENS_PER_RANK=4096
export SGLANG_OPT_FIX_NEXTN_MEGA_MOE=1
export SGLANG_DEEPEP_NUM_MAX_DISPATCH_TOKENS_PER_RANK=0
PARALLEL_ARGS=(
--dp-size "$TP"
--enable-dp-attention
--moe-a2a-backend deepep
--deepep-config "$DEEPEP_CONFIG"
--mem-fraction-static 0.82
--cuda-graph-max-bs 64
--max-running-requests 128
--chunked-prefill-size 32768
)
else
RECIPE=max-throughput
export SGLANG_DEEPEP_NUM_MAX_DISPATCH_TOKENS_PER_RANK=256
RECIPE_FLAGS=(
--dp-size "$TP"
--enable-dp-attention
--moe-a2a-backend deepep
--deepep-config "$DEEPEP_CONFIG"
--mem-fraction-static 0.82
--cuda-graph-max-bs 64
--max-running-requests 256
PARALLEL_ARGS=(
--moe-runner-backend flashinfer_mxfp4
--chunked-prefill-size 8192
--disable-flashinfer-autotune
)
fi
echo "Recipe: $RECIPE (CONC=$CONC)"

# Print all SGLANG_* env vars to both the CI step log and server.log so the
# launch config is auditable from the result artifact alone.
{
echo "=== SGLANG_* env vars at launch ==="
env | grep -E '^SGLANG_' | sort
echo "==================================="
} | tee "$SERVER_LOG"

set -x
PYTHONNOUSERSITE=1 sglang serve \
Expand All @@ -90,7 +91,10 @@ PYTHONNOUSERSITE=1 sglang serve \
--trust-remote-code \
--tp $TP \
--disable-radix-cache \
"${RECIPE_FLAGS[@]}" $EVAL_CONTEXT_ARGS > $SERVER_LOG 2>&1 &
--max-running-requests "$((CONC * 3 / 2))" \
--mem-fraction-static 0.90 \
--swa-full-tokens-ratio 0.1 \
"${PARALLEL_ARGS[@]}" $EVAL_CONTEXT_ARGS >> $SERVER_LOG 2>&1 &

SERVER_PID=$!

Expand Down
10 changes: 9 additions & 1 deletion perf-changelog.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1877,4 +1877,12 @@
- "Image pinned to lmsysorg/sglang:deepseek-v4-b300@sha256:26e116bd211e300dbb76924d56c5cbe6cc3ee5ee2fe314859cb8774f5bc070f3"
- "DP-attention path enables SGLANG_OPT_SWA_EVICT_DROP_PAGE_MARGIN=1 for better SWA eviction behavior"
pr-link: https://github.com/SemiAnalysisAI/InferenceX/pull/1185


- config-keys:
- dsv4-fp4-b200-sglang
description:
- "Two-recipe dispatch for DeepSeek-V4-Pro on B200, selected by the DP_ATTENTION knob: low-latency (TP=8, EP=1, flashinfer_mxfp4) for conc 1-32, DP-attention (TP=8, EP=8, DP-attn + DeepEP + mega_moe) for conc 64-1024 at ISL 1024 and conc 64-512 at ISL 8192. The DP-attention recipe uses identical flags across the balanced and max-throughput CONC ranges; only --max-running-requests scales with CONC."
- "Recipes from https://docs.sglang.io/cookbook/autoregressive/DeepSeek/DeepSeek-V4"
Comment thread
cquil11 marked this conversation as resolved.
- "Image pinned to lmsysorg/sglang:deepseek-v4-blackwell@sha256:df18bfc4aa9ecf59451002b49ba00cae58042de9e2a96378bbd21b404dd62c7b"
- "Adds SGLANG_OPT_* env knobs (SWA_SPLIT_LEAF_ON_INSERT, USE_JIT_NORM, USE_JIT_INDEXER_METADATA, USE_TOPK_V2, USE_CUSTOM_ALL_REDUCE_V2)"
pr-link: https://github.com/SemiAnalysisAI/InferenceX/pull/1187
Loading