Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions docs/source/en/model_doc/whisper.md
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ Here is a step-by-step guide to transcribing an audio sample using a pre-trained
' Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.'
```

Whisper is compatible with the following optimisations:
Whisper is compatible with the following optimisations for both short and long-form generation:
- [PyTorch Scaled Dot Product Attention (SDPA)](../perf_infer_gpu_one#pytorch-scaled-dot-product-attention): flash attention and memory-efficient attention kernels. Enabled by default for `torch>=2.1.1`.
- [Flash Attention 2](../perf_infer_gpu_one#flashattention-2): improved implementation of flash attention through better parallelism and work partitioning.
- [torch.compile](../llm_optims#static-kv-cache-and-torchcompile): JIT-compile the forward pass to dispatch to efficient fused kernels.
Expand Down Expand Up @@ -101,7 +101,8 @@ As an example, the following codesnippet enables SDPA and `torch.compile` for up
... ).input_features

>>> # Compile the forward pass
>>> _ = model.generate(input_features)
>>> for _ in range(2):
>>> model.generate(input_features)

>>> # Generate token ids using compiled graph (fast!)
>>> predicted_ids = model.generate(input_features)
Expand Down
72 changes: 60 additions & 12 deletions src/transformers/models/whisper/generation_whisper.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,12 +126,24 @@ def _get_attr_from_logit_processors(logits_processor, logit_processor_class, att


def _pad_to_max_length(
current_segments, pad_token_id, device, padding="right", bos_token_tensor=None, cut_off_length=None
current_segments,
pad_token_id,
device,
padding_side="right",
padding="longest",
bos_token_tensor=None,
cut_off_length=None,
):
max_total_length = 0
sequences = []
if padding not in ["right", "left"]:
raise ValueError(f"`padding` must be either 'right' or 'left', not {padding}")

if padding_side not in ["right", "left"]:
Comment thread
sanchit-gandhi marked this conversation as resolved.
raise ValueError(f"`padding_side` must be either 'right' or 'left', not {padding_side}")

if padding not in ["longest", "max_length"]:
raise ValueError(f"`padding` must be either 'longest' or 'max_length', not {padding}")
elif padding == "max_length" and cut_off_length is None:
raise ValueError("`cut_off_length` must be specified when `padding='max_length'`")

for current_segment_list in current_segments:
if current_segment_list is not None and len([d["tokens"] for d in current_segment_list]) > 0:
Expand All @@ -150,9 +162,10 @@ def _pad_to_max_length(
else:
sequences.append(torch.tensor([], device=device))

max_total_length = cut_off_length + 1 if padding == "max_length" else max_total_length
for i in range(len(current_segments)):
pad_length = max_total_length - len(sequences[i])
pad = (0, pad_length) if padding == "right" else (pad_length, 0)
pad = (0, pad_length) if padding_side == "right" else (pad_length, 0)
sequences[i] = F.pad(sequences[i], pad=pad, value=pad_token_id)

sequences = torch.stack(sequences, dim=0)
Expand Down Expand Up @@ -672,6 +685,7 @@ def generate(
return_token_timestamps=return_token_timestamps,
do_condition_on_prev_tokens=do_condition_on_prev_tokens,
is_shortform=is_shortform,
batch_size=batch_size,
Comment thread
sanchit-gandhi marked this conversation as resolved.
Outdated
kwargs=kwargs,
)

Expand Down Expand Up @@ -712,7 +726,7 @@ def generate(
)

sequences = _pad_to_max_length(
final_segments, generation_config.pad_token_id, device=self.device, padding="right"
final_segments, generation_config.pad_token_id, device=self.device, padding_side="right"
)

# 8. If we return all segments, the predicted output sequences are put under `"sequences"`.
Expand Down Expand Up @@ -775,6 +789,7 @@ def generate_with_fallback(
return_token_timestamps,
do_condition_on_prev_tokens,
is_shortform,
batch_size,
Comment thread
sanchit-gandhi marked this conversation as resolved.
Outdated
kwargs,
):
kwargs = copy.copy(kwargs)
Expand All @@ -798,6 +813,22 @@ def generate_with_fallback(
for key in ["do_sample", "temperature", "num_beams"]:
if key in generate_kwargs:
del generate_kwargs[key]

cur_bsz = decoder_input_ids.shape[0]
if generation_config.cache_implementation == "static" and cur_bsz < batch_size:
segment_input = F.pad(segment_input, (0, 0, 0, 0, 0, batch_size - cur_bsz), value=0)
decoder_input_ids = F.pad(
decoder_input_ids, (0, 0, 0, batch_size - cur_bsz), value=generation_config.pad_token_id
)
if generate_kwargs.get("decoder_attention_mask") is not None:
generate_kwargs["decoder_attention_mask"] = F.pad(
generate_kwargs["decoder_attention_mask"], (0, 0, 0, batch_size - cur_bsz), value=True
)
if generate_kwargs.get("encoder_outputs") is not None:
generate_kwargs["encoder_outputs"] = F.pad(
generate_kwargs["encoder_outputs"], (0, 0, 0, 0, 0, batch_size - cur_bsz), value=0
)
Comment thread
ArthurZucker marked this conversation as resolved.

seek_outputs = super().generate(
segment_input,
generation_config=generation_config,
Expand All @@ -820,6 +851,10 @@ def generate_with_fallback(
is_shortform=is_shortform,
)

if cur_bsz < batch_size:
seek_sequences = seek_sequences[:cur_bsz]
seek_outputs = seek_outputs[:cur_bsz]
Comment thread
sanchit-gandhi marked this conversation as resolved.
Outdated

# 6.7 Extract cut sequences from every sequence and check if fallback should be applied
# Loop over each decoded audio individually as each decoding can be of a different length
new_fallback_index_map = []
Expand Down Expand Up @@ -925,17 +960,27 @@ def split_by_batch_index(values, key, batch_idx, is_shortform):
if not is_shortform:
# we don't save `past_key_values` as this is too costly for longform
return None
elif isinstance(values, EncoderDecoderCache):
all_past_key_values = []
for layer_idx in range(self.config.decoder_layers):
layer_past_key_values = []
for cache_cls in [values.self_attention_cache, values.cross_attention_cache]:
for v in [cache_cls.key_cache, cache_cls.value_cache]:
layer_past_key_values.append(v[layer_idx][batch_idx][None].cpu())
all_past_key_values.append(tuple(layer_past_key_values))
return tuple(all_past_key_values)
else:
return tuple(tuple(w[batch_idx][None].cpu() for w in values[v]) for v in range(len(values)))
all_past_key_values = []
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is just a re-factor to split the single line list iteration over several lines, in order to be more verbose

for v in range(len(values)):
layer_past_key_values = []
for w in values[v]:
layer_past_key_values.append(w[batch_idx][None].cpu())
all_past_key_values.append(tuple(layer_past_key_values))
return tuple(all_past_key_values)

return values[batch_idx].cpu()

sequence_tokens = seek_outputs["sequences"]

if hasattr(seek_outputs, "past_key_values") and seek_outputs.past_key_values is not None:
if isinstance(seek_outputs["past_key_values"], EncoderDecoderCache):
seek_outputs.past_key_values = seek_outputs.past_key_values.to_legacy_cache()

seek_outputs = [
{k: split_by_batch_index(v, k, i, is_shortform) for k, v in seek_outputs.items()}
for i in range(sequence_tokens.shape[0])
Expand Down Expand Up @@ -1613,11 +1658,14 @@ def _prepare_decoder_input_ids(
one_tensor = torch.ones((cur_bsz, 1), device=device, dtype=torch.long)
prev_ids = prev_start_of_text * one_tensor[0] if prev_start_of_text is not None else None

padding = "max_length" if generation_config.cache_implementation == "static" else "longest"

prev_tokens = _pad_to_max_length(
active_segments,
generation_config.pad_token_id,
device=device,
padding="left",
padding_side="left",
padding=padding,
bos_token_tensor=prev_ids,
cut_off_length=cut_off_length,
)
Expand Down
34 changes: 33 additions & 1 deletion src/transformers/models/whisper/modeling_whisper.py
Original file line number Diff line number Diff line change
Expand Up @@ -1835,8 +1835,10 @@ def prepare_inputs_for_generation(

decoder_input_ids = decoder_input_ids[:, remove_prefix_length:]

if decoder_position_ids is not None and decoder_position_ids.shape[1] > decoder_input_ids.shape[1]:
if decoder_position_ids is not None:
decoder_position_ids = decoder_position_ids[:, remove_prefix_length:]
# This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture.
decoder_position_ids = decoder_position_ids.clone(memory_format=torch.contiguous_format)

if cache_position is None:
cache_position = torch.arange(
Expand All @@ -1845,6 +1847,36 @@ def prepare_inputs_for_generation(
elif use_cache:
cache_position = cache_position[-decoder_input_ids.shape[1] :]

# The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
# recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114
decoder_input_ids = decoder_input_ids.contiguous()

Comment thread
sanchit-gandhi marked this conversation as resolved.
if (
isinstance(past_key_values, EncoderDecoderCache)
and (
isinstance(past_key_values.self_attention_cache, StaticCache)
or isinstance(past_key_values.cross_attention_cache, StaticCache)
)
and decoder_attention_mask is not None
and decoder_attention_mask.ndim == 2
):
batch_size, sequence_length = decoder_input_ids.shape
device = decoder_input_ids.device

dtype = self.proj_out.weight.dtype
min_dtype = torch.finfo(dtype).min

decoder_attention_mask = _prepare_4d_causal_attention_mask_with_cache_position(
decoder_attention_mask,
sequence_length=sequence_length,
target_length=past_key_values.self_attention_cache.get_max_length(),
dtype=dtype,
device=device,
min_dtype=min_dtype,
cache_position=cache_position,
batch_size=batch_size,
)

return {
"encoder_outputs": encoder_outputs,
"past_key_values": past_key_values,
Expand Down
60 changes: 60 additions & 0 deletions tests/models/whisper/test_modeling_whisper.py
Original file line number Diff line number Diff line change
Expand Up @@ -3386,6 +3386,66 @@ def test_tiny_static_generation(self):
# assert re-ordered generations match those from eager
assert (eager_generated_ids[permutation_idx, :] == static_generated_ids).all()

@slow
def test_tiny_static_generation_long_form(self):
import torch._dynamo.config

# only permit 4 compilations: 2 prefill steps and 2 decoding steps (1 for each of conditioned/not conditioned)
torch._dynamo.config.cache_size_limit = 4

processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
model.to(torch_device)

dataset = load_dataset("distil-whisper/meanwhile", "default")["test"]
dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))
input_speech = [audio["array"] for audio in dataset[2:4]["audio"]]

inputs = processor(
input_speech,
return_tensors="pt",
padding="longest",
truncation=False,
return_attention_mask=True,
sampling_rate=16_000,
)
inputs = inputs.to(torch_device)
Comment thread
sanchit-gandhi marked this conversation as resolved.

gen_kwargs = {
"return_timestamps": True,
"no_speech_threshold": 0.6,
"temperature": (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
"compression_ratio_threshold": 1.35,
"condition_on_prev_tokens": True, # conditioning on prev tokens introduces a recompile on the second time step
"logprob_threshold": -1.0,
"num_beams": 1,
}

set_seed(42)
eager_generated_ids = model.generate(**inputs, **gen_kwargs)

# compile the forward pass and assert equivalence
model.generation_config.cache_implementation = "static"
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)

set_seed(42)
static_generated_ids = model.generate(**inputs, **gen_kwargs)
assert (eager_generated_ids == static_generated_ids).all()

# check the compiled graph can be re-used and that the cache is correctly reset
# reverse the ordering of the input features
input_features = inputs.input_features
permutation_idx = (
torch.arange(input_features.shape[0], 0, step=-1, dtype=torch.long, device=input_features.device) - 1
)
input_features = input_features[permutation_idx, ...]
attention_mask = inputs.attention_mask[permutation_idx, ...]

set_seed(42)
static_generated_ids = model.generate(input_features, attention_mask=attention_mask, **gen_kwargs)
# assert re-ordered generations match those from eager
assert (eager_generated_ids[permutation_idx, :] == static_generated_ids).all()


def prepare_whisper_encoder_inputs_dict(config, input_features, head_mask=None):
if head_mask is None:
Expand Down