fix kosmos2 tests #39037

Open

ydshieh wants to merge 3 commits into main from fix_kosmos2

Conversation

@ydshieh (Collaborator) commented Jun 25, 2025

What does this PR do?

[VLMs] support attention backends (#37576) actually breaks kosmos2, since KosmosTextAttention is used both in the decoder (Kosmos2TextBlock) and in Kosmos2ImageToTextProjection (which should attend to all image positions).

But without an is_causal attribute on the module, sdpa_attention_forward treats the attention as causal, because it falls back to getattr(module, "is_causal", True):

    if is_causal is None:
        is_causal = query.shape[2] > 1 and attention_mask is None and getattr(module, "is_causal", True)

    if torch.jit.is_tracing() and isinstance(is_causal, torch.Tensor):
        is_causal = is_causal.item()

Maybe there is a better way to handle this, but I would not spend too much time on it: this PR just adds an is_causal argument and passes it through.
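For illustration, here is a minimal, self-contained sketch of the idea: store an explicit is_causal flag on the module so the SDPA backend no longer falls back to its causal default. The class name, signature, and forward below are simplified assumptions, not the actual modeling code.

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class KosmosTextAttentionSketch(nn.Module):
        # Simplified stand-in for KosmosTextAttention; the real signature differs.
        def __init__(self, embed_dim: int, num_heads: int, is_causal: bool = False):
            super().__init__()
            self.num_heads = num_heads
            self.head_dim = embed_dim // num_heads
            self.q_proj = nn.Linear(embed_dim, embed_dim)
            self.k_proj = nn.Linear(embed_dim, embed_dim)
            self.v_proj = nn.Linear(embed_dim, embed_dim)
            # Stored on the module: the backend reads it via
            # getattr(module, "is_causal", True), so a missing attribute
            # silently means "causal".
            self.is_causal = is_causal

        def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
            bsz, seq_len, _ = hidden_states.shape
            shape = (bsz, seq_len, self.num_heads, self.head_dim)
            q = self.q_proj(hidden_states).view(shape).transpose(1, 2)
            k = self.k_proj(hidden_states).view(shape).transpose(1, 2)
            v = self.v_proj(hidden_states).view(shape).transpose(1, 2)
            attn = F.scaled_dot_product_attention(q, k, v, is_causal=self.is_causal)
            return attn.transpose(1, 2).reshape(bsz, seq_len, -1)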

All tests pass on A10 now

@ydshieh ydshieh requested a review from zucchini-nlp June 25, 2025 15:35
@ydshieh (Collaborator, Author) commented on a diff:

    expected_slice = torch.tensor(
        [[0.9148, -1.4148, 3.8040], [3.3443, 1.9478, 0.2080], [1.6604, 2.8184, -0.3618]]
    )

This value was obtained from the buggy modeling code. Now that we fix the code, we need to update the value.

@HuggingFaceDocBuilderDev commented

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@zucchini-nlp (Member) left a comment


Let's do is_causal = True if is_decoder else False. I think that makes sense and removes redundant argument passing.
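For concreteness, a hypothetical version of the constructor under this suggestion, reusing the simplified sketch from above (again, not the real KosmosTextAttention signature):

    import torch.nn as nn

    class KosmosTextAttentionSketch(nn.Module):
        def __init__(self, embed_dim: int, num_heads: int, is_decoder: bool = False):
            super().__init__()
            # Derive causality from the existing is_decoder flag instead of
            # threading a new is_causal argument through every call site.
            self.is_causal = True if is_decoder else False
            # projections and forward exactly as in the earlier sketch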

@ydshieh (Collaborator, Author) commented Jun 25, 2025

I love the suggestion; unfortunately, there is a super-nit issue.

In class Kosmos2TextBlock:

        if config.add_cross_attention:
            self.encoder_attn = KosmosTextAttention(
                config,
                embed_dim=self.embed_dim,
                num_heads=config.attention_heads,
                dropout=config.attention_dropout,
                is_decoder=True,
                add_inner_attn_layernorm=False,
            )

Back at the time, I added this block so that a decoder could do self-attention as well as cross-attention over an encoder's output, just like in the original Attention Is All You Need.

However, this path is never used for Kosmos2 (Microsoft's fairseq bundles a lot of code in the library, and not all paths are used for a particular model).

So if a user sets config.add_cross_attention, a cross-attention layer is added to Kosmos2TextBlock, which is a decoder, but that layer should still use is_causal = False.

Kosmos2ImageToTextProjection, however, has is_decoder = False, so it would work with your suggestion.

That said, I would say no Hub repository for Kosmos2 uses config.add_cross_attention=True. So do you want me to remove that confusing part of the code and go with your suggestion? (The sketch below shows why the heuristic breaks on this path.)
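To make the corner case concrete, using the simplified sketch from above with hypothetical dimensions: a decoder block would build both attention layers with is_decoder=True, yet only self-attention should be causal.

    # Both layers live in a decoder block, so is_decoder=True for both...
    self_attn = KosmosTextAttentionSketch(embed_dim=64, num_heads=4, is_decoder=True)     # causal: correct
    encoder_attn = KosmosTextAttentionSketch(embed_dim=64, num_heads=4, is_decoder=True)  # causal: wrong
    # ...but encoder_attn must attend to every encoder position, i.e. it
    # needs is_causal=False despite is_decoder=True. Removing the unused
    # cross-attention path makes the is_decoder heuristic safe.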

@zucchini-nlp (Member) commented

> I would say no Hub repository for Kosmos2 uses config.add_cross_attention=True. So do you want me to remove that confusing part of the code and go with your suggestion?

Ah I see, so we have an unused code path. I think we can remove it with a minor deprecation cycle, which is in line with the overall "unbloating" effort we have going currently. Then we can safely assume that a non-decoder uses a non-causal mask.

@ydshieh (Collaborator, Author) commented Jun 30, 2025

OK, will do that. Thanks
