
🔴 [Attention] Attention refactor for Whisper-based models #38235

Merged
vasqu merged 36 commits into main from vas-whisper-attn-refactor
May 28, 2025

Conversation

@vasqu (Contributor) commented May 20, 2025

Whisper attention refactor according to the same strategies applied in #38108
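For context, a minimal sketch (not the PR diff itself) of the attention-interface pattern that #38108 introduced and this PR applies to Whisper: the model keeps a plain eager path, and every other backend (sdpa, flash_attention_2, flex_attention) is resolved through the shared registry at call time.

import torch
from torch import nn
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS

def eager_attention_forward(module, query, key, value, attention_mask, scaling, dropout=0.0, **kwargs):
    # query/key/value: (batch, heads, seq, head_dim)
    attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
    if attention_mask is not None:
        # the mask is built upstream; slice it to this call's key length
        attn_weights = attn_weights + attention_mask[:, :, :, : key.shape[-2]]
    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
    attn_output = torch.matmul(attn_weights, value).transpose(1, 2).contiguous()
    return attn_output, attn_weights

# inside the attention module's forward:
# attention_interface = eager_attention_forward
# if self.config._attn_implementation != "eager":
#     attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]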

Also, several fixes on Whisper along the way, reducing the number of failing tests to 3 (disregarding skipped tests):

  • test_small_longform_timestamps_generation
  • test_tiny_token_timestamp_batch_generation
  • test_whisper_longform_multi_batch_hard_prev_cond

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@vasqu marked this pull request as ready for review May 21, 2025 12:53
github-actions bot requested review from ArthurZucker and eustlb May 21, 2025 12:54
@vasqu requested a review from gante May 21, 2025 12:55
Comment thread src/transformers/cache_utils.py
Comment thread src/transformers/generation/logits_process.py
Comment thread src/transformers/generation/utils.py Outdated
@vasqu requested a review from gante May 22, 2025 16:03
@vasqu (Contributor, Author) commented May 22, 2025

Ready for another round of reviews imo

@ArthurZucker (Collaborator) left a comment

Very nice!
What are the compilation issues with flex? Pretty sure updating to the new causal mask would solve it!

Comment on lines +601 to +603
# Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_causal_mask
def _update_causal_mask(
self,
Collaborator

I will be noisy but let's use the new cache creation no? 👀

@vasqu (Contributor, Author) commented May 23, 2025

It's introducing even more issues (torchscript, plus another integration test then starts failing - test_tiny_static_generation_long_form). Would leave it for now if that's ok. Don't want to make whisper even more broken tbh, wdyt?

cc @Cyrilvallez for the new mask creation, tried replacing it with

# previously
#causal_mask = self._update_causal_mask(
#    attention_mask,
#    inputs_embeds,
#    cache_position,
#    past_key_values.self_attention_cache if past_key_values is not None else None,
#)

causal_mask = create_causal_mask(
    config=self.config,
    input_embeds=inputs_embeds,
    attention_mask=attention_mask,
    cache_position=cache_position,
    past_key_values=past_key_values.self_attention_cache if past_key_values is not None else None,
)
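(Context: create_causal_mask comes from transformers.masking_utils; it builds the mask in whatever format config._attn_implementation requires, which is why the single call above can serve eager, sdpa, and flex alike.)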

@Cyrilvallez (Member) commented May 23, 2025

torchscript with a mask is a known issue, but it's not super important anyway - for the other one, if it's compile-related it will probably be fixed in #38319 - TLDR I had introduced a workaround for Python<3.11 on torch.export, but it's also an issue with compile and fullgraph=True for those same python versions (which was obvious as they use the same dynamo tracing, but I missed it 🙃), so made the workaround the default

@vasqu (Contributor, Author) commented May 23, 2025

I'll try to apply the fix locally and check if it works but sounds good :)

For torchscript, I would skip those tests and add a todo for you (if we try to make it work again). Agree that it's not the most important feature. Wdyt?

@vasqu (Contributor, Author) commented May 23, 2025

Ok, it seems the failures are unrelated to the compile issues. Encoder-decoder compilation seems to rely on the mask functions living under PretrainedXXX, which makes the integration test fail. This needs a closer look tbh, will leave addressing the new masking to a future PR.

cc @gante @zucchini-nlp if you know anything about the encoder-decoder caches relying on the functions under PretrainedXXX

@Cyrilvallez (Member) commented May 26, 2025

A bit late to the party sorry, but if you use the new mask API you should get rid of _prepare_4d_causal_attention_mask_with_cache_position entirely everywhere, otherwise generate will use it instead of the new create_masks_for_generate! Just check a bit further here 🤗

Member

(Otherwise generate will not create the correct mask with flex, or custom attention, and you're mixing new and old API which is not good)

Member

You may need to overwrite the general one though, to account for the EncoderDecoderCache, but it can be done super easily as in Gemma3 for example, where we need to overwrite to account for the additional mask for the image tokens in training

Member

LMK on slack if you cannot make it work, but TLDR we should not mix the old/new mask APIs, and the new one will be more general as flex will work correctly!

Member

def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[int, int]:
    """
    Return a tuple (kv_length, kv_offset) corresponding to the length and offset that will be returned for
    the given layer at `layer_idx`.
    The masks are then prepared according to the given lengths (kv_length, kv_offset) and patterns
    (i.e. sliding_window, chunk_size) for each layer.
    """
    return self.self_attention_cache.get_mask_sizes(cache_position, layer_idx)

Adding this in EncoderDecoderCache is probably even better/cleaner, as you don't need to overwrite create_masks_for_generate, and it will always work independently of the type of Cache being used
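A minimal usage sketch of that delegation, assuming the method above is added to EncoderDecoderCache:

import torch
from transformers.cache_utils import DynamicCache, EncoderDecoderCache

cache = EncoderDecoderCache(DynamicCache(), DynamicCache())
cache_position = torch.arange(4)  # positions of the current query tokens
# mask utilities can now query the composite cache exactly like its
# self-attention sub-cache, whatever concrete Cache type is inside
kv_length, kv_offset = cache.get_mask_sizes(cache_position, layer_idx=0)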

@vasqu (Contributor, Author) commented May 23, 2025

Investigating what's happening with flex attention: Something weird is going on 👀

Edit: Found the issue, gonna open another PR since it affects more models

@vasqu (Contributor, Author) commented May 23, 2025

Found the root cause behind flex attention failing, will open another PR for this which should be merged before this PR. See #38321

@Cyrilvallez (Member) commented

> Found the root cause behind flex attention failing, will open another PR for this which should be merged before this PR.

If you're talking about compilation with flex, it's a known issue as well, as flex auto-compiles itself it seems to interfere when the forward is compiled as well 🥲 Super nice if you found a workaround!

@vasqu (Contributor, Author) commented May 23, 2025

> If you're talking about compilation with flex, it's a known issue as well, as flex auto-compiles itself it seems to interfere when the forward is compiled as well 🥲 Super nice if you found a workaround!

Ah, no I think I'm talking about something different. We have one flex attention test and it was failing for like 99% of the models. Fixing this in another PR. Iiuc, it's basically an issue of torch compile compiling twice, as we compile forward and flex tries to compile again - that's indeed not nice 👀

@gante (Contributor) left a comment

One more nit

(whisper is fun :D)

input_embeds=inputs_embeds,
attention_mask=attention_mask,
cache_position=cache_position,
past_key_values=past_key_values.self_attention_cache if past_key_values is not None else None,
@gante (Contributor) commented May 23, 2025

do we need to type check this statement and add more logic?

if we follow the nesting from WhisperForCausalLM (-> WhisperDecoderWrapper -> WhisperDecoder [this class]), then past_key_values can also be a decoder-only cache and past_key_values.self_attention_cache is fail-prone

this also means:

  • cache initialization above is incomplete
  • the docstring for past_key_values is imprecise

☠️
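A minimal sketch of the guard this points at (the helper name _self_attn_cache is hypothetical, not from the PR):

from transformers.cache_utils import EncoderDecoderCache

def _self_attn_cache(past_key_values):
    # only unwrap when we actually hold the composite cache; a decoder-only
    # Cache (or None) passes through untouched
    if isinstance(past_key_values, EncoderDecoderCache):
        return past_key_values.self_attention_cache
    return past_key_values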

@vasqu (Contributor, Author) commented

Yea, fair point lemme check what's even happening here 👀

@vasqu (Contributor, Author) commented May 23, 2025

Ok, so it seems to me that we create an EncoderDecoderCache even if we use the decoder-only model. And it's not only the case for whisper, but for any encoder-decoder model that has a decoder-only flavor (with cache class support).

In short, this handles our cases:

if use_cache or past_key_values is not None:
    if isinstance(past_key_values, Cache) and not isinstance(past_key_values, EncoderDecoderCache):
        return_self_attention_cache = True
        past_key_values = EncoderDecoderCache(past_key_values, DynamicCache())
    elif not isinstance(past_key_values, EncoderDecoderCache):
        return_legacy_cache = True
        logger.warning_once(
            "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.43.0. "
            "You should pass an instance of `EncoderDecoderCache` instead, e.g. "
            "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`."
        )
        past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values)
  • Decoder-only cache is passed --> internally we use an encoder-decoder cache --> return decoder-only again after everything
  • Defaulting to encoder-decoder in any other case

We could imo add to the docs that decoder-only is possible? Not necessarily a fan of this tbh, but I can see it.
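To illustrate the round trip described above (a sketch, not the modeling code itself):

from transformers.cache_utils import DynamicCache, EncoderDecoderCache

decoder_only = DynamicCache()                                # what the caller passed in
wrapped = EncoderDecoderCache(decoder_only, DynamicCache())  # what the trunk uses internally
returned = wrapped.self_attention_cache                      # what is handed back at the end
assert returned is decoder_only                              # same object, so the caller is unaffected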

Contributor

Ah, good point!

I think a cleaner approach would move the conversion logic to the decoder-only classes. In other words, the shared classes always assume EncoderDecoderCache, while the decoder-only AutoModelFor... classes hold the conversion logic. This would keep the reference classes and their docs as clean as possible, with the expansions (decoder-only) being responsible for the adaptation.

WDYT?
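A minimal sketch of that split (the helper name _as_encoder_decoder_cache is hypothetical):

from transformers.cache_utils import Cache, DynamicCache, EncoderDecoderCache

def _as_encoder_decoder_cache(past_key_values):
    # called by the decoder-only entry point before delegating, so the shared
    # trunk can assume an EncoderDecoderCache unconditionally
    if isinstance(past_key_values, Cache) and not isinstance(past_key_values, EncoderDecoderCache):
        return EncoderDecoderCache(past_key_values, DynamicCache())
    return past_key_values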

Collaborator

The thing is, this can be completely done inside the EncoderDecoderCache.
Something like:

class EncoderDecoderCache:
    def __new__(cls, past_key_values):
        # legacy tuple input: build the instance via the existing classmethod
        if not isinstance(past_key_values, Cache):
            return cls.from_legacy_cache(past_key_values)
        return super().__new__(cls)

    def __init__(self, past_key_values):
        if isinstance(past_key_values, Cache) and not isinstance(past_key_values, EncoderDecoderCache):
            ...  # do the init: wrap the decoder-only cache

Collaborator

then the only check left to do in the model is:

if use_cache or past_key_values is not None:

Contributor

@ArthurZucker that handles cache initialization, but doesn't solve the part where the whisper class may selectively return a Cache in certain circumstances (as opposed to an EncoderDecoderCache).

IMO, the root issue comes from the PR that introduced this Cache<>EncoderDecoderCache logic: to minimize cross-class dependencies, the more recent decoder-only model should wrap its Cache into an EncoderDecoderCache before calling the main trunk of Whisper. This is a solution without if/elses where all main classes can be kept without cross-references :)

@vasqu (Contributor, Author) commented

I feel like this change to the cache is out of scope for this PR. It's affecting multiple models not related to whisper (e.g. bart). It would make more sense to tackle this in a separate PR that covers all of these models.

There are two options imo:

  1. Update cache logic to move the decoder-only cache logic out (Joao's suggestion)
  2. Update docstrings to reflect the logic that happens here

Contributor

Yeah, makes sense to do it in another PR :)

Collaborator

Yeah, in general we want to move away from sub modules returning cache classes!

@gante (Contributor) left a comment

(Happy with the PR, with the exception of the related but out-of-scope issue discussed above)

@Cyrilvallez (Member) commented

Please see my comments #38235 (comment) here before merging! I believe we can make it much cleaner to set a proper example for all similar models!

@Cyrilvallez (Member) left a comment

LGTM, just left 2 last comments to try to simplify the code as much as possible (especially the part on general FA2 attention as it would unnecessarily impact every model here I think)! Thanks for making the changes 🤗

Comment on lines +35 to +36
if attention_mask is not None and attention_mask.ndim == 2:
attention_mask = attention_mask[:, : key.shape[-2]]
@Cyrilvallez (Member) commented May 28, 2025

Is this still needed based on latest mask refactors in the modeling? It should already have been taken care of upstream in the mask creation function

@vasqu (Contributor, Author) commented

Seems to be an old relic from the previous mask creation logic which caused this :D removed it

Comment on lines +936 to +937
cache_position=cache_position,
past_key_values=past_key_values.self_attention_cache if past_key_values is not None else None,
Member

Here we should not need the if/else with the change to the Cache itself 🤗

@vasqu (Contributor, Author) commented

Yup good point! Also changed to just return past_kvs

@ArthurZucker (Collaborator) left a comment

Very happy with the changes! 🤗

-    cache_position,
-    past_key_values.self_attention_cache if past_key_values is not None else None,
-    output_attentions,
+    causal_mask = create_causal_mask(
Collaborator

haha damn the change is soooo much cleaner!

@vasqu (Contributor, Author) commented May 28, 2025

Running slow tests locally and then merging (if nothing's broken)

Edit: slow tests pass as expected

@vasqu merged commit badc71b into main May 28, 2025
21 checks passed
@vasqu deleted the vas-whisper-attn-refactor branch May 28, 2025 11:32