
[Startup] Parallelize torch/transformers import + weight prefetch + forkserver prewarm#40331

Merged
chaunceyjiang merged 3 commits into vllm-project:main from simon-mo:startup/parallelization on Apr 21, 2026

Conversation

@simon-mo
Collaborator

Summary

Five commits that overlap previously-serial startup work with other startup-time work already in flight. All are behavior-preserving; nothing changes how the engine runs — only when, during the first ~20 seconds, specific pieces of work get done.

Cold-start speedup (back-to-back measurements, shared dev box, single GPU):

| Model | cold (main) | cold (this PR) | Δ cold | warm (main) | warm (this PR) | Δ warm |
|---|---|---|---|---|---|---|
| Qwen2.5-0.5B-Instruct | 78.93 s (σ=0.6%) | 69.89 s (σ=7.7%) | -9.0 s (-11.5%) | 45.67 s (σ=3.4%) | 42.65 s (σ=5.4%) | -3.0 s (-6.6%) |
| Qwen2.5-7B-Instruct | 92.44 s (σ=1.6%) | 73.72 s (σ=1.1%) | -18.7 s (-20.2%) | 49.59 s (σ=1.9%) | 46.88 s (σ=2.1%) | -2.7 s (-5.5%) |

3 cold + 3 warm samples per config, interleaved, page cache dropped between cold samples via posix_fadvise(POSIX_FADV_DONTNEED). All A/B pairs measured back-to-back on the same box state.
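
For context, evicting a file's pages between cold samples can be done with os.posix_fadvise; a minimal sketch of such a harness step (the harness itself is not part of this PR, and the function name here is illustrative):

```python
import os

def drop_page_cache(path: str) -> None:
    """Evict a file's pages so the next read is a true cold-disk read (Linux)."""
    fd = os.open(path, os.O_RDONLY)
    try:
        os.fsync(fd)  # make sure no dirty pages keep the cache pinned
        # length=0 means "to end of file"
        os.posix_fadvise(fd, 0, 0, os.POSIX_FADV_DONTNEED)
    finally:
        os.close(fd)
```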

What's in each commit

Each commit is independently useful and independently measurable; the table above is the cumulative effect.

1. Parent APIServer starts weight-shard page-cache prefetch

A background thread in the parent APIServer opens the model's .safetensors/.bin files and reads them in 16 MB blocks with an 8-thread pool. Reads land in the OS page cache. When EngineCore starts a few seconds later in the child and mmaps the same files, the kernel already has them. Best-effort: any failure (unknown model, permission, etc.) silently falls back to the existing in-child prefetch.
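
A minimal sketch of that parent-side reader, assuming the 16 MB block size and 8-worker pool described above; the helper names and wiring are illustrative, not the PR's code verbatim:

```python
import glob
import os
from concurrent.futures import ThreadPoolExecutor

_BLOCK_SIZE = 16 * 1024 * 1024  # 16 MB per read


def _warm_file(path: str) -> None:
    try:
        with open(path, "rb") as f:
            while f.read(_BLOCK_SIZE):  # each read lands in the OS page cache
                pass
    except OSError:
        pass  # best-effort: any failure is silently ignored


def _prefetch_weight_shards(model_dir: str) -> None:
    shards = sorted(
        glob.glob(os.path.join(model_dir, "*.safetensors"))
        + glob.glob(os.path.join(model_dir, "*.bin"))
    )
    with ThreadPoolExecutor(max_workers=8) as pool:  # 8 concurrent readers
        list(pool.map(_warm_file, shards))


# Launched from the parent APIServer once the model directory is known, e.g.:
# threading.Thread(target=_prefetch_weight_shards, args=(d,), daemon=True).start()
```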

2. Kick torch .so load in background thread at CLI entry

import torch releases the GIL during its .so dlopens (CUDA kernels etc.), which is ~2 s of I/O that we currently pay serially. Kick it off from vllm/entrypoints/cli/main.py, before from vllm.logger import init_logger and before vllm/__init__.py even starts, so the main thread's non-torch imports (envs, stdlib, fastapi, argparse) overlap with torch's .so loading.
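
A minimal sketch of that kick-off, with an illustrative helper name (not necessarily the exact code in the PR):

```python
import threading


def _bg_preload_torch() -> None:
    try:
        import torch  # noqa: F401  # the .so dlopens release the GIL during I/O
    except Exception:
        pass  # a failed preload must never break the CLI


# At the very top of vllm/entrypoints/cli/main.py, before heavy vllm imports:
threading.Thread(target=_bg_preload_torch, name="vllm-torch-preload", daemon=True).start()
# The main thread continues with argparse / non-torch imports; a later
# `import torch` just returns the module already cached in sys.modules.
```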

3. Pre-spawn forkserver in background thread for vllm serve

cli_env_setup() defaults VLLM_WORKER_MULTIPROC_METHOD=spawn, which costs the EngineCore child process ~5 s of fresh Python startup + imports. This commit switches to forkserver and kicks forkserver.ensure_running(), with vllm.v1.engine.async_llm preloaded, on a background thread at CLI entry. The ~3-5 s spent preloading and starting the forkserver overlaps with the parent's argparse + config resolution. By the time Process.start() is called, the forkserver is already warm and the child is a cheap fork instead of a fresh interpreter.
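
A sketch of the prewarm under the behavior described above; names and exact placement are illustrative:

```python
import multiprocessing
import multiprocessing.forkserver as forkserver
import os
import sys
import threading


def _bg_prewarm_forkserver() -> None:
    try:
        # Preload the heavy engine module into the forkserver parent so every
        # later fork of EngineCore starts with it already imported.
        multiprocessing.set_forkserver_preload(["vllm.v1.engine.async_llm"])
        forkserver.ensure_running()  # blocks on a pipe until the forkserver is up
    except Exception:
        pass  # best-effort; spawn remains a valid fallback


if len(sys.argv) > 1 and sys.argv[1] == "serve":
    os.environ.setdefault("VLLM_WORKER_MULTIPROC_METHOD", "forkserver")
    threading.Thread(target=_bg_prewarm_forkserver, daemon=True).start()
```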

Includes the envs.py update that widens VLLM_WORKER_MULTIPROC_METHOD's allowed choices to include forkserver.

4. Also BG-preload transformers alongside torch

Chain import transformers after import torch in the existing BG thread. Another ~2 s of cold-disk import work, gated on torch being done, that currently blocks the main thread. Since transformers is always imported unconditionally on the vllm serve path (via vllm/transformers_utils/config.py's top-level from transformers import ...), preloading is safe and never wasted.
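
As a sketch, the background preloader from commit 2 just gains a second import in the same try block (illustrative names):

```python
def _bg_preload() -> None:
    try:
        import torch  # noqa: F401
        import transformers  # noqa: F401  # chained: only starts once torch is done
    except Exception:
        pass  # never let a preload failure break the CLI
```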

5. Also prefetch tokenizer + config sidecar files

Extend _startup_prefetch_weights to glob .json, tokenizer.model, and *tokenizer* files from the HF snapshot dir alongside the weight shards. Tiny files, but EngineCore opens several of them synchronously during tokenizer + hf_config init before anything else can progress.
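
The added globs, wrapped here in an illustrative helper (the actual change extends _startup_prefetch_weights in place; the exact patterns also appear in the reviewed snippet later in this thread):

```python
import glob
import os


def _sidecar_paths(candidate_dir: str) -> list[str]:
    # Small tokenizer/config files that EngineCore opens synchronously at init.
    return sorted(
        glob.glob(os.path.join(candidate_dir, "*.json"))
        + glob.glob(os.path.join(candidate_dir, "tokenizer.model"))
        + glob.glob(os.path.join(candidate_dir, "*tokenizer*"))
    )
```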

Why these shapes and not simpler alternatives

  • Why BG threads vs. multiprocessing.Process? Each of these operations is either I/O-bound (releases the GIL — torch dlopen, file reads, transformers import) or releases the GIL through syscalls (forkserver.ensure_running() blocks on a pipe). A thread is sufficient and ~100x cheaper than a subprocess.

  • Why put the prewarm in cli/main.py rather than api_server.py? cli/main.py runs before argparse and vllm/__init__.py. That's the earliest possible hook on the user-facing vllm serve path and leaves the most wall-clock budget for the BG work to complete before its result is needed.

  • Why gate the forkserver switch on sys.argv[1] == "serve"? vllm bench, vllm chat, and friends don't spawn long-running subprocesses; for them, spawn is fine and forkserver only adds fixed setup cost.

Correctness and compatibility

  • No change to accuracy, sampling, KV-cache behavior, attention, or any user-visible runtime behavior.
  • All prefetch paths are best-effort: any failure falls through to the existing in-child prefetch.
  • The forkserver path is a pre-existing supported multiproc method in Python's stdlib; we just opt vllm serve into it rather than the default spawn.
  • torch and transformers being imported on a BG thread is safe because Python's import lock serializes concurrent imports of the same module — the main thread's later import torch / import transformers just returns the cached module.

Test plan

  • Measurement harness runs back-to-back main vs. this branch on 0.5B and 7B (done, table above)
  • vllm serve Qwen/Qwen2.5-0.5B-Instruct and vllm serve Qwen/Qwen2.5-7B-Instruct produce correct chat completions (done)
  • pytest tests/entrypoints/cli/ — CLI dispatch paths
  • pytest tests/entrypoints/openai/ — API server path

Duplicate-work checks

gh pr list --repo vllm-project/vllm --state open --search "forkserver startup"
gh pr list --repo vllm-project/vllm --state open --search "cold start"
gh pr list --repo vllm-project/vllm --state open --search "weight prefetch"

No overlapping open PRs as of the time of writing. #40068 is about multi-rank I/O contention (different problem). The several torch.compile cold-start PRs in flight target compile-cache reuse and are orthogonal to parallelizing imports / I/O.

AI-assist disclosure

AI-assisted (Claude via Claude Code). I reviewed every changed line, ran the test plan above locally, and defend the change end-to-end. The measurement harness and box are mine.


@claude claude Bot left a comment


Claude Code Review

This pull request is from a fork — automated review is disabled. A repository maintainer can comment @claude review to run a one-time review.

@mergify mergify Bot added the frontend label Apr 20, 2026
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request introduces several startup optimizations for vLLM, including background preloading of heavy modules like torch and transformers, pre-warming the multiprocessing forkserver, and prefetching model weights into the OS page cache. Feedback suggests setting the forkserver pre-warm thread to daemon mode to prevent the CLI from hanging during early exits and moving synchronous I/O operations in the weight prefetcher into the background thread to avoid blocking the asynchronous event loop.

Comment thread vllm/entrypoints/cli/main.py Outdated
os.environ.setdefault("VLLM_WORKER_MULTIPROC_METHOD", "forkserver")
_threading.Thread(
    target=_bg_prewarm_forkserver,
    daemon=False,  # must survive so spawn can use it
Contributor


Severity: high

The vllm-forkserver-prewarm thread is initialized with daemon=False. This can cause the CLI process to hang during exit if the pre-warming process (which involves heavy imports and starting a separate process) is still active. For example, if the process exits early due to invalid arguments or an immediate error during initialization, it will be blocked from exiting until this background thread completes forkserver.ensure_running(), which can take several seconds. Setting daemon=True would allow the process to exit immediately in such cases, while still allowing the pre-warm to complete if the process continues to run.

Suggested change
daemon=False, # must survive so spawn can use it
daemon=True, # must survive so spawn can use it

Collaborator Author


Addressed in d96d0a3 — switched to daemon=True. The forkserver subprocess is tracked by module-level state in multiprocessing.forkserver and survives the BG thread exiting, so subsequent spawn() calls can still reuse it. This avoids hanging early CLI exits (bad args, import errors, --help).

Comment thread vllm/entrypoints/openai/api_server.py Outdated
Comment on lines +89 to +132
model_ref = vllm_config.model_config.model
candidate_dir: str | None = None

# 1. Local path?
if os.path.isdir(model_ref):
    candidate_dir = model_ref
else:
    # 2. HF repo id — try to resolve to the local cache snapshot dir.
    # Include tokenizer / config sidecar files so they're warm in the
    # page cache too; EngineCore re-opens them during tokenizer init.
    try:
        from huggingface_hub import snapshot_download

        candidate_dir = snapshot_download(
            repo_id=model_ref,
            allow_patterns=[
                "*.safetensors",
                "*.bin",
                "*.json",
                "*tokenizer*",
            ],
            local_files_only=True,
        )
    except Exception:
        return  # not cached yet or not an HF repo id

if not candidate_dir or not os.path.isdir(candidate_dir):
    return

# Weight shards: large files, read into page cache.
shard_paths = sorted(
    glob.glob(os.path.join(candidate_dir, "*.safetensors"))
    + glob.glob(os.path.join(candidate_dir, "*.bin"))
)
# Tokenizer/config sidecars: small, but re-opened in the child and
# add synchronous open+read latency when the disk is cold.
sidecar_paths = sorted(
    glob.glob(os.path.join(candidate_dir, "*.json"))
    + glob.glob(os.path.join(candidate_dir, "tokenizer.model"))
    + glob.glob(os.path.join(candidate_dir, "*tokenizer*"))
)
shard_paths.extend(sidecar_paths)
if not shard_paths:
    return
Contributor


Severity: high

This block of code performs multiple synchronous I/O operations, including directory checks, Hugging Face Hub cache resolution (snapshot_download), and multiple glob.glob calls. Since _startup_prefetch_weights is called from an asynchronous context (build_async_engine_client_from_engine_args), these blocking calls will stall the event loop, potentially delaying the startup of the API server and other concurrent tasks. All directory resolution and file discovery logic should be moved inside the background thread.

Collaborator Author


Addressed in d96d0a3 — moved all synchronous I/O (directory check, snapshot_download, and the glob.glob calls) inside the background thread. _startup_prefetch_weights now only reads three scalar fields off vllm_config synchronously (model, revision, download_dir) and launches the thread immediately; the event loop is no longer blocked.


@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 1732ca62f1


Comment thread vllm/entrypoints/openai/api_server.py Outdated
# Best-effort: if the model is a local path, glob for safetensors; if
# it's a repo-id, try to resolve via HF hub's local cache. Any failure
# silently falls through to the existing in-child prefetch path.
_startup_prefetch_weights(vllm_config)

P1: Skip parent prefetch in API workers without engine ownership

This prefetch runs unconditionally in every API server process, including multi-API-server workers that only connect to already-launched engines via client_config and do not perform weight loading themselves. In --api-server-count > 1 deployments this makes multiple workers concurrently scan the same model shards, which can heavily contend with actual engine startup I/O and increase startup latency/timeouts instead of reducing them. Gate this to the process that owns engine launch (for example when client_config is absent, or a single designated rank).


Collaborator Author


Addressed in d96d0a3 — added a guard at the call site in build_async_engine_client_from_engine_args that skips _startup_prefetch_weights when client_config contains input_address (which is the marker for an API-only worker connecting to an already-running EngineCore via v1.utils.APIServerProcessManager). In multi-API-server / disaggregated setups only the process that actually launches the engine now prefetches.
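
The call-site guard being described is roughly of this shape (a sketch; client_config, vllm_config, and the helper name come from the surrounding api_server.py context and are not defined here):

```python
# Only the process that launches EngineCore should warm the page cache;
# API-only workers carry an "input_address" in client_config.
if not (client_config and client_config.get("input_address")):
    _startup_prefetch_weights(vllm_config)
```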

Comment thread vllm/entrypoints/openai/api_server.py Outdated
Comment on lines +102 to +104
candidate_dir = snapshot_download(
repo_id=model_ref,
allow_patterns=[

P2: Resolve prefetch snapshot with configured revision/cache dir

The parent-side prefetch lookup only passes repo_id and local_files_only, so it can target a different snapshot than the one actually used for model loading when users set a non-default --revision or --download-dir. In those cases this thread may read large, irrelevant files (default revision/default HF cache), while the real load path remains cold, causing unnecessary I/O and lost prefetch benefit. Use the same revision/cache_dir inputs as the weight-loading path.


Collaborator Author


Addressed in d96d0a3 — snapshot_download now receives revision=vllm_config.model_config.revision and cache_dir=vllm_config.load_config.download_dir, matching the weight-loader's inputs. If either is non-default we now resolve and prefetch the same snapshot the engine will load.
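
So the cache lookup now looks roughly like this (sketch; only the relevant arguments are shown, and model_ref / vllm_config come from the enclosing prefetch worker):

```python
from huggingface_hub import snapshot_download

candidate_dir = snapshot_download(
    repo_id=model_ref,
    revision=vllm_config.model_config.revision,        # same revision as the loader
    cache_dir=vllm_config.load_config.download_dir,    # honours --download-dir
    allow_patterns=["*.safetensors", "*.bin", "*.json", "*tokenizer*"],
    local_files_only=True,
)
```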

Comment thread vllm/entrypoints/openai/api_server.py Outdated
# Include tokenizer / config sidecar files so they're warm in the
# page cache too; EngineCore re-opens them during tokenizer init.
try:
from huggingface_hub import snapshot_download
Collaborator


Can we also do it this way: VLLM_USE_MODELSCOPE=1?

Collaborator Author


Good call. Addressed in d96d0a3 — the prefetch worker now branches on envs.VLLM_USE_MODELSCOPE and uses modelscope.hub.snapshot_download in that case (matching vllm/transformers_utils/repo_utils.py's get_model_path).
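
A sketch of the branch, assuming the ModelScope resolver mirrors the HF one (keyword arguments differ slightly between the two libraries, so only the import selection is shown):

```python
import vllm.envs as envs

if envs.VLLM_USE_MODELSCOPE:
    # Same resolver vLLM's repo utils use when ModelScope is enabled.
    from modelscope.hub.snapshot_download import snapshot_download
else:
    from huggingface_hub import snapshot_download
```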

@simon-mo
Collaborator Author

@claude review

Kick off `import torch` and `import transformers` in a background thread
at the top of `vllm/entrypoints/cli/main.py`, before vllm/__init__.py
runs on the main thread. Both modules spend ~2 s on cold-disk .so/module
loading; the GIL is released during file I/O, so non-torch/transformers
imports on the main thread (stdlib, argparse, fastapi, etc.) can make
progress while the BG thread pays that cost. Subsequent `from torch ...`
and `from transformers ...` lines on the main thread hit a warm module
cache.

`transformers` depends on torch, so it's chained into the same thread
after torch (a separate thread would just wait on the import lock).

Both modules are imported unconditionally by api_server's top-level
imports, so the preload is safe and never wasted.

Co-authored-by: Claude
Signed-off-by: simon-mo <simon@inferact.ai>
For `vllm serve`, default VLLM_WORKER_MULTIPROC_METHOD to `forkserver`
(opt-in in envs.py) and pre-spawn the forkserver process from a
background thread at CLI entry.

cli_env_setup() previously defaulted to `spawn`, which costs the child
process ~5 s of fresh Python startup + vllm imports when
AsyncLLM.from_vllm_config spawns EngineCore. `forkserver` forks EngineCore
from an already-warm subprocess preloaded with vllm.v1.engine.async_llm,
avoiding that re-import cost.

The prewarm runs in a daemon=True thread so early CLI exits (bad args,
import errors, --help) don't block on ensure_running(). The forkserver
subprocess itself is tracked by module-level state in
multiprocessing.forkserver and survives this thread exiting; subsequent
spawn() calls reuse it.

api_server.py's existing forkserver setup path is made tolerant of
set_start_method being called twice (the CLI prewarm runs first, then
build_async_engine_client runs again on the main thread).

envs.py Literal and env_with_choices list are updated to include
"forkserver" as a valid value.

Co-authored-by: Claude
Signed-off-by: simon-mo <simon@inferact.ai>
Kick off a background-thread prefetch of the model's safetensors shards
and tokenizer / config sidecar files from the PARENT APIServer process as
soon as vllm_config is resolved. EngineCore's existing in-child prefetch
then races against a warm page cache that the parent started filling
~10-15 s earlier.

At 32 B weights, the prefetch+load phase was 28-30 s cold with the child
doing all the I/O. Moving the kick-off to parent overlaps it with the
'APIServer bootstrap + spawn gap' phase that was otherwise idle for disk.

Implementation:
- All work (directory resolution, HF/ModelScope cache lookup, globbing,
  and the reads themselves) runs inside the background thread so the
  asyncio event loop is never blocked.
- Resolve model ref to a local directory (local path OR cache snapshot
  via snapshot_download(local_files_only=True)). Uses the same revision
  and cache_dir (`--download-dir`) the weight loader will use, so we
  prefetch the same snapshot that gets loaded.
- Honours VLLM_USE_MODELSCOPE and falls back to modelscope's
  snapshot_download when set.
- Glob .safetensors/.bin shards plus *.json / tokenizer.* sidecars;
  spawn daemon thread reading each in 16 MB blocks using 8 concurrent
  workers (mirrors vLLM's in-child prefetch block size + thread count).
  Sidecar files are small but EngineCore opens several synchronously
  during tokenizer + hf_config init, so warming their pages helps too.
- Skip when the APIServer is a headless API worker (client_config
  contains `input_address`) — those processes never load weights, and
  prefetching from every worker would contend with the real engine I/O.
- All failures silently swallowed; in-child prefetch then runs normally.

Co-authored-by: Claude
Signed-off-by: simon-mo <simon@inferact.ai>
@simon-mo simon-mo force-pushed the startup/parallelization branch from 1732ca6 to d96d0a3 on April 20, 2026 at 21:33
@simon-mo
Collaborator Author

Rebased onto latest main and squashed down to 3 logically cohesive commits:

  • a50b6640 — BG-preload torch + transformers at CLI entry
  • 0d3bd8ce — Pre-spawn forkserver in BG thread (+ envs.py + api_server tolerant re-call)
  • d96d0a34 — Parent-side weight + tokenizer/config sidecar prefetch

Addressed the four automated review comments, plus @chaunceyjiang's ModelScope suggestion, in d96d0a34:

  1. gemini (forkserver thread daemon) — switched to daemon=True. The forkserver subprocess is tracked by module-level state in multiprocessing.forkserver and survives the BG thread exiting, so spawn()/Process.start() calls still reuse it. Early CLI exits (bad args, --help, import errors) no longer hang.
  2. gemini (sync I/O in async context) — moved all blocking work (isdir, snapshot_download, glob) inside the background thread. The outer function only reads three scalar fields off vllm_config synchronously.
  3. codex P1 (skip prefetch in headless API workers) — call site now guards on client_config.get("input_address"). Multi-API-server / disaggregated deployments only prefetch from the process that owns engine launch.
  4. codex P2 (use revision/download_dir) — snapshot_download now receives revision=model_config.revision and cache_dir=load_config.download_dir, matching the weight loader.
  5. @chaunceyjiang (modelscope) — prefetch worker now branches on envs.VLLM_USE_MODELSCOPE and uses modelscope.hub.snapshot_download in that case (mirrors vllm/transformers_utils/repo_utils.py::get_model_path).

Also fixed DCO: all 3 commits now carry Signed-off-by.

Original measurements still hold (no change to the hot paths):

  • Qwen2.5-0.5B: cold -9.0 s (-11.5%), warm -3.0 s (-6.6%)
  • Qwen2.5-7B: cold -18.7 s (-20.2%), warm -2.7 s (-5.5%)

AI assistance was used (Claude) for drafting the code and review replies; I (simon-mo) reviewed every changed line, ran the benchmarks, and take full ownership of the change.

@simon-mo simon-mo added the ready label (ONLY add when PR is ready to merge/full CI is needed) Apr 20, 2026
Collaborator

@chaunceyjiang chaunceyjiang left a comment


LGTM

@chaunceyjiang chaunceyjiang merged commit 8256833 into vllm-project:main Apr 21, 2026
57 checks passed
mmangkad added a commit to mmangkad/vllm that referenced this pull request Apr 21, 2026
@chaunceyjiang
Collaborator

@simon-mo This PR seems to have introduced some new issues.

Traceback (most recent call last):
  File "/mnt/data4/jxy/venv/bin/vllm", line 10, in <module>
    sys.exit(main())
             ^^^^^^
  File "/mnt/data4/jxy/vllm/vllm/entrypoints/cli/main.py", line 87, in main
    import vllm.entrypoints.cli.benchmark.main
  File "/mnt/data4/jxy/vllm/vllm/entrypoints/cli/benchmark/main.py", line 10, in <module>
    from vllm.entrypoints.utils import VLLM_SUBCMD_PARSER_EPILOG
  File "/mnt/data4/jxy/vllm/vllm/entrypoints/utils.py", line 19, in <module>
    from vllm.engine.arg_utils import EngineArgs
  File "/mnt/data4/jxy/vllm/vllm/engine/arg_utils.py", line 35, in <module>
    from vllm.config import (
  File "/mnt/data4/jxy/vllm/vllm/config/__init__.py", line 20, in <module>
    from vllm.config.model import (
  File "/mnt/data4/jxy/vllm/vllm/config/model.py", line 30, in <module>
    from vllm.transformers_utils.config import (
  File "/mnt/data4/jxy/vllm/vllm/transformers_utils/config.py", line 18, in <module>
    from transformers import GenerationConfig, PretrainedConfig
ImportError: cannot import name 'GenerationConfig' from 'transformers' (/mnt/data4/jxy/venv/lib/python3.12/site-packages/transformers/__init__.py)


https://buildkite.com/vllm/ci/builds/62254#019daea0-fbef-4d6b-b423-a1c5a5d2b947

Commits referencing this pull request ("[Startup] Parallelize torch/transformers import + weight prefetch + forkserver prewarm", vllm-project#40331, Signed-off-by: simon-mo <simon@inferact.ai>) were subsequently pushed to downstream forks:

  • Copilot AI pushed to hongbolv/vllm (Apr 22, 2026 and May 7, 2026)
  • baonudesifeizhai pushed to baonudesifeizhai/vllm (Apr 23, 2026)
  • yzong-rh pushed to yzong-rh/vllm (Apr 23, 2026)
  • avinashsingh77 pushed to avinashsingh77/vllm (Apr 27, 2026)
  • Lafunamor pushed to Lafunamor/vllm (May 1, 2026)
  • weifang231 pushed to weifang231/eb-vllm (May 13, 2026)
  • my-other-github-account pushed to my-other-github-account/vllm (May 15, 2026, two commits)
