Add Pixio pre-trained models #42795

Merged
molbap merged 28 commits into huggingface:main from LiheYoung:add_pixo
Dec 17, 2025

Conversation

@LiheYoung
Contributor

What does this PR do?

Add Pixo models. Pixo is a capable pre-trained ViT encoder.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@LiheYoung changed the title from "Add pixo pre-trained models" to "Add Pixo pre-trained models" on Dec 11, 2025
@Rocketknight1
Member

cc @yonigozlan @molbap


@yonigozlan (Member) left a comment

Thanks for opening this PR @LiheYoung! Just a comment before I dive deeper into the modeling code: the model seems very similar to ViT, and I see that you're using # Copied from statements.

We deprecated the use of # Copied from in favor of using modular. I think this model should benefit greatly from using modular, and it would make the reviewing process much simpler.

@LiheYoung
Contributor Author

@yonigozlan It is similar to ViT, but we use multiple class tokens instead of a single class token. As previously suggested by @molbap, this change is enough to warrant a new pixo directory. In addition, Pixo drops some components of the original ViT that it does not need, such as the mask tokens.
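
To illustrate the difference, here is a minimal sketch with made-up sizes (not the PR's actual code):

```python
import torch
from torch import nn

hidden_size, num_patches, n_cls_tokens = 768, 196, 4  # illustrative sizes

# ViT prepends a single learned class token to the patch sequence
vit_cls = nn.Parameter(torch.randn(1, 1, hidden_size))

# Pixo prepends several learned class tokens instead
pixo_cls = nn.Parameter(torch.randn(1, n_cls_tokens, hidden_size))

patches = torch.randn(2, num_patches, hidden_size)  # (batch, seq, dim)
tokens = torch.cat((pixo_cls.expand(2, -1, -1), patches), dim=1)
print(tokens.shape)  # torch.Size([2, 200, 768]): n_cls_tokens + num_patches
```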

@molbap
Contributor

molbap commented Dec 15, 2025

Hey @LiheYoung! I am reviewing more extensively now, but one clarification first: yes, the model differs from ViT because of the multiple CLS tokens, so it is fine to add a new architecture, but we do need to use modular. I will explain in the review.

You can check it out here: https://huggingface.co/docs/transformers/v5.0.0rc1/modular_transformers. All new models must be modulars of another model unless they are entirely new, which is not the case here. Having a modular file still produces a modeling file, generated automatically in the new model's directory, and it simplifies review and maintenance by a significant factor.

As for the unneeded patterns, they do not matter: that code will never run, and it will be absent from the modular file.

@LiheYoung
Contributor Author

LiheYoung commented Dec 15, 2025

Thanks for reviewing @molbap! The current Pixo code structure indeed follows DINOv2 closely. I removed some unused configuration options and added an n_cls_tokens argument. I noticed that DINOv2 does not use a modular file, so should I import modules from DINOv2 and modify them? I am worried that inheriting from DINOv2 may introduce unexpected behavior; for example, Pixo does not contain mask tokens at all.

Do you have any examples of vision models that use modular?


@molbap (Contributor) left a comment

Hi! I have added a modular implementation in a comment that should work out. Thanks for the PR!

Contributor

We are not using converters anymore. Instead, in v5 we have a file named conversion_mapping.py here, in which we declare WeightRenamer operations; this allows us to load existing checkpoints from the Hub on the fly.

If you haven't released the weights yet, you can even change the state dict on the Hub, which removes the need for a converter here.
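
For intuition only, here is a plain-PyTorch sketch of what on-the-fly key renaming amounts to (the rule names are hypothetical, and this is not the actual v5 WeightRenamer API):

```python
import torch

# Hypothetical rules: original checkpoint keys on the left, transformers-style keys on the right
RENAME_RULES = {
    "backbone.cls_tokens": "embeddings.cls_token",
    "backbone.pos_embed": "embeddings.position_embeddings",
}

def rename_state_dict(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
    """Return a new state dict with keys rewritten according to RENAME_RULES."""
    renamed = {}
    for key, tensor in state_dict.items():
        for old, new in RENAME_RULES.items():
            if key.startswith(old):
                key = new + key[len(old):]
                break
        renamed[key] = tensor
    return renamed
```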

Member

The on-the-fly weight renaming does not necessarily replace a converter! It is better to convert once at the beginning, before release, if we can, rather than on every load later. The renaming is more for post-release integrations!

Contributor

This ought to be auto-generated. I tested locally and the following should do as a modular file.

````python
# coding=utf-8
# Copyright 2025 Meta AI and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch Pixo model."""

from typing import Optional

import torch
from torch import nn

from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BackboneOutput, BaseModelOutput, BaseModelOutputWithPooling
from ...utils import auto_docstring, logging, torch_int
from ...utils.generic import check_model_inputs
from ..dinov2.modeling_dinov2 import Dinov2Backbone, Dinov2DropPath, Dinov2MLP
from ..vit.modeling_vit import ViTAttention, ViTPatchEmbeddings, ViTPreTrainedModel
from .configuration_pixo import PixoConfig


logger = logging.get_logger(__name__)


class PixoPatchEmbeddings(ViTPatchEmbeddings):
    pass


class PixoEmbeddings(nn.Module):
    """Construct the CLS tokens, position and patch embeddings while reusing ViT's initialization utilities."""

    def __init__(self, config: PixoConfig) -> None:
        super().__init__()
        self.patch_embeddings = PixoPatchEmbeddings(config)
        self.cls_token = nn.Parameter(torch.randn(1, config.n_cls_tokens, config.hidden_size))
        self.mask_token = None
        num_patches = self.patch_embeddings.num_patches
        self.position_embeddings = nn.Parameter(torch.randn(1, num_patches + config.n_cls_tokens, config.hidden_size))
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.n_cls_tokens = config.n_cls_tokens
        self.patch_size = config.patch_size
        self.config = config

    def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
        num_patches = embeddings.shape[1] - self.n_cls_tokens
        num_positions = self.position_embeddings.shape[1] - self.n_cls_tokens

        if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
            return self.position_embeddings

        class_pos_embed = self.position_embeddings[:, : self.n_cls_tokens]
        patch_pos_embed = self.position_embeddings[:, self.n_cls_tokens :]

        dim = embeddings.shape[-1]

        new_height = height // self.patch_size
        new_width = width // self.patch_size

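        # Pretraining used a square patch grid, so its side length is sqrt(num_positions)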
        sqrt_num_positions = torch_int(num_positions**0.5)
        patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
        patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
        target_dtype = patch_pos_embed.dtype
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed.to(torch.float32),
            size=(new_height, new_width),
            mode="bicubic",
            align_corners=False,
        ).to(dtype=target_dtype)

        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)

        return torch.cat((class_pos_embed, patch_pos_embed), dim=1)

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        batch_size, _, height, width = pixel_values.shape
        target_dtype = self.patch_embeddings.projection.weight.dtype
        embeddings = self.patch_embeddings(pixel_values.to(dtype=target_dtype))

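        # Prepend the learned CLS tokens; unlike ViT, there are n_cls_tokens of them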
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        embeddings = torch.cat((cls_tokens, embeddings), dim=1)

        embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)

        embeddings = self.dropout(embeddings)

        return embeddings


class PixoAttention(ViTAttention):
    pass


class PixoDropPath(Dinov2DropPath):
    pass


class PixoMLP(Dinov2MLP):
    pass


class PixoLayer(GradientCheckpointingLayer):
    def __init__(self, config: PixoConfig) -> None:
        super().__init__()

        self.norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.attention = PixoAttention(config)
        self.drop_path = PixoDropPath(config.drop_path_rate) if config.drop_path_rate > 0.0 else nn.Identity()

        self.norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.mlp = PixoMLP(config)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
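        # Pre-norm residual block: attention then MLP, each with optional stochastic depth (drop path)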
        hidden_states_norm = self.norm1(hidden_states)
        self_attention_output = self.attention(hidden_states_norm)

        hidden_states = self.drop_path(self_attention_output) + hidden_states

        layer_output = self.norm2(hidden_states)
        layer_output = self.mlp(layer_output)

        layer_output = self.drop_path(layer_output) + hidden_states

        return layer_output


class PixoEncoder(nn.Module):
    def __init__(self, config: PixoConfig):
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList([PixoLayer(config) for _ in range(config.num_hidden_layers)])
        self.gradient_checkpointing = False

    def forward(self, hidden_states: torch.Tensor, output_hidden_states: bool = False) -> BaseModelOutput:
        all_hidden_states = [hidden_states] if output_hidden_states else None
        for i, layer_module in enumerate(self.layer):
            hidden_states = layer_module(hidden_states)
            if all_hidden_states:
                all_hidden_states.append(hidden_states)

        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=tuple(all_hidden_states) if all_hidden_states else None,
        )


class PixoPreTrainedModel(ViTPreTrainedModel):
    pass


@auto_docstring
class PixoModel(PixoPreTrainedModel):
    def __init__(self, config: PixoConfig):
        super().__init__(config)
        self.config = config

        self.embeddings = PixoEmbeddings(config)
        self.encoder = PixoEncoder(config)

        self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

        self.post_init()

    def get_input_embeddings(self) -> PixoPatchEmbeddings:
        return self.embeddings.patch_embeddings

    @check_model_inputs(tie_last_hidden_states=False)
    @auto_docstring
    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        output_hidden_states: Optional[bool] = None,
        **kwargs,
    ) -> BaseModelOutputWithPooling:
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states

        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        embedding_output = self.embeddings(pixel_values)

        encoder_outputs: BaseModelOutput = self.encoder(embedding_output, output_hidden_states=output_hidden_states)
        sequence_output = encoder_outputs.last_hidden_state
        sequence_output = self.layernorm(sequence_output)
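        # Pool by averaging the multiple CLS tokens instead of slicing a single [CLS] position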
        pooled_output = sequence_output[:, : self.embeddings.n_cls_tokens, :].mean(dim=1)

        return BaseModelOutputWithPooling(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
        )


@auto_docstring(
    custom_intro="""
    Pixo backbone, to be used with frameworks like DETR and MaskFormer.
    """
)
class PixoBackbone(Dinov2Backbone):
    @check_model_inputs
    @auto_docstring
    def forward(
        self, pixel_values: torch.Tensor, output_hidden_states: Optional[bool] = None, **kwargs
    ) -> BackboneOutput:
        r"""
        Examples:

        ```python
        >>> from transformers import AutoImageProcessor, AutoBackbone
        >>> import torch
        >>> from PIL import Image
        >>> import requests

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> processor = AutoImageProcessor.from_pretrained("facebook/pixo-huge")
        >>> model = AutoBackbone.from_pretrained(
        ...     "facebook/pixo-huge", out_features=["stage7", "stage15", "stage23", "stage31"]
        ... )

        >>> inputs = processor(image, return_tensors="pt")

        >>> outputs = model(**inputs)
        >>> feature_maps = outputs.feature_maps
        >>> list(feature_maps[-1].shape)
        [1, 1280, 16, 16]
        ```"""
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states

        embedding_output = self.embeddings(pixel_values)
        output: BaseModelOutput = self.encoder(embedding_output, output_hidden_states=True)
        hidden_states = output.hidden_states

        feature_maps = []
        for stage, hidden_state in zip(self.stage_names, hidden_states):
            if stage in self.out_features:
                if self.config.apply_layernorm:
                    hidden_state = self.layernorm(hidden_state)
                if self.config.reshape_hidden_states:
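                    # Drop the CLS tokens and fold the patch sequence back into a (B, C, H, W) feature map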
                    hidden_state = hidden_state[:, self.embeddings.n_cls_tokens :]
                    batch_size, _, height, width = pixel_values.shape
                    patch_size = self.config.patch_size
                    hidden_state = hidden_state.reshape(batch_size, height // patch_size, width // patch_size, -1)
                    hidden_state = hidden_state.permute(0, 3, 1, 2).contiguous()
                feature_maps.append(hidden_state)

        return BackboneOutput(
            feature_maps=tuple(feature_maps),
            hidden_states=hidden_states if output_hidden_states else None,
        )


__all__ = ["PixoModel", "PixoPreTrainedModel", "PixoBackbone"]
````
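
For a quick local smoke test of the generated model, something like the following should work (a sketch: the config arguments assume Pixo keeps the standard ViT-style config fields plus the new n_cls_tokens):

```python
import torch

from transformers.models.pixo.configuration_pixo import PixoConfig
from transformers.models.pixo.modeling_pixo import PixoModel

# Tiny config for a fast check; all values here are illustrative
config = PixoConfig(image_size=64, patch_size=16, num_hidden_layers=2, n_cls_tokens=4)
model = PixoModel(config).eval()

pixel_values = torch.randn(1, config.num_channels, 64, 64)
with torch.no_grad():
    outputs = model(pixel_values)

# 16 patch tokens (a 4x4 grid from 64 // 16) plus 4 CLS tokens -> sequence length 20
print(outputs.last_hidden_state.shape)  # (1, 20, config.hidden_size)
print(outputs.pooler_output.shape)      # (1, config.hidden_size)
```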

@LiheYoung
Contributor Author

Thanks @molbap for your help, and sorry for not noticing your message in time. I have also just pushed a commit with a modular file. I tested it locally and it should be correct. Could you please review it?

@molbap
Contributor

molbap commented Dec 15, 2025

run-slow: pixo

@github-actions
Contributor

This comment contains run-slow, running the specified jobs:

models: ["models/pixo"]
quantizations: []

@molbap requested a review from yonigozlan, December 15, 2025 14:17
@molbap
Contributor

molbap commented Dec 15, 2025

cc @yonigozlan If you want to take another look before core review!

@LiheYoung
Contributor Author

@molbap thanks for the instructions! I have just run make fixup.

@molbap requested a review from ArthurZucker, December 17, 2025 07:21
@molbap
Contributor

molbap commented Dec 17, 2025

run-slow: pixo

@github-actions
Contributor

This comment contains run-slow, running the specified jobs:

models: ["models/pixio"]
quantizations: []


@molbap (Contributor) left a comment

Looks good to me - needs core maintainer review!

@github-actions
Contributor

CI Results

Workflow Run ⚙️

Model CI Report

❌ Failed tests

  • pixio:
    tests/models/pixio/test_modeling_pixio.py::PixioModelTest::test_batching_equivalence

@molbap
Contributor

molbap commented Dec 17, 2025

run-slow: pixio

@huggingface deleted a comment from the github-actions bot, Dec 17, 2025
@github-actions
Contributor

This comment contains run-slow, running the specified jobs:

models: ["models/pixio"]
quantizations: []

@github-actions
Contributor

CI Results

Workflow Run ⚙️

✅ No failing test specific to this PR 🎉 !


@Cyrilvallez (Member) left a comment

Very nice! Just left a few final comments!
Also, don't forget to add the model to src/transformers/models/__init__.py!


Comment on lines +196 to +198

```python
        if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
            return self.position_embeddings
```

Member

Let's use the more general is_tracing fn from utils/import_utils.py to guard data-dependent control flow.

Contributor

This is a common pattern in many, many models (43 of them) though; just noting, should we remove it everywhere?

Member

Sorry, I did not notice your question before! Yes, it would make sense to update it everywhere!

Comment on lines +207 to +208

```python
        sqrt_num_positions = torch_int(num_positions**0.5)
```
Member

Let's just use a Python int here

Contributor

Same here, it is used in 47 other models hehe. So no tracing issues here?

Member

I'm unsure exactly how torch.jit works, but other ints are used in the same way... Anyway, it has been deprecated, so it would make sense to go back to ints imo.
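
Concretely, the suggested replacement is just the built-in int:

```python
sqrt_num_positions = int(num_positions**0.5)  # plain Python int instead of torch_int
```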

Comment on lines +311 to +313

```python
    def get_input_embeddings(self) -> PixioPatchEmbeddings:
        return self.embeddings.patch_embeddings
```
Member

Is this still required @molbap?

Contributor

Yes, it's not handled automatically here unfortunately

Comment on lines +218 to +238
```python
    @unittest.skip(reason="Pixio does not use inputs_embeds")
    def test_inputs_embeds(self):
        pass

    @unittest.skip(
        reason="This architecture seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
    )
    def test_training_gradient_checkpointing(self):
        pass

    @unittest.skip(
        reason="This architecture seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
    )
    def test_training_gradient_checkpointing_use_reentrant(self):
        pass

    @unittest.skip(
        reason="This architecture seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
    )
    def test_training_gradient_checkpointing_use_reentrant_false(self):
        pass
```

Member

These should probably not all be skipped!

Contributor

Indeed, they work hehe. Removed.

Comment on lines +260 to +266

```python
    def test_batching_equivalence(self, atol=1e-4, rtol=1e-4):
        super().test_batching_equivalence(atol=atol, rtol=rtol)

    @unittest.skip(reason="Pixio does not support feedforward chunking yet")
    def test_feed_forward_chunking(self):
        pass
```

Member

Similarly, is this true/needed?

Contributor

Batching equivalence: yes, the default tolerance is slightly too small. For feedforward chunking: no! Removing the skip.

Comment on lines +268 to +273
```python
    @slow
    def test_model_from_pretrained(self):
        model_name = "LiheYoung/pixio-vith16"
        model = PixioModel.from_pretrained(model_name)
        self.assertIsNotNone(model)
```

Member

Let's move to IntegrationTest

Contributor

Safe to remove imo; we already cover from_pretrained in the other test, which has more scope.

@molbap
Contributor

molbap commented Dec 17, 2025

@Cyrilvallez I think I have addressed your comments. LMK!

@molbap requested a review from Cyrilvallez, December 17, 2025 11:08

@Cyrilvallez (Member) left a comment

LGTM! Very clean! Thanks a lot @LiheYoung and @molbap! Great work 🤗

@molbap enabled auto-merge (squash), December 17, 2025 13:47
@molbap disabled auto-merge, December 17, 2025 14:39
@github-actions
Contributor

[For maintainers] Suggested jobs to run (before merge)

run-slow: auto, pixio

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@ArthurZucker removed their request for review, December 17, 2025 14:51
@molbap removed the request for review from yonigozlan, December 17, 2025 14:53
@molbap merged commit a05e0e2 into huggingface:main, Dec 17, 2025
25 checks passed
@LiheYoung changed the title from "Add Pixo pre-trained models" to "Add Pixio pre-trained models" on Dec 19, 2025
SangbumChoi pushed a commit to SangbumChoi/transformers that referenced this pull request Jan 23, 2026
* Add Pixo

* Add Pixo

* Add test

* Add model_doc

* Add model_doc

* modularize

* modularize more

* Add Pixo

* Add Pixo

* Add test

* Add model_doc

* Add model_doc

* Use modular for Pixo

* missing backbone autodoc

* cleanup

* cleanup

* Revise converting

* rename

* rename

* cleanup

* small test update

* address core review comments

* also docs

* fix

* better with the toctree 👀

---------

Co-authored-by: Pablo Montalvo <pablo.montalvo.leroux@gmail.com>
Co-authored-by: Pablo Montalvo <39954772+molbap@users.noreply.github.com>