diff --git a/.gitignore b/.gitignore index 7b822165d3eb..9596421b182a 100644 --- a/.gitignore +++ b/.gitignore @@ -241,3 +241,5 @@ vllm/grpc/vllm_engine_pb2.pyi # Ignore generated cpu headers csrc/cpu/cpu_attn_dispatch_generated.h + +logs/ \ No newline at end of file diff --git a/docs/benchmarking/sweeps.md b/docs/benchmarking/sweeps.md index 41a799cf2109..2b54a0956a64 100644 --- a/docs/benchmarking/sweeps.md +++ b/docs/benchmarking/sweeps.md @@ -132,13 +132,66 @@ The algorithm for exploring different workload levels can be summarized as follo You can override the number of iterations in the algorithm by setting `--workload-iters`. -!!! tip - This is our equivalent of [GuideLLM's `--profile sweep`](https://github.com/vllm-project/guidellm/blob/v0.5.3/src/guidellm/benchmark/profiles.py#L575). +!!! important + SLA tuning is applied over each combination of `--serve-params`, `--bench-params`, and `--sla-params`. + + For a given combination of `--serve-params` and `--bench-params`, we share the benchmark results across `--sla-params` to avoid rerunning benchmarks with the same SLA variable value. + +### Optuna auto-tuner + +`vllm bench sweep serve_optuna` uses Optuna to tune serve arguments and scores each trial across multiple benchmark concurrencies. + +The score formula is: + +`sum(mean(score_metric) / concurrency)` + +where `mean(score_metric)` is computed across `--num-runs` for that concurrency. + +1. (Optional) Create a JSON search space file: + +```json +{ + "gpu_memory_utilization": { "type": "float", "low": 0.7, "high": 0.98, "step": 0.02 }, + "max_num_batched_tokens": { "type": "categorical", "choices": [null, 512, 1024, 2048, 4096, 8192] }, + "max_num_seqs": { "type": "categorical", "choices": [null, 8, 16, 32, 64, 128, 256] }, + "enable_chunked_prefill": { "type": "bool" }, + "enable_prefix_caching": { "type": "bool" } +} +``` + +If `--search-space` is omitted, `serve_optuna` uses built-in defaults: +- `gpu_memory_utilization` in `[0.7, 0.98]` (step `0.02`) +- `max_num_batched_tokens` in `[null, 512, 1024, 2048, 4096, 8192]` +- `max_num_seqs` in `[null, 8, 16, 32, 64, 128, 256]` +- `enable_chunked_prefill` in `[true, false]` +- `enable_prefix_caching` in `[true, false]` + +2. Run the optimizer: + +```bash +vllm bench sweep serve_optuna \ + --serve-cmd 'vllm serve meta-llama/Llama-2-7b-chat-hf' \ + --search-space benchmarks/search_space.json \ + --score-metric total_token_throughput \ + --score-concurrencies 1,8,64,256 \ + --n-trials 20 \ + -o benchmarks/results +``` + +`--bench-cmd` is optional. When omitted, `serve_optuna` uses `vllm bench serve` +and auto-fills model/base-url/tokenizer from `--serve-cmd`. + +By default, `serve_optuna` also launches the best server configuration at the end. +Use `--no-start-best-server` to disable this behavior. + +3. Inspect outputs under the timestamped run directory: - In general, `--workload-var max_concurrency` produces more reliable results because it directly controls the workload imposed on the vLLM engine. - Nevertheless, we default to `--workload-var request_rate` to maintain similar behavior as GuideLLM. +- `baseline.json`: baseline run score and full per-concurrency benchmark payload. +- `trials.json`: all trial records (`complete`, `pruned`, baseline). +- `best_params.json`: best Optuna parameters. +- `best.json`: best trial score and benchmark payload. -## Startup Benchmark +### Startup `vllm bench sweep startup` runs `vllm bench startup` across parameter combinations to compare cold/warm startup time for different engine settings. diff --git a/docs/cli/README.md b/docs/cli/README.md index c708eb795898..f8d45ea21ea0 100644 --- a/docs/cli/README.md +++ b/docs/cli/README.md @@ -9,7 +9,7 @@ vllm --help Available Commands: ```bash -vllm {chat,complete,serve,bench,collect-env,run-batch} +vllm {chat,complete,serve,serve-optuna,bench,collect-env,run-batch} ``` ## serve @@ -147,6 +147,24 @@ vllm bench throughput \ See [vllm bench throughput](./bench/throughput.md) for the full reference of all available arguments. +## serve-optuna + +Tune `vllm serve` parameters with Optuna and benchmark scoring across multiple concurrencies. + +```bash +vllm serve-optuna \ + --serve-cmd 'vllm serve Qwen/Qwen3-0.6B' \ + --score-concurrencies 1,8,64,256 \ + --n-trials 20 \ + -o benchmarks/results +``` + +`--search-space` is optional. If omitted, vLLM uses built-in default serve tuning ranges. +`--bench-cmd` is optional. If omitted, vLLM auto-fills model/base-url/tokenizer from `--serve-cmd`. +By default, the best server config is started after optimization. Use `--no-start-best-server` to skip. + +See [vllm bench sweep serve_optuna](./bench/sweep/serve_optuna.md) for the full reference of all available arguments. + ## collect-env Start collecting environment information. diff --git a/docs/cli/bench/sweep/serve_optuna.md b/docs/cli/bench/sweep/serve_optuna.md new file mode 100644 index 000000000000..f4e175917ca3 --- /dev/null +++ b/docs/cli/bench/sweep/serve_optuna.md @@ -0,0 +1,9 @@ +# vllm bench sweep serve_optuna + +## JSON CLI Arguments + +--8<-- "docs/cli/json_tip.inc.md" + +## Arguments + +--8<-- "docs/generated/argparse/bench_sweep_serve_optuna.inc.md" diff --git a/docs/mkdocs/hooks/generate_argparse.py b/docs/mkdocs/hooks/generate_argparse.py index 9d87f88f5666..d237766f9a26 100644 --- a/docs/mkdocs/hooks/generate_argparse.py +++ b/docs/mkdocs/hooks/generate_argparse.py @@ -103,6 +103,12 @@ def auto_mock(module_name: str, attr: str, max_mocks: int = 100): bench_sweep_serve_workload = auto_mock( "vllm.benchmarks.sweep.serve_workload", "SweepServeWorkloadArgs" ) +bench_sweep_serve_optuna = auto_mock( + "vllm.benchmarks.sweep.serve_optuna", "SweepServeOptunaArgs" +) +bench_sweep_serve_sla = auto_mock( + "vllm.benchmarks.sweep.serve_sla", "SweepServeSLAArgs" +) bench_throughput = auto_mock("vllm.benchmarks", "throughput") AsyncEngineArgs = auto_mock("vllm.engine.arg_utils", "AsyncEngineArgs") EngineArgs = auto_mock("vllm.engine.arg_utils", "EngineArgs") @@ -232,6 +238,10 @@ def on_startup(command: Literal["build", "gh-deploy", "serve"], dirty: bool): "bench_sweep_serve_workload": create_parser( bench_sweep_serve_workload.add_cli_args ), + "bench_sweep_serve_optuna": create_parser( + bench_sweep_serve_optuna.add_cli_args + ), + "bench_sweep_serve_sla": create_parser(bench_sweep_serve_sla.add_cli_args), "bench_throughput": create_parser(bench_throughput.add_cli_args), } diff --git a/pyproject.toml b/pyproject.toml index fad8c8c687a1..58973ebc394b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,6 +40,7 @@ Slack="https://slack.vllm.ai/" [project.scripts] vllm = "vllm.entrypoints.cli.main:main" +unieinfra = "vllm.entrypoints.unieai:main" [project.entry-points."vllm.general_plugins"] lora_filesystem_resolver = "vllm.plugins.lora_resolvers.filesystem_resolver:register_filesystem_resolver" diff --git a/requirements/common.txt b/requirements/common.txt index 05666c5d14b0..2923ba506c3b 100644 --- a/requirements/common.txt +++ b/requirements/common.txt @@ -55,3 +55,5 @@ opentelemetry-sdk >= 1.27.0 opentelemetry-api >= 1.27.0 opentelemetry-exporter-otlp >= 1.27.0 opentelemetry-semantic-conventions-ai >= 0.4.1 +cryptography # Required for UnieAI license verification in unieai.py +uvloop # Optional event loop for performance in unieai.py and other entrypoints \ No newline at end of file diff --git a/tests/benchmarks/sweep/test_serve_optuna.py b/tests/benchmarks/sweep/test_serve_optuna.py new file mode 100644 index 000000000000..7a0927a71bce --- /dev/null +++ b/tests/benchmarks/sweep/test_serve_optuna.py @@ -0,0 +1,536 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import json +import argparse +import contextlib +from types import SimpleNamespace + +import pytest + +from vllm.benchmarks.sweep.param_sweep import ParameterSweep +from vllm.benchmarks.sweep.param_sweep import ParameterSweepItem +from vllm.benchmarks.sweep import serve_optuna +from vllm.benchmarks.sweep.serve_optuna import ( + SweepServeOptunaArgs, + default_search_space, + parse_score_concurrencies, + run_main, + suggest_trial_params, +) + + +def test_parse_score_concurrencies(): + assert parse_score_concurrencies("1,8,64,256") == [1, 8, 64, 256] + + +def test_parse_score_concurrencies_invalid(): + with pytest.raises(ValueError): + parse_score_concurrencies("1,0") + + +def test_suggest_trial_params(): + search_space = { + "gpu_memory_utilization": { + "type": "float", + "low": 0.7, + "high": 0.98, + "step": 0.02, + }, + "max_num_seqs": {"type": "int", "low": 8, "high": 32, "step": 8}, + "max_num_batched_tokens": { + "type": "categorical", + "choices": [None, 1024, 2048], + }, + "enable_prefix_caching": {"type": "bool"}, + "constant_param": "keep", + } + + class DummyTrial: + def suggest_float(self, name, low, high, step=None, log=False): # noqa: ARG002 + assert name == "gpu_memory_utilization" + return 0.78 + + def suggest_int(self, name, low, high, step=1, log=False): # noqa: ARG002 + assert name == "max_num_seqs" + return 24 + + def suggest_categorical(self, name, choices): # noqa: ARG002 + if name == "max_num_batched_tokens": + return 1024 + assert name == "enable_prefix_caching" + return True + + trial = DummyTrial() + + suggested = suggest_trial_params(trial, search_space) + + assert suggested == { + "gpu_memory_utilization": 0.78, + "max_num_seqs": 24, + "max_num_batched_tokens": 1024, + "enable_prefix_caching": True, + "constant_param": "keep", + } + + +def test_from_cli_args_uses_default_search_space_when_omitted(): + parser = argparse.ArgumentParser() + SweepServeOptunaArgs.add_cli_args(parser) + + parsed = parser.parse_args( + [ + "--serve-cmd", + "vllm serve Qwen/Qwen3-0.6B", + "--bench-cmd", + "vllm bench serve --model Qwen/Qwen3-0.6B", + "--dry-run", + ] + ) + + args = SweepServeOptunaArgs.from_cli_args(parsed) + assert args.search_space == default_search_space() + assert args.start_best_server is True + + +def test_from_cli_args_reads_search_space_file(tmp_path): + parser = argparse.ArgumentParser() + SweepServeOptunaArgs.add_cli_args(parser) + + search_space_path = tmp_path / "search_space.json" + payload = {"max_num_seqs": {"type": "categorical", "choices": [16, 32, 64]}} + search_space_path.write_text(json.dumps(payload), encoding="utf-8") + + parsed = parser.parse_args( + [ + "--serve-cmd", + "vllm serve Qwen/Qwen3-0.6B", + "--bench-cmd", + "vllm bench serve --model Qwen/Qwen3-0.6B", + "--search-space", + str(search_space_path), + "--dry-run", + ] + ) + + args = SweepServeOptunaArgs.from_cli_args(parsed) + assert args.search_space == payload + + +def test_from_cli_args_populates_bench_cmd_from_serve_cmd(): + parser = argparse.ArgumentParser() + SweepServeOptunaArgs.add_cli_args(parser) + + parsed = parser.parse_args( + [ + "--serve-cmd", + "vllm serve /tmp/model --served-model-name test --host 0.0.0.0 --port 12470", + "--dry-run", + ] + ) + + args = SweepServeOptunaArgs.from_cli_args(parsed) + assert args.bench_cmd[0:3] == ["vllm", "bench", "serve"] + assert "--model" in args.bench_cmd + assert "--base-url" in args.bench_cmd + assert "http://127.0.0.1:12470" in args.bench_cmd + + +def test_drop_none_values_for_serve_overrides(): + cleaned = serve_optuna._drop_none_values( + ParameterSweepItem( + { + "gpu_memory_utilization": 0.9, + "max_num_batched_tokens": None, + "max_num_seqs": 64, + } + ) + ) + assert cleaned == { + "gpu_memory_utilization": 0.9, + "max_num_seqs": 64, + } + + +def test_infer_max_model_len_from_serve_cmd_reads_flag(): + serve_cmd = [ + "vllm", + "serve", + "Qwen/Qwen3-0.6B", + "--max-model-len", + "8192", + ] + assert serve_optuna._infer_max_model_len_from_serve_cmd(serve_cmd) == 8192 + + +def test_sanitize_serve_trial_params_drops_invalid_max_num_batched_tokens(): + params = { + "gpu_memory_utilization": 0.98, + "max_num_batched_tokens": 8192, + "max_num_seqs": 32, + } + sanitized = serve_optuna._sanitize_serve_trial_params( + params, + inferred_max_model_len=40960, + ) + assert sanitized["max_num_batched_tokens"] is None + assert sanitized["gpu_memory_utilization"] == 0.98 + assert sanitized["max_num_seqs"] == 32 + + +def test_run_main_writes_outputs(tmp_path, monkeypatch): + class FakeTrialState: + COMPLETE = "complete" + + class FakeTrial: + def __init__(self, number: int): + self.number = number + self.params: dict[str, object] = {} + self.user_attrs: dict[str, object] = {} + self.state = FakeTrialState.COMPLETE + self.value: float | None = None + + def suggest_float(self, name, low, high, step=None, log=False): # noqa: ARG002 + value = low + self.params[name] = value + return value + + def suggest_int(self, name, low, high, step=1, log=False): # noqa: ARG002 + value = low + self.params[name] = value + return value + + def suggest_categorical(self, name, choices): + value = choices[0] + self.params[name] = value + return value + + def set_user_attr(self, key: str, value: object) -> None: + self.user_attrs[key] = value + + class FakeStudy: + def __init__(self, direction: str): + self.direction = direction + self.trials: list[FakeTrial] = [] + self.best_trial: FakeTrial | None = None + + def optimize(self, objective, n_trials: int): + for trial_number in range(n_trials): + trial = FakeTrial(trial_number) + trial.value = objective(trial) + self.trials.append(trial) + + if self.direction == "minimize": + self.best_trial = min(self.trials, key=lambda trial: trial.value or 0.0) + else: + self.best_trial = max(self.trials, key=lambda trial: trial.value or 0.0) + + fake_optuna = SimpleNamespace( + samplers=SimpleNamespace(TPESampler=lambda seed=None: object()), # noqa: ARG005 + trial=SimpleNamespace(TrialState=FakeTrialState), + TrialPruned=RuntimeError, + create_study=lambda **kwargs: FakeStudy(kwargs["direction"]), + ) + monkeypatch.setattr(serve_optuna, "optuna", fake_optuna) + + def mock_evaluate_configuration(*args, **kwargs): + output_dir = kwargs["output_dir"] + if output_dir.name == "baseline_runs": + return 10.0, {"score": 10.0, "runs": []} + if output_dir.name.startswith("trial="): + return 12.0, {"score": 12.0, "runs": []} + raise AssertionError(f"unexpected output dir: {output_dir}") + + monkeypatch.setattr( + "vllm.benchmarks.sweep.serve_optuna.evaluate_configuration", + mock_evaluate_configuration, + ) + + args = SweepServeOptunaArgs( + serve_cmd=["vllm", "serve", "Qwen/Qwen3-0.6B"], + bench_cmd=["vllm", "bench", "serve", "--model", "Qwen/Qwen3-0.6B"], + after_bench_cmd=[], + show_stdout=False, + serve_params=ParameterSweep.from_records([{}]), + bench_params=ParameterSweep.from_records([{}]), + output_dir=tmp_path, + num_runs=1, + dry_run=False, + resume=None, + link_vars=[], + server_ready_timeout=1, + search_space={}, + n_trials=1, + direction="maximize", + score_metric="total_token_throughput", + score_concurrencies=[1, 8], + baseline_params=ParameterSweepItem(), + fixed_serve_overrides=ParameterSweepItem(), + fixed_bench_overrides=ParameterSweepItem(), + study_name="test-study", + sampler_seed=0, + start_best_server=False, + ) + + best_record = run_main(args) + + assert best_record is not None + assert best_record["score"] == 12.0 + + run_dirs = [path for path in tmp_path.iterdir() if path.is_dir()] + assert len(run_dirs) == 1 + run_dir = run_dirs[0] + + with (run_dir / "baseline.json").open("rb") as f: + baseline = json.load(f) + with (run_dir / "best.json").open("rb") as f: + best = json.load(f) + with (run_dir / "best_params.json").open("rb") as f: + best_params = json.load(f) + with (run_dir / "trials.json").open("rb") as f: + trials = json.load(f) + + assert baseline["score"] == 10.0 + assert best["score"] == 12.0 + assert best_params == {} + assert len(trials) == 2 + assert trials[0]["state"] == "baseline" + assert trials[1]["state"] == "complete" + + +def test_run_main_starts_best_server_when_enabled(tmp_path, monkeypatch): + class FakeTrialState: + COMPLETE = "complete" + + class FakeTrial: + def __init__(self, number: int): + self.number = number + self.params = {"gpu_memory_utilization": 0.9} + self.user_attrs: dict[str, object] = {} + self.state = FakeTrialState.COMPLETE + self.value = 1.0 + + def set_user_attr(self, key: str, value: object) -> None: + self.user_attrs[key] = value + + class FakeStudy: + def __init__(self, direction: str): # noqa: ARG002 + self.trials = [FakeTrial(0)] + self.best_trial = self.trials[0] + + def optimize(self, objective, n_trials: int): # noqa: ARG002 + trial = self.trials[0] + trial.value = objective(trial) + + fake_optuna = SimpleNamespace( + samplers=SimpleNamespace(TPESampler=lambda seed=None: object()), # noqa: ARG005 + trial=SimpleNamespace(TrialState=FakeTrialState), + TrialPruned=RuntimeError, + create_study=lambda **kwargs: FakeStudy(kwargs["direction"]), + ) + monkeypatch.setattr(serve_optuna, "optuna", fake_optuna) + + def mock_evaluate_configuration(*args, **kwargs): # noqa: ARG001 + return 1.0, {"score": 1.0, "runs": []} + + monkeypatch.setattr( + "vllm.benchmarks.sweep.serve_optuna.evaluate_configuration", + mock_evaluate_configuration, + ) + + start_calls = {"count": 0} + + def mock_start_best_server(*args, **kwargs): # noqa: ARG001 + start_calls["count"] += 1 + return 12345 + + monkeypatch.setattr( + "vllm.benchmarks.sweep.serve_optuna._start_best_server", + mock_start_best_server, + ) + + args = SweepServeOptunaArgs( + serve_cmd=["vllm", "serve", "Qwen/Qwen3-0.6B"], + bench_cmd=["vllm", "bench", "serve", "--model", "Qwen/Qwen3-0.6B"], + after_bench_cmd=[], + show_stdout=False, + serve_params=ParameterSweep.from_records([{}]), + bench_params=ParameterSweep.from_records([{}]), + output_dir=tmp_path, + num_runs=1, + dry_run=False, + resume=None, + link_vars=[], + server_ready_timeout=1, + search_space={}, + n_trials=1, + direction="maximize", + score_metric="total_token_throughput", + score_concurrencies=[1], + baseline_params=ParameterSweepItem(), + fixed_serve_overrides=ParameterSweepItem(), + fixed_bench_overrides=ParameterSweepItem(), + study_name="test-study", + sampler_seed=0, + start_best_server=True, + ) + + best_record = run_main(args) + assert best_record is not None + assert start_calls["count"] == 1 + + +def test_evaluate_configuration_sets_num_prompts_from_concurrency( + tmp_path, + monkeypatch, +): + captured_bench_overrides: list[dict[str, object]] = [] + + @contextlib.contextmanager + def mock_run_server(*args, **kwargs): # noqa: ARG001 + yield object() + + def mock_run_benchmark( + server, # noqa: ARG001 + bench_cmd, # noqa: ARG001 + *, + serve_overrides, # noqa: ARG001 + bench_overrides, + run_number, # noqa: ARG001 + output_path, # noqa: ARG001 + dry_run, # noqa: ARG001 + ): + captured_bench_overrides.append(dict(bench_overrides)) + return {"total_token_throughput": 100.0} + + monkeypatch.setattr("vllm.benchmarks.sweep.serve_optuna.run_server", mock_run_server) + monkeypatch.setattr( + "vllm.benchmarks.sweep.serve_optuna.run_benchmark", + mock_run_benchmark, + ) + + result = serve_optuna.evaluate_configuration( + serve_cmd=["vllm", "serve", "Qwen/Qwen3-0.6B"], + bench_cmd=["vllm", "bench", "serve", "--model", "Qwen/Qwen3-0.6B"], + after_bench_cmd=[], + show_stdout=False, + dry_run=False, + server_ready_timeout=1, + serve_overrides=ParameterSweepItem(), + bench_overrides=ParameterSweepItem({"request_rate": 1.0}), + score_metric="total_token_throughput", + score_concurrencies=[1, 4], + num_runs=1, + output_dir=tmp_path, + ) + + assert result is not None + assert len(captured_bench_overrides) == 2 + assert captured_bench_overrides[0]["max_concurrency"] == 1 + assert captured_bench_overrides[0]["num_prompts"] == 5 + assert captured_bench_overrides[1]["max_concurrency"] == 4 + assert captured_bench_overrides[1]["num_prompts"] == 20 + + +def test_run_main_starts_best_server_with_effective_sanitized_params( + tmp_path, + monkeypatch, +): + class FakeTrialState: + COMPLETE = "complete" + + class FakeTrial: + def __init__(self, number: int): + self.number = number + self.params: dict[str, object] = {} + self.user_attrs: dict[str, object] = {} + self.state = FakeTrialState.COMPLETE + self.value: float | None = None + + def suggest_categorical(self, name, choices): # noqa: ARG002 + value = 1024 + self.params[name] = value + return value + + def set_user_attr(self, key: str, value: object) -> None: + self.user_attrs[key] = value + + class FakeStudy: + def __init__(self, direction: str): # noqa: ARG002 + self.trials: list[FakeTrial] = [] + self.best_trial: FakeTrial | None = None + + def optimize(self, objective, n_trials: int): # noqa: ARG002 + trial = FakeTrial(0) + trial.value = objective(trial) + self.trials = [trial] + self.best_trial = trial + + fake_optuna = SimpleNamespace( + samplers=SimpleNamespace(TPESampler=lambda seed=None: object()), # noqa: ARG005 + trial=SimpleNamespace(TrialState=FakeTrialState), + TrialPruned=RuntimeError, + create_study=lambda **kwargs: FakeStudy(kwargs["direction"]), + ) + monkeypatch.setattr(serve_optuna, "optuna", fake_optuna) + + def mock_evaluate_configuration(*args, **kwargs): # noqa: ARG001 + return 1.0, {"score": 1.0, "runs": []} + + monkeypatch.setattr( + "vllm.benchmarks.sweep.serve_optuna.evaluate_configuration", + mock_evaluate_configuration, + ) + + start_call: dict[str, object] = {} + + def mock_start_best_server(serve_cmd, serve_overrides, **kwargs): # noqa: ARG001 + start_call["serve_cmd"] = serve_cmd + start_call["serve_overrides"] = dict(serve_overrides) + return 12345 + + monkeypatch.setattr( + "vllm.benchmarks.sweep.serve_optuna._start_best_server", + mock_start_best_server, + ) + + args = SweepServeOptunaArgs( + serve_cmd=[ + "vllm", + "serve", + "Qwen/Qwen3-0.6B", + "--max-model-len", + "40960", + ], + bench_cmd=["vllm", "bench", "serve", "--model", "Qwen/Qwen3-0.6B"], + after_bench_cmd=[], + show_stdout=False, + serve_params=ParameterSweep.from_records([{}]), + bench_params=ParameterSweep.from_records([{}]), + output_dir=tmp_path, + num_runs=1, + dry_run=False, + resume=None, + link_vars=[], + server_ready_timeout=1, + search_space={ + "max_num_batched_tokens": {"type": "categorical", "choices": [1024]} + }, + n_trials=1, + direction="maximize", + score_metric="total_token_throughput", + score_concurrencies=[1], + baseline_params=ParameterSweepItem(), + fixed_serve_overrides=ParameterSweepItem(), + fixed_bench_overrides=ParameterSweepItem(), + study_name="test-study", + sampler_seed=0, + start_best_server=True, + ) + + best_record = run_main(args) + + assert best_record is not None + assert best_record["params"]["max_num_batched_tokens"] is None + assert "serve_overrides" in start_call + # None-valued overrides must not be emitted into the final serve command. + assert "max_num_batched_tokens" not in start_call["serve_overrides"] diff --git a/tests/entrypoints/test_serve_optuna_cli_alias.py b/tests/entrypoints/test_serve_optuna_cli_alias.py new file mode 100644 index 000000000000..5faeba98a829 --- /dev/null +++ b/tests/entrypoints/test_serve_optuna_cli_alias.py @@ -0,0 +1,68 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from vllm.entrypoints.cli.serve_optuna import ServeOptunaSubcommand +from vllm.utils.argparse_utils import FlexibleArgumentParser + + +def _build_parser() -> tuple[FlexibleArgumentParser, ServeOptunaSubcommand]: + parser = FlexibleArgumentParser(description="vLLM CLI test parser") + subparsers = parser.add_subparsers(required=True, dest="subparser") + cmd = ServeOptunaSubcommand() + cmd.subparser_init(subparsers).set_defaults(dispatch_function=cmd.cmd) + return parser, cmd + + +def test_serve_optuna_alias_dispatches_to_sweep_entrypoint(monkeypatch, tmp_path): + calls = {"count": 0} + + def _fake_main(args): + calls["count"] += 1 + calls["subparser"] = args.subparser + + monkeypatch.setattr( + "vllm.entrypoints.cli.serve_optuna.serve_optuna_main", + _fake_main, + ) + + search_space_file = tmp_path / "search_space.json" + search_space_file.write_text("{}", encoding="utf-8") + + parser, _ = _build_parser() + args = parser.parse_args( + [ + "serve-optuna", + "--serve-cmd", + "vllm serve Qwen/Qwen3-0.6B", + "--bench-cmd", + "vllm bench serve --model Qwen/Qwen3-0.6B", + "--search-space", + str(search_space_file), + "--dry-run", + ] + ) + args.dispatch_function(args) + + assert calls["count"] == 1 + assert calls["subparser"] == "serve-optuna" + + +def test_serve_optuna_underscore_alias_is_supported(tmp_path): + search_space_file = tmp_path / "search_space.json" + search_space_file.write_text("{}", encoding="utf-8") + + parser, _ = _build_parser() + args = parser.parse_args( + [ + "serve_optuna", + "--serve-cmd", + "vllm serve Qwen/Qwen3-0.6B", + "--bench-cmd", + "vllm bench serve --model Qwen/Qwen3-0.6B", + "--search-space", + str(search_space_file), + "--dry-run", + ] + ) + + assert args.subparser in {"serve-optuna", "serve_optuna"} diff --git a/vllm/benchmarks/sweep/cli.py b/vllm/benchmarks/sweep/cli.py index 75549105fa97..95cf7d0959da 100644 --- a/vllm/benchmarks/sweep/cli.py +++ b/vllm/benchmarks/sweep/cli.py @@ -8,6 +8,8 @@ from .plot import main as plot_main from .plot_pareto import SweepPlotParetoArgs from .plot_pareto import main as plot_pareto_main +from .serve_optuna import SweepServeOptunaArgs +from .serve_optuna import main as serve_optuna_main from .serve import SweepServeArgs from .serve import main as serve_main from .serve_workload import SweepServeWorkloadArgs @@ -18,6 +20,7 @@ SUBCOMMANDS = ( (SweepServeArgs, serve_main), (SweepServeWorkloadArgs, serve_workload_main), + (SweepServeOptunaArgs, serve_optuna_main), (SweepStartupArgs, startup_main), (SweepPlotArgs, plot_main), (SweepPlotParetoArgs, plot_pareto_main), diff --git a/vllm/benchmarks/sweep/serve.py b/vllm/benchmarks/sweep/serve.py index f64006ee1023..e80ead97e78e 100644 --- a/vllm/benchmarks/sweep/serve.py +++ b/vllm/benchmarks/sweep/serve.py @@ -323,7 +323,11 @@ class SweepServeArgs: @classmethod def from_cli_args(cls, args: argparse.Namespace): serve_cmd = shlex.split(args.serve_cmd) - bench_cmd = shlex.split(args.bench_cmd) + bench_cmd = ( + ["vllm", "bench", "serve"] + if args.bench_cmd is None + else shlex.split(args.bench_cmd) + ) after_bench_cmd = ( [] if args.after_bench_cmd is None else shlex.split(args.after_bench_cmd) ) @@ -378,8 +382,11 @@ def add_cli_args(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParse parser.add_argument( "--bench-cmd", type=str, - required=True, - help="The command used to run the benchmark: `vllm bench serve ...`", + default=None, + help=( + "The command used to run the benchmark: `vllm bench serve ...`. " + "If omitted, defaults to `vllm bench serve`." + ), ) parser.add_argument( "--after-bench-cmd", diff --git a/vllm/benchmarks/sweep/serve_optuna.py b/vllm/benchmarks/sweep/serve_optuna.py new file mode 100644 index 000000000000..cb6cdb19fcd2 --- /dev/null +++ b/vllm/benchmarks/sweep/serve_optuna.py @@ -0,0 +1,815 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import argparse +import json +import copy +import os +import subprocess +import time +from dataclasses import asdict, dataclass +from datetime import datetime +from pathlib import Path +from typing import Any, ClassVar + +import requests + +from vllm.utils.import_utils import PlaceholderModule + +from .param_sweep import ParameterSweepItem +from .serve import SweepServeArgs, run_benchmark, run_server + +try: + import optuna +except ImportError: + optuna = PlaceholderModule("optuna") + + +DEFAULT_VLLM_SEARCH_SPACE: dict[str, Any] = { + "gpu_memory_utilization": { + "type": "float", + "low": 0.7, + "high": 0.98, + "step": 0.02, + }, + "max_num_batched_tokens": { + "type": "categorical", + "choices": [None, 512, 1024, 2048, 4096, 8192], + }, + "max_num_seqs": { + "type": "categorical", + "choices": [None, 8, 16, 32, 64, 128, 256], + }, + "enable_chunked_prefill": {"type": "bool"}, + "enable_prefix_caching": {"type": "bool"}, +} + + +def _require_optuna() -> None: + if isinstance(optuna, PlaceholderModule): + raise ImportError( + "Please install optuna to use `vllm bench sweep serve_optuna`." + ) + + +def parse_score_concurrencies(raw: str) -> list[int]: + values = [int(part.strip()) for part in raw.split(",") if part.strip()] + if not values: + raise ValueError("score_concurrencies cannot be empty") + if any(value <= 0 for value in values): + raise ValueError("score_concurrencies must be > 0") + return values + + +def read_search_space(path: str) -> dict[str, Any]: + with open(path, "rb") as f: + loaded = json.load(f) + + if not isinstance(loaded, dict): + raise TypeError("search space must be a JSON object") + + return loaded + + +def default_search_space() -> dict[str, Any]: + # Return a deep copy so callers can safely mutate the object. + return copy.deepcopy(DEFAULT_VLLM_SEARCH_SPACE) + + +def _extract_arg_value(cmd: list[str], flag: str) -> str | None: + for i, token in enumerate(cmd): + if token == flag and i + 1 < len(cmd): + return cmd[i + 1] + if token.startswith(flag + "="): + return token.split("=", 1)[1] + return None + + +def _has_flag(cmd: list[str], flag: str) -> bool: + if flag in cmd: + return True + return any(token.startswith(flag + "=") for token in cmd) + + +def _extract_model_path_from_serve_cmd(serve_cmd: list[str]) -> str | None: + model_arg = _extract_arg_value(serve_cmd, "--model") + if model_arg: + return model_arg + + for i, token in enumerate(serve_cmd): + if token == "serve" and i + 1 < len(serve_cmd): + candidate = serve_cmd[i + 1] + if not candidate.startswith("-"): + return candidate + return None + + +def _read_model_max_len_from_config(model_path: str) -> int | None: + config_path = Path(model_path) / "config.json" + if not config_path.is_file(): + return None + + try: + with config_path.open("rb") as f: + loaded = json.load(f) + except (OSError, json.JSONDecodeError): + return None + + if not isinstance(loaded, dict): + return None + + candidate_keys = ( + "max_model_len", + "max_position_embeddings", + "n_positions", + "seq_length", + "max_sequence_length", + "model_max_length", + ) + + for key in candidate_keys: + value = loaded.get(key) + if isinstance(value, int) and value > 0: + return value + + text_config = loaded.get("text_config") + if isinstance(text_config, dict): + for key in candidate_keys: + value = text_config.get(key) + if isinstance(value, int) and value > 0: + return value + + return None + + +def _infer_max_model_len_from_serve_cmd(serve_cmd: list[str]) -> int | None: + max_model_len_arg = _extract_arg_value(serve_cmd, "--max-model-len") + if max_model_len_arg is not None: + try: + parsed = int(max_model_len_arg) + if parsed > 0: + return parsed + except ValueError: + pass + + model_path = _extract_model_path_from_serve_cmd(serve_cmd) + if not model_path: + return None + + return _read_model_max_len_from_config(model_path) + + +def _extract_served_model_name_from_serve_cmd(serve_cmd: list[str]) -> str | None: + value = _extract_arg_value(serve_cmd, "--served-model-name") + if value: + return value + return _extract_model_path_from_serve_cmd(serve_cmd) + + +def _extract_server_base_url_from_serve_cmd(serve_cmd: list[str]) -> str: + host = _extract_arg_value(serve_cmd, "--host") or "127.0.0.1" + port = _extract_arg_value(serve_cmd, "--port") or _extract_arg_value(serve_cmd, "-p") + if port is None: + port = "8000" + + # `vllm serve` often uses 0.0.0.0, but benchmark client should target localhost. + if host in ("0.0.0.0", "::"): + host = "127.0.0.1" + return f"http://{host}:{port}" + + +def _apply_default_bench_cmd_fields( + bench_cmd: list[str], + serve_cmd: list[str], +) -> list[str]: + cmd = list(bench_cmd) + if len(cmd) >= 3 and cmd[:3] == ["vllm", "bench", "serve"]: + pass + elif len(cmd) == 0: + cmd = ["vllm", "bench", "serve"] + + model_name = _extract_served_model_name_from_serve_cmd(serve_cmd) + model_path = _extract_model_path_from_serve_cmd(serve_cmd) + base_url = _extract_server_base_url_from_serve_cmd(serve_cmd) + + if model_name and not _has_flag(cmd, "--model"): + cmd.extend(["--model", model_name]) + if model_path and not _has_flag(cmd, "--tokenizer"): + cmd.extend(["--tokenizer", model_path]) + if not _has_flag(cmd, "--base-url"): + cmd.extend(["--base-url", base_url]) + if not _has_flag(cmd, "--backend"): + cmd.extend(["--backend", "openai"]) + if not _has_flag(cmd, "--endpoint"): + cmd.extend(["--endpoint", "/v1/completions"]) + return cmd + + +def _sanitize_serve_trial_params( + trial_params: dict[str, Any], + inferred_max_model_len: int | None, +) -> dict[str, Any]: + sanitized = dict(trial_params) + + if inferred_max_model_len is None: + return sanitized + + max_num_batched_tokens = sanitized.get("max_num_batched_tokens") + if ( + isinstance(max_num_batched_tokens, int) + and max_num_batched_tokens < inferred_max_model_len + ): + sanitized["max_num_batched_tokens"] = None + + return sanitized + + +def read_single_record(path: str | None) -> ParameterSweepItem: + if path is None: + return ParameterSweepItem() + + with open(path, "rb") as f: + loaded = json.load(f) + + if isinstance(loaded, list): + if len(loaded) != 1: + raise ValueError("baseline_params JSON list must contain exactly one object") + loaded = loaded[0] + + if not isinstance(loaded, dict): + raise TypeError("baseline_params must be a JSON object") + + return ParameterSweepItem.from_record(loaded) + + +def suggest_trial_params(trial: Any, search_space: dict[str, Any]) -> dict[str, Any]: + suggested: dict[str, Any] = {} + + for key, spec in search_space.items(): + if not isinstance(spec, dict) or "type" not in spec: + suggested[key] = spec + continue + + dist_type = str(spec["type"]) + if dist_type == "categorical": + choices = spec.get("choices") + if not isinstance(choices, list) or not choices: + raise ValueError(f"search space key '{key}' requires non-empty choices") + suggested[key] = trial.suggest_categorical(key, choices) + continue + + if dist_type == "bool": + suggested[key] = trial.suggest_categorical(key, [True, False]) + continue + + if dist_type == "int": + low = int(spec["low"]) + high = int(spec["high"]) + step = int(spec.get("step", 1)) + log = bool(spec.get("log", False)) + suggested[key] = trial.suggest_int(key, low, high, step=step, log=log) + continue + + if dist_type == "float": + low = float(spec["low"]) + high = float(spec["high"]) + step = spec.get("step") + if step is not None: + step = float(step) + log = bool(spec.get("log", False)) + if step is not None and log: + raise ValueError( + f"search space key '{key}' cannot set both step and log for float" + ) + suggested[key] = trial.suggest_float(key, low, high, step=step, log=log) + continue + + raise ValueError(f"unsupported distribution type for '{key}': {dist_type}") + + return suggested + + +def _default_params_from_search_space(search_space: dict[str, Any]) -> dict[str, Any]: + defaults: dict[str, Any] = {} + for key, spec in search_space.items(): + if not isinstance(spec, dict) or "type" not in spec: + defaults[key] = spec + continue + + dist_type = str(spec["type"]) + if dist_type == "categorical": + choices = spec.get("choices") + if not isinstance(choices, list) or not choices: + raise ValueError(f"search space key '{key}' requires non-empty choices") + defaults[key] = choices[0] + elif dist_type == "bool": + defaults[key] = False + elif dist_type == "int": + defaults[key] = int(spec["low"]) + elif dist_type == "float": + defaults[key] = float(spec["low"]) + else: + raise ValueError(f"unsupported distribution type for '{key}': {dist_type}") + + return defaults + + +def _drop_none_values(params: dict[str, Any] | ParameterSweepItem) -> ParameterSweepItem: + return ParameterSweepItem({k: v for k, v in dict(params).items() if v is not None}) + + +def _start_best_server( + serve_cmd: list[str], + serve_overrides: ParameterSweepItem, + *, + show_stdout: bool, + server_ready_timeout: int, +) -> int: + best_server_cmd = serve_overrides.apply_to_cmd(serve_cmd) + print("[START BEST SERVER]") + print(f"Best server command: {best_server_cmd}") + + process = subprocess.Popen( + best_server_cmd, + start_new_session=True, + stdout=None if show_stdout else subprocess.DEVNULL, + stderr=None if show_stdout else subprocess.DEVNULL, + env=os.environ | {"VLLM_SERVER_DEV_MODE": "1"}, + ) + + start_time = time.monotonic() + server_address = _extract_server_base_url_from_serve_cmd(best_server_cmd) + health_url = server_address + "/health" + while True: + if process.poll() is not None: + raise RuntimeError( + f"Best server process crashed with return code {process.returncode}" + ) + try: + response = requests.get(health_url, timeout=3) + if response.status_code == 200: + print(f"Best server is ready at {server_address} (pid={process.pid})") + return process.pid + except requests.RequestException: + pass + + if time.monotonic() - start_time > server_ready_timeout: + process.kill() + raise TimeoutError( + f"Best server failed to become ready within {server_ready_timeout} seconds" + ) + time.sleep(1) + + +def score_benchmark_runs( + run_data: list[dict[str, object]], + score_metric: str, + concurrency: int, +) -> tuple[float, float]: + metric_values = list[float]() + for run in run_data: + if score_metric not in run: + raise KeyError(f"benchmark output missing metric '{score_metric}'") + metric_values.append(float(run[score_metric])) + + if not metric_values: + raise RuntimeError("benchmark output is empty") + + metric_mean = sum(metric_values) / len(metric_values) + return metric_mean / float(concurrency), metric_mean + + +def evaluate_configuration( + serve_cmd: list[str], + bench_cmd: list[str], + after_bench_cmd: list[str], + *, + show_stdout: bool, + dry_run: bool, + server_ready_timeout: int, + serve_overrides: ParameterSweepItem, + bench_overrides: ParameterSweepItem, + score_metric: str, + score_concurrencies: list[int], + num_runs: int, + output_dir: Path, +) -> tuple[float, dict[str, Any]] | None: + with run_server( + serve_cmd, + after_bench_cmd, + show_stdout=show_stdout, + serve_overrides=serve_overrides, + dry_run=dry_run, + server_ready_timeout=server_ready_timeout, + ) as server: + run_summaries = list[dict[str, Any]]() + total_score = 0.0 + saw_dry_run_output = False + + for concurrency in score_concurrencies: + benchmark_runs = list[dict[str, object]]() + merged_bench_overrides = bench_overrides | { + "max_concurrency": concurrency, + "num_prompts": concurrency * 5, + } + + for run_number in range(num_runs): + run_data = run_benchmark( + server, + bench_cmd, + serve_overrides=serve_overrides, + bench_overrides=merged_bench_overrides, + run_number=run_number, + output_path=( + output_dir + / f"concurrency={concurrency}" + / f"run={run_number}.json" + ), + dry_run=dry_run, + ) + if run_data is None: + assert dry_run + saw_dry_run_output = True + continue + benchmark_runs.append(run_data) + + if saw_dry_run_output: + continue + + normalized_score, metric_mean = score_benchmark_runs( + benchmark_runs, + score_metric, + concurrency, + ) + total_score += normalized_score + run_summaries.append( + { + "concurrency": concurrency, + "metric": score_metric, + "metric_mean": metric_mean, + "normalized_score": normalized_score, + "runs": benchmark_runs, + } + ) + + if saw_dry_run_output: + return None + + return total_score, { + "score_formula": f"sum(mean({score_metric}) / concurrency)", + "score_metric": score_metric, + "score_concurrencies": score_concurrencies, + "runs": run_summaries, + "score": total_score, + } + + +def append_trial_record(file_path: Path, record: dict[str, Any]) -> None: + try: + with file_path.open("rb") as f: + loaded = json.load(f) + except (FileNotFoundError, json.JSONDecodeError): + loaded = [] + + if not isinstance(loaded, list): + loaded = [] + + loaded.append(record) + + with file_path.open("w") as f: + json.dump(loaded, f, indent=4) + + +@dataclass +class SweepServeOptunaArgs(SweepServeArgs): + search_space: dict[str, Any] + n_trials: int + direction: str + score_metric: str + score_concurrencies: list[int] + baseline_params: ParameterSweepItem + fixed_serve_overrides: ParameterSweepItem + fixed_bench_overrides: ParameterSweepItem + study_name: str | None + sampler_seed: int | None + start_best_server: bool + + parser_name: ClassVar[str] = "serve_optuna" + parser_help: ClassVar[str] = "Tune serve parameters with Optuna." + + @classmethod + def from_cli_args(cls, args: argparse.Namespace): + base_args = SweepServeArgs.from_cli_args(args) + base_args_dict = asdict(base_args) + + if len(base_args.serve_params) != 1: + raise ValueError( + "serve_optuna supports exactly one fixed serve_params entry. " + "Use --search-space to tune values." + ) + if len(base_args.bench_params) != 1: + raise ValueError( + "serve_optuna supports exactly one fixed bench_params entry. " + "Use --score-concurrencies for benchmark concurrency sweep." + ) + if base_args.link_vars: + raise ValueError("serve_optuna does not support --link-vars") + + if args.n_trials < 1: + raise ValueError("n_trials should be at least 1") + + base_args_dict["bench_cmd"] = _apply_default_bench_cmd_fields( + base_args.bench_cmd, + base_args.serve_cmd, + ) + + return cls( + **base_args_dict, + search_space=( + default_search_space() + if args.search_space is None + else read_search_space(args.search_space) + ), + n_trials=args.n_trials, + direction=args.direction, + score_metric=args.score_metric, + score_concurrencies=parse_score_concurrencies(args.score_concurrencies), + baseline_params=read_single_record(args.baseline_params), + fixed_serve_overrides=base_args.serve_params[0], + fixed_bench_overrides=base_args.bench_params[0], + study_name=args.study_name, + sampler_seed=args.sampler_seed, + start_best_server=args.start_best_server, + ) + + @classmethod + def add_cli_args(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + parser = super().add_cli_args(parser) + parser.set_defaults(num_runs=1) + + optuna_group = parser.add_argument_group("optuna options") + optuna_group.add_argument( + "--search-space", + type=str, + default=None, + help=( + "Optional path to JSON object defining trial parameters. " + "Each key supports one of: " + "{\"type\":\"float\",\"low\":...,\"high\":...,\"step\":...}, " + "{\"type\":\"int\",\"low\":...,\"high\":...,\"step\":...}, " + "{\"type\":\"categorical\",\"choices\":[...]}, " + "{\"type\":\"bool\"}. " + "Non-object values are treated as constants. " + "If omitted, uses built-in vLLM defaults for common " + "serve tuning knobs." + ), + ) + optuna_group.add_argument( + "--n-trials", + type=int, + default=20, + help="Number of Optuna trials.", + ) + optuna_group.add_argument( + "--direction", + type=str, + choices=("maximize", "minimize"), + default="maximize", + help="Optimization direction for score.", + ) + optuna_group.add_argument( + "--score-metric", + type=str, + default="total_token_throughput", + help="Metric key read from each benchmark result JSON.", + ) + optuna_group.add_argument( + "--score-concurrencies", + type=str, + default="1,8,64,256", + help="Comma-separated concurrency list used for scoring.", + ) + optuna_group.add_argument( + "--baseline-params", + type=str, + default=None, + help=( + "Optional path to JSON object of serve overrides for baseline run. " + "If omitted, baseline uses the base serve command plus fixed serve_params." + ), + ) + optuna_group.add_argument( + "--study-name", + type=str, + default=None, + help="Optional Optuna study name.", + ) + optuna_group.add_argument( + "--sampler-seed", + type=int, + default=None, + help="Optional random seed for Optuna sampler.", + ) + optuna_group.add_argument( + "--start-best-server", + action=argparse.BooleanOptionalAction, + default=True, + help=( + "Whether to launch `serve_cmd` with the best Optuna parameters " + "after optimization." + ), + ) + + return parser + + +def run_main(args: SweepServeOptunaArgs): + _require_optuna() + inferred_max_model_len = _infer_max_model_len_from_serve_cmd(args.serve_cmd) + + timestamp = args.resume or datetime.now().strftime("%Y%m%d_%H%M%S") + output_dir = args.output_dir / timestamp + + if args.resume and not output_dir.exists(): + raise ValueError(f"Cannot resume from non-existent directory ({output_dir})") + + output_dir.mkdir(parents=True, exist_ok=True) + + baseline_file = output_dir / "baseline.json" + trials_file = output_dir / "trials.json" + best_params_file = output_dir / "best_params.json" + best_file = output_dir / "best.json" + + baseline_overrides = _drop_none_values(args.fixed_serve_overrides | args.baseline_params) + + baseline_result = evaluate_configuration( + args.serve_cmd, + args.bench_cmd, + args.after_bench_cmd, + show_stdout=args.show_stdout, + dry_run=args.dry_run, + server_ready_timeout=args.server_ready_timeout, + serve_overrides=baseline_overrides, + bench_overrides=args.fixed_bench_overrides, + score_metric=args.score_metric, + score_concurrencies=args.score_concurrencies, + num_runs=args.num_runs, + output_dir=output_dir / "baseline_runs", + ) + + if baseline_result is None: + assert args.dry_run + preview_overrides = _drop_none_values( + args.fixed_serve_overrides | _default_params_from_search_space(args.search_space) + ) + evaluate_configuration( + args.serve_cmd, + args.bench_cmd, + args.after_bench_cmd, + show_stdout=args.show_stdout, + dry_run=True, + server_ready_timeout=args.server_ready_timeout, + serve_overrides=preview_overrides, + bench_overrides=args.fixed_bench_overrides, + score_metric=args.score_metric, + score_concurrencies=args.score_concurrencies, + num_runs=args.num_runs, + output_dir=output_dir / "sample_trial", + ) + return None + + baseline_score, baseline_payload = baseline_result + baseline_record = { + "params": dict(baseline_overrides), + "score": baseline_score, + "result": baseline_payload, + } + with baseline_file.open("w") as f: + json.dump(baseline_record, f, indent=4) + + append_trial_record( + trials_file, + { + "trial": -1, + "state": "baseline", + "params": dict(baseline_overrides), + "score": baseline_score, + "result": baseline_payload, + }, + ) + + sampler = optuna.samplers.TPESampler(seed=args.sampler_seed) + study = optuna.create_study( + study_name=args.study_name, + direction=args.direction, + sampler=sampler, + ) + + def objective(trial: Any) -> float: + trial_params = _sanitize_serve_trial_params( + suggest_trial_params(trial, args.search_space), + inferred_max_model_len, + ) + trial.set_user_attr("effective_serve_params", trial_params) + serve_overrides = _drop_none_values(args.fixed_serve_overrides | trial_params) + + try: + evaluated = evaluate_configuration( + args.serve_cmd, + args.bench_cmd, + args.after_bench_cmd, + show_stdout=args.show_stdout, + dry_run=False, + server_ready_timeout=args.server_ready_timeout, + serve_overrides=serve_overrides, + bench_overrides=args.fixed_bench_overrides, + score_metric=args.score_metric, + score_concurrencies=args.score_concurrencies, + num_runs=args.num_runs, + output_dir=output_dir / f"trial={trial.number}", + ) + if evaluated is None: + raise RuntimeError("unexpected dry-run state during optimization") + + score, payload = evaluated + trial.set_user_attr("benchmark_result", payload) + append_trial_record( + trials_file, + { + "trial": trial.number, + "state": "complete", + "params": trial_params, + "score": score, + "result": payload, + }, + ) + return score + except BaseException as exc: + append_trial_record( + trials_file, + { + "trial": trial.number, + "state": "pruned", + "params": trial_params, + "score": None, + "error": str(exc), + }, + ) + raise optuna.TrialPruned() from exc + + study.optimize(objective, n_trials=args.n_trials) + + completed_trials = [ + trial + for trial in study.trials + if trial.state == optuna.trial.TrialState.COMPLETE + ] + if not completed_trials: + raise RuntimeError("No completed trials. All trials were pruned.") + + best_trial = study.best_trial + best_payload = best_trial.user_attrs.get("benchmark_result") + if not isinstance(best_payload, dict): + raise RuntimeError("best trial is missing benchmark result payload") + + effective_best_params = best_trial.user_attrs.get("effective_serve_params") + if not isinstance(effective_best_params, dict): + effective_best_params = _sanitize_serve_trial_params( + dict(best_trial.params), + inferred_max_model_len, + ) + + with best_params_file.open("w") as f: + json.dump(effective_best_params, f, indent=4) + + best_record = { + "trial": best_trial.number, + "params": effective_best_params, + "score": best_trial.value, + "result": best_payload, + } + with best_file.open("w") as f: + json.dump(best_record, f, indent=4) + + if args.start_best_server: + effective_best_overrides = _drop_none_values( + args.fixed_serve_overrides | effective_best_params + ) + _start_best_server( + args.serve_cmd, + effective_best_overrides, + show_stdout=args.show_stdout, + server_ready_timeout=args.server_ready_timeout, + ) + + return best_record + + +def main(args: argparse.Namespace): + run_main(SweepServeOptunaArgs.from_cli_args(args)) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description=SweepServeOptunaArgs.parser_help) + SweepServeOptunaArgs.add_cli_args(parser) + + main(parser.parse_args()) diff --git a/vllm/entrypoints/cli/main.py b/vllm/entrypoints/cli/main.py index 2261ef233134..2478dcb27059 100644 --- a/vllm/entrypoints/cli/main.py +++ b/vllm/entrypoints/cli/main.py @@ -20,6 +20,7 @@ def main(): import vllm.entrypoints.cli.openai import vllm.entrypoints.cli.run_batch import vllm.entrypoints.cli.serve + import vllm.entrypoints.cli.serve_optuna from vllm.entrypoints.utils import VLLM_SUBCMD_PARSER_EPILOG, cli_env_setup from vllm.utils.argparse_utils import FlexibleArgumentParser @@ -27,6 +28,7 @@ def main(): vllm.entrypoints.cli.openai, vllm.entrypoints.cli.serve, vllm.entrypoints.cli.launch, + vllm.entrypoints.cli.serve_optuna, vllm.entrypoints.cli.benchmark.main, vllm.entrypoints.cli.collect_env, vllm.entrypoints.cli.run_batch, @@ -34,10 +36,11 @@ def main(): cli_env_setup() - # For 'vllm bench *': use CPU instead of UnspecifiedPlatform by default - if len(sys.argv) > 1 and sys.argv[1] == "bench": + # For benchmark-style commands: use CPU instead of + # UnspecifiedPlatform by default. + if len(sys.argv) > 1 and sys.argv[1] in ("bench", "serve-optuna", "serve_optuna"): logger.debug( - "Bench command detected, must ensure current platform is not " + "Benchmark command detected, must ensure current platform is not " "UnspecifiedPlatform to avoid device type inference error" ) from vllm import platforms diff --git a/vllm/entrypoints/cli/serve_optuna.py b/vllm/entrypoints/cli/serve_optuna.py new file mode 100644 index 000000000000..ac5b939247b9 --- /dev/null +++ b/vllm/entrypoints/cli/serve_optuna.py @@ -0,0 +1,44 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import argparse +import typing + +from vllm.benchmarks.sweep.serve_optuna import SweepServeOptunaArgs +from vllm.benchmarks.sweep.serve_optuna import main as serve_optuna_main +from vllm.entrypoints.cli.types import CLISubcommand +from vllm.entrypoints.utils import VLLM_SUBCMD_PARSER_EPILOG + +if typing.TYPE_CHECKING: + from vllm.utils.argparse_utils import FlexibleArgumentParser +else: + FlexibleArgumentParser = argparse.ArgumentParser + + +class ServeOptunaSubcommand(CLISubcommand): + """The `serve-optuna` top-level subcommand for the vLLM CLI.""" + + name = "serve-optuna" + help = "Tune vLLM serve parameters with Optuna." + + @staticmethod + def cmd(args: argparse.Namespace) -> None: + serve_optuna_main(args) + + def subparser_init( + self, subparsers: argparse._SubParsersAction + ) -> FlexibleArgumentParser: + parser = subparsers.add_parser( + self.name, + aliases=["serve_optuna"], + help=self.help, + description=self.help, + usage=f"vllm {self.name} [options]", + ) + SweepServeOptunaArgs.add_cli_args(parser) + parser.epilog = VLLM_SUBCMD_PARSER_EPILOG.format(subcmd=self.name) + return parser + + +def cmd_init() -> list[CLISubcommand]: + return [ServeOptunaSubcommand()] diff --git a/vllm/entrypoints/unieai.py b/vllm/entrypoints/unieai.py new file mode 100644 index 000000000000..60c10274193a --- /dev/null +++ b/vllm/entrypoints/unieai.py @@ -0,0 +1,421 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import cmd +import sys +import os +import base64 +import json +import hashlib +import asyncio +import pathlib +import subprocess +import urllib.request +import urllib.error +import shlex +from argparse import ArgumentParser + +from datetime import datetime, timezone + +try: + import uvloop +except ImportError: + uvloop = None + +from cryptography.hazmat.primitives import hashes, serialization +from cryptography.hazmat.primitives.asymmetric import padding + +from vllm.entrypoints.openai.api_server import run_server +from vllm.entrypoints.openai.cli_args import make_arg_parser, validate_parsed_serve_args +# from vllm.utils.argparse_utils import FlexibleArgumentParser +from vllm.logger import init_logger + +logger = init_logger(__name__) + +LICENSE_SERVER = { + "host": [ + "https://auth.unieai.com", + "https://uls.unieai.com", + "https://13.114.141.202", + "http://13.114.141.202", + ], + "info": """-----BEGIN PUBLIC KEY----- +MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEA5Xtc4qwU2nxODyg3h2i8 +2wMYofSUA9ZKpjaaLE1sbH8gJrij5KxSAShtLgq9I5O6FKtLfA+OVdYJnM7TzMS7 +DIIgxabp3+x4NEbhUW2zi2Z4sX2eUFHlSDcy6xNoi6txk9KpHMKYt0QtHL7XGJPN +lULIvG5zwDTbJY3MpYiJW27U8qCbCGq/gnV3mva1NLNjL0vqTeiUQgPiwYakEPuJ +H0Yt5exYueMltRoTxRIOq2uK6KPJiQu0f9m1u/J3PXoTZN4WyySXealneN95wfeF +InxZBNLEBVHnJ1adWSAcmIdPLvljDixpMt57OPUa7dEDXO6e5mKF9aj9HcAER8BC +lQIDAQAB +-----END PUBLIC KEY-----""" +} + +CACHE_DIR = pathlib.Path(os.environ.get("UNIEAI_CACHE_DIR", str(pathlib.Path.home() / ".unieai"))) +CACHE_FILE = CACHE_DIR / "last_verified.json" + +# --------------------------------------------------------------------------- +# Device fingerprint +# --------------------------------------------------------------------------- + +def get_session_id() -> str: + """Derive a deterministic session ID from NVIDIA GPU UUIDs. + + Runs ``nvidia-smi --query-gpu=uuid --format=csv,noheader``, collects all + GPU UUIDs, sorts them alphabetically, concatenates them, and returns the + SHA-256 hex digest as the session ID. + """ + try: + result = subprocess.run( + ["nvidia-smi", "--query-gpu=uuid", "--format=csv,noheader"], + capture_output=True, + text=True, + timeout=10, + ) + if result.returncode != 0: + logger.error("nvidia-smi failed (rc=%d): %s", result.returncode, result.stderr.strip()) + sys.exit(1) + + uuids = sorted(line.strip() for line in result.stdout.strip().splitlines() if line.strip()) + if not uuids: + logger.error("nvidia-smi returned no GPU UUIDs.") + sys.exit(1) + + combined = ",".join(uuids) + session_id = hashlib.sha256(combined.encode("utf-8")).hexdigest() + logger.info("Session ID derived from %d GPU(s): %s", len(uuids), session_id) + return session_id + + except FileNotFoundError: + logger.error("nvidia-smi not found. NVIDIA drivers must be installed.") + sys.exit(1) + except subprocess.TimeoutExpired: + logger.error("nvidia-smi timed out.") + sys.exit(1) + + +# --------------------------------------------------------------------------- +# License helpers +# --------------------------------------------------------------------------- + +def _load_public_key(): + """Load the embedded RSA public key.""" + pem_data = LICENSE_SERVER["info"].encode("utf-8") + return serialization.load_pem_public_key(pem_data) + + +def decrypt_license() -> dict: + """Read UNIEAI_LICENSE env var, decrypt with public key, return license data. + + The env var holds a base64-encoded blob containing: + { + "data": "", + "signature": "" + } + + The public key is used to verify the signature (i.e. "decrypt"), proving + the data was produced by the holder of the private key. On success the + inner license JSON is returned as a dict. Expected fields: + + license_key – e.g. "UNIE-TEST-001" + session_id – SHA-256 of sorted GPU UUIDs + expires_at – ISO-8601 date, e.g. "2026-12-31" + """ + raw = os.environ.get("UNIEAI_LICENSE") + if not raw: + logger.error("UNIEAI_LICENSE environment variable is not set.") + sys.exit(1) + + # ── Step 1: base64-decode the outer envelope ────────────────────────── + try: + envelope = json.loads(base64.b64decode(raw)) + except Exception as exc: + logger.error("Failed to decode UNIEAI_LICENSE: %s", exc) + sys.exit(1) + + data_b64: str | None = envelope.get("data") + sig_b64: str | None = envelope.get("signature") + + if not data_b64 or not sig_b64: + logger.error( + "UNIEAI_LICENSE envelope must contain 'data' and 'signature' fields." + ) + sys.exit(1) + + data_bytes = base64.b64decode(data_b64) + sig_bytes = base64.b64decode(sig_b64) + + # ── Step 2: verify RSA signature (PSS + SHA-256) ───────────────────── + try: + public_key = _load_public_key() + public_key.verify( + sig_bytes, + data_bytes, + padding.PSS( + mgf=padding.MGF1(hashes.SHA256()), + salt_length=padding.PSS.MAX_LENGTH, + ), + hashes.SHA256(), + ) + logger.info("License decrypted and verified successfully.") + except Exception as exc: + logger.error("License verification failed (invalid signature): %s", exc) + sys.exit(1) + + # ── Step 3: parse the inner JSON ───────────────────────────────────── + try: + license_data = json.loads(data_bytes) + except Exception as exc: + logger.error("Failed to parse license data JSON: %s", exc) + sys.exit(1) + + required = ("license_key", "session_id", "expires_at") + missing = [k for k in required if k not in license_data] + if missing: + logger.error("License data is missing required fields: %s", missing) + sys.exit(1) + + logger.info( + "License loaded — key=%s, session=%s, expires=%s", + license_data["license_key"], + license_data["session_id"], + license_data["expires_at"], + ) + return license_data + +def check_online_server_availability() -> str: # return server host if reachable, else empty string + """Check if any license server host is reachable.""" + for host in LICENSE_SERVER["host"]: + url = f"{host}/api/health" + try: + with urllib.request.urlopen(url, timeout=5) as resp: + if resp.status == 200: + logger.info("License server is reachable at %s", host) + return host + except Exception as exc: + logger.debug("License server %s is unreachable: %s", host, exc) + continue + logger.warning("All license server hosts are unreachable.") + return "" + +# --------------------------------------------------------------------------- +# Online verification +# --------------------------------------------------------------------------- + +def verify_license_online(license_key: str, session_id: str, host: str) -> bool: + """Verify the license by POSTing to the server's heartbeat endpoint. + + Returns True if the server confirms ACTIVE + allowed, False if the server + rejects the license, or raises an exception if all servers are unreachable. + """ + if not host: + logger.warning("No license server available for online verification.") + return False + + payload = { + "license_key": license_key, + "machine_id": session_id, + } + # example payload in curl + # curl -X POST http://13.114.141.202/api/licenses/heartbeat \ + # -H "Content-Type: application/json" \ + # -d '{"license_key":"UNIE-TEST-001","session_id":"test-machine-abc123"}' + url = f"{host}/api/licenses/heartbeat" + try: + req = urllib.request.Request( + url, + data=json.dumps(payload).encode("utf-8"), + headers={"Content-Type": "application/json"}, + method="POST", + ) + with urllib.request.urlopen(req, timeout=10) as resp: + resp_body = json.loads(resp.read().decode("utf-8")) + + status = resp_body.get("status", "").upper() + allowed = resp_body.get("allowed", False) + + if status == "ACTIVE" and allowed: + logger.info( + "License heartbeat OK via %s — status=%s, allowed=%s", + host, status, allowed, + ) + _write_verified_cache(license_key, resp_body) + return True + else: + logger.warning("License rejected by %s — %s", host, resp_body) + return False + except Exception as exc: + logger.warning("License verification failed for %s: %s", host, exc) + return False + +def _write_verified_cache(license_key: str, server_response: dict): + """Persist the last successful verification to disk.""" + try: + CACHE_DIR.mkdir(parents=True, exist_ok=True) + cache = { + "license_key": license_key, + "verified_at": datetime.now(timezone.utc).isoformat(), + "server_response": server_response, + } + CACHE_FILE.write_text(json.dumps(cache, indent=2), encoding="utf-8") + logger.info("Last-verified cache written to %s", CACHE_FILE) + except Exception as exc: + logger.warning("Could not write verification cache: %s", exc) + + +def _read_verified_cache() -> dict | None: + """Read the last-verified cache, or None if unavailable.""" + try: + if CACHE_FILE.exists(): + return json.loads(CACHE_FILE.read_text(encoding="utf-8")) + except Exception as exc: + logger.warning("Could not read verification cache: %s", exc) + return None + + +# --------------------------------------------------------------------------- +# Validity check +# --------------------------------------------------------------------------- + +def check_license_validity(license_data: dict, online_ok: bool) -> bool: + """Decide whether the license allows the server to launch. + + * If online verification succeeded → check ``expires_at`` is in the future. + * If all servers were unreachable → allow launch only if the cached + last-verified timestamp exists AND ``expires_at`` is still in the future. + """ + now = datetime.now(timezone.utc) + + # Parse expiry + try: + expires_str = license_data["expires_at"] + # Support both date-only ("2026-12-31") and full ISO datetime + if "T" in expires_str: + expires_at = datetime.fromisoformat(expires_str) + else: + expires_at = datetime.fromisoformat(expires_str + "T23:59:59+00:00") + # Ensure timezone-aware + if expires_at.tzinfo is None: + expires_at = expires_at.replace(tzinfo=timezone.utc) + except Exception as exc: + logger.error("Invalid expires_at in license data: %s", exc) + return False + + if expires_at <= now: + logger.error( + "License has expired (expires_at=%s, now=%s).", + expires_at.isoformat(), now.isoformat(), + ) + return False + + if online_ok: + logger.info("License is valid (online verified, expires %s).", expires_at.date()) + return True + + # Offline fallback — check cache + cache = _read_verified_cache() + if cache is None: + logger.error( + "All license servers are unreachable and no previous verification " + "cache exists. Cannot launch." + ) + return False + + logger.info( + "All license servers are unreachable, but license was last verified " + "at %s and does not expire until %s. Allowing launch.", + cache.get("verified_at", "unknown"), + expires_at.date(), + ) + return True + + +# --------------------------------------------------------------------------- +# Main entry-point +# --------------------------------------------------------------------------- + +def main(): + print("▗▖ ▗▖▗▖ ▗▖▗▄▄▄▖▗▄▄▄▖\033[94m ▗▄▖ ▗▄▄▄▖\033[0m") + print("▐▌ ▐▌▐▛▚▖▐▌ █ ▐▌ \033[94m▐▌ ▐▌ █\033[0m") + print("▐▌ ▐▌▐▌ ▝▜▌ █ ▐▛▀▀▘\033[94m▐▛▀▜▌ █\033[0m") + print("▝▚▄▞▘▐▌ ▐▌▗▄█▄▖▐▙▄▄▖\033[94m▐▌ ▐▌▗▄█▄▖\033[0m") + print() + print("▗▖ ▗▖▗▖ ▗▖▗▄▄▄▖▗▄▄▄▖\033[91m▗▄▄▄▖▗▖ ▗▖▗▄▄▄▖▗▄▄▖ ▗▄▖\033[0m") + print("▐▌ ▐▌▐▛▚▖▐▌ █ ▐▌ \033[91m █ ▐▛▚▖▐▌▐▌ ▐▌ ▐▌▐▌ ▐▌\033[0m") + print("▐▌ ▐▌▐▌ ▝▜▌ █ ▐▛▀▀▘\033[91m █ ▐▌ ▝▜▌▐▛▀▀▘▐▛▀▚▖▐▛▀▜▌\033[0m") + print("▝▚▄▞▘▐▌ ▐▌▗▄█▄▖▐▙▄▄▖\033[91m▗▄█▄▖▐▌ ▐▌▐▌ ▐▌ ▐▌▐▌ ▐▌\033[0m") + parser = ArgumentParser( + description="UnieInfra - UnieAI Licensed Inference Engine", + ) + + subparsers = parser.add_subparsers(dest="command", required=True) + + serve = subparsers.add_parser("serve", help="UnieInfra Launch Command") + serve.add_argument("model_name", help="The model tag to serve (optional if specified in config) (default: None)", default=None) + + unieconfig = subparsers.add_parser("unieconfig", help="Print license config and exit") + unieconfig.add_argument("model_name", help="The model tag to serve (optional if specified in config) (default: None)", default=None) + + # License check before doing anything + # mock license data + # { + # "license_key": "UNIE-TEST-001", + # "session_id": "test-session-abc123", # using GPU UUID to get session_id + # "expires_at": "2026-12-31T23:59:59+00:00" + # } + license_data = decrypt_license() + available_server = check_online_server_availability() + online_ok = verify_license_online(license_data["license_key"], license_data["session_id"], available_server) + + if not online_ok: + logger.info("Online license verification failed.") + + # Check offline license data is valid (e.g. not expired or GUP UUID mismatch) before allowing launch + + # Check expirey date + # exmaple license_data["expires_at"] = "2026/06/18" + expires_at_str = license_data["expires_at"] + try: + expires_at = datetime.fromisoformat(expires_at_str) + if expires_at <= datetime.now(timezone.utc): + logger.error("License has expired (expires_at=%s).", expires_at_str) + sys.exit(1) + except ValueError: + logger.error("Invalid expires_at format in license data: %s", expires_at_str) + sys.exit(1) + + if get_session_id() != license_data["session_id"]: + logger.error( + "GPU UUID mismatch: license session_id=%s, but current machine session_id=%s.", + license_data["session_id"], get_session_id(), + ) + sys.exit(1) + + + + args, unknown_args = parser.parse_known_args() + if args.command == "serve": + if "--easy" in unknown_args: + # remove --easy + unknown_args = [arg for arg in unknown_args if arg != "--easy"] + else: + unknown_args += ["--async-scheduling", "--speculative-config", '{"method":"ngram_dsc","num_speculative_tokens":4,"draft_tensor_parallel_size":1,"prompt_lookup_min":3,"prompt_lookup_max":8}'] + cmd = ["vllm", "serve", args.model_name] + unknown_args + # logger.info("Running UnieConfig with command: %s", " ".join(cmd)) + subprocess.run(cmd, check=True, shell=False, text=True) + if args.command == "unieconfig": + unieconfig_args = [ + "--score-concurrencies", "1,8,64,256", "--n-trials", "20", "-o", "benchmarks/results" + ] + unknown_cmd = " ".join(unknown_args) + serve_cmd = f"vllm serve {args.model_name} {unknown_cmd}" + cmd = [ + "vllm", + "serve-optuna", + "--serve-cmd", + shlex.quote(serve_cmd), + ] + unieconfig_args + # logger.info("Running UnieConfig with command: %s", " ".join(cmd)) + subprocess.run(" ".join(cmd), check=True, shell=True, text=True) + +if __name__ == "__main__": + main()