Add DSv4 B200 configs #1156
Changes from all commits: 4a70e24, 6ee148f, 707f225, 8ec1310, 81594d7, f2fcfae, 98d83ff, d7df79a, c18f413, 1e95b00, 5bab835, e28c638, 44729b1, b653b7f, e86d6e9, a222193, 34d9bc3, 3f038c4, fae14d9, 4d99225, acb8510, 53b7f59, 9c2f8c9, 984064a, 09599ae
First file, a new single-node benchmark script (120 added lines):

```bash
#!/usr/bin/env bash

# DeepSeek-V4-Pro B200 single-node vLLM recipe derived from the B200 pareto
# sweep. TP mode (dp-attn=false) runs without expert parallel; DP mode
# (dp-attn=true) enables expert parallel (EP_SIZE=TP value = DP size).

source "$(dirname "$0")/../benchmark_lib.sh"

check_env_vars \
  MODEL \
  TP \
  DP_ATTENTION \
  CONC \
  ISL \
  OSL \
  MAX_MODEL_LEN \
  RANDOM_RANGE_RATIO \
  RESULT_FILENAME

if [[ -n "$SLURM_JOB_ID" ]]; then
  echo "JOB $SLURM_JOB_ID running on $SLURMD_NODENAME"
fi

nvidia-smi

hf download "$MODEL"

SERVER_LOG=/workspace/server.log
PORT=${PORT:-8888}

# DeepSeek-V4-Pro weights are large; engine startup can exceed the default
# 600s. Give it an hour to load.
export VLLM_ENGINE_READY_TIMEOUT_S=3600

PARALLEL_ARGS=(--tensor-parallel-size "$TP" --data-parallel-size 1)
if [ "${DP_ATTENTION}" = "true" ]; then
  PARALLEL_ARGS=(--tensor-parallel-size 1 --data-parallel-size "$TP")
fi

EP_ARGS=()
if [ "${EP_SIZE:-1}" -gt 1 ]; then
  EP_ARGS=(--enable-expert-parallel)
fi

GMU_ARGS=()
if [ "${DP_ATTENTION}" = "true" ]; then
  GMU_ARGS=(--gpu-memory-utilization 0.85)
fi

if [ "${ISL}" -eq 8192 ] && [ "${CONC}" -le 128 ]; then
  MAX_NUM_BATCHED_TOKENS=${ISL}
else
  MAX_NUM_BATCHED_TOKENS=2048
fi

BENCHMARK_MAX_MODEL_LEN="$MAX_MODEL_LEN"
if [ "$ISL" -eq 1024 ] && [ "$OSL" -eq 1024 ]; then
  BENCHMARK_MAX_MODEL_LEN=4096
fi

if [ "${EVAL_ONLY}" = "true" ]; then
  EVAL_MAX_MODEL_LEN=$(compute_eval_context_length "$MODEL" "$BENCHMARK_MAX_MODEL_LEN")
  export EVAL_MAX_MODEL_LEN
  SERVE_MAX_MODEL_LEN="$EVAL_MAX_MODEL_LEN"
else
  SERVE_MAX_MODEL_LEN="$BENCHMARK_MAX_MODEL_LEN"
fi

# Start GPU monitoring (power, temperature, clocks every second)
start_gpu_monitor

set -x
vllm serve "$MODEL" --host 0.0.0.0 --port "$PORT" \
  "${PARALLEL_ARGS[@]}" \
  --pipeline-parallel-size 1 \
  --kv-cache-dtype fp8 \
  --trust-remote-code \
  --block-size 256 \
  --no-enable-prefix-caching \
  "${EP_ARGS[@]}" \
  "${GMU_ARGS[@]}" \
  --compilation-config '{"cudagraph_mode":"FULL_AND_PIECEWISE","custom_ops":["all"]}' \
  --attention_config.use_fp4_indexer_cache True \
  --tokenizer-mode deepseek_v4 \
  --tool-call-parser deepseek_v4 \
  --enable-auto-tool-choice \
  --reasoning-parser deepseek_v4 \
  --max-cudagraph-capture-size 2048 \
  --max-model-len "$SERVE_MAX_MODEL_LEN" \
  --max-num-batched-tokens "$MAX_NUM_BATCHED_TOKENS" > "$SERVER_LOG" 2>&1 &
```
**Contributor** commented on lines +72 to +91:

> 🔴 The bug: with DP=8 and the per-replica default of 256, the engine accepts at most 8 × 256 = 2048 concurrent sequences server-wide. The CONC=4096 sweep point in the 1k1k DP-attn branch therefore cannot be served at the requested concurrency: half the requests sit in the client-side or engine waiting queue while only ~2048 are processed in flight.
>
> Why this matters for the sweep: this is a benchmark recipe whose entire point is to populate a Pareto curve. At CONC=4096 (and likely the second-highest point too), the reported throughput and latency reflect the server cap, not the requested in-flight count, polluting the curve. The output looks plausible (no crash, no error), so the issue is silent; exactly the kind of regression the verifiers flagged as "normal" rather than "nit."
>
> An internal contradiction in the script confirms intent: line 83 sets … Sibling recipes consistently set this: every other vLLM script in … Step-by-step proof: …
>
> Fix: Add …
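The cap arithmetic in the reviewer's comment can be sanity-checked with a few lines of shell. This is a hypothetical pre-flight check, not part of the recipe: the variable names and the per-replica default of 256 are taken from the reviewer's description, not from flags the script currently sets.

```shell
# Hypothetical pre-flight check for the concurrency cap described above.
# With data-parallel attention, the server-wide in-flight limit is
# dp_size * per-replica max-num-seqs (256 is the assumed default).
DP_SIZE=${DP_SIZE:-8}
MAX_NUM_SEQS=${MAX_NUM_SEQS:-256}
CONC=${CONC:-4096}

server_cap=$((DP_SIZE * MAX_NUM_SEQS))
echo "server-wide cap: $server_cap, requested concurrency: $CONC"
if [ "$CONC" -gt "$server_cap" ]; then
  echo "WARNING: CONC=$CONC exceeds server cap $server_cap;" \
       "measured throughput/latency will reflect the cap, not the request."
fi
```

With the reviewer's numbers this prints a cap of 2048 against a requested 4096, which is exactly the silent mismatch being flagged.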
```bash
SERVER_PID=$!

# Wait for server to be ready
wait_for_server_ready --port "$PORT" --server-log "$SERVER_LOG" --server-pid "$SERVER_PID"

pip install -q datasets pandas

run_benchmark_serving \
  --model "$MODEL" \
  --port "$PORT" \
  --backend vllm \
  --input-len "$ISL" \
  --output-len "$OSL" \
  --random-range-ratio "$RANDOM_RANGE_RATIO" \
  --num-prompts "$((CONC * 10))" \
  --max-concurrency "$CONC" \
  --result-filename "$RESULT_FILENAME" \
  --result-dir /workspace/ \
  --trust-remote-code

# After throughput, run evaluation only if RUN_EVAL is true
if [ "${RUN_EVAL}" = "true" ]; then
  run_eval --framework lm-eval --port "$PORT"
  append_lm_eval_summary
fi

# Stop GPU monitoring
stop_gpu_monitor
set +x
```
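The TP/DP mode selection described in the recipe's header comment reduces to a small pure function. This standalone sketch mirrors that logic so the flag mapping can be checked in isolation; `parallel_args` is an illustrative helper, not a function the recipe defines, and TP=8 is just an example value.

```shell
# Mirror of the recipe's parallel-mode selection, factored into a function.
# TP mode (dp_attention=false): shard the model across TP GPUs.
# DP mode (dp_attention=true): run TP independent replicas instead.
parallel_args() {  # usage: parallel_args <TP> <dp_attention:true|false>
  local tp=$1 dp=$2
  if [ "$dp" = "true" ]; then
    echo "--tensor-parallel-size 1 --data-parallel-size $tp"
  else
    echo "--tensor-parallel-size $tp --data-parallel-size 1"
  fi
}

tp_mode=$(parallel_args 8 false)
dp_mode=$(parallel_args 8 true)
echo "TP mode: $tp_mode"
echo "DP mode: $dp_mode"
```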
Second file, the launcher script (diff):

```diff
@@ -6,6 +6,14 @@ export PORT=8888
 MODEL_CODE="${EXP_NAME%%_*}"
 FRAMEWORK_SUFFIX=$([[ "$FRAMEWORK" == "trt" ]] && printf '_trt' || printf '')
 SPEC_SUFFIX=$([[ "$SPEC_DECODING" == "mtp" ]] && printf '_mtp' || printf '')
+# Prefer a framework-tagged script (e.g. dsv4_fp4_b200_vllm.sh) so models
+# with multiple inference engines can coexist; fall back to the historical
+# name without an engine suffix (`_trt` for trt, bare for everyone else).
+BENCH_BASE="benchmarks/single_node/${MODEL_CODE}_${PRECISION}_b200"
```
**Collaborator:** Double-check this is back-compatible.

**Author:** Yes, I think this is back-compatible; it first checks whether …

**Collaborator:** Does it even need to be backwards compatible? @Oseltamivir
```diff
+BENCH_SCRIPT="${BENCH_BASE}_${FRAMEWORK}${SPEC_SUFFIX}.sh"
+if [[ ! -f "$BENCH_SCRIPT" ]]; then
+  BENCH_SCRIPT="${BENCH_BASE}${FRAMEWORK_SUFFIX}${SPEC_SUFFIX}.sh"
+fi

 PARTITION="b200"
 SQUASH_FILE="/tmp/gharunner/squash/$(echo "$IMAGE" | sed 's/[\/:@#]/_/g').sqsh"
```
```diff
@@ -58,6 +66,6 @@ srun --jobid=$JOB_ID \
   --container-mount-home \
   --container-workdir=$CONTAINER_MOUNT_DIR \
   --no-container-entrypoint --export=ALL \
-  bash benchmarks/single_node/${MODEL_CODE}_${PRECISION}_b200${FRAMEWORK_SUFFIX}${SPEC_SUFFIX}.sh
+  bash "$BENCH_SCRIPT"

 scancel $JOB_ID
```
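The back-compat question in the thread above comes down to the lookup order: prefer the framework-tagged script name, fall back to the historical one. A hypothetical standalone check of that order, using scratch fixture files rather than real repo paths:

```shell
# Sketch of the script-name resolution under discussion. The temp directory,
# model code, and precision are illustrative fixtures only.
tmpdir=$(mktemp -d)
MODEL_CODE=dsv4
PRECISION=fp4
FRAMEWORK=vllm
FRAMEWORK_SUFFIX=""   # historical scheme: no suffix for non-trt frameworks
SPEC_SUFFIX=""

BENCH_BASE="$tmpdir/${MODEL_CODE}_${PRECISION}_b200"

resolve_bench_script() {
  BENCH_SCRIPT="${BENCH_BASE}_${FRAMEWORK}${SPEC_SUFFIX}.sh"
  if [[ ! -f "$BENCH_SCRIPT" ]]; then
    BENCH_SCRIPT="${BENCH_BASE}${FRAMEWORK_SUFFIX}${SPEC_SUFFIX}.sh"
  fi
}

# Case 1: only the historical (bare) script exists -> fallback is used.
touch "${BENCH_BASE}.sh"
resolve_bench_script
legacy_resolved="$BENCH_SCRIPT"

# Case 2: the framework-tagged script also exists -> it wins.
touch "${BENCH_BASE}_${FRAMEWORK}.sh"
resolve_bench_script
tagged_resolved="$BENCH_SCRIPT"

echo "legacy: $legacy_resolved"
echo "tagged: $tagged_resolved"
rm -rf "$tmpdir"
```

Case 1 is the back-compat path: a repo with only the old bare-named script still resolves, which supports the author's reply.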