Skip to content

Add Videoprism#39895

Open
MHRDYN7 wants to merge 125 commits intohuggingface:mainfrom
MHRDYN7:videoprism
Open

Add Videoprism#39895
MHRDYN7 wants to merge 125 commits intohuggingface:mainfrom
MHRDYN7:videoprism

Conversation

@MHRDYN7
Copy link
Copy Markdown
Contributor

@MHRDYN7 MHRDYN7 commented Aug 4, 2025

Fixes #39893. This pr adds the VideoPrism model by google deepmind. Original repo

@MHRDYN7 MHRDYN7 marked this pull request as ready for review August 23, 2025 21:08
@MHRDYN7
Copy link
Copy Markdown
Contributor Author

MHRDYN7 commented Aug 23, 2025

Summary of the code so far

  • The VideoPrismModel has been implemented using the modular code on top of Vivit and this is only a video encoder.

  • The code design is such that the same encoder code is reused as many times as possible.

  • Batches of video segments of shape (num_frames=16, H=288, W=288) are passed into the model and then the tubelet embedding class is used to convert each frame into hidden states of shape (256, 768) and therefore the whole input becomes (B*num_frame, 256, 768).

  • These spatial embeddings are passed into a spatial encoder

  • The outputs of the spatial encoder are reshaped to (B*256, num_frame, 768) and then passed into a temporal encoder

  • The attention function has an internal attention cap implemented in the modified eager_attention, not sure if this can be somehow used along with sdpa.

  • The VideoPrismClip model uses VideoPrismModel as a backbone for the video input, passes the embeddings through an auxiliary encoder then through an attention pooling layer.

  • There is also a standard text encoder that is called inside VideoPrismClip.

  • All the exact details from the original code have been extracted and correct tensors are being returned for both the models.

@qubvel I'd request your preliminary review on the current code structure. There are detailed comments with "# ?" for ease of review, these will be removed later on.

Todos:

  1. Get started with the tests
  2. Implement the video and text processors
  3. Add code for the video classification model

** Please note that i am currently using the preprocessing utils from the original code in my convert_weights_to_hf script. The code uses mediapy, which uses lanczos interpolation for resizing videos by default. However, lanczos interpolation is not supported in torch yet and that's why we can't get the exact same outputs if a fast video processor is used.

** The videoprism team have not released the weights for the classification head.

** The weights released on hub are in npz format, safetensors need to be uploaded there.

@qubvel qubvel requested review from qubvel and removed request for ArthurZucker and Rocketknight1 August 25, 2025 10:54
Copy link
Copy Markdown
Contributor

@qubvel qubvel left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @MHRDYN7, huge thanks for working on the model addition! You already did a great work, please see the comments to align it further to the transformers standards 🤗

Comment on lines +126 to +134
@dataclass
class TextEncoderOutput(ModelOutput):
"""
Base class for text encoder outputs.
"""

last_hidden_state: Optional[torch.FloatTensor] = None
hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
attentions: Optional[tuple[torch.FloatTensor, ...]] = None
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is just a common BaseModelOutput, no need to redefine it

Comment on lines +197 to +202
if self.mode == "spatial":
self.patch_embeddings = VideoPrismTubeletEmbeddings(config)
self.spatial_pos_emb = nn.Parameter(torch.zeros(1, self.pos_emb_shape[1] * self.pos_emb_shape[2], config.hidden_size)) # ? (1, 256, 768)

elif self.mode == "temporal":
self.temporal_pos_emb = nn.Parameter(torch.zeros(1, self.pos_emb_shape[0], config.hidden_size)) # ? (1, 16, 768)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we have checkpoint releases for both versions? Otherwise, please leave only one.

Comment on lines +257 to +259
def _interpolate_emb_2d(
self, emb: torch.Tensor, source_emb_shape: tuple[int, int], target_emb_shape: tuple[int, int]
):
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should be defined in interpolate_pos_encoding instead, no?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

didn't use interpolate_pos_encoding because there are two different types interpolation for the pos embeds (spatial first then temporal). Now that the plan is to have two different embedding classes, I'll inherit from VivitEmbeddings for both and they both will have interpolate_pos_encoding method.

Comment on lines +351 to +357
with torch.no_grad():
self.layernorm_before.weight += nn.Parameter(
torch.ones(self.config.hidden_size)
)
self.layernorm_after.weight += nn.Parameter(
torch.ones(self.config.hidden_size)
)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That seems super strange to me. Why do we need this, is this correct? The operation is in-place, so after each forward pass, are we continuing to increase the weight?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I found it very strange too.

Here is the original code
image
https://github.com/google-deepmind/videoprism/blob/main/videoprism/layers.py#L182-L193

when direct_scale is set to False (this is always the case), +1 is added to the scale tensor of layernorm.
image

This code seems more like an attempt to ensure the layernorm scale factor is not zero during training. I can create an issue in the repo to confirm if they want this behavior during inference as well.

You are right that the hf code in the current form means that the layernorm scale will get increased by +1 for every iteration of forward pass. The jax code initializes the Layernorm class on the go just before it is called so this problem does not happen there. If this +1 portion is moved (from forward) to the init of the relevant class, then that does not work as I guess during creation of a model instance, the init methods are evoked with the initialized weights, and later the pretrained weights are placed and that's why +1 does not happen. I've been working with single forward passes of the model, so it's been fine, but this issue still needs to be resolved.

Copy link
Copy Markdown
Contributor

@qubvel qubvel Aug 27, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, I suppose we must not modify weight inplace and define it as follows:

import torch
import torch.nn as nn
import torch.nn.functional as F

class VideoPrismLayerNorm(nn.LayerNorm):
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return F.layer_norm(
            input, self.normalized_shape, self.weight + 1, self.bias, self.eps
        )

That should be equivalent, right?

Copy link
Copy Markdown
Contributor

@qubvel qubvel Aug 27, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Or in case +1 is only for initialization, we should not have it in modeling code, just in _init_weights method.
Please keep in mind that logits should match exactly (1e-3/1e-4) with the original implementation, and matching them should give you the right answer whether this addition is relevant

P.S. Just saw the message below 👍

Copy link
Copy Markdown
Contributor Author

@MHRDYN7 MHRDYN7 Aug 28, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

alright, using F.layer_norm is a very good solution; I'll rename the weights and refactor the code. Also I got the confirmation from the videoprism team, scale = 0 for initialization, and the +1 is expected to be there.

Comment on lines +366 to +375
if mode == "spatial":
self.layer = nn.ModuleList([VideoPrismLayer(config) for _ in range(config.num_spatial_layers)])
elif mode == "temporal":
self.layer = nn.ModuleList([VideoPrismLayer(config) for _ in range(config.num_temporal_layers)])
elif mode == "auxiliary":
self.layer = nn.ModuleList([VideoPrismLayer(config) for _ in range(config.num_auxiliary_layers)])
elif mode == "unimodal":
self.layer = nn.ModuleList([VideoPrismLayer(config) for _ in range(config.num_unimodal_layers)])
else:
raise ValueError(f"Unknown mode: {mode}. Supported modes are: spatial, temporal, auxiliary and unimodal.")
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't see any reason to split it into different if/else. We might have only one attribute num_layers and that's it

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we remove these conditionals, then we will need to set config.num_layers = config.num_spatial_layers, before instantiating the spatial encoder object and then config.num_layers = config.num_temporal_layers before temporal encoder in the model init. It's slightly less explicit, but, since you agreed with the change of config from python_gelu to relu in the init, I'll go ahead with this one too, and then the encoder can be directly used from vivit.

Comment on lines +641 to +642
with torch.no_grad():
self.layernorm.weight += nn.Parameter(torch.ones(self.config.hidden_size))
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

again, super strange

Comment on lines +652 to +675
class PositionalEmbedding(nn.Module):
def __init__(self, config: VideoPrismConfig):
super().__init__()
self.hidden_size = config.hidden_size
self.min_timescale = 1
self.max_timescale = 10000

def forward(self, seq_length):
position = torch.arange(seq_length, dtype=torch.float32).unsqueeze(0) # ? (1, seq_length)
num_timescales = self.hidden_size // 2

log_timescale_increment = math.log(
float(self.max_timescale) / float(self.min_timescale) # ? log(10000/1) = ln(10000)
) / torch.maximum(torch.tensor(num_timescales, dtype=torch.float32) - 1, torch.tensor(1))

inv_timescales = self.min_timescale * torch.exp(
torch.arange(num_timescales, dtype=torch.float32) * -log_timescale_increment
)

scaled_time = position.unsqueeze(-1) * inv_timescales.expand(1, 1, -1)

embs = torch.cat((torch.sin(scaled_time), torch.cos(scaled_time)), dim=-1)

return embs
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

align with the RoPE classes defined in transformers

Comment thread src/transformers/models/videoprism/modular_videoprism.py Outdated
Comment on lines +748 to +751
self.backbone = VideoPrismModel(config)
self.auxiliary_encoder = VideoPrismEncoder(config, mode="auxiliary")
self.contrastive_vision_pooler = VideoPrismMultiheadAttentionPoolingHead(config)
self.text_encoder = VideoPrismTextEncoder(config)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I assume we should have the following classes (similar to CLIP)

  • VideoPrismVideoModel
  • VideoPrismTextModel

please, align to this

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be refactored to follow mllama or dinov3_vit conversion script format. We should have a weights mapping dict and manipulate that

@MHRDYN7
Copy link
Copy Markdown
Contributor Author

MHRDYN7 commented Aug 27, 2025

The following notebooks are from the original repo and the expected values of the output tensors used to validate the HF implementation have been taken from here.
Video encoder model notebook
Video Text model notebook

Please note that the output tensors are slightly different for the larger models when a TPU is used. My expected tensor values are for the cpu.

Since the output of the final tensors (of HF code) matches that of the original code, the elements in the current code are aligned with the original implementation, despite that strange layernorm weight increase.

@zucchini-nlp
Copy link
Copy Markdown
Member

run-slow: videoprism

@github-actions
Copy link
Copy Markdown
Contributor

github-actions Bot commented Apr 1, 2026

Workflow Run ⚙️

This comment contains run-slow, running the specified jobs:

models: ["models/videoprism"]
quantizations: []

@github-actions
Copy link
Copy Markdown
Contributor

github-actions Bot commented Apr 1, 2026

CI Results

Workflow Run ⚙️

Commit Info

Context Commit Description
RUN bb19e5c0 workflow commit (merge commit)
PR dd80ec0b branch commit (from PR)
main f38d6639 base commit (on main)

Model CI Report

3 new failed tests from this PR 😭

  • videoprism:
    tests/models/videoprism/test_modeling_videoprism.py::VideoPrismModelIntegrationTest::test_videoprism_classification_model (✅ ⟹ ❌)
    tests/models/videoprism/test_modeling_videoprism.py::VideoPrismModelIntegrationTest::test_videoprism_clip_model (✅ ⟹ ❌)
    tests/models/videoprism/test_modeling_videoprism.py::VideoPrismModelIntegrationTest::test_videoprism_vision_model (✅ ⟹ ❌)

@github-actions
Copy link
Copy Markdown
Contributor

github-actions Bot commented Apr 2, 2026

[For maintainers] Suggested jobs to run (before merge)

run-slow: auto, videoprism, vivit

@github-actions
Copy link
Copy Markdown
Contributor

github-actions Bot commented Apr 2, 2026

[For maintainers] Suggested jobs to run (before merge)

run-slow: auto, videoprism, vivit

@MHRDYN7
Copy link
Copy Markdown
Contributor Author

MHRDYN7 commented Apr 2, 2026

@zucchini-nlp, I investigated the three failing tests on gpu to the core. There was really no issue (except one) so added the cuda logits in the tests. The tolerance might need to be changed as these values are on a L4 machine. Also, please note the the last slow tests skipped FA2 tests (although FA is enabled for the model), but on my machine the FA2 tests still have some logits that don't fall within the tolerances. Let me know if these FA tests are supposed to be skipped or not

@zucchini-nlp
Copy link
Copy Markdown
Member

Yeah these tests usually fail due to numerical diff in hardware. We can help you get the correct expectations later on, so you can copy-paste

Copy link
Copy Markdown
Contributor

@vasqu vasqu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok this is a long review but we may have to wait for the vit refactor first tbh (I've linked it somewhere in the review)

This goes over the modular code, not the tests yet. It's often little details to make this more aligned to what we have. A lot can also be done even before the refactor imo

# Video Processor

A **Video Processor** is a utility responsible for preparing input features for video models, as well as handling the post-processing of their outputs. It provides transformations such as resizing, normalization, and conversion into PyTorch. Along ith transformations the `VideoProcessor` class handles video decoding from local paths or URLs (requires [`torchcodec`](https://pypi.org/project/torchcodec/)) and frame sampling according to model-specific strategies.
A **Video Processor** is a utility responsible for preparing input features for video models, as well as handling the post-processing of their outputs. It provides transformations such as resizing, normalization, and conversion into PyTorch. Along with transformations, the `VideoProcessor` class also handles video decoding from local paths or URLs (requires [`torchcodec`](https://pypi.org/project/torchcodec/)) and frame sampling according to model-specific strategies.
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe better to move to a different docs PR

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll remove it from this PR, but I have a suggestion. It would be very nice if a maintainer managed a running issue + pr to collect all these little typos from the community and merged these little improvements biweekly/monthly. This is perhaps a more organized solution than having a bunch of '2 line change' PRs and also does not discourage user to report such issues from time to time.

Comment thread docs/source/en/model_doc/videoprism.md Outdated
@@ -0,0 +1,123 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
<!--Copyright 2026 The HuggingFace Team. All rights reserved.

it's been a while :D likely elsewhere marking once here only

Comment thread docs/source/en/model_doc/videoprism.md Outdated
Comment on lines +64 to +66
processed_video_inputs = processor(videos=[video_url], return_metadata=True, do_sample_frames=True)
video_metadata = processed_video_inputs["video_metadata"]
video_inputs = processed_video_inputs["pixel_values_videos"]
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to move the inputs to model.device? In case its cuda or similar, I suspect this would return cpu no?

@@ -0,0 +1,1094 @@
from collections.abc import Callable
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

licence missing

@@ -0,0 +1,565 @@
import argparse
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

licence

Comment on lines +907 to +908
self.auxiliary_encoder = VideoPrismAuxiliaryEncoder(config)
self.contrastive_vision_pooler = VideoPrismMultiheadAttentionPoolingHead(config)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not merge these into one class? Or does it not make sense?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

because contrastive_vision_pooler is also used in the video classification model.

Comment on lines +951 to +952
if not isinstance(config, VideoPrismConfig):
raise TypeError(f"`config` is expected to be of type `VideoPrismConfig` but is of type {type(config)}.")
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if not isinstance(config, VideoPrismConfig):
raise TypeError(f"`config` is expected to be of type `VideoPrismConfig` but is of type {type(config)}.")

just makes the code harder to read

Comment on lines +992 to +995
if video_emb_dim != text_emb_dim:
raise ValueError(
f"Dimension of video ({video_emb_dim}) and text ({text_emb_dim}) embeddings must match for similarity computation."
)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if video_emb_dim != text_emb_dim:
raise ValueError(
f"Dimension of video ({video_emb_dim}) and text ({text_emb_dim}) embeddings must match for similarity computation."
)

I think youll notice I dont like to error out too much :D this would be a config issue then if anything

nll = -torch.sum(loglik, dim=-1)
loss = nll.mean()

return VideoPrismClipOutput(
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Imo we should follow siglip a bit closer and have pooler outputs it should then be mostly 1:1

)
super().__init__(config)
self.config = config
self.encoder = VideoPrismVisionModel._from_config(self.config)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should match with base model prefix

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you mean self.vision_model instead of self.encoder? ok will update the weights

@github-actions
Copy link
Copy Markdown
Contributor

github-actions Bot commented Apr 6, 2026

[For maintainers] Suggested jobs to run (before merge)

run-slow: auto, videoprism, vivit

@github-actions
Copy link
Copy Markdown
Contributor

github-actions Bot commented Apr 9, 2026

[For maintainers] Suggested jobs to run (before merge)

run-slow: auto, videoprism, vivit

1 similar comment
@github-actions
Copy link
Copy Markdown
Contributor

github-actions Bot commented Apr 9, 2026

[For maintainers] Suggested jobs to run (before merge)

run-slow: auto, videoprism, vivit

@MHRDYN7
Copy link
Copy Markdown
Contributor Author

MHRDYN7 commented Apr 10, 2026

@vasqu thank you very much the detailed review. Certainly this PR needs to wait for the VIT refactor, which should help reduce LOC a lot. Many of the suggestions have been applied and in some cases, further consideration is required. Please see my comments and the reasoning.

In summary, the three key points to decide are

  1. sdpa should stay or not (other models have kept it despite softcap)
  2. MultiheadAttentionPoolingHead (This name has precedent + I'll implement this following the existing pattern phi4multimodal instead of your suggestions on top of my current implementation)
  3. Possibly refactor ClipEmbedding model to accommodate sinusoidal pos embeds so that we can directly import that class in modular file. All the design issues pointed out here stem from that clip class and these propagate via copied from.

@vasqu
Copy link
Copy Markdown
Contributor

vasqu commented Apr 15, 2026

@MHRDYN7 Answering your concerns here but will take a look at the vit refactor to get it in first 🫡

  1. Imo, it kind of depends on the sensitivity of the model. We can keep it but then need to raise our tols and indicate that in our tests with proper comments why
  2. Sounds good, if there is precedence I'd rather follow that then indeed. However, definitely the version that keeps the attention interface and does not use the nn.MultiHeadAttention from torch (it does not play nicely within our designs and should be avoided)
  3. If it really is that different, dw about not having perfect modular there then. These are often ideas I have, not fool proof. I'd rather not force change clip for this

@github-actions
Copy link
Copy Markdown
Contributor

[For maintainers] Suggested jobs to run (before merge)

run-slow: auto, videoprism, vivit

@github-actions
Copy link
Copy Markdown
Contributor

View the CircleCI Test Summary for this PR:

https://huggingface.co/spaces/transformers-community/circle-ci-viz?pr=39895&sha=fe2226

@vasqu
Copy link
Copy Markdown
Contributor

vasqu commented Apr 21, 2026

Just as a small update: The vit refactor is getting closer to the finish line. We are now just making sure to fix the last details and tests 🤗

@MHRDYN7
Copy link
Copy Markdown
Contributor Author

MHRDYN7 commented Apr 21, 2026

Just as a small update: The vit refactor is getting closer to the finish line. We are now just making sure to fix the last details and tests 🤗

yes I am monitoring that PR, will pull the changes as soon as it is merged with main

@MHRDYN7
Copy link
Copy Markdown
Contributor Author

MHRDYN7 commented Apr 21, 2026

@zucchini-nlp I tried to run the utils.check_auto.py script and there are some undesirable changes in auto_mappings.py like

- ("detr", "DetrConfig"),
+ ("detr", "MaskFormerDetrConfig"),

-("sam3_video", "Sam3VisionConfig"),
-("sam3_vit_model", "Sam3ViTConfig"),
+("sam3_vision_model", "Sam3LiteTextVisionConfig"),
+("sam3_vit_model", "Sam3LiteTextViTConfig"),

-("sam3_vision_model", "sam3"),
-("sam3_vit_model", "sam3"),
+("sam3_vision_model", "sam3_lite_text"),
+("sam3_vit_model", "sam3_lite_text"),

where the values get reassigned when newer models have the model_type of existing models.
The simple fix that prevents this is changing the following lines

if model_type != module_name:
special_mappings[model_type] = module_name
model_type_map[model_type] = config_cls_name

to this

    if model_type != module_name:
        if model_type not in special_mappings:
            special_mappings[model_type] = module_name 

    if model_type not in model_type_map:
        model_type_map[model_type] = config_cls_name

This logic of checking existing key worked for me, but I haven't fully read all the other parts of the script. Please let me if the changes are to be made, and if in a separate short PR.

@vasqu
Copy link
Copy Markdown
Contributor

vasqu commented Apr 21, 2026

possibly re #45535 cc @yonigozlan, we might need to override the model type in the configs(?). Seems like that branch was out of sync maybe with main

@github-actions
Copy link
Copy Markdown
Contributor

[For maintainers] Suggested jobs to run (before merge)

run-slow: auto, videoprism, vivit

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Add VideoPrism

7 participants