[skyrl] vLLM Renderer for rendering Multi-Modal ModelInputChunks for training backend #1464
Conversation
```python
from skyrl.train.config import SkyRLTrainConfig
from tests.backends.skyrl_train.gpu.utils import InferenceEngineState


requires_local_vllm = pytest.mark.skipif(
```
All of the current vision-language rendering and generation tests require a local vLLM install from main until the next vLLM release.
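As a rough sketch, the truncated `requires_local_vllm` marker in the diff above could look something like the following, assuming the `SKYRL_LOCAL_VLLM=1` gate mentioned in the summary; the exact condition and reason string in the PR may differ:

```python
import os

import pytest

# Gate the VLM rendering/generation tests behind a local vLLM install;
# skipped unless SKYRL_LOCAL_VLLM=1 is set (condition and reason are illustrative).
requires_local_vllm = pytest.mark.skipif(
    os.environ.get("SKYRL_LOCAL_VLLM") != "1",
    reason="requires a local vLLM install from main until the next vLLM release",
)
```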
```diff
 class RenderedModelInput(BaseModel):
     prompt_ids: list[int]
-    multi_modal_kwargs: dict[str, bytes] | None = None
+    multi_modal_kwargs: dict[str, list[str]] | None = None
```
Would it make sense to make this a TypedDict? To make it easier to understand which keys can be there (even if they are optional)
That makes sense to me. Added it as a TypedDict, but I kept the value typing as `Any` instead of `torch.Tensor` since types.py is also used by the JAX backend.
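A minimal sketch of what that TypedDict could look like; the class name here is illustrative, the keys follow the vision tensors used elsewhere in this PR series, and the values stay `Any` for JAX-backend compatibility:

```python
from typing import Any, TypedDict


class MultiModalFeatures(TypedDict, total=False):
    # Values are Any rather than torch.Tensor because types.py is also
    # used by the JAX backend; total=False keeps every key optional.
    pixel_values: Any
    image_grid_thw: Any
```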
```python
]


def decode_mm_kwargs(rendered: RenderedModelInput) -> Tuple[torch.Tensor, torch.Tensor]:
```
It seems more natural to me to put only the `multi_modal_kwargs` as an argument (which will go well with the suggestion below to introduce a TypedDict for it), since that's the only part of the `RenderedModelInput` that is used here.
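A sketch of the suggested narrower signature; the body is elided and only the argument change is illustrated:

```python
import torch


def decode_mm_kwargs(multi_modal_kwargs: dict[str, list[str]]) -> tuple[torch.Tensor, torch.Tensor]:
    """Decode the serialized multi-modal kwargs into (pixel_values, image_grid_thw)."""
    ...


# Call sites would then pass the field directly:
# pixel_values, image_grid_thw = decode_mm_kwargs(rendered.multi_modal_kwargs)
```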
This looks great to me! It is slightly sad that we depend on

```python
from vllm.entrypoints.serve.disagg.mm_serde import (
    decode_mm_kwargs_item as _vllm_decode,
)
```

which seems more of an internal vLLM API (and depending on it violates the client/server separation for vLLM a little). If the messagepack protocol is stable, maybe we would want to replicate it here. We can also do that going forward (e.g. maybe in the future it makes sense to have skyrl/backends/renderer.py not depend on vLLM and put the VLLMRenderer into the skyrl_train folder). So feel free to just move forward with the PR for now :)
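One way to contain that dependency, as a rough sketch, is to funnel it through a single wrapper that could later be swapped for an in-repo msgpack implementation; the wrapper name and the one-argument call signature of `decode_mm_kwargs_item` are assumptions:

```python
from typing import Any


def _decode_mm_item(item: Any) -> Any:
    """Single touch point for vLLM's internal mm_serde module, so the
    dependency can later be replaced without changing callers."""
    # Deferred import keeps the internal-API dependency local to this helper.
    from vllm.entrypoints.serve.disagg.mm_serde import (
        decode_mm_kwargs_item as _vllm_decode,
    )

    return _vllm_decode(item)
```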
## Summary

Integrates the VLLMRenderer (landed in #1464) into the SkyRL train backend so that VLM training batches include image placeholder tokens and decoded vision tensors (`pixel_values`, `image_grid_thw`).

- When using new inference (`_SKYRL_USE_NEW_INFERENCE`), `_to_training_batch` lazily creates a `VLLMRenderer` and renders all `ModelInput`s through it.
- Extracts `pixel_values` and `image_grid_thw` from rendered outputs and adds them to the `TrainingInputBatch` as `TensorList` entries (one tensor per batch element, since patch counts vary per image).
- Extends `_pad_batch` to handle `TensorList` fields by cycling and cloning entries, matching the existing padding strategy for regular tensors (see the sketch after this description).
- Reorders `forward_backward` and `forward` to call `_to_training_batch` before `_sleep_inference_engines`, since the renderer needs the inference servers to be initialized. Note that this does not wake the KV cache or model GPU memory, since that is explicitly done in `save_weights_for_sampler` via the dispatcher.

## E2E Tinker VLM Classifier Curves

With #1484, we can now run tinker vision-language recipes against the SkyRL backend. Merging both closes #1200.

Example:

```bash
_SKYRL_USE_NEW_INFERENCE=1 uv run --extra tinker --extra fsdp -m skyrl.tinker.api \
  --base-model "Qwen/Qwen3-VL-8B-Instruct" \
  --backend fsdp \
  --backend-config '{"trainer.placement.policy_num_gpus_per_node": 8, "generator.inference_engine.num_engines": 8, "trainer.placement.colocate_all": true, "trainer.use_sample_packing": false}'
```

Cookbook:

```bash
TINKER_API_KEY=tml-dummy uv run --with tinker --with datasets --with torch python -m \
  tinker_cookbook.recipes.vlm_classifier.train \
  base_url=http://localhost:8000 \
  model_name="Qwen/Qwen3-VL-4B-Instruct" \
  dataset=caltech101
```

Train nll:
<img width="1200" height="675" alt="train_nll" src="https://github.com/user-attachments/assets/82e36767-edee-43b7-ab4a-7fbf496c8cbb" />

Val nll:
<img width="1200" height="675" alt="val_nll" src="https://github.com/user-attachments/assets/1dc6e96b-7e1b-4ead-bf0e-71e42eab0491" />

Val accuracy:
<img width="1200" height="675" alt="accuracy" src="https://github.com/user-attachments/assets/ec6f92b8-a544-42d9-9a00-4c06292e7ae3" />
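The `_pad_batch` extension mentioned above can be sketched roughly as follows; the helper name and its exact integration into `_pad_batch` are assumptions:

```python
import itertools

import torch


def _pad_tensor_list(entries: list[torch.Tensor], target_len: int) -> list[torch.Tensor]:
    """Pad a per-example tensor list up to target_len by cycling over the
    existing entries and appending clones, mirroring the existing padding
    strategy for regular tensors. (Helper name is hypothetical.)"""
    padded = list(entries)
    for entry in itertools.cycle(entries):
        if len(padded) >= target_len:
            break
        padded.append(entry.clone())
    return padded
```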
Summary
This PR introduces a vLLM renderer which is used by the SkyRL backend to convert `ModelInputChunks` into tokenized text, `pixel_values`, and `image_grid_thw`. The scope of this PR is limited to the actual renderer implementation; it is not added to the backend yet.

- Adds `VLLMRenderer`.
- `RenderedImage` NamedTuple return type from `_render_images`.
- Uses `render_model_input` for the text-only fast path in `_render_single` instead of duplicating the token concatenation logic.
- Changes the `RenderedModelInput.multi_modal_kwargs` type annotation from `dict[str, bytes]` to `dict[str, list[str]]` to match actual usage (a list of base64-encoded strings per modality key).
- Unit tests for `VLLMRenderer` with a mocked `RemoteInferenceClient` (text-only, image-only, mixed, error cases).
- A GPU test (`test_vlm_renderer.py`) that exercises the full renderer against a real VLM via `InferenceEngineState`. The GPU tests are gated behind `SKYRL_LOCAL_VLLM=1` since they depend on a local vLLM fork with `/v1/chat/completions/render` support that is not yet upstreamed.
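For orientation, a rough usage sketch of the rendering/decoding flow; the import path, constructor arguments, and method name here are assumptions rather than the exact API added in this PR:

```python
# Hypothetical import path and API shape, shown only to illustrate the flow
# of rendering a ModelInput and decoding its vision tensors.
from skyrl.backends.renderer import VLLMRenderer, decode_mm_kwargs


def render_and_decode(client, model_input):
    renderer = VLLMRenderer(client)                      # client: a RemoteInferenceClient
    rendered = renderer.render_model_input(model_input)  # -> RenderedModelInput
    prompt_ids = rendered.prompt_ids                     # token ids, incl. image placeholders
    if rendered.multi_modal_kwargs is None:              # text-only fast path
        return prompt_ids, None, None
    pixel_values, image_grid_thw = decode_mm_kwargs(rendered)
    return prompt_ids, pixel_values, image_grid_thw
```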
Test plan

- `python -m pytest tests/backends/skyrl_train/test_renderer.py -v` (8 tests, no GPU required)
- `SKYRL_LOCAL_VLLM=1 uv run --extra fsdp --extra dev --extra tinker pytest tests/backends/skyrl_train/gpu/gpu_ci/inference_servers/test_vlm_inference_generation.py -m vllm -v`