
Sdpa for owlvit #42136

Merged
vasqu merged 77 commits into huggingface:main from Aravind-11:sdpa_for_OWL_ViT
Mar 17, 2026

Conversation

@Aravind-11 (Contributor) commented Nov 10, 2025

What does this PR do?

Implements SDPA for OWL-ViT.

Fixes #28103

Who can review?

@vasqu @younesbelkada

@Aravind-11 (Contributor, Author)

Implements SDPA for OWL-ViT. Revamp of #28818.

Fixes #28103

I ran RUN_SLOW=1 python -m pytest tests/models/owlvit/test_modeling_owlvit.py against the original OWL-ViT implementation, and it seemed to fail the same tests as my current implementation. I'm not sure what to infer from that.

@vasqu (Contributor) left a comment

Sorry, but I've got to be strict about this. We no longer implement separate classes for all the attention flavors, but one unified one. I think ViT is a good example in this case, e.g. see https://github.com/huggingface/transformers/blob/main/src/transformers/models/vit/modeling_vit.py

Until this is changed to match these standards, I won't take a proper look for now.
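For context, the unified pattern keeps a single attention class and picks the kernel at runtime from the configured implementation. A minimal sketch of the idea, assuming simplified names and signatures (the real code in modeling_vit.py differs in detail, and the actual registry is ALL_ATTENTION_FUNCTIONS in transformers' modeling_utils):

```python
# Hedged sketch of the unified attention pattern (simplified illustration).
import torch
from torch import nn


def eager_attention_forward(module, query, key, value, attention_mask, scaling, dropout=0.0):
    # Plain matmul-softmax attention: the "eager" backend.
    attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask
    attn_weights = nn.functional.softmax(attn_weights, dim=-1)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
    attn_output = torch.matmul(attn_weights, value)
    return attn_output.transpose(1, 2).contiguous(), attn_weights


def sdpa_attention_forward(module, query, key, value, attention_mask, scaling, dropout=0.0):
    # SDPA backend: delegates to torch.nn.functional.scaled_dot_product_attention.
    attn_output = nn.functional.scaled_dot_product_attention(
        query, key, value, attn_mask=attention_mask,
        dropout_p=dropout if module.training else 0.0, scale=scaling,
    )
    return attn_output.transpose(1, 2).contiguous(), None


# Hypothetical minimal registry standing in for ALL_ATTENTION_FUNCTIONS.
ATTENTION_FUNCTIONS = {"eager": eager_attention_forward, "sdpa": sdpa_attention_forward}


class UnifiedAttention(nn.Module):
    """One attention class for all backends; no per-flavor subclasses."""

    def __init__(self, hidden_size=64, num_heads=4, attn_implementation="sdpa"):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.scale = self.head_dim**-0.5
        self.attn_implementation = attn_implementation
        self.q_proj = nn.Linear(hidden_size, hidden_size)
        self.k_proj = nn.Linear(hidden_size, hidden_size)
        self.v_proj = nn.Linear(hidden_size, hidden_size)
        self.out_proj = nn.Linear(hidden_size, hidden_size)

    def forward(self, hidden_states, attention_mask=None):
        batch, seq_len, _ = hidden_states.shape
        shape = (batch, seq_len, self.num_heads, self.head_dim)
        query = self.q_proj(hidden_states).view(*shape).transpose(1, 2)
        key = self.k_proj(hidden_states).view(*shape).transpose(1, 2)
        value = self.v_proj(hidden_states).view(*shape).transpose(1, 2)

        # Dispatch to the selected backend instead of subclassing per flavor.
        attention_interface = ATTENTION_FUNCTIONS[self.attn_implementation]
        attn_output, attn_weights = attention_interface(
            self, query, key, value, attention_mask, scaling=self.scale
        )
        attn_output = attn_output.reshape(batch, seq_len, -1)
        return self.out_proj(attn_output), attn_weights
```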

@Aravind-11 (Contributor, Author)

> Sorry, but I've got to be strict about this. We no longer implement separate classes for all the attention flavors, but one unified one. [...]

Got it. Thanks a lot!

@Aravind-11 (Contributor, Author)

I made changes similar to ViT and removed the separate SDPA class. Let me know what you think!

@vasqu (Contributor) left a comment

Added some comments, but in general it would be best to have a green CI before requesting a review. At the moment, things are likely not working as expected.

Comment on lines -716 to -722

```diff
-        causal_attention_mask = _create_4d_causal_attention_mask(
-            input_shape, hidden_states.dtype, device=hidden_states.device
-        )
-        # expand attention_mask
-        if attention_mask is not None:
-            # [num_samples, seq_len] -> [num_samples, 1, tgt_seq_len, src_seq_len]
-            attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype)
+        # OWL-ViT uses a bidirectional (non-causal) encoder.
+        attention_mask = create_bidirectional_mask(
+            config=self.config,
+            input_embeds=hidden_states,
+            attention_mask=attention_mask,
+        )
```
@vasqu (Contributor)

This seems to suffer from the same issue as in #41750.

It does not use a bidirectional mask, but a causal mask:

  • The first mask is a base causal mask
  • The second is a padding mask
  • These are added together, creating a causal mask with padding included
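A hedged sketch of that combination, assuming a simple additive float-mask convention (illustrative only; the actual transformers mask utilities differ):

```python
# Illustrative: adding a base causal mask and a padding mask yields a causal
# mask with padding included, as described above.
import torch

batch, seq_len = 2, 5
dtype = torch.float32
min_value = torch.finfo(dtype).min

# Base causal mask: 0 on/below the diagonal, a large negative value above it.
causal = torch.triu(torch.full((seq_len, seq_len), min_value, dtype=dtype), diagonal=1)
causal = causal[None, None, :, :]  # [1, 1, seq, seq]

# Padding mask: 1 for real tokens, 0 for padding (second sample is left-padded).
attention_mask = torch.tensor([[1, 1, 1, 1, 1],
                               [0, 0, 1, 1, 1]])
padding = (1.0 - attention_mask[:, None, None, :].to(dtype)) * min_value  # [batch, 1, 1, seq]

# Added on top of each other: causal structure plus masked-out padding columns.
combined = causal + padding  # [batch, 1, seq, seq]
print(combined[1, 0])
```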

@vasqu (Contributor)

This may also need to adjust the is_causal argument dynamically, as in the PR I linked - although I'm not sure if it's just causal in general.
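A sketch of what dynamic is_causal handling typically looks like with SDPA (an assumed pattern, not the exact code from the linked PR): the fast causal path can only be taken when no explicit mask is passed and the query length is greater than 1.

```python
# Hedged sketch: choose is_causal at call time rather than hard-coding it.
import torch
import torch.nn.functional as F

def sdpa_with_dynamic_is_causal(query, key, value, attention_mask=None, causal=True):
    # is_causal and attn_mask are mutually exclusive in SDPA, and single-token
    # queries (decode steps) need no causal masking.
    is_causal = causal and attention_mask is None and query.shape[2] > 1
    return F.scaled_dot_product_attention(
        query, key, value, attn_mask=attention_mask, is_causal=is_causal
    )
```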

@Aravind-11 (Contributor, Author)

Thanks! I made some changes to the code after referring to CLIP: removing the output_attentions, return_dict, and causal_attention_mask handling. I also copied the eager attention part and the attention reshaping from CLIP, and added the flash and flex attention paths too.

I think the current CI is failing because the OWL-ViT config file conflicts with the current encoder implementation. Could you guide me here? Thanks a lot!

@Aravind-11 (Contributor, Author)

Hi, I investigated the failing OwlViTForObjectDetectionTest::test_eager_matches_sdpa_inference_09_fp32_pad_left.

The failure is due to the test invoking OwlViTForObjectDetection.forward() without providing pixel_values.

OwlViTForObjectDetection requires pixel_values (image tensors) for its vision backbone. When the test omits them, the model raises ValueError: 'pixel_values' is None.
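The guard in question is presumably the standard transformers pattern (a sketch; the exact message in the OWL-ViT source may differ):

```python
# Assumed guard at the top of the vision forward pass.
if pixel_values is None:
    raise ValueError("You have to specify pixel_values")
```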

@Aravind-11 (Contributor, Author)

Also, when I run make fix-copies, it adds the output_attentions and create_causal_mask parameters back to OwlViTEncoderLayer.forward().

@vasqu (Contributor)

Responded here: #42136 (comment)

Resolving my previous comments, since the state has changed quite a bit from last time.
@github-actions (Contributor)

[For maintainers] Suggested jobs to run (before merge)

run-slow: owlv2, owlvit


- Add missing can_return_tuple import to owlvit and owlv2 modeling files
- Remove duplicate _can_record_outputs in OwlViTPreTrainedModel and Owlv2PreTrainedModel
- Remove unused OWLVITModelTesterMixin class from test file

Made-with: Cursor

@github-actions (Contributor)

View the CircleCI Test Summary for this PR:

https://huggingface.co/spaces/transformers-community/circle-ci-viz?pr=42136&sha=c41625

The merge left both old (bmm-based) and new (ALL_ATTENTION_FUNCTIONS)
attention code in OwlViTAttention.forward and Owlv2Attention.forward.
Remove the old dead code that references the deleted _shape method.

Made-with: Cursor

- Match CLIP return types: EncoderLayer -> torch.FloatTensor,
  Encoder -> BaseModelOutput
- Align test ConfigTester hidden_size=32 (divisible by num_heads)

Made-with: Cursor
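(A quick check of the constraint behind that test change, since attention splits hidden_size evenly across heads; illustrative values only:)

```python
# hidden_size must be divisible by num_heads so each head gets an equal slice.
hidden_size, num_heads = 32, 4
assert hidden_size % num_heads == 0
head_dim = hidden_size // num_heads  # 8
```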

@Aravind-11 (Contributor, Author)

@vasqu please take a look, thank you!

@vasqu (Contributor) left a comment

Got a few small details, but overall this looks good! Thanks a lot for sticking with this; I really didn't make it easy for you either.

Comment on lines +439 to +441

@vasqu (Contributor)

Suggested change:

```diff
-        queries = self.q_proj(hidden_states).view(*hidden_shape).transpose(1, 2)
-        keys = self.k_proj(hidden_states).view(*hidden_shape).transpose(1, 2)
-        values = self.v_proj(hidden_states).view(*hidden_shape).transpose(1, 2)
+        query_states = self.q_proj(hidden_states).view(*hidden_shape).transpose(1, 2)
+        key_states = self.k_proj(hidden_states).view(*hidden_shape).transpose(1, 2)
+        value_states = self.v_proj(hidden_states).view(*hidden_shape).transpose(1, 2)
```

Super nit, but that naming is just more standard across the library.

@Aravind-11 (Contributor, Author)

Done, renamed to query_states/key_states/value_states.

Comment on lines +742 to +748

```diff
         self.config = config
+        embed_dim = config.hidden_size

         self.embeddings = OwlViTVisionEmbeddings(config)
-        self.pre_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.pre_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
         self.encoder = OwlViTEncoder(config)
-        self.post_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
```

@vasqu (Contributor)

This change seems unnecessary? Or does it come from copies?

@Aravind-11 (Contributor, Author)

Reverted.


```diff
         # Get image embeddings
-        last_hidden_state = outputs.vision_model_output[0]
+        last_hidden_state = outputs.vision_model_output.last_hidden_state
```

@vasqu (Contributor)

This can break, no? We have no can_return_tuple decorator here, and if someone passes return_dict=False, this will fail.

I would rather revert these changes, here at least.
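To illustrate the concern (a hedged sketch, not the actual OWL-ViT code): with return_dict=False the submodel returns a plain tuple, so index access works either way while attribute access raises.

```python
# Index access works for both ModelOutput and tuple; attribute access only
# works for ModelOutput.
import torch
from transformers.modeling_outputs import BaseModelOutput

hidden = torch.zeros(1, 4, 8)

out = BaseModelOutput(last_hidden_state=hidden)   # return_dict=True case
assert out[0] is out.last_hidden_state            # both accesses agree

out_tuple = (hidden,)                             # return_dict=False case
_ = out_tuple[0]                                  # fine
# out_tuple.last_hidden_state                     # AttributeError on a tuple
```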

@vasqu (Contributor)

And below.

@Aravind-11 (Contributor, Author)

Done.

Comment on lines +1413 to +1406

```diff
-        input_ids: torch.Tensor,
         pixel_values: torch.FloatTensor,
+        input_ids: torch.Tensor | None = None,
```

@vasqu (Contributor)

This seems breaking to me; any reason we need it?

@Aravind-11 (Contributor, Author)

Reverted to original signature.

@Aravind-11 (Contributor, Author)

When I put input_ids back as the first required param, the main_input_name = "pixel_values" on OwlViTForObjectDetection no longer matched, which caused the test failure.

@Aravind-11 (Contributor, Author)

So I had to:

- Remove main_input_name = "pixel_values" from the class
- Change additional_model_inputs in the test from ["input_ids", "attention_mask"] to ["pixel_values", "attention_mask"] - since input_ids is now the main input, the test needs to provide pixel_values as an additional input
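For context on that constraint: main_input_name is the class attribute transformers uses to mark a model's primary input, and the common tests expect it to line up with the first forward() parameter. A hedged sketch of the convention (illustrative class, not the real OWL-ViT code):

```python
# The test suite builds the primary input from main_input_name, so it should
# match the first (non-self) parameter of forward().
import inspect

class SketchModel:
    main_input_name = "input_ids"

    def forward(self, input_ids=None, pixel_values=None, attention_mask=None):
        ...

params = list(inspect.signature(SketchModel.forward).parameters)  # ["self", "input_ids", ...]
assert params[1] == SketchModel.main_input_name
```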

@vasqu (Contributor)

Same comments apply here, so I'm not mentioning things twice.

@Aravind-11 (Contributor, Author)

Got it.

- Rename queries/keys/values to query_states/key_states/value_states
- Revert VisionTransformer embed_dim local var (unnecessary)
- Revert attribute access (.last_hidden_state, .text_embeds) back to
  index access to avoid breaking with return_dict=False
- Revert ForObjectDetection.forward param order to original

Made-with: Cursor
@github-actions (Contributor)

View the CircleCI Test Summary for this PR:

https://huggingface.co/spaces/transformers-community/circle-ci-viz?pr=42136&sha=8fb57e

Remove main_input_name="pixel_values" from OwlViTForObjectDetection
since forward keeps input_ids first. Update additional_model_inputs
in detection tests to provide pixel_values instead of input_ids.

Made-with: Cursor

@Aravind-11 (Contributor, Author)

> Got a few small details, but overall this looks good! Thanks a lot for sticking with this [...]

Haha, no worries!! Thank you for helping out!!! :)))

@vasqu (Contributor) commented Mar 17, 2026

run-slow: owlv2, owlvit

@github-actions (Contributor)

Workflow Run ⚙️

This comment contains run-slow, running the specified jobs:

models: ["models/owlv2", "models/owlvit"]
quantizations: []

@github-actions (Contributor)

CI Results

Workflow Run ⚙️

Commit Info

Context  Commit    Description
RUN      31a2e413  workflow commit (merge commit)
PR       1df037b3  branch commit (from PR)
main     acc89e74  base commit (on main)

✅ No failing test specific to this PR 🎉 👏 !

@vasqu vasqu enabled auto-merge March 17, 2026 19:42
@vasqu vasqu added this pull request to the merge queue Mar 17, 2026
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Merged via the queue into huggingface:main with commit f1f34de Mar 17, 2026
25 checks passed


Development

Successfully merging this pull request may close these issues.

OWL-VIT Vision Foundation Model deployment in the edge cases - Need SDPA support for OWL-ViT Model optimization for Edge Deployment

4 participants