🔴[Attention] Attention refactor for Whisper-based models#38235
Conversation
Ready for another round of reviews imo
ArthurZucker
left a comment
Very nice!
What are the compilation issues with flex? Pretty sure updating to the new causal mask would solve it!
```python
# Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_causal_mask
def _update_causal_mask(
    self,
```
I will be noisy but let's use the new cache creation no? 👀
It's introducing even more issues (torchscript, plus another integration test then starts to fail - test_tiny_static_generation_long_form). Would leave it for now if that's ok. Don't want to make whisper even more broken tbh, wdyt?
cc @Cyrilvallez for the new mask creation, tried replacing it with

```python
# previously
# causal_mask = self._update_causal_mask(
#     attention_mask,
#     inputs_embeds,
#     cache_position,
#     past_key_values.self_attention_cache if past_key_values is not None else None,
# )
causal_mask = create_causal_mask(
    config=self.config,
    input_embeds=inputs_embeds,
    attention_mask=attention_mask,
    cache_position=cache_position,
    past_key_values=past_key_values.self_attention_cache if past_key_values is not None else None,
)
```
torchscript with a mask is a known issue, but it's not super important anyway. For the other one, if it's compile-related it will probably be fixed in #38319. TLDR: I had introduced a workaround for Python<3.11 on torch.export, but it's also an issue with compile and fullgraph=True for those same Python versions (which was obvious as they use the same dynamo tracing, but I missed it 🙃), so I made the workaround the default
I'll try to apply the fix locally and check if it works but sounds good :)
For torchscript, I would skip those tests and add a todo for you (if we try to make it work again). Agree that it's not the most important feature. Wdyt?
Ok, it seems the failures are unrelated to the compiling issues. It seems that encoder-decoder compiling relies on the mask functions to be under PretrainedXXX, which leads to the integration test failing. This needs a closer look tbh, will leave it to a future PR to address the new masking.
cc @gante @zucchini-nlp if you know anything about the encoder-decoder caches relying on the functions under PretrainedXXX
A bit late to the party sorry, but if you use the new mask API you should get rid of _prepare_4d_causal_attention_mask_with_cache_position entirely everywhere, otherwise generate will use it instead of the new create_masks_for_generate! Just check a bit further here 🤗
(Otherwise generate will not create the correct mask with flex, or custom attention, and you're mixing new and old API which is not good)
You may need to overwrite the general one though, to account for the EncoderDecoderCache, but it can be done super easily as in Gemma3 for example, where we need to overwrite to account for the additional mask for the image tokens in training
LMK on slack if you cannot make it work, but TLDR we should not mix the old/new mask APIs, and the new one will be more general as flex will work correctly!
```python
def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[int, int]:
    """
    Return a tuple (kv_length, kv_offset) corresponding to the length and offset that will be returned for
    the given layer at `layer_idx`.
    The masks are then prepared according to the given lengths (kv_length, kv_offset) and patterns
    (i.e. sliding_window, chunk_size), for each layer.
    """
    return self.self_attention_cache.get_mask_sizes(cache_position, layer_idx)
```

Adding this in EncoderDecoderCache is probably even better/cleaner, as you don't need to overwrite create_masks_for_generate, and it will always work independently of the type of Cache being used
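To make the delegation pattern concrete, here is a minimal runnable sketch with toy stand-in classes (the names mirror the transformers API, but these are simplified assumptions, not the actual implementation):

```python
# Toy sketch of the delegation pattern: an encoder-decoder cache forwards
# mask-size queries to its wrapped self-attention cache. Stand-in classes
# only; not the real transformers implementation.

class ToySelfAttentionCache:
    """Minimal decoder-only cache: tracks how many tokens are already cached."""
    def __init__(self):
        self.cached_tokens = 0

    def get_mask_sizes(self, cache_position, layer_idx):
        # kv_length covers the cached tokens plus the new query positions;
        # a plain dynamic cache has no offset.
        kv_length = self.cached_tokens + len(cache_position)
        kv_offset = 0
        return kv_length, kv_offset

class ToyEncoderDecoderCache:
    """Wraps a self-attention cache and forwards mask-size queries to it."""
    def __init__(self, self_attention_cache):
        self.self_attention_cache = self_attention_cache

    def get_mask_sizes(self, cache_position, layer_idx):
        # Delegate, so masking works regardless of the wrapped cache type.
        return self.self_attention_cache.get_mask_sizes(cache_position, layer_idx)

cache = ToyEncoderDecoderCache(ToySelfAttentionCache())
cache.self_attention_cache.cached_tokens = 10
assert cache.get_mask_sizes(range(10, 14), layer_idx=0) == (14, 0)
```

Because the wrapper only delegates, the mask-creation code never needs to branch on the concrete cache type.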
Investigating what's happening with flex attention: Something weird is going on 👀 Edit: Found the issue, gonna open another PR since it affects more models
Found the root cause behind flex attention failing, will open another PR for this which should be merged before this PR. See #38321

If you're talking about compilation with flex, it's a known issue as well; as flex auto-compiles itself, it seems to interfere when the forward is compiled as well 🥲 Super nice if you found a workaround!

Ah, no I think I'm talking about something different. We have one flex attention test and it was failing for like 99% of the models. Fixing this in another PR. Iiuc, then it's basically an issue with torch compile doubly compiling, as we compile forward and flex tries to compile again - that's indeed not nice 👀
gante
left a comment
One more nit
(whisper is fun :D)
```python
input_embeds=inputs_embeds,
attention_mask=attention_mask,
cache_position=cache_position,
past_key_values=past_key_values.self_attention_cache if past_key_values is not None else None,
```
do we need to type check this statement and add more logic?
if we follow the nesting from WhisperForCausalLM (-> WhisperDecoderWrapper -> WhisperDecoder [this class]), then past_key_values can also be a decoder-only cache and past_key_values.self_attention_cache is fail-prone
this also means:
- cache initialization above is incomplete
- the docstring for `past_key_values` is imprecise
☠️
Yea, fair point lemme check what's even happening here 👀
Ok, so it seems to me that we create an EncoderDecoderCache even if we use the decoder-only model. And, it's not only for whisper the case but also for any encoder-decoder model that has a decoder-only model flavor (with cache class support).
In short, this handles our cases:

```python
if use_cache or past_key_values is not None:
    if isinstance(past_key_values, Cache) and not isinstance(past_key_values, EncoderDecoderCache):
        return_self_attention_cache = True
        past_key_values = EncoderDecoderCache(past_key_values, DynamicCache())
    elif not isinstance(past_key_values, EncoderDecoderCache):
        return_legacy_cache = True
        logger.warning_once(
            "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.43.0. "
            "You should pass an instance of `EncoderDecoderCache` instead, e.g. "
            "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`."
        )
        past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values)
```

- Decoder-only cache is passed --> internally we use an encoder-decoder cache --> return decoder-only again after everything
- Defaulting to encoder-decoder in any other case

We could imo add to the docs that decoder-only is possible? Not necessarily a fan of this tbh, but I can see it.
Ah, good point!
I think a cleaner approach would move the conversion logic to the decoder-only classes. In other words, the shared classes always assume EncoderDecoderCache, the decoder-only AutoModelFor... classes hold the conversion logic. This would keep the reference classes and their docs as clean as possible, with the expansions (decoder-only) being responsible for the adaptation.
WDYT?
the thing is, this can be completely done inside the EncoderDecoderCache.
Something like:

```python
class EncoderDecoderCache:
    def __init__(self, past_key_values):
        if isinstance(past_key_values, Cache) and not isinstance(past_key_values, EncoderDecoderCache):
            ...  # do the init
            return self
        if not isinstance(past_key_values, EncoderDecoderCache):
            return self.from_legacy_cache(past_key_values)
```

then the only check in the model to do is:

```python
if use_cache or past_key_values is not None:
```
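As written, the `__init__` sketch isn't valid Python (`__init__` cannot return a different object), but the same idea works as a classmethod. A hedged stand-in sketch with toy classes follows; `from_any` is a hypothetical name, not actual transformers API:

```python
# Toy stand-in classes illustrating the "conversion inside the cache class"
# idea. `from_any` is a hypothetical name; not the real transformers API.

class Cache:
    pass

class DynamicCache(Cache):
    pass

class EncoderDecoderCache(Cache):
    def __init__(self, self_attention_cache, cross_attention_cache):
        self.self_attention_cache = self_attention_cache
        self.cross_attention_cache = cross_attention_cache

    @classmethod
    def from_legacy_cache(cls, past_key_values):
        # Stand-in for the real conversion from the legacy tuple format.
        return cls(DynamicCache(), DynamicCache())

    @classmethod
    def from_any(cls, past_key_values):
        # Already the right type: pass through unchanged.
        if isinstance(past_key_values, cls):
            return past_key_values
        # A decoder-only Cache: pair it with a fresh cross-attention cache.
        if isinstance(past_key_values, Cache):
            return cls(past_key_values, DynamicCache())
        # Anything else is treated as the legacy tuple format.
        return cls.from_legacy_cache(past_key_values)

self_cache = DynamicCache()
wrapped = EncoderDecoderCache.from_any(self_cache)
assert wrapped.self_attention_cache is self_cache
assert EncoderDecoderCache.from_any(wrapped) is wrapped
```

With this, the model-side check really does collapse to `if use_cache or past_key_values is not None:` followed by a single conversion call.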
There was a problem hiding this comment.
@ArthurZucker that handles cache initialization, but doesn't solve the part where the whisper class may selectively return a Cache in certain circumstances (as opposed to an EncoderDecoderCache).
IMO, the root issue comes from the PR that introduced this Cache<>EncoderDecoderCache logic: to minimize cross-class dependencies, the more recent decoder-only model should wrap its Cache into an EncoderDecoderCache before calling the main trunk of Whisper. This is a solution without if/elses where all main classes can be kept without cross-references :)
I feel like this change to the cache is out of scope for this PR. It's affecting multiple models not related to whisper (e.g. bart). It would make more sense to tackle this in a separate PR that covers all of these models.
There are two options imo:
- Update cache logic to move the decoder-only cache logic out (Joao's suggestion)
- Update docstrings to reflect the logic that happens here
Yeah, makes sense to do it in another PR :)
Yeah, in general we want to move away from sub modules returning cache classes!
gante
left a comment
(Happy with the PR, with the exception of the related but out-of-scope issue discussed above)
Please see my comments #38235 (comment) here before merging! I believe we can make it much cleaner to set a proper example for all similar models!
Cyrilvallez
left a comment
LGTM, just left 2 last comments to try to simplify the code as much as possible (especially the part on general FA2 attention as it would unnecessarily impact every model here I think)! Thanks for making the changes 🤗
```python
if attention_mask is not None and attention_mask.ndim == 2:
    attention_mask = attention_mask[:, : key.shape[-2]]
```
Is this still needed based on latest mask refactors in the modeling? It should already have been taken care of upstream in the mask creation function
Seems to be an old relic from the previous mask creation which caused this :D removed it
```python
cache_position=cache_position,
past_key_values=past_key_values.self_attention_cache if past_key_values is not None else None,
```
Here we should not need the if/else with the change to the Cache itself 🤗
Yup good point! Also changed to just return past_kvs
ArthurZucker
left a comment
Very happy with the changes! 🤗
```python
cache_position,
past_key_values.self_attention_cache if past_key_values is not None else None,
output_attentions,
causal_mask = create_causal_mask(
```
haha damn the change is soooo much cleaner!
Running slow tests locally and then merge (if nothing's broken) Edit: slow tests pass as expected
Whisper attention refactor according to the same strategies applied in #38108
Also, several fixes on Whisper along the way, reducing the number of failing tests to 3 (disregarding skipped tests):
- `test_small_longform_timestamps_generation`
- `test_tiny_token_timestamp_batch_generation`
- `test_whisper_longform_multi_batch_hard_prev_cond`