[Whisper + beam search] fix usage of beam_indices #38259

gante merged 6 commits into huggingface:main

Conversation
```python
        tensor containing the timestamps in seconds for each predicted token
    """
    # Create a list with `decoder_layers` elements, each a tensor of shape
    # (batch size, attention_heads, output length, input length).
```
The shape comments were incorrect for the case with beam search.
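For context, a minimal sketch (illustrative shapes and values only, not code from the PR) of why the old comment was misleading under beam search: the leading dimension of each cross-attention tensor is `batch_size * num_beams`, not `batch_size`:

```python
import torch

batch_size, num_beams, attention_heads = 2, 3, 4
output_length, input_length = 10, 1500  # made-up lengths for illustration

# One tensor per decoder layer; with beam search the leading dimension is
# batch_size * num_beams, so shape comments written for greedy decoding
# (plain batch_size) no longer describe what is actually being indexed.
layer_weights = torch.randn(
    batch_size * num_beams, attention_heads, output_length, input_length
)
print(layer_weights.shape)  # torch.Size([6, 4, 10, 1500])
```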
```python
weight_length = None

if "beam_indices" in generate_outputs:
    # If beam search has been used, the output sequences may have been generated
    # for more timesteps than their sequence_lengths
```
In this `if` block, I've rewritten the comments to better explain what's happening.
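A small illustration (token values invented) of the situation those comments describe: with beam search, all returned sequences are padded to the length of the longest one, so a sequence can carry more timesteps than its own `sequence_length`:

```python
import torch

# Two beams of different true lengths, padded to a common length of 5.
sequences = torch.tensor(
    [[50258, 123, 456, 50257, 50257],   # ended early, then padded
     [50258, 789, 101, 112, 50257]]     # ran for the full length
)
# beam_indices mirrors this padding with -1 entries past each sequence's end,
# which is what lets the code recover the real per-sequence lengths later.
print(sequences.shape)  # torch.Size([2, 5])
```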
```python
# beam search takes `decoder_input_ids` into account in the `beam_indices` length
# but forgot to shift the beam_indices by the number of `decoder_input_ids`
beam_indices = torch.zeros_like(generate_outputs.beam_indices[:, :weight_length])
# we actually shift the beam indices here
beam_indices[:, num_input_ids:] = generate_outputs.beam_indices[:, : weight_length - num_input_ids]
```
This is no longer correct after the beam search refactor (#35802): `beam_indices` was corrected to have the same output length as the other optional outputs (i.e. the length of the generated tokens).
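To make the mismatch concrete, here is a hedged sketch (all values invented) of what the old shifting workaround did, and why it misaligns beams once `beam_indices` already matches the generated-token length:

```python
import torch

weight_length, num_input_ids = 8, 2
# Pre-#35802, beam_indices also covered the decoder prompt, so the code
# shifted it right by num_input_ids before use:
old_beam_indices = torch.tensor([[0, 1, 1, 0, 0, 1, -1, -1]])

shifted = torch.zeros_like(old_beam_indices[:, :weight_length])
shifted[:, num_input_ids:] = old_beam_indices[:, : weight_length - num_input_ids]
print(shifted)  # tensor([[0, 0, 0, 1, 1, 0, 0, 1]])

# Post-#35802, beam_indices already has one entry per generated token, so the
# same shift would misalign entries and gather attentions from the wrong beams.
```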
```python
# In that case, the `cross_attentions` weights are too long and we have to make sure
# that they have the right `output_length`

weights = weights[:, :, :weight_length]
```
Redundant: we rebuild `weights` below with sequence length `range(unrolled_beam_indices.shape[1])` (= `weight_length`).
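A sketch of the kind of rebuild the reviewer is referring to (shapes and everything except the `unrolled_beam_indices` name are assumptions, not the PR's code): gathering one row per timestep already bounds the sequence dimension, so the earlier truncation does nothing extra:

```python
import torch

batch, beams, heads, seq_len, frames = 2, 3, 4, 6, 10
weights = torch.randn(batch * beams, heads, seq_len, frames)
# Which flat beam each (sample, timestep) token came from:
unrolled_beam_indices = torch.randint(0, batch * beams, (batch, seq_len))

# Rebuild weights token by token; the output's sequence length is
# unrolled_beam_indices.shape[1] by construction, i.e. weight_length.
rebuilt = torch.stack(
    [weights[unrolled_beam_indices[:, t], :, t] for t in range(unrolled_beam_indices.shape[1])],
    dim=2,
)
print(rebuilt.shape)  # torch.Size([2, 4, 6, 10])
```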
```python
# since the beam search strategy chooses the most probable sequences at the end of the search.
# In that case, the cross_attentions weights are too long and we have to make sure that
# they have the right output_length
weight_length = (generate_outputs.beam_indices != -1).sum(-1).max()
weight_length = weight_length if num_input_ids is None else weight_length + num_input_ids
```
Root issue of #36093: `weight_length` is off by one. The comments in the new version explain why :)
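A minimal illustration (invented values) of what this computation produces: counting non-`-1` entries gives the number of generated tokens, and the interaction with `num_input_ids` is where, per the review, the off-by-one crept in:

```python
import torch

# -1 marks positions after a sequence has finished.
beam_indices = torch.tensor([[0, 1, 1, -1],
                             [2, 2, 2, 2]])
num_input_ids = 1

weight_length = (beam_indices != -1).sum(-1).max()  # tensor(4): generated tokens
weight_length = weight_length if num_input_ids is None else weight_length + num_input_ids
print(weight_length)  # tensor(5): generated tokens plus prompt tokens
```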
vasqu left a comment:

Just some nits on the shape comments. An important step to have something "work" again, even if it's not producing the correct output quality-wise at first :)
Co-authored-by: Anton Vlasjuk <73884904+vasqu@users.noreply.github.com>
What does this PR do?
Fixes the shape issues reported in #36093, which have been around since the code was added 👀. It doesn't fix the quality of word-timestamp outputs (see e.g. #36632); rather, it fixes how we gather the cross attentions from the right beams when beam search is used, which was broken.
`test_tiny_token_timestamp_batch_generation` is a test with the same pattern (beam search + timestamps) and is failing on `main` with the same exception as reported in #36093. This PR does NOT fix that test, but it allows the test to move past the shape exception to the output quality checks, which are still broken 🙃
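For reference, a minimal repro sketch of the failing pattern (the model id, audio file, and beam count are assumptions, not taken from the test): word timestamps combined with beam search is what used to raise the shape exception on `main`:

```python
from transformers import pipeline

asr = pipeline("automatic-speech-recognition", model="openai/whisper-tiny")
out = asr(
    "sample.wav",                      # any short audio clip
    return_timestamps="word",          # requests per-token cross attentions
    generate_kwargs={"num_beams": 3},  # beam search triggers the broken path
)
print(out["chunks"])
```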