
Multimodal serve support #45220

Merged
SunMarc merged 19 commits into main from audio-video-serve on Apr 15, 2026

Conversation

@SunMarc
Member

@SunMarc SunMarc commented Apr 3, 2026

What does this PR do?

This PR adds transformers serve compatibility for multimodal models such as Qwen Omni or Gemma 4. We add audio support to both the Chat Completions and Responses APIs through input_audio: the client needs to base64-encode the audio and send it as an input_audio content part.

For video, the OpenAI API doesn't natively support video_url as a content type, so we extended it to accept one. For simplicity, we also allow passing a URL for audio through audio_url.
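
For example, a minimal client-side sketch with the Chat Completions API could look like this (assuming a transformers serve instance already running on localhost:8000 and a local mp3 file; the commented-out parts show the audio_url / video_url extensions described above):

```python
import base64

from openai import OpenAI

# Hedged sketch: assumes a `transformers serve` instance is already listening on
# localhost:8000 and that "sample.mp3" exists on disk.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="unused")

with open("sample.mp3", "rb") as f:
    audio_b64 = base64.b64encode(f.read()).decode()

completion = client.chat.completions.create(
    model="Qwen/Qwen2.5-Omni-3B",
    messages=[
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "Transcribe this audio."},
                {"type": "input_audio", "input_audio": {"data": audio_b64, "format": "mp3"}},
                # audio_url / video_url are the extensions described above:
                # {"type": "audio_url", "audio_url": {"url": "https://.../audio.mp3"}},
                # {"type": "video_url", "video_url": {"url": "https://.../video.mp4"}},
            ],
        }
    ],
    max_tokens=200,
)
print(completion.choices[0].message.content)
```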

Results (tested with google/gemma-4-E2B-it and Qwen/Qwen2.5-Omni-3B)

import base64
import socket
import time

import httpx
from openai import OpenAI

from transformers.cli.serve import Serve

AUDIO_URL = "https://huggingface.co/datasets/hf-internal-testing/dummy-audio-samples/resolve/main/obama_first_45_secs.mp3"
VIDEO_URL = "https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/concert.mp4"
MODEL = "google/gemma-4-E2B-it"

# Qwen Omni
# MODEL = "Qwen/Qwen2.5-Omni-3B"


def find_free_port():
    with socket.socket() as s:
        s.bind(("", 0))
        return s.getsockname()[1]


def start_serve():
    port = find_free_port()
    serve = Serve(port=port, non_blocking=True)
    for _ in range(30):
        try:
            if httpx.get(f"http://localhost:{port}/health", timeout=2).status_code == 200:
                return serve, port
        except Exception:
            pass
        time.sleep(1)
    raise RuntimeError("Server did not start in time")


serve, port = start_serve()
client = OpenAI(base_url=f"http://localhost:{port}/v1", api_key="unused")

audio_bytes = httpx.get(AUDIO_URL, follow_redirects=True).content
audio_b64 = base64.b64encode(audio_bytes).decode()

print("=== Audio via responses API ===")
resp = client.responses.create(
    model=MODEL,
    input=[
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "Transcribe this audio."},
                {"type": "input_audio", "input_audio": {"data": audio_b64, "format": "mp3"}},
            ],
        }
    ],
    stream=False,
    max_output_tokens=200,
)
print(resp.output[0].content[0].text)
print()

# --- Video with audio (responses API) ---
print("=== Video via responses API ===")
resp = client.responses.create(
    model=MODEL,
    input=[
        {
            "role": "user",
            "content": [
                {"type": "video_url", "video_url": {"url": VIDEO_URL}},
                {"type": "text", "text": "Transcribe the lyrics of the song being played in this video."},
            ],
        }
    ],
    stream=False,
    max_output_tokens=500,
)
print(resp.output[0].content[0].text)
print()

serve.kill_server()
print("Done!")
Audio via responses API
This week, I traveled to Chicago to deliver my final farewell address to the nation, following in the tradition of
presidents before me. It was an opportunity to say thank you. Whether we've seen eye-to-eye or rarely agreed at
all, my conversations with you, the American people, in living rooms and schools, at farms and on factory floors,
at diners, and on distant military outposts, all these conversations are what have kept me honest.

Video via responses API
(Song lyrics)

I don't care how straight
From neck to chest
We're in the same predicament
Another one wantin' is in the storm alone
I'm the one down below this
You don't wanna be my
I never thought you'd say
Of this nice sad place you've been
I don't want it my face
But I don't wanna die

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@SunMarc SunMarc requested a review from LysandreJik April 3, 2026 14:57
Comment thread src/transformers/cli/serving/chat_completion.py Outdated
return_tensors=None if use_cb else "pt",
return_dict=True,
tokenize=True,
load_audio_from_video=modality == Modality.MULTIMODAL and has_video,
Member Author

@SunMarc SunMarc Apr 3, 2026

Managed to use it, but torchcodec is required for that. Otherwise, it falls back to the other lib and fails. Also, torchcodec + ffmpeg was a bit of a pain to install correctly.
Maybe we should force the user to install torchcodec when using this, no? cc @eustlb

Contributor

Indeed, load_audio only works with torchcodec, as video containers are not supported by librosa. Agree we need to raise a clear error; let me raise a PR for that.

Member

note that this will probably error out when a video is passed and it is silent; iirc old torchcodec would complain that no audio is found

Also, the whole load_audio_from_video is mostly a heuristic for models other than qwen-omni, no guarantee on performance

Member Author

I feel like it would be better if we could delegate this to the audio processing, e.g. have a try/catch around the loading.
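
Something like the following rough sketch, purely illustrative (the helper name and placement are hypothetical, not code from this PR):

```python
# Hypothetical sketch of that delegation: let the audio-loading path itself fail
# fast with an actionable message instead of checking torchcodec availability in
# the serving layer.
def _load_audio_track_from_video(path):
    try:
        import torchcodec  # required: librosa cannot read video containers
    except ImportError as exc:
        raise ImportError(
            "Extracting audio from video requires `torchcodec`. "
            "Install it with: `pip install torchcodec`."
        ) from exc
    ...  # decode the audio track with torchcodec here
```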

@SunMarc SunMarc requested a review from eustlb April 3, 2026 15:13
Contributor

@eustlb eustlb left a comment

@SunMarc what would you think of having an ALM modality, and to differentiate:
ALM: audio + text
VLM: vision + text
MULTIMODAL: audio + vision + text
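
For illustration only, the proposed split could be expressed as an enum like the following (a hypothetical sketch, not the Modality enum actually shipped in this PR):

```python
from enum import Enum


# Hypothetical sketch of the proposed modality split; names and values are
# illustrative only and do not reflect the enum merged in this PR.
class Modality(str, Enum):
    LLM = "llm"                # text only
    VLM = "vlm"                # vision + text
    ALM = "alm"                # audio + text
    MULTIMODAL = "multimodal"  # audio + vision + text
```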

Comment on lines +122 to +125
if load_audio_from_video and not is_torchcodec_available():
    raise ValueError(
        "Extracting audio from video requires `torchcodec`. Install it with: `pip install torchcodec`."
    )
Contributor

I guess this can be removed, better to locate the error in load_audio

Comment on lines +914 to +916
All modalities extract text. VLM additionally handles ``image_url`` and ``video_url``.
MULTIMODAL handles all of the above plus ``input_audio`` and ``audio_url``.
For LLMs, the content parts are collapsed into a plain text string.
Contributor

My issue with this is that ALMs are seen as a sub-category of omni. While that is the case for gemma4, for other models we can use the ALM and VLM capabilities separately and together. This makes even more sense knowing that audio + vision is an emergent capability: the model has not been trained on both.

Member

yes, the separation line now looks weird with audio LLMs. Actually I was planning to keep a single MULTIMODAL key for all types and combinations. Though first it needs to be aligned with Lucain and the Hub.

@SunMarc
Member Author

SunMarc commented Apr 3, 2026

@SunMarc what would you think of having an ALM modality, and to differentiate:
ALM: audio + text
VLM: vision + text
MULTIMODAL: audio + vision + text

Maybe we should have a MODEL_FOR_AUDIO_TEXT_MAPPING_NAMES mapping? I created MULTIMODAL as we have this MODEL_FOR_MULTIMODAL_LM_MAPPING_NAMES. We have MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES, but this is more for encoder-decoder style models, so we can't really use that, no? cc @eustlb
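
For illustration, such a mapping would presumably mirror the existing *_MAPPING_NAMES pattern, something like the following hypothetical sketch (the mapping name and the single entry are illustrative, not part of this PR):

```python
from collections import OrderedDict

# Hypothetical sketch mirroring the existing *_MAPPING_NAMES pattern; name and
# entry chosen for illustration only.
MODEL_FOR_AUDIO_TEXT_MAPPING_NAMES = OrderedDict(
    [
        ("qwen2_audio", "Qwen2AudioForConditionalGeneration"),
        # ... other audio + text models would be registered here
    ]
)
```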

@SunMarc
Member Author

SunMarc commented Apr 14, 2026

run-slow: serve

@SunMarc
Member Author

SunMarc commented Apr 14, 2026

run-slow: cli

Member

@LysandreJik LysandreJik left a comment

Ok, looks good! Implementation is nice.

Let's please add some docs which go over the new features, each with examples of how to use the feature/new arguments

Comment thread src/transformers/cli/serving/utils.py Outdated
Comment on lines +120 to +123
# Default to 32 frames for video (Gemma 4 default); some processors load all frames otherwise
chat_template_kwargs = {}
if has_video:
    chat_template_kwargs["num_frames"] = 32
Member

(nit) should we apply this to only gemma 4 then? but maybe easier to do down the road when adding support for other video models

Member Author

gemma 4 has this default of 32 frames, which is okay since it comes from their official implementation. For now, I was thinking about hardcoding this to 32 because otherwise, with qwen omni, you get OOM very quickly even if the video is only 10s, for example. But we should definitely improve that and probably set the fps instead, or another default. cc @zucchini-nlp

Member

yep, if video models are supported when serving, there needs to be a default sampling arg or users need to pass it explicitly. In inference, users are usually encouraged to pass a value of their own, because most video processor classes don't have a default
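
For context, a minimal sketch of that inference-side pattern (illustrative only; it assumes the Qwen2.5-Omni processor and that it accepts a num_frames kwarg, mirroring the value hardcoded in this PR):

```python
from transformers import AutoProcessor

# Hedged sketch: pass an explicit frame count when applying the chat template,
# instead of relying on a processor default. Kwarg support varies by processor.
processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-Omni-3B")
messages = [
    {
        "role": "user",
        "content": [
            {"type": "video", "url": "https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/concert.mp4"},
            {"type": "text", "text": "Describe this video."},
        ],
    }
]
inputs = processor.apply_chat_template(
    messages,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
    num_frames=32,  # explicit sampling value, mirroring the serve-side default above
)
```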

Member Author

let's fix this in a follow-up PR then. Let's keep this for now, so that it runs smoothly for all kinds of models

@github-actions
Contributor

Workflow Run ⚙️

This comment contains run-slow, running the specified jobs:

models: ["cli"]
quantizations: []

@github-actions
Contributor

CI Results

Workflow Run ⚙️

Commit Info

| Context | Commit | Description |
| --- | --- | --- |
| RUN | d35cca92 | workflow commit (merge commit) |
| PR | e85c9c1f | branch commit (from PR) |
| main | 27fbb514 | base commit (on main) |

✅ No failing test specific to this PR 🎉 👏 !

Member

@zucchini-nlp zucchini-nlp left a comment

Adding my 5 cents 😄


Comment on lines +951 to +956
if "base64" in url:
image_data = re.sub("^data:image/.+;base64,", "", url)
image = Image.open(BytesIO(base64.b64decode(image_data)))
file = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
image.save(file.name)
url = file.name
Member

we can decode images from base64 in processing, or do we not call the processor's loading?

# Try to load as base64
try:
    b64 = base64.decodebytes(image.encode())
    image = PIL.Image.open(BytesIO(b64))

Member Author

I'll update that in a follow-up PR. It was there from the start, but it looks like we can indeed simplify this a bit.
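
For reference, the kind of simplification being discussed could look roughly like this (a sketch only, based on the two snippets quoted above, not the actual follow-up change):

```python
import base64
import re
from io import BytesIO

from PIL import Image


# Rough sketch of the suggested simplification: decode a data URL straight to a
# PIL image instead of writing a temporary PNG and handing a file path back to
# the processor. The helper name is hypothetical.
def load_image_from_url_or_data_url(url: str):
    if url.startswith("data:image"):
        image_data = re.sub(r"^data:image/.+;base64,", "", url)
        return Image.open(BytesIO(base64.b64decode(image_data)))
    return url  # plain URL: let the processor's own image loading handle it
```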

@SunMarc
Member Author

SunMarc commented Apr 15, 2026

@bot /style

@github-actions
Contributor

github-actions Bot commented Apr 15, 2026

Style fix runs successfully without any file modified.

@SunMarc SunMarc added this pull request to the merge queue Apr 15, 2026
Merged via the queue into main with commit fd45a42 Apr 15, 2026
18 checks passed
@SunMarc SunMarc deleted the audio-video-serve branch April 15, 2026 14:13