Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions .github/configs/nvidia-master.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7723,3 +7723,34 @@ dsv4-fp4-gb200-dynamo-vllm:
tp: 8
ep: 8
dp-attn: true

dsv4-fp4-gb300-dynamo-sglang:
image: lmsysorg/sglang:deepseek-v4-grace-blackwell
model: deepseek-ai/DeepSeek-V4-Pro
model-prefix: dsv4
runner: gb300-cw
precision: fp4
framework: dynamo-sglang
multinode: true
disagg: true
seq-len-configs:
- isl: 8192
osl: 1024
search-space:
# Max throughput: 7 prefills (TP=4 / DP=4 / EP=4) + 1 decode
# (TP=8 / DP=8 / EP=8 wideep). 9 nodes.
# Reference: Job 588 on gb300-cw — 359,226 total_token_throughput
# (9,979 tok/s/gpu).
- conc-list: [8192]
prefill:
num-worker: 7
tp: 4
ep: 4
dp-attn: true
additional-settings:
- "CONFIG_FILE=recipes/sglang/deepseek-v4/8k1k/disagg-gb300-7p1d-dep4-dep8.yaml"
decode:
num-worker: 1
tp: 8
ep: 8
dp-attn: true
5 changes: 5 additions & 0 deletions .github/configs/runners.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -139,3 +139,8 @@ gb300:
- 'gb300-nv_0'
- 'gb300-nv_1'
- 'gb300-nv_2'
gb300-cw:
- 'gb300-cw_0'
- 'gb300-cw_1'
- 'gb300-cw_2'
- 'gb300-cw_3'
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
name: "dsv4-sglang-disagg-gb300-7p1d-dep4-dep8"

# 8k/1k high-throughput topology for DeepSeek-V4-Pro on GB300.
#
# Derived from Job 588 on the gb300-cw cluster (yangminl@slurm-login-0,
# 2026-04-28). That run achieved 359,226 total_token_throughput
# (9,979 tok/s/gpu) on 9 GB300 nodes (36 GPUs) with the patched sglang
# container + DeepEP disaggregated inference.
#
# Topology: 7 prefill workers (TP=4 / DP=4 / EP=4, 1 node each) +
# 1 decode worker (TP=8 / DP=8 / EP=8 "wideep", spanning 2 nodes).
# 9 nodes / 36 GPUs total.
#
# Cluster-specific items removed from the original config:
# - slurm.partition (original: hpc-mid)
# - frontend.nginx_container (original: /mnt/home/yangminl/containers/nginx-1.27.4.sqsh)
# - extra_mount (original: /mnt/home/yangminl/sglang-patched/sglang)
# - SGLANG_DG_CACHE_DIR, SGLANG_JIT_DEEPGEMM_PRECOMPILE (host-specific
# deepgemm cache); replaced with SGLANG_JIT_DEEPGEMM_FAST_WARMUP=1

model:
path: "deepseek-v4-pro"
container: "lmsysorg/sglang:deepseek-v4-grace-blackwell"
precision: "fp4"

dynamo:
hash: "9d3c913d300eb368cda28b3f98a23a5762621e0d"
install: true

slurm:
time_limit: "03:00:00"

sbatch_directives:
cpus-per-task: "144"
mem: "0"

health_check:
max_attempts: 1440
interval_seconds: 10

resources:
gpu_type: "gb300"
gpus_per_node: 4
prefill_nodes: 7
prefill_workers: 7
decode_nodes: 2
decode_workers: 1
gpus_per_prefill: 4
gpus_per_decode: 8

frontend:
type: dynamo
enable_multiple_frontends: true
num_additional_frontends: 8

backend:
type: sglang

prefill_environment:
PYTHONUNBUFFERED: "1"
SGLANG_JIT_DEEPGEMM_FAST_WARMUP: "1"
SGLANG_ENABLE_THINKING: "1"
SGLANG_REASONING_EFFORT: "max"
SGLANG_OPT_SWA_SPLIT_LEAF_ON_INSERT: "1"
SGLANG_OPT_SWA_EVICT_DROP_PAGE_MARGIN: "1"
SGLANG_OPT_USE_JIT_NORM: "1"
SGLANG_OPT_USE_JIT_INDEXER_METADATA: "1"
SGLANG_OPT_USE_TOPK_V2: "1"
SGLANG_OPT_USE_CUSTOM_ALL_REDUCE_V2: "1"
SGLANG_OPT_USE_DEEPGEMM_MEGA_MOE: "1"
SGLANG_OPT_FIX_HASH_MEGA_MOE: "1"
SGLANG_OPT_USE_FAST_MASK_EP: "1"
SGLANG_OPT_FIX_MEGA_MOE_MEMORY: "1"
SGLANG_OPT_DEEPGEMM_MEGA_MOE_NUM_MAX_TOKENS_PER_RANK: "9216"
SGLANG_OPT_FIX_NEXTN_MEGA_MOE: "1"
SGLANG_DEEPEP_NUM_MAX_DISPATCH_TOKENS_PER_RANK: "0"
NCCL_MNNVL_ENABLE: "1"
NCCL_CUMEM_ENABLE: "1"
SGLANG_MOONCAKE_CUSTOM_MEM_POOL: "True"
MC_FORCE_MNNVL: "1"
SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT: "100000"
SGLANG_DISAGGREGATION_WAITING_TIMEOUT: "100000"
SGLANG_OPT_SWA_RELEASE_LEAF_LOCK_AFTER_WINDOW: "1"
DYN_SKIP_SGLANG_LOG_FORMATTING: "1"
SGLANG_LOG_FORWARD_ITERS: "1"
SGLANG_LOG_MS: "1"
SGLANG_REQUEST_STATE_WAIT_TIMEOUT: "60"

decode_environment:
PYTHONUNBUFFERED: "1"
SGLANG_JIT_DEEPGEMM_FAST_WARMUP: "1"
SGLANG_ENABLE_THINKING: "1"
SGLANG_REASONING_EFFORT: "max"
SGLANG_OPT_SWA_SPLIT_LEAF_ON_INSERT: "1"
SGLANG_OPT_SWA_EVICT_DROP_PAGE_MARGIN: "1"
SGLANG_OPT_USE_JIT_NORM: "1"
SGLANG_OPT_USE_JIT_INDEXER_METADATA: "1"
SGLANG_OPT_USE_TOPK_V2: "1"
SGLANG_OPT_USE_DEEPGEMM_MEGA_MOE: "1"
SGLANG_OPT_FIX_HASH_MEGA_MOE: "1"
SGLANG_OPT_USE_FAST_MASK_EP: "1"
SGLANG_OPT_FIX_MEGA_MOE_MEMORY: "1"
SGLANG_OPT_DEEPGEMM_MEGA_MOE_NUM_MAX_TOKENS_PER_RANK: "1152"
SGLANG_OPT_FIX_NEXTN_MEGA_MOE: "1"
SGLANG_DEEPEP_NUM_MAX_DISPATCH_TOKENS_PER_RANK: "0"
NCCL_MNNVL_ENABLE: "1"
NCCL_CUMEM_ENABLE: "1"
SGLANG_MOONCAKE_CUSTOM_MEM_POOL: "True"
MC_FORCE_MNNVL: "1"
SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT: "100000"
SGLANG_DISAGGREGATION_WAITING_TIMEOUT: "100000"
SGLANG_OPT_SWA_RELEASE_LEAF_LOCK_AFTER_WINDOW: "1"
DYN_SKIP_SGLANG_LOG_FORMATTING: "1"
SGLANG_LOG_FORWARD_ITERS: "1"
SGLANG_LOG_MS: "1"
SGLANG_REQUEST_STATE_WAIT_TIMEOUT: "60"
# SGLANG_OPT_USE_CUSTOM_ALL_REDUCE_V2 intentionally NOT set: CAR_V2
# is single-node only and corrupts results in 2-node decode setups.

sglang_config:
prefill:
served-model-name: "deepseek-ai/DeepSeek-V4-Pro"
trust-remote-code: true
watchdog-timeout: 86400
skip-tokenizer-init: true
stream-interval: 30

tensor-parallel-size: 4
data-parallel-size: 4
expert-parallel-size: 4

enable-dp-attention: true
moe-a2a-backend: "deepep"
deepep-config: '{"normal_dispatch":{"num_sms":96},"normal_combine":{"num_sms":96}}'

disaggregation-mode: "prefill"
disaggregation-transfer-backend: mooncake

mem-fraction-static: 0.90
max-running-requests: 512
cuda-graph-max-bs: 512
chunked-prefill-size: 32768

decode:
served-model-name: "deepseek-ai/DeepSeek-V4-Pro"
trust-remote-code: true
watchdog-timeout: 86400
skip-tokenizer-init: true
stream-interval: 30

tensor-parallel-size: 8
data-parallel-size: 8
expert-parallel-size: 8

enable-dp-attention: true
enable-dp-lm-head: true

moe-a2a-backend: "deepep"
deepep-config: '{"normal_dispatch":{"num_sms":96},"normal_combine":{"num_sms":96}}'

disaggregation-mode: "decode"
disaggregation-transfer-backend: mooncake

mem-fraction-static: 0.94
swa-full-tokens-ratio: 0.15
context-length: 16384
max-running-requests: 9216
cuda-graph-max-bs: 1152

benchmark:
type: "custom"
command: |
set -e
REPO=/configs/upstream-sa-bench/InferenceX
[ -d "$REPO" ] || git clone https://github.com/SemiAnalysisAI/InferenceX.git "$REPO"
cd "$REPO/utils/bench_serving"
python3 benchmark_serving.py \
--backend vllm --model deepseek-ai/DeepSeek-V4-Pro --tokenizer /model \
--host 127.0.0.1 --port 8000 --endpoint /v1/completions \
--dataset-name random \
--random-input-len 8192 --random-output-len 1024 --random-range-ratio 0.8 \
--random-num-workers 96 \
--num-prompts 40960 --max-concurrency 8192 --request-rate 48 \
--num-warmups 512 \
--ignore-eos --trust-remote-code \
--percentile-metrics ttft,tpot,itl,e2el \
--save-result --result-dir /logs --result-filename results.json
17 changes: 13 additions & 4 deletions benchmarks/single_node/dsv4_fp4_b300_sglang.sh
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ if [ "${DP_ATTENTION}" = "true" ]; then
# ep=8 in the yaml signals the mega_moe deepep backend; check high-conc
# recipes first (they also have ep=8) so they aren't shadowed by the
# medium-conc EP_SIZE=8 branch below.
if [ "$CONC" = "2048" ] || [ "$CONC" = "4096" ] || [ "$CONC" = "8192" ]; then
if [ "$CONC" = "2048" ] || [ "$CONC" = "4096" ] || [ "$CONC" = "8192" ] || [ "$CONC" = "12288" ]; then
export NVSHMEM_DISABLE_IB=1
export SGLANG_OPT_SWA_RELEASE_LEAF_LOCK_AFTER_WINDOW=1
export SGLANG_OPT_USE_DEEPGEMM_MEGA_MOE=1
Expand All @@ -98,14 +98,23 @@ if [ "${DP_ATTENTION}" = "true" ]; then
MEM_FRACTION_STATIC=0.835
SWA_FULL_TOKENS_RATIO=0.075
TOKENIZER_WORKER_NUM=8
else
elif [ "$CONC" = "8192" ]; then
export SGLANG_OPT_USE_ONLINE_COMPRESS=1
export SGLANG_OPT_DEEPGEMM_MEGA_MOE_NUM_MAX_TOKENS_PER_RANK=8256
CUDA_GRAPH_MAX_BS=1088
MAX_RUNNING_REQUESTS=8192
MEM_FRACTION_STATIC=0.80
SWA_FULL_TOKENS_RATIO=0.3
TOKENIZER_WORKER_NUM=16
else
export SGLANG_LOG_FORWARD_ITERS=1
export SGLANG_OPT_USE_ONLINE_COMPRESS=1
export SGLANG_OPT_DEEPGEMM_MEGA_MOE_NUM_MAX_TOKENS_PER_RANK=8256
CUDA_GRAPH_MAX_BS=1600
MAX_RUNNING_REQUESTS=12288
MEM_FRACTION_STATIC=0.72
SWA_FULL_TOKENS_RATIO=0.3
TOKENIZER_WORKER_NUM=16
fi
PARALLEL_ARGS=(
--dp-size "$TP"
Expand All @@ -117,10 +126,10 @@ if [ "${DP_ATTENTION}" = "true" ]; then
--tokenizer-worker-num "$TOKENIZER_WORKER_NUM"
--enable-prefill-delayer
)
if [ "$CONC" = "4096" ]; then
if [ "$CONC" = "4096" ] || [ "$CONC" = "12288" ]; then
PARALLEL_ARGS+=(--decode-log-interval 5)
fi
if [ "$CONC" = "8192" ]; then
if [ "$CONC" = "8192" ] || [ "$CONC" = "12288" ]; then
PARALLEL_ARGS+=(--stream-interval 30)
fi
elif [ "${EP_SIZE}" = "8" ]; then
Expand Down
7 changes: 7 additions & 0 deletions perf-changelog.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1992,3 +1992,10 @@
- "Add conc=8192 recipe for 1k1k: deepep mega_moe backend with cuda-graph-max-bs 1088, max-running-requests 8192, mem-fraction-static 0.80, swa-full-tokens-ratio 0.3, tokenizer-worker-num 16"
- "conc=8192 enables SGLANG_OPT_USE_ONLINE_COMPRESS=1 and --stream-interval 30"
pr-link: https://github.com/SemiAnalysisAI/InferenceX/pull/1209

- config-keys:
- dsv4-fp4-gb300-dynamo-sglang
description:
- "Add GB300 Dynamo SGLang disaggregated 7P+1D max-throughput recipe (Job 588)"
- "Topology: 7 prefill (TP=4/DP=4/EP=4 DeepEP) + 1 decode (TP=8/DP=8/EP=8 wideep), 9 nodes, mooncake transfer"
- "Benchmark: 8k/1k random, conc=8192, rate=48, 40960 prompts"
Loading