Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions src/transformers/modeling_flash_attention_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,11 +264,10 @@ def _flash_attention_forward(
)
attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)

# if position_ids is provided and check not all examples (row) contain only 1 sequence, and is in pre-fill/training stage
# then use `flash_attn_varlen_func` to prevent cross-example attention and also allow padding free approach
elif (
position_ids is not None and not (position_ids[:, -1] == position_ids.size(1) - 1).all() and query_length != 1
):
# If position_ids is provided and check all examples do not contain only 1 sequence, If tensor in increasing
# then we probably have one sequence, otherwise it is packed. Additionally check we are in pre-fill/training stage.
# Use `flash_attn_varlen_func` to prevent cross-example attention and also allow padding free approach
elif position_ids is not None and not (torch.diff(position_ids, dim=-1) >= 0).all() and query_length != 1:
batch_size = query_states.size(0)
query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = prepare_fa2_from_position_ids(
query_states, key_states, value_states, position_ids
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/phi3/modeling_phi3.py
Original file line number Diff line number Diff line change
Expand Up @@ -540,7 +540,7 @@ def forward(
max(kv_seq_len, position_ids[:, -1].max().item() + 1) if position_ids is not None else kv_seq_len
)

cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len)
cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len, position_ids=position_ids)

query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

Expand Down
59 changes: 59 additions & 0 deletions tests/models/whisper/test_modeling_whisper.py
Original file line number Diff line number Diff line change
Expand Up @@ -1844,6 +1844,59 @@ def test_generate_output_type(self, return_dict_in_generate):
)
assert isinstance(pred_ids, expected_output_type)

@require_flash_attn
@require_torch_gpu
@pytest.mark.flash_attn_test
@slow
def test_flash_attn_2_generate_reuse_cache(self):
max_new_tokens = 2
for model_class in self.all_generative_model_classes:
if not model_class._supports_flash_attn_2:
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")

config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

dummy_input = inputs_dict[model_class.main_input_name][..., :10]
if dummy_input.dtype in [torch.float32, torch.bfloat16]:
dummy_input = dummy_input.to(torch.float16)

# make sure that all models have enough positions for generation
if hasattr(config, "max_position_embeddings"):
config.max_position_embeddings = dummy_input.shape[1] * 2 + max_new_tokens * 2 + 1

model = model_class(config)

with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)

model = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch.float16,
attn_implementation="flash_attention_2",
low_cpu_mem_usage=True,
).to(torch_device)

# run generate once to get filled cache
output = model.generate(
dummy_input,
max_new_tokens=max_new_tokens,
do_sample=False,
use_cache=True,
return_dict_in_generate=True,
)
past_key_values = output.past_key_values

# Try to continue generation from where we left, given that we have more than 1 new token to process
# e.g. this can happen in speculative decoding when feeding candidate tokens back to target model
_ = model.generate(
dummy_input,
decoder_input_ids=output.sequences,
max_new_tokens=max_new_tokens,
do_sample=False,
use_cache=True,
past_key_values=past_key_values,
)


@require_torch
@require_torchaudio
Expand Down Expand Up @@ -4066,6 +4119,12 @@ def test_retain_grad_hidden_states_attentions(self):
def test_save_load_fast_init_from_base(self):
pass

@unittest.skip(
reason="FA2 testing suite needs to be refactored to be compatible with WhisperDecoder for that test"
)
def test_flash_attn_2_generate_reuse_cache(self):
pass

@unittest.skip(
"Duplicated test with WhisperModelTest + the FA2 testing suite needs to be refactored to be compatible with WhisperDecoder for that test"
)
Expand Down
56 changes: 56 additions & 0 deletions tests/test_modeling_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -4284,6 +4284,62 @@ def test_flash_attn_2_generate_use_cache(self):
use_cache=True,
)

@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@slow
def test_flash_attn_2_generate_reuse_cache(self):
if not self.has_attentions:
self.skipTest(reason="Model architecture does not support attentions")

max_new_tokens = 2
for model_class in self.all_generative_model_classes:
if not model_class._supports_flash_attn_2:
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")

config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

dummy_input = inputs_dict[model_class.main_input_name]
if dummy_input.dtype in [torch.float32, torch.bfloat16]:
dummy_input = dummy_input.to(torch.float16)

# make sure that all models have enough positions for generation
if hasattr(config, "max_position_embeddings"):
config.max_position_embeddings = dummy_input.shape[1] * 2 + max_new_tokens * 2 + 1

model = model_class(config)

with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)

model = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch.float16,
attn_implementation="flash_attention_2",
low_cpu_mem_usage=True,
).to(torch_device)

# run generate once to get filled cache
output = model.generate(
dummy_input,
max_new_tokens=max_new_tokens,
do_sample=False,
use_cache=True,
return_dict_in_generate=True,
)
past_key_values = output.past_key_values

# Try to continue generation from where we left, given that we have more than 1 new token to process
# e.g. this can happen in speculative decoding when feeding candidate tokens back to target model
dummy_input_updated = torch.cat([dummy_input, output.sequences], dim=-1)
_ = model.generate(
dummy_input_updated,
max_new_tokens=max_new_tokens,
do_sample=False,
use_cache=True,
past_key_values=past_key_values,
)

@require_flash_attn
@require_torch_gpu
@require_bitsandbytes
Expand Down