🚨 [Attn] New attn mask interface everywhere #42848
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
run-slow: gpt2, mllama, opt, biogpt, blt, decision_transformer
This comment contains models: ["models/biogpt", "models/blt", "models/decision_transformer", "models/gpt2", "models/mllama", "models/opt"]
CI Results / Model CI Report: ❌ Failed tests
run-slow: bert, bert_generation, blt, data2vec, decision_transformer, electra, ernie, glm46v, glm4v, glm4v_moe, gpt2, mllama, opt, paddleocr_vl, roberta, roberta_prelayernorm
This comment contains models: ["models/bert", "models/bert_generation", "models/blt", "models/data2vec", "models/decision_transformer", "models/electra", "models/ernie", "models/glm46v", "models/glm4v", "models/glm4v_moe", "models/gpt2", "models/mllama", "models/opt", "models/paddleocr_vl", "models/roberta", "models/roberta_prelayernorm"]
CI Results / Model CI Report: ❌ Failed tests
[FA] Fix paddingfree tests to properly consider position ids and default create a mask
[Attn] More new interface switches and proper paddingfree test
run-slow: bamba, falcon_h1, mllama, moshi, zamba2, zamba, pop2piano
This comment contains models: ["models/bamba", "models/falcon_h1", "models/mllama", "models/moshi", "models/pop2piano", "models/zamba", "models/zamba2"]
[For maintainers] Suggested jobs to run (before merge) run-slow: altclip, autoformer, bamba, bark, bigbird_pegasus, bloom, blt, clipseg, clvp, codegen, conditional_detr, dab_detr, decision_transformer, detr, falcon, falcon_h1
vasqu left a comment:
Self-review on points that may seem weird, to clarify a bit.

Just reapplied modular
```python
    attention_mask: torch.Tensor | None = None,
    causal_attention_mask: torch.Tensor | None = None,
    output_attentions: bool | None = False,
    **kwargs,
) -> tuple[torch.Tensor, torch.Tensor | None]:
```
You will see this pattern a few times; it comes from an old CLIP implementation which built the padding mask and the causal (naive triu) mask separately and then added them up - we now create both at the same time.
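For context, a minimal sketch of that legacy two-mask pattern (illustrative helper names and shapes, not the actual CLIP code):

```python
import torch

# Legacy pattern (sketch): padding and causal masks were built separately as
# additive float masks and summed before the attention softmax.
def make_padding_mask(attention_mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    # [bsz, src_len] -> [bsz, 1, 1, src_len]: 0.0 where attended, -inf where padded
    inverted = 1.0 - attention_mask[:, None, None, :].to(dtype)
    return inverted.masked_fill(inverted.bool(), torch.finfo(dtype).min)

def make_causal_mask(seq_len: int, dtype: torch.dtype) -> torch.Tensor:
    # naive triu: -inf above the diagonal, 0.0 on and below it
    mask = torch.full((seq_len, seq_len), torch.finfo(dtype).min)
    return torch.triu(mask, diagonal=1)[None, None, :, :].to(dtype)

attention_mask = torch.tensor([[1, 1, 1, 0]])
combined = make_padding_mask(attention_mask, torch.float32) + make_causal_mask(4, torch.float32)
```

The new interface produces the equivalent combined mask in a single pass instead of summing the two inside the layer.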
```python
kwargs.pop("is_causal", None)
encoder_outputs = self.encoder(
    inputs_embeds=hidden_states,
    attention_mask=attention_mask,
    causal_attention_mask=causal_attention_mask,
    output_attentions=output_attentions,
    output_hidden_states=output_hidden_states,
    return_dict=return_dict,
    is_causal=True,
    **kwargs,
```
Technically not needed in a few models, but it's the same as in CLIP, so I'd rather preemptively add the correct kwargs in case someone refactors these models.
```python
if not module.is_cross_attention:
    # if only "normal" attention layer implements causal mask
    query_length, key_length = query.size(-2), key.size(-2)
    causal_mask = module.bias[:, :, key_length - query_length : key_length, :key_length]
    mask_value = torch.finfo(attn_weights.dtype).min
    # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
    # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
    mask_value = torch.full([], mask_value, dtype=attn_weights.dtype, device=attn_weights.device)
    attn_weights = torch.where(causal_mask, attn_weights.to(attn_weights.dtype), mask_value)
```
Mentioning it here, but it comes from gpt2 - what essentially happened back in the day was to create the padding mask and then add a triu buffer on top, i.e. padding + causal (might remind you of what happened with the CLIP-likes).
I have checked that these buffers are ignored on load now and results are the same.
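For reference, that buffer follows the usual old-style registration (a sketch of the legacy pattern; `max_positions` would come from the config):

```python
import torch
from torch import nn

class LegacyCausalAttention(nn.Module):
    # Sketch: a lower-triangular bool buffer is stored on the module and
    # sliced per forward call, while the padding mask is applied separately
    # on top (padding + causal).
    def __init__(self, max_positions: int):
        super().__init__()
        self.register_buffer(
            "bias",
            torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
                1, 1, max_positions, max_positions
            ),
        )

    def causal_slice(self, query_length: int, key_length: int) -> torch.Tensor:
        # sliced exactly like in the hunk above
        return self.bias[:, :, key_length - query_length : key_length, :key_length]
```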
```python
causal_mask = create_causal_mask(
    config=self.config,
    input_embeds=inputs_embeds,
    attention_mask=attention_mask,
    cache_position=cache_position,
    past_key_values=past_key_values,
    # Force mask creation for alibi
    and_mask_function=lambda *args: torch.tensor(True, dtype=torch.bool),
)
if alibi is not None and causal_mask is not None and causal_mask.ndim == 4:
    min_dtype = torch.finfo(inputs_embeds.dtype).min

    # Only using non-bool mask for alibi
    if causal_mask.dtype == torch.bool:
        causal_mask = torch.where(
            causal_mask, torch.tensor(0.0, device=causal_mask.device, dtype=inputs_embeds.dtype), min_dtype
        )

    # We take care to integrate alibi bias in the causal_mask here
    alibi = alibi.reshape(batch_size, -1, *alibi.shape[1:])
    causal_mask = torch.masked_fill(
        alibi / math.sqrt(self.config.hidden_size // self.num_heads),
        causal_mask < -1,
        min_dtype,
    )
```
This is a bit special, but I went as far as I could with what is available - alibi is applied on top of the mask, so it needs a float mask.
yep, I don't really think our API supports this better than what you did. Though this only works if the mask is not a BlockMask from flex, right?
Yup, but the flags don't support flex, so we are fine.
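To make the constraint concrete, a standalone sketch (made-up shapes and slopes) of why alibi forces a float mask:

```python
import torch

# Alibi is an additive per-head bias on the attention scores, so it must be
# folded into an additive float mask - a boolean keep/drop mask has nowhere
# to store the bias values.
batch, heads, q_len, k_len = 1, 2, 4, 4
bool_mask = torch.tril(torch.ones(q_len, k_len, dtype=torch.bool))[None, None]

# boolean mask -> additive float mask (0.0 = keep, large negative = drop)
min_val = torch.finfo(torch.float32).min
float_mask = torch.zeros(batch, heads, q_len, k_len).masked_fill(~bool_mask, min_val)

# hypothetical alibi bias: slope * key_position per head
slopes = torch.tensor([0.5, 0.25]).view(1, heads, 1, 1)
alibi = (slopes * torch.arange(k_len).view(1, 1, 1, k_len)).expand(batch, heads, q_len, k_len)

# same masked_fill trick as in the snippet above: keep the bias where the mask
# allows attention, overwrite with -inf where it doesn't
combined = alibi.masked_fill(float_mask < -1, min_val)
```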
```python
# We need to prepare position ids according to the attention mask as we use it to extract embeddings that
# rely on the correct position - naively increasing sequences do not suffice anymore atp. The solution here
# calculates an increasing sequence for all 1s and puts 0s elsewhere.
inputs_dict["position_ids"] = ((inputs_dict["attention_mask"] == 1).long().cumsum(dim=1) - 1) * (
    inputs_dict["attention_mask"] == 1
).long()
```
This is super important and allows native support for absolute position embeddings like bert without overwriting their tests.
It makes the tests work, but I'm unsure if we want this.
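A quick demo of what that expression computes, on a made-up mask: positions count up only over real tokens, and padded slots are forced to 0 instead of increasing naively.

```python
import torch

attention_mask = torch.tensor([[0, 0, 1, 1, 1],   # left-padded
                               [1, 1, 1, 1, 0]])  # right-padded
position_ids = ((attention_mask == 1).long().cumsum(dim=1) - 1) * (attention_mask == 1).long()
print(position_ids)
# tensor([[0, 0, 0, 1, 2],
#         [0, 1, 2, 3, 0]])
```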
```python
    device=torch_device,
)
inputs_dict["input_ids"] = inputs_dict["labels"]
inputs_dict["attention_mask"] = torch.tril(torch.ones_like(inputs_dict["input_ids"]).to(torch_device))
```
This allows us to test gpt2 on the padding-free tests - it was skipped before.
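Sketched on a tiny batch, the tril trick gives every row a different number of attended tokens, i.e. the varied right-padding the padding-free comparison needs:

```python
import torch

input_ids = torch.ones(3, 4, dtype=torch.long)
attention_mask = torch.tril(torch.ones_like(input_ids))
print(attention_mask)
# tensor([[1, 0, 0, 0],
#         [1, 1, 0, 0],
#         [1, 1, 1, 0]])
```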
Low usage, not worth fixing imo with a lot of custom stuff happening - it's the model most closely related to the old API.
```python
# the following models should have been PreTrainedModels
"Owlv2TextTransformer",
"Owlv2VisionTransformer",
"OwlViTTextTransformer",
"OwlViTVisionTransformer",
"XCLIPTextTransformer",
"CLIPSegTextTransformer",
"DetrDecoder",
"GroupViTTextTransformer",
"CLIPTextTransformer",
"CLIPVisionTransformer",
"MetaClip2TextTransformer",
"MetaClip2VisionTransformer",
"MLCDVisionTransformer",
# end of should have beens
```
Same as in Cyril's PR, but now for attention-related things :)
ArthurZucker left a comment:
Huge cleanup, much welcome.
Got the idea with passing some models to pretrained ones; not super sure we should, vs placing the mask creation code in the parent class that uses it.
that's a really nice cleanup!
```python
output_attentions: bool | None = None,
output_hidden_states: bool | None = None,
return_dict: bool | None = None,
**kwargs,
```
I think we should TypedDict-enforce the kwargs for all the ones you added.
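A sketch of the TypedDict pattern being suggested (illustrative class and kwarg names, not the exact ones in the library):

```python
from typing import TypedDict

import torch
from torch import nn
from typing_extensions import Unpack


class AttentionKwargs(TypedDict, total=False):
    # illustrative keys; the real set would mirror what the layers accept
    is_causal: bool
    output_attentions: bool


class Encoder(nn.Module):
    def forward(self, hidden_states: torch.Tensor, **kwargs: Unpack[AttentionKwargs]) -> torch.Tensor:
        # static checkers now flag unknown or mistyped kwargs at call sites
        return hidden_states
```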
```diff
-class GroupViTTextTransformer(nn.Module):
+class GroupViTTextTransformer(GroupViTPreTrainedModel):
```
that's weird because the config object should just be passed to the encoder, no? (meaning changing it globally, not for GroupViTTextTransformer but for the one that contains GroupViTTextTransformer).
But yeah, not a big deal - a lot of these vision models have shitty design with wrappers around wrappers
```python
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype)

kwargs.pop("is_causal", None)
```
arf not super super clean but fine
Yea, it's not ideal tbh :(
```python
class Wav2Vec2BertAdapterLayer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
```
bit weird not to have this one go to PreTrainedModel but no worries
I have rechecked those; they don't need it as of 28f8a74 - I initially did pretrained models there too, but they have proper top modules which handle setting the attention etc.
```python
self.assertEqual(position_ids.shape, expected_positions.shape)
self.assertTrue(torch.all(torch.eq(position_ids, expected_positions)))

def attention_mask_padding_matches_padding_free_with_position_ids(
```
So attention_mask_padding_matches_padding_free_with_position_ids from GenerationTesterMixin now works, that's actually cool, ty!
Yup, the position ids are way more sensitive for absolute position embeddings than for RoPE, so this was silently producing wrong positions at times.
```python
@unittest.skip(reason="doesn't support padding yet")
def test_eager_matches_sdpa_inference_1_bfloat16(self):

# TODO: vasqu
@unittest.skip(reason="why the heck does this have bigger tols")
```
[For maintainers] Suggested jobs to run (before merge) run-slow: align, altclip, autoformer, bamba, bark, bigbird_pegasus, bloom, blt, chinese_clip, clap, clipseg, clvp, codegen, conditional_detr, dab_detr, decision_transformer
* fix
* fix order
* style
* vision 3d rope get extra test for now
* fix gpt2
* more gpt2 fixes
* let's see...
* fix
* test
* fix opt+biogpt
* fix
* fix
* fix
* fix opt
* mask exchange test
* style
* several small fixes
* shouldnt be needed
* fix zamba models
* retrigger ci
* force skip for now
* this wont work, will fix step by step
* to git
* another batch
* fix a few models, clip related models are gonna be hard...
* another batch
* style
* fix gpt2 attempt
* another batch + some models do not set their attn implementation? TODO
* fix
* last models
* style
* repo fix
* check
* some quick fixes, error to catch wrong inits in some models
* small fixes
* fixes for wrong mask pretrained model relation
* fix
* remove mask defaulting --> that's part of the prep + fixup some other tests
* small fixes
* fix last few models --> last to check recurrent gemma + repo consistency
* fixup test cleanup
* revert these tests
* these were not necessary, they have a proper top module
* fixup kwargs
* remove old API
* more kwargs
* let's revert this - im in a fork :D
* fix
* dang
* revert removal and add deprecation msg
* kwargs typing
* style
As per title ~