Use torchvision decode_image to load images in the torchvision backend #45195

Merged

yonigozlan merged 7 commits into huggingface:main from yonigozlan:torchvision-decode-image on Apr 9, 2026

Conversation

@yonigozlan (Member) commented Apr 2, 2026

What this PR does

Adds a new load_image_as_tensor utility to image_utils.py, built on torchvision's decode_image, and overrides fetch_images in TorchvisionBackend to use it. Previously, all image loading went through PIL regardless of which backend was used.


Benchmarks

Hardware: NVIDIA A10G + CPU (AWS)
Method: 20 repetitions, median, 3 warm-up runs

Image loading

| format | size | PIL (ms) | torchvision (ms) | speedup |
|--------|-----------|-------|-------|--------|
| JPEG | 224×224 | 0.77 | 0.64 | +17.5% |
| JPEG | 512×512 | 3.19 | 2.99 | +6.2% |
| JPEG | 1024×1024 | 12.29 | 11.60 | +5.6% |
| PNG | 224×224 | 0.83 | 0.73 | +11.3% |
| PNG | 512×512 | 4.93 | 4.52 | +8.4% |
| PNG | 1024×1024 | 19.15 | 17.93 | +6.4% |

Pixel values are identical between both paths for JPEG and PNG across all sizes tested.

End-to-end processor pipeline

The pipeline speedups exceed the standalone loading gains because the old path also ran pil_to_tensor inside process_image, which is skipped entirely when the image is already a tensor.
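A toy calculation makes this concrete. The decode times below are taken from the JPEG 512×512 row of the loading table; the costs of the skipped pil_to_tensor step and of the shared processing are assumed numbers for illustration only:

```python
# Illustrative only: decode times from the JPEG 512×512 row above;
# pil_to_tensor and shared-processing costs are assumed, not measured.
pil_decode, tv_decode = 3.19, 2.99  # ms
pil_to_tensor = 1.0                 # ms, step skipped on the tensor path (assumed)
shared = 2.0                        # ms, processing common to both paths (assumed)

load_speedup = (pil_decode - tv_decode) / pil_decode * 100
pipeline_pil = pil_decode + pil_to_tensor + shared
pipeline_tv = tv_decode + shared
pipeline_speedup = (pipeline_pil - pipeline_tv) / pipeline_pil * 100

print(f"decode-only speedup: {load_speedup:.1f}%")
print(f"end-to-end speedup:  {pipeline_speedup:.1f}%")
```

With these assumed numbers the decode-only gain is about 6.3% while the end-to-end gain is about 19.4%: the absolute time saved grows (skipped step) while the baseline shrinks relative to it.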

| processor | img size | device | batch | PIL (ms) | path (ms) | speedup |
|-----------|----------|--------|-------|----------|-----------|---------|
| SiglipImageProcessor | 224×224 | cpu | 1 | 1.73 | 1.38 | +20% |
| SegformerImageProcessor | 224×224 | cpu | 1 | 3.68 | 3.20 | +13% |
| SiglipImageProcessor | 512×512 | cpu | 1 | 6.58 | 5.55 | +16% |
| SegformerImageProcessor | 512×512 | cpu | 1 | 5.61 | 4.37 | +22% |
| SiglipImageProcessor | 1024×1024 | cpu | 1 | 20.52 | 17.01 | +17% |
| SegformerImageProcessor | 1024×1024 | cpu | 1 | 22.13 | 16.74 | +24% |
| SiglipImageProcessor | 224×224 | cpu | 16 | 17.10 | 14.11 | +17% |
| SegformerImageProcessor | 512×512 | cpu | 16 | 98.85 | 91.76 | +7% |
| SiglipImageProcessor | 1024×1024 | cpu | 16 | 278.42 | 233.92 | +16% |
| SegformerImageProcessor | 1024×1024 | cpu | 16 | 306.07 | 264.34 | +14% |
| SiglipImageProcessor | 224×224 | cuda | 4 | 3.71 | 3.11 | +16% |
| SegformerImageProcessor | 512×512 | cuda | 4 | 14.37 | 12.88 | +10% |
| SiglipImageProcessor | 1024×1024 | cuda | 4 | 55.98 | 48.34 | +14% |
| SegformerImageProcessor | 1024×1024 | cuda | 4 | 56.11 | 48.26 | +14% |
| SiglipImageProcessor | 1024×1024 | cuda | 16 | 248.25 | 191.55 | +23% |
| SegformerImageProcessor | 1024×1024 | cuda | 16 | 222.62 | 191.51 | +14% |

Summary

| scenario | speedup |
|----------|---------|
| Standalone image loading (local file) | +6–17% |
| Standard processors, CPU | +7–24% |
| Standard processors, CUDA | +10–23% |

Benchmark script:
"""
Benchmark: PIL fetch_images vs torchvision load_image_as_tensor

Sections
--------
1. Correctness  – pixel-level comparison of PIL vs torchvision decoded images
2. Load         – isolated load_image / load_image_as_tensor timing
3. Pipeline     – end-to-end processor() call with PIL inputs vs path inputs

Usage
-----
    conda run -n hf_latest1 python .debug/scripts/benchmark_fetch_images.py
"""

import io
import os
import sys
import tempfile
import time
from pathlib import Path

import numpy as np

# ---------------------------------------------------------------------------
# Setup: make sure repo src is importable
# ---------------------------------------------------------------------------
REPO = Path(__file__).resolve().parents[2] / "transformers" / "src"
if str(REPO) not in sys.path:
    sys.path.insert(0, str(REPO))

import torch
from PIL import Image as PILImage
from transformers import SiglipImageProcessor, SegformerImageProcessor
from transformers.image_utils import load_image, load_image_as_tensor

# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
WARM = 3      # warm-up iterations
REPS = 20     # timed iterations

DIVIDER     = "=" * 72
SUBDIV      = "-" * 72
COL         = 28          # label column width

def hdr(title):
    print(f"\n{DIVIDER}")
    print(f"  {title}")
    print(DIVIDER)

def sub(title):
    print(f"\n  {title}")
    print(f"  {'-' * (len(title))}")

def row(label, value, unit="ms"):
    print(f"  {label:<{COL}}{value:>10.3f}  {unit}")

def pct(label, a, b):
    """Print speedup of b over a (positive = b is faster)."""
    delta = (a - b) / a * 100
    sign  = "faster" if delta > 0 else "slower"
    print(f"  {label:<{COL}}{abs(delta):>10.1f}% {sign}")


def timeit(fn, reps=REPS, warm=WARM):
    """Return (median_ms, mean_ms, std_ms) over `reps` repetitions."""
    # warm-up
    for _ in range(warm):
        fn()
    times = []
    for _ in range(reps):
        t0 = time.perf_counter()
        fn()
        times.append((time.perf_counter() - t0) * 1000)
    times = np.array(times)
    return float(np.median(times)), float(np.mean(times)), float(np.std(times))


def make_test_image(width, height, fmt="JPEG", quality=95) -> bytes:
    """Create a random RGB image and return its bytes in `fmt`."""
    arr = np.random.randint(0, 256, (height, width, 3), dtype=np.uint8)
    img = PILImage.fromarray(arr)
    buf = io.BytesIO()
    img.save(buf, format=fmt, quality=quality)
    return buf.getvalue()


# ---------------------------------------------------------------------------
# 1. CORRECTNESS
# ---------------------------------------------------------------------------
hdr("1 · CORRECTNESS  –  PIL vs torchvision decode")

SIZES   = [(224, 224), (512, 512), (1024, 1024)]
FORMATS = ["JPEG", "PNG"]

for fmt in FORMATS:
    sub(fmt)
    for w, h in SIZES:
        raw = make_test_image(w, h, fmt=fmt)
        with tempfile.NamedTemporaryFile(suffix=f".{fmt.lower()}", delete=False) as f:
            f.write(raw)
            fpath = f.name
        try:
            pil_img   = load_image(fpath)                              # PIL.Image → RGB
            tv_tensor = load_image_as_tensor(fpath)                    # [C,H,W] uint8

            pil_arr   = np.array(pil_img)                              # [H,W,3]
            tv_arr    = tv_tensor.permute(1, 2, 0).numpy()             # [H,W,3]

            max_diff  = int(np.abs(pil_arr.astype(int) - tv_arr.astype(int)).max())
            mean_diff = float(np.abs(pil_arr.astype(float) - tv_arr.astype(float)).mean())
            identical = max_diff == 0

            flag = "✓ identical" if identical else f"Δmax={max_diff} Δmean={mean_diff:.4f}"
            print(f"    {w:>5}×{h:<5}  {flag}")
        finally:
            os.unlink(fpath)

print()
print("  Note: JPEG differences (if any) are expected – libjpeg vs libjpeg-turbo.")


# ---------------------------------------------------------------------------
# 2. LOAD BENCHMARK
# ---------------------------------------------------------------------------
hdr("2 · LOAD BENCHMARK  –  load_image vs load_image_as_tensor")

LOAD_SIZES = [(224, 224), (512, 512), (1024, 1024)]
LOAD_FMTS  = ["JPEG", "PNG"]

# -- (optional) URL test: use a well-known public image
URL = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg"

print(f"\n  {'label':<{COL}}{'PIL (ms)':>12}  {'TV (ms)':>12}  {'speedup':>10}")
print(f"  {'-'*58}")


def _fmt_row(label, pil_ms, tv_ms):
    sp = (pil_ms - tv_ms) / pil_ms * 100
    sign = "faster" if sp > 0 else "slower"
    print(f"  {label:<{COL}}{pil_ms:>12.2f}  {tv_ms:>12.2f}  {abs(sp):>8.1f}% {sign}")


for fmt in LOAD_FMTS:
    for w, h in LOAD_SIZES:
        raw = make_test_image(w, h, fmt=fmt)
        with tempfile.NamedTemporaryFile(suffix=f".{fmt.lower()}", delete=False) as f:
            f.write(raw)
            fpath = f.name
        try:
            label = f"{fmt} {w}×{h}"
            pil_med, *_ = timeit(lambda p=fpath: load_image(p))
            tv_med,  *_ = timeit(lambda p=fpath: load_image_as_tensor(p))
            _fmt_row(label, pil_med, tv_med)
        finally:
            os.unlink(fpath)

# URL
try:
    label = "JPEG URL"
    pil_med, *_ = timeit(lambda: load_image(URL))
    tv_med,  *_ = timeit(lambda: load_image_as_tensor(URL))
    _fmt_row(label, pil_med, tv_med)
except Exception as e:
    print(f"  URL test skipped ({e})")


# ---------------------------------------------------------------------------
# 3. PIPELINE BENCHMARK
# ---------------------------------------------------------------------------
hdr("3 · PIPELINE BENCHMARK  –  PIL input vs path input")

print("""
  Two input strategies are compared:
    [pil]    load_image(path) × N  →  processor(images=[pil, …])
    [path]   processor(images=[path, …])  — fetch_images does the decode
""")

DEVICES     = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
BATCH_SIZES = [1, 4, 16]
PIPE_SIZES  = [(224, 224), (512, 512), (1024, 1024)]

PROCESSORS = {
    "SiglipImageProcessor":    SiglipImageProcessor(),
    "SegformerImageProcessor": SegformerImageProcessor(),
}

# Pre-generate one JPEG file per size
_pipe_tmpfiles = {}
for w, h in PIPE_SIZES:
    raw = make_test_image(w, h, fmt="JPEG")
    f = tempfile.NamedTemporaryFile(suffix=".jpg", delete=False)
    f.write(raw)
    f.close()
    _pipe_tmpfiles[(w, h)] = f.name


def _pil_input(path, n):
    return [load_image(path) for _ in range(n)]

def _path_input(path, n):
    return [path] * n


def bench_processor(proc, path, batch, device):
    dev_arg = device if device != "cpu" else None

    def run_pil():
        proc(images=_pil_input(path, batch), device=dev_arg, return_tensors="pt")

    def run_path():
        proc(images=_path_input(path, batch), device=dev_arg, return_tensors="pt")

    pil_ms,  *_ = timeit(run_pil)
    path_ms, *_ = timeit(run_path)
    return pil_ms, path_ms


HDR_FMT = "  {:<26} {:>10}  {:>6}  {:>5}  {:>10}  {:>10}  {:>12}"
ROW_FMT = "  {:<26} {:>10}  {:>6}  {:>5}  {:>10.2f}  {:>10.2f}  {:>11.1f}%"

print(HDR_FMT.format("processor", "img size", "device", "batch", "pil (ms)", "path (ms)", "speedup"))
print("  " + "-" * 92)

for dev in DEVICES:
    for batch in BATCH_SIZES:
        for w, h in PIPE_SIZES:
            path = _pipe_tmpfiles[(w, h)]
            for proc_name, proc in PROCESSORS.items():
                pil_ms, path_ms = bench_processor(proc, path, batch, dev)
                sp = (pil_ms - path_ms) / pil_ms * 100
                print(ROW_FMT.format(proc_name, f"{w}×{h}", dev, batch, pil_ms, path_ms, sp))
        print()

for p in _pipe_tmpfiles.values():
    os.unlink(p)

print(f"\n{DIVIDER}")
print("  Done.")
print(DIVIDER)

Cc @zucchini-nlp @NicolasHug

@yonigozlan force-pushed the torchvision-decode-image branch from ec78ace to e5a7a0f on April 2, 2026 18:03
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@vasqu (Contributor) left a comment

thanks 🫡 that makes 100% sense to avoid too many intermediate conversions

  1. For completeness' sake, can you add your benchmark script to the PR description?
  2. I assume this is not sensitive to the torch version, but asking just in case

else:
    pil_torch_interpolation_mapping = {}
    torch_pil_interpolation_mapping = {}
if is_torchvision_available():
Contributor:
General mistake oops 😅

Member Author:

I'd prefer not to have nested imports here, but that's just me 😅. I can change it back if needed.

Contributor:
All good, shouldn't be too bad

@vasqu (Contributor) commented Apr 2, 2026

Let's figure out the one failing processor test as well 👀

@yonigozlan (Member Author)

Thanks for the review! The issue with the test should be resolved, and I've added the source code for the benchmarks.

@yonigozlan (Member Author) commented Apr 3, 2026

@vasqu re your review question: no, this is not sensitive to the torch version. torchvision.io.decode_image, ImageReadMode, and torch.frombuffer have all been available since before transformers' minimum of torch>=2.4.

@yonigozlan yonigozlan added this pull request to the merge queue Apr 6, 2026
@github-merge-queue github-merge-queue Bot removed this pull request from the merge queue due to no response for status checks Apr 6, 2026
@yonigozlan yonigozlan enabled auto-merge April 6, 2026 18:49
Comment on lines +547 to +549
elif isinstance(image, PIL.Image.Image):
    image = PIL.ImageOps.exif_transpose(image)
    return pil_to_tensor(image.convert("RGB"))
Member:
hm, do we need to convert from PIL when loading? I'd expect a no-op if an already-decoded image is provided

Member Author:
This is just to be consistent with the helper function's name, load_image_as_tensor. In practice, in the processors this will be a no-op (with a pil_to_tensor later on for the torchvision backend) because of this:
https://github.com/yonigozlan/transformers/blob/2b5d481df844c19907c15fab7cd547ccf9f27c7f/src/transformers/image_processing_backends.py#L122-L125

-> We only enter the decode path if we have a str; it's a no-op if we already have a valid image
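The dispatch described above can be sketched as follows (a simplified stand-in, not the actual image_processing_backends code; the decode callable here is a placeholder for the real loader):

```python
# Simplified sketch of the described behavior: only string inputs
# (paths/URLs) are decoded; already-loaded images pass through untouched.
def fetch_image(image, decode=lambda p: ("decoded", p)):
    if isinstance(image, str):
        # str input: decode it (via load_image_as_tensor in the
        # torchvision backend, per the linked code)
        return decode(image)
    # anything else is assumed to be a valid, already-decoded image: no-op
    return image
```

So tensors and PIL images hit the no-op branch, and only string paths trigger a decode.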

@vasqu (Contributor) commented Apr 9, 2026

@yonigozlan sounds good to me, so we can merge? I want to avoid force-merging now, as CI is kind of flaky atm and has trouble with the network / hub.

You can ping me again if it gets too annoying

@github-actions Bot commented Apr 9, 2026

[For maintainers] Suggested jobs to run (before merge)

run-slow: colpali

@yonigozlan yonigozlan added this pull request to the merge queue Apr 9, 2026
Merged via the queue into huggingface:main with commit d6a8904 Apr 9, 2026
28 checks passed
@yonigozlan yonigozlan deleted the torchvision-decode-image branch April 9, 2026 17:07
@vasqu (Contributor) commented Apr 9, 2026

Ok, just flaky CI I guess; it merged itself

sirzechs66 pushed a commit to sirzechs66/transformers that referenced this pull request Apr 18, 2026
…kend (huggingface#45195)

* use torchvision's decode_image to load images for torchvision backend

* fix video processor issue

---------

Co-authored-by: Anton Vlasjuk <73884904+vasqu@users.noreply.github.com>