From e9ca1ea2ce98fee2dd9b5716c77ac5e96007a34e Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Thu, 22 Feb 2024 18:07:40 +0000 Subject: [PATCH 01/18] no cache positions in the public api --- .../models/llama/modeling_llama.py | 45 ++++------- tests/test_cache_utils.py | 74 ++++++++++++++++++- 2 files changed, 88 insertions(+), 31 deletions(-) diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 8e494adefc2d..a945805b1355 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -357,7 +357,7 @@ def forward( query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: - # sin and cos are specific to RoPE models; position_ids needed for the static cache + # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) @@ -446,7 +446,7 @@ def forward( past_key_value = getattr(self, "past_key_value", past_key_value) if past_key_value is not None: - # sin and cos are specific to RoPE models; position_ids needed for the static cache + # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) @@ -625,7 +625,6 @@ def forward( past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, - cache_position=cache_position, ) bsz, q_len, _ = hidden_states.size() @@ -645,7 +644,7 @@ def forward( past_key_value = getattr(self, "past_key_value", past_key_value) if past_key_value is not None: - # sin and cos are specific to RoPE models; position_ids needed for the static cache + # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) @@ -943,7 +942,6 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple, BaseModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -972,15 +970,21 @@ def forward( past_key_values = DynamicCache.from_legacy_cache(past_key_values) past_seen_tokens = past_key_values.get_seq_length() - if cache_position is None: + if position_ids is None: if isinstance(past_key_values, StaticCache): - raise ValueError("cache_position is a required argument when using StaticCache.") - cache_position = torch.arange( + raise ValueError("position_ids is a required argument when using StaticCache.") + position_ids = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device - ) + ).unsqueeze(0) - if position_ids is None: - position_ids = cache_position.unsqueeze(0) + # One of the rows in `position_ids` contains the highest sequence of cache indexes, excluding left-padding + # applied on all batch members. Left-padding on all batch members can be detected from the `attention_mask`. 
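The heuristic added on the following lines can be read as: take, per token position, the largest `position_ids` value across the batch (the row with the least left-padding), then shift by however much left-padding every batch member shares. A rough standalone sketch of that computation with made-up tensors, not part of the patch:

import torch

# two left-padded rows; pad positions were masked_fill'ed to 0 in position_ids
position_ids = torch.tensor([[0, 0, 1, 2],   # one pad token
                             [0, 1, 2, 3]])  # no padding
attention_mask = torch.tensor([[0, 1, 1, 1],
                               [1, 1, 1, 1]])

cache_position = torch.max(position_ids, dim=0).values         # tensor([0, 1, 2, 3])
padded_positions = torch.sum(attention_mask == 0, dim=1).min()  # tensor(0): no pad column shared by all rows
cache_position = cache_position + padded_positions              # tensor([0, 1, 2, 3])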
+ cache_position = torch.max(position_ids, dim=0).values + if attention_mask is None: + padded_positions = 0 + else: + padded_positions = torch.sum(attention_mask == 0, dim=1).min() + cache_position = cache_position + padded_positions causal_mask = self._update_causal_mask(attention_mask, inputs_embeds) @@ -1130,7 +1134,6 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple, CausalLMOutputWithPast]: r""" Args: @@ -1174,7 +1177,6 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, - cache_position=cache_position, ) hidden_states = outputs[0] @@ -1248,24 +1250,10 @@ def prepare_inputs_for_generation( if attention_mask is not None and position_ids is None: # create position_ids on the fly for batch generation position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) + position_ids.masked_fill_(attention_mask == 0, 0) if past_key_values: position_ids = position_ids[:, -input_ids.shape[1] :] - if getattr(self.model.layers[0].self_attn, "past_key_value", None) is not None: - # generation with static cache - cache_position = kwargs.get("cache_position", None) - if cache_position is None: - past_length = 0 - else: - past_length = cache_position[-1] + 1 - input_ids = input_ids[:, past_length:] - position_ids = position_ids[:, past_length:] - - # TODO @gante we should only keep a `cache_position` in generate, and do +=1. - # same goes for position ids. Could also help with continued generation. - cache_position = torch.arange(past_length, past_length + position_ids.shape[-1], device=position_ids.device) - # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and past_key_values is None: model_inputs = {"inputs_embeds": inputs_embeds} @@ -1278,7 +1266,6 @@ def prepare_inputs_for_generation( model_inputs.update( { "position_ids": position_ids.contiguous(), - "cache_position": cache_position, "past_key_values": past_key_values, "use_cache": kwargs.get("use_cache"), "attention_mask": attention_mask, diff --git a/tests/test_cache_utils.py b/tests/test_cache_utils.py index 6d31d63e82ef..0b194417bb5e 100644 --- a/tests/test_cache_utils.py +++ b/tests/test_cache_utils.py @@ -291,7 +291,7 @@ def test_sink_cache_iterative_prompts(self): @require_torch_gpu @parameterized.expand(["eager", "sdpa", "flash_attention_2"]) - def test_static_cache_greedy_sampling_pad_left(self, attn_implementation): + def test_static_cache_greedy_decoding_pad_left(self, attn_implementation): EXPECTED_GENERATION = [ "The best color is the one that complements the skin tone of the", "We should not undermind the issues at hand.\nWe should not undermind the issues", @@ -331,7 +331,7 @@ def test_static_cache_greedy_sampling_pad_left(self, attn_implementation): @require_torch_gpu @parameterized.expand(["eager", "sdpa", "flash_attention_2"]) - def test_static_cache_greedy_sampling_pad_right(self, attn_implementation): + def test_static_cache_greedy_decoding_pad_right(self, attn_implementation): EXPECTED_GENERATION = [ "The best color isЋ the one that complements the skin tone of", "We should not undermind the issues at hand.\nWe should not undermind the issues", @@ -382,6 +382,76 @@ def call(input_ids, **kwargs): with self.subTest(f"{attn_implementation}, static, compiled"): self.assertListEqual(decoded, 
EXPECTED_GENERATION) + def test_dynamic_cache_extra_left_padding(self): + """Tests that adding extra left-padding does not affect the generation with the dynamic cache""" + EXPECTED_GENERATION = [ + "The best color is the one that complements the skin tone of the", + "We should not undermind the issues at hand.\nWe should not undermind the issues", + ] + + tokenizer = AutoTokenizer.from_pretrained( + "NousResearch/Llama-2-7b-chat-hf", padding_side="left", pad_token="" + ) + model = AutoModelForCausalLM.from_pretrained( + "NousResearch/Llama-2-7b-chat-hf", + torch_dtype=torch.bfloat16, + ).to(torch_device) + inputs = tokenizer( + ["The best color is", "We should not undermind the issues at hand"], padding=True, return_tensors="pt" + ).to(model.device) + + gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=10) + decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True) + self.assertListEqual(decoded, EXPECTED_GENERATION) + + # Now with extra left-padding + inputs_expanded = tokenizer( + ["The best color is", "We should not undermind the issues at hand"], + padding=True, + return_tensors="pt", + pad_to_multiple_of=32, + ).to(model.device) + self.assertTrue(inputs.input_ids.shape[1] < inputs_expanded.input_ids.shape[1]) + gen_out = model.generate(**inputs_expanded, do_sample=False, max_new_tokens=10) + decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True) + self.assertListEqual(decoded, EXPECTED_GENERATION) + + def test_static_cache_extra_left_padding(self): + """Tests that adding extra left-padding does not affect the generation with the static cache""" + EXPECTED_GENERATION = [ + "The best color is the one that complements the skin tone of the", + "We should not undermind the issues at hand.\nWe should not undermind the issues", + ] + + tokenizer = AutoTokenizer.from_pretrained( + "NousResearch/Llama-2-7b-chat-hf", padding_side="left", pad_token="" + ) + model = AutoModelForCausalLM.from_pretrained( + "NousResearch/Llama-2-7b-chat-hf", + torch_dtype=torch.bfloat16, + ).to(torch_device) + inputs = tokenizer( + ["The best color is", "We should not undermind the issues at hand"], padding=True, return_tensors="pt" + ).to(model.device) + + model.generation_config.cache_implementation = "static" + + gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=10) + decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True) + self.assertListEqual(decoded, EXPECTED_GENERATION) + + # Now with extra left-padding + inputs_expanded = tokenizer( + ["The best color is", "We should not undermind the issues at hand"], + padding=True, + return_tensors="pt", + pad_to_multiple_of=32, + ).to(model.device) + self.assertTrue(inputs.input_ids.shape[1] < inputs_expanded.input_ids.shape[1]) + gen_out = model.generate(**inputs_expanded, do_sample=False, max_new_tokens=10) + decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True) + self.assertListEqual(decoded, EXPECTED_GENERATION) + @unittest.skip("TODO @gante static cache's does not support beam search yet") def test_static_cache_beam_search(self): pass From 3b7fbfbb1e82ac741beb280b43fc3f8a15978f45 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Thu, 22 Feb 2024 18:14:24 +0000 Subject: [PATCH 02/18] propagate changes to gemma --- src/transformers/generation/utils.py | 7 +-- .../models/gemma/modeling_gemma.py | 44 +++++++------------ 2 files changed, 19 insertions(+), 32 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 
d337e5593440..d878fc8ebdd4 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -678,8 +678,6 @@ def _update_model_kwargs_for_generation( dim=-1, ) - model_kwargs["cache_position"] = model_inputs.get("cache_position", None) - return model_kwargs def _reorder_cache(self, past_key_values, beam_idx): @@ -4931,9 +4929,8 @@ def _split_model_inputs( # Here we can have four types of values: tensors, tuples of tensors and booleans, and encoder_outputs which is a # ModelOutput object. # bool should not be split but replicated for each split - bool_keys = [k for k in keys if isinstance(model_input[k], bool) or k == "cache_position"] - keys_to_ignore = ["cache_position", "encoder_outputs"] - non_bool_keys = [k for k in keys if not isinstance(model_input[k], bool) and k not in keys_to_ignore] + bool_keys = [k for k in keys if isinstance(model_input[k], bool)] + non_bool_keys = [k for k in keys if not isinstance(model_input[k], bool) and k != "encoder_outputs"] # we split the tensors and tuples of tensors data_split_list = [ diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index 165ef5a05451..fd58d1bb1765 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -247,7 +247,7 @@ def forward( query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, None) if past_key_value is not None: - # sin and cos are specific to RoPE models; position_ids needed for the static cache + # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) @@ -334,7 +334,7 @@ def forward( past_key_value = getattr(self, "past_key_value", past_key_value) if past_key_value is not None: - # sin and cos are specific to RoPE models; position_ids needed for the static cache + # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) @@ -533,7 +533,7 @@ def forward( past_key_value = getattr(self, "past_key_value", past_key_value) if past_key_value is not None: - # sin and cos are specific to RoPE models; position_ids needed for the static cache + # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) @@ -835,7 +835,6 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple, BaseModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -864,13 +863,21 @@ def forward( past_key_values = DynamicCache.from_legacy_cache(past_key_values) past_seen_tokens = past_key_values.get_seq_length() - if cache_position is None: - cache_position = torch.arange( + if position_ids is None: + if isinstance(past_key_values, StaticCache): + raise ValueError("position_ids is a required argument when using 
StaticCache.") + position_ids = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device - ) + ).unsqueeze(0) - if position_ids is None: - position_ids = cache_position.unsqueeze(0) + # One of the rows in `position_ids` contains the highest sequence of cache indexes, excluding left-padding + # applied on all batch members. Left-padding on all batch members can be detected from the `attention_mask`. + cache_position = torch.max(position_ids, dim=0).values + if attention_mask is None: + padded_positions = 0 + else: + padded_positions = torch.sum(attention_mask == 0, dim=1).min() + cache_position = cache_position + padded_positions causal_mask = self._update_causal_mask(attention_mask, inputs_embeds) @@ -1025,7 +1032,6 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple, CausalLMOutputWithPast]: r""" Args: @@ -1069,7 +1075,6 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, - cache_position=cache_position, ) hidden_states = outputs[0] @@ -1137,24 +1142,10 @@ def prepare_inputs_for_generation( if attention_mask is not None and position_ids is None: # create position_ids on the fly for batch generation position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) + position_ids.masked_fill_(attention_mask == 0, 0) if past_key_values: position_ids = position_ids[:, -input_ids.shape[1] :] - if getattr(self.model.layers[0].self_attn, "past_key_value", None) is not None: - # generation with static cache - cache_position = kwargs.get("cache_position", None) - if cache_position is None: - past_length = 0 - else: - past_length = cache_position[-1] + 1 - input_ids = input_ids[:, past_length:] - position_ids = position_ids[:, past_length:] - - # TODO @gante we should only keep a `cache_position` in generate, and do +=1. - # same goes for position ids. Could also help with continued generation. 
- cache_position = torch.arange(past_length, past_length + position_ids.shape[-1], device=position_ids.device) - # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and past_key_values is None: model_inputs = {"inputs_embeds": inputs_embeds} @@ -1167,7 +1158,6 @@ def prepare_inputs_for_generation( model_inputs.update( { "position_ids": position_ids.contiguous(), - "cache_position": cache_position, "past_key_values": past_key_values, "use_cache": kwargs.get("use_cache"), "attention_mask": attention_mask, From 694b26580c7d159df4dc8261815afdf8a6f0221f Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Thu, 22 Feb 2024 18:19:03 +0000 Subject: [PATCH 03/18] should not have been deleted --- src/transformers/models/llama/modeling_llama.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index a945805b1355..e2fa5249fb41 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -625,6 +625,7 @@ def forward( past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, + cache_position=cache_position, ) bsz, q_len, _ = hidden_states.size() From 75aebbe0e4f9fed6a73eb74c2e23d58b59bb06bc Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Thu, 22 Feb 2024 19:03:44 +0000 Subject: [PATCH 04/18] more precise padded offset calculation --- src/transformers/models/gemma/modeling_gemma.py | 9 ++++++--- src/transformers/models/llama/modeling_llama.py | 9 ++++++--- tests/test_cache_utils.py | 6 +++--- 3 files changed, 15 insertions(+), 9 deletions(-) diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index fd58d1bb1765..133049dad52a 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -874,10 +874,13 @@ def forward( # applied on all batch members. Left-padding on all batch members can be detected from the `attention_mask`. cache_position = torch.max(position_ids, dim=0).values if attention_mask is None: - padded_positions = 0 + padded_offset = 0 else: - padded_positions = torch.sum(attention_mask == 0, dim=1).min() - cache_position = cache_position + padded_positions + padded_offset = (1 - torch.sum(attention_mask, dim=0).clamp(max=1)).cumsum(-1) + padded_offset = torch.cat( + (torch.zeros((1,), dtype=padded_offset.dtype, device=padded_offset.device), padded_offset) + )[-cache_position.shape[0] - 1 : -1] + cache_position = cache_position + padded_offset causal_mask = self._update_causal_mask(attention_mask, inputs_embeds) diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index e2fa5249fb41..6afd2c08cdfb 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -982,10 +982,13 @@ def forward( # applied on all batch members. Left-padding on all batch members can be detected from the `attention_mask`. 
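The cumulative `padded_offset` introduced by this patch shifts each position by the number of columns before it that are padding for every batch member, replacing the earlier scalar `min()`-based count. A small worked example with a made-up mask (left padding, `position_ids` masked to 0 at pad positions), not part of the patch:

import torch

attention_mask = torch.tensor([[0, 0, 1, 1, 1],   # two pad tokens
                               [0, 1, 1, 1, 1]])  # one pad token
position_ids = torch.tensor([[0, 0, 0, 1, 2],
                             [0, 0, 1, 2, 3]])

cache_position = torch.max(position_ids, dim=0).values                          # tensor([0, 0, 1, 2, 3])
padded_offset = (1 - torch.sum(attention_mask, dim=0).clamp(max=1)).cumsum(-1)  # tensor([1, 1, 1, 1, 1])
padded_offset = torch.cat(
    (torch.zeros((1,), dtype=padded_offset.dtype, device=padded_offset.device), padded_offset)
)[-cache_position.shape[0] - 1 : -1]                                             # tensor([0, 1, 1, 1, 1])
cache_position = cache_position + padded_offset                                  # tensor([0, 1, 2, 3, 4])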
cache_position = torch.max(position_ids, dim=0).values if attention_mask is None: - padded_positions = 0 + padded_offset = 0 else: - padded_positions = torch.sum(attention_mask == 0, dim=1).min() - cache_position = cache_position + padded_positions + padded_offset = (1 - torch.sum(attention_mask, dim=0).clamp(max=1)).cumsum(-1) + padded_offset = torch.cat( + (torch.zeros((1,), dtype=padded_offset.dtype, device=padded_offset.device), padded_offset) + )[-cache_position.shape[0] - 1 : -1] + cache_position = cache_position + padded_offset causal_mask = self._update_causal_mask(attention_mask, inputs_embeds) diff --git a/tests/test_cache_utils.py b/tests/test_cache_utils.py index 0b194417bb5e..a134e916630b 100644 --- a/tests/test_cache_utils.py +++ b/tests/test_cache_utils.py @@ -436,9 +436,9 @@ def test_static_cache_extra_left_padding(self): model.generation_config.cache_implementation = "static" - gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=10) - decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True) - self.assertListEqual(decoded, EXPECTED_GENERATION) + # gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=10) + # decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True) + # self.assertListEqual(decoded, EXPECTED_GENERATION) # Now with extra left-padding inputs_expanded = tokenizer( From 88d597b88bb0c948ae985d3a252b5ab341fd6cb3 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Thu, 22 Feb 2024 19:20:14 +0000 Subject: [PATCH 05/18] attention mask dtype is sometimes wrong in the tests --- src/transformers/models/gemma/modeling_gemma.py | 4 ++-- src/transformers/models/llama/modeling_llama.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index 133049dad52a..70bd229d8c8b 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -876,9 +876,9 @@ def forward( if attention_mask is None: padded_offset = 0 else: - padded_offset = (1 - torch.sum(attention_mask, dim=0).clamp(max=1)).cumsum(-1) + padded_offset = (1 - torch.sum(attention_mask.to(torch.int64), dim=0).clamp(max=1)).cumsum(-1) padded_offset = torch.cat( - (torch.zeros((1,), dtype=padded_offset.dtype, device=padded_offset.device), padded_offset) + (torch.zeros((1,), dtype=torch.int64, device=padded_offset.device), padded_offset) )[-cache_position.shape[0] - 1 : -1] cache_position = cache_position + padded_offset diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 6afd2c08cdfb..f147fb6d844c 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -984,9 +984,9 @@ def forward( if attention_mask is None: padded_offset = 0 else: - padded_offset = (1 - torch.sum(attention_mask, dim=0).clamp(max=1)).cumsum(-1) + padded_offset = (1 - torch.sum(attention_mask.to(torch.int64), dim=0).clamp(max=1)).cumsum(-1) padded_offset = torch.cat( - (torch.zeros((1,), dtype=padded_offset.dtype, device=padded_offset.device), padded_offset) + (torch.zeros((1,), dtype=torch.int64, device=padded_offset.device), padded_offset) )[-cache_position.shape[0] - 1 : -1] cache_position = cache_position + padded_offset From e499ac9bd60a65efbbc3e1b244b0e2ac74b26f1f Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Fri, 23 Feb 2024 18:41:02 +0000 Subject: [PATCH 06/18] get_seq_length() working --- 
src/transformers/cache_utils.py | 13 ++--- .../models/llama/modeling_llama.py | 48 +++++++++++-------- tests/test_cache_utils.py | 6 +-- 3 files changed, 33 insertions(+), 34 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 87d24c6cf663..1ed5780a5ef1 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -398,16 +398,9 @@ def update( def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: """Returns the sequence length of the cached states that were seen by the model. `layer_idx` kept for BC""" - # TODO: Fix once the stateful `int` bug in PyTorch is fixed. - raise ValueError( - "get_seq_length is not implemented for StaticCache. Please refer to https://github.com/huggingface/transformers/pull/29114." - ) - - def get_usable_length(self, new_sequence_length=None, layer_idx: Optional[int] = 0) -> int: - # TODO: Fix once the stateful `int` bug in PyTorch is fixed. - raise ValueError( - "get_seq_length is not implemented for StaticCache. Please refer to https://github.com/huggingface/transformers/pull/29114." - ) + # Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's + # check the first batch member and the first head only. + return (self.key_cache[0, 0].sum(dim=-1) != 0).sum() def get_max_length(self) -> Optional[int]: """Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length.""" diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index f147fb6d844c..4e71fcffbba4 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -966,29 +966,22 @@ def forward( inputs_embeds = self.embed_tokens(input_ids) past_seen_tokens = 0 - if use_cache: # kept for BC (cache positions) - if not isinstance(past_key_values, StaticCache): + if use_cache: + if past_key_values is not None and not isinstance(past_key_values, Cache): past_key_values = DynamicCache.from_legacy_cache(past_key_values) + # non-static cache + if past_key_values is not None: past_seen_tokens = past_key_values.get_seq_length() - + # static cache + elif past_key_values is None: + static_cache = getattr(self.layers[0].self_attn, "past_key_value", None) + if static_cache is not None: + past_seen_tokens = static_cache.get_seq_length() + + # `torch.compile`-friendly `torch.arange` from a shape + cache_position = torch.ones_like(inputs_embeds[0, :, 0], dtype=torch.int64).cumsum(0) + past_seen_tokens - 1 if position_ids is None: - if isinstance(past_key_values, StaticCache): - raise ValueError("position_ids is a required argument when using StaticCache.") - position_ids = torch.arange( - past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device - ).unsqueeze(0) - - # One of the rows in `position_ids` contains the highest sequence of cache indexes, excluding left-padding - # applied on all batch members. Left-padding on all batch members can be detected from the `attention_mask`. 
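The `torch.ones_like(...).cumsum(0)` construction introduced above is a `torch.arange` whose length is derived from a tensor shape, which the patch comment notes is friendlier to `torch.compile`. A quick sanity check with made-up shapes, not part of the patch:

import torch

inputs_embeds = torch.randn(2, 5, 16)  # [batch, seq_len, hidden]
past_seen_tokens = 3

cache_position = torch.ones_like(inputs_embeds[0, :, 0], dtype=torch.int64).cumsum(0) + past_seen_tokens - 1
assert torch.equal(cache_position, torch.arange(3, 8))  # tensor([3, 4, 5, 6, 7])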
- cache_position = torch.max(position_ids, dim=0).values - if attention_mask is None: - padded_offset = 0 - else: - padded_offset = (1 - torch.sum(attention_mask.to(torch.int64), dim=0).clamp(max=1)).cumsum(-1) - padded_offset = torch.cat( - (torch.zeros((1,), dtype=torch.int64, device=padded_offset.device), padded_offset) - )[-cache_position.shape[0] - 1 : -1] - cache_position = cache_position + padded_offset + position_ids = cache_position.unsqueeze(0) causal_mask = self._update_causal_mask(attention_mask, inputs_embeds) @@ -1220,12 +1213,22 @@ def forward( def prepare_inputs_for_generation( self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs ): + # With static cache, the `past_key_values` is None + has_static_cache = False + if past_key_values is None: + has_static_cache = True + past_key_values = getattr(self.model.layers[0].self_attn, "past_key_value", None) + past_length = 0 if past_key_values is not None: if isinstance(past_key_values, Cache): cache_length = past_key_values.get_seq_length() - past_length = past_key_values.seen_tokens max_cache_length = past_key_values.get_max_length() + # TODO joao: find a better way to track the total number of tokens seen in the static cache + if max_cache_length is not None: + past_length = cache_length + else: + past_length = past_key_values.seen_tokens else: cache_length = past_length = past_key_values[0][0].shape[2] max_cache_length = None @@ -1267,6 +1270,9 @@ def prepare_inputs_for_generation( # TODO: use `next_tokens` directly instead. model_inputs = {"input_ids": input_ids.contiguous()} + if has_static_cache: + past_key_values = None + model_inputs.update( { "position_ids": position_ids.contiguous(), diff --git a/tests/test_cache_utils.py b/tests/test_cache_utils.py index a134e916630b..0b194417bb5e 100644 --- a/tests/test_cache_utils.py +++ b/tests/test_cache_utils.py @@ -436,9 +436,9 @@ def test_static_cache_extra_left_padding(self): model.generation_config.cache_implementation = "static" - # gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=10) - # decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True) - # self.assertListEqual(decoded, EXPECTED_GENERATION) + gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=10) + decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True) + self.assertListEqual(decoded, EXPECTED_GENERATION) # Now with extra left-padding inputs_expanded = tokenizer( From 6cc17ecf7fd9770d202bd5f8707169240e178757 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Fri, 23 Feb 2024 19:05:23 +0000 Subject: [PATCH 07/18] nits --- src/transformers/cache_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 1ed5780a5ef1..250b25d5b010 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -399,8 +399,8 @@ def update( def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: """Returns the sequence length of the cached states that were seen by the model. `layer_idx` kept for BC""" # Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's - # check the first batch member and the first head only. - return (self.key_cache[0, 0].sum(dim=-1) != 0).sum() + # limit the check to the first batch member and head dimension. 
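The one-line occupancy check described by the comment above (and added on the next diff line) can be illustrated on a toy static cache; shapes and values here are made up:

import torch

# StaticCache-style layout: [batch, num_heads, max_cache_len, head_dim], pre-allocated with zeros
key_cache = torch.zeros(1, 2, 8, 4)
key_cache[:, :, :3, :] = torch.randn(1, 2, 3, 4)  # three slots written so far

seen = (key_cache[0, 0].any(dim=-1)).sum()  # tensor(3)

As a later patch in this series notes, a slot whose key vector happens to be exactly zero would be miscounted, hence the TODO about moving to a stateless integer counter.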
+ return (self.key_cache[0, 0].any(dim=-1)).sum() def get_max_length(self) -> Optional[int]: """Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length.""" From 6e4b511e214ea7e0436d7a88bac0c3fbc9ece547 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Fri, 23 Feb 2024 19:20:18 +0000 Subject: [PATCH 08/18] gemma --- .../models/gemma/modeling_gemma.py | 47 ++++++++++--------- .../models/llama/modeling_llama.py | 15 +++--- 2 files changed, 31 insertions(+), 31 deletions(-) diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index 70bd229d8c8b..9bf9dd87c040 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -858,29 +858,19 @@ def forward( inputs_embeds = self.embed_tokens(input_ids) past_seen_tokens = 0 - if use_cache: # kept for BC (cache positions) - if not isinstance(past_key_values, StaticCache): - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - past_seen_tokens = past_key_values.get_seq_length() + if use_cache: + static_cache = getattr(self.layers[0].self_attn, "past_key_value", None) + if static_cache is not None: + past_seen_tokens = static_cache.get_seq_length() + else: + if not isinstance(past_key_values, Cache): + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + past_seen_tokens = past_key_values.get_seq_length() + # `torch.compile`-friendly `torch.arange` from a shape + cache_position = torch.ones_like(inputs_embeds[0, :, 0], dtype=torch.int64).cumsum(0) + past_seen_tokens - 1 if position_ids is None: - if isinstance(past_key_values, StaticCache): - raise ValueError("position_ids is a required argument when using StaticCache.") - position_ids = torch.arange( - past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device - ).unsqueeze(0) - - # One of the rows in `position_ids` contains the highest sequence of cache indexes, excluding left-padding - # applied on all batch members. Left-padding on all batch members can be detected from the `attention_mask`. 
- cache_position = torch.max(position_ids, dim=0).values - if attention_mask is None: - padded_offset = 0 - else: - padded_offset = (1 - torch.sum(attention_mask.to(torch.int64), dim=0).clamp(max=1)).cumsum(-1) - padded_offset = torch.cat( - (torch.zeros((1,), dtype=torch.int64, device=padded_offset.device), padded_offset) - )[-cache_position.shape[0] - 1 : -1] - cache_position = cache_position + padded_offset + position_ids = cache_position.unsqueeze(0) causal_mask = self._update_causal_mask(attention_mask, inputs_embeds) @@ -1111,12 +1101,22 @@ def forward( def prepare_inputs_for_generation( self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs ): + # With static cache, the `past_key_values` is None + has_static_cache = False + if past_key_values is None: + has_static_cache = True + past_key_values = getattr(self.model.layers[0].self_attn, "past_key_value", None) + past_length = 0 if past_key_values is not None: if isinstance(past_key_values, Cache): cache_length = past_key_values.get_seq_length() - past_length = past_key_values.seen_tokens max_cache_length = past_key_values.get_max_length() + # TODO joao: find a better way to track the total number of tokens seen in the static cache + if max_cache_length is not None: + past_length = cache_length + else: + past_length = past_key_values.seen_tokens else: cache_length = past_length = past_key_values[0][0].shape[2] max_cache_length = None @@ -1158,6 +1158,9 @@ def prepare_inputs_for_generation( # TODO: use `next_tokens` directly instead. model_inputs = {"input_ids": input_ids.contiguous()} + if has_static_cache: + past_key_values = None + model_inputs.update( { "position_ids": position_ids.contiguous(), diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 4e71fcffbba4..a5479c99ba43 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -967,16 +967,13 @@ def forward( past_seen_tokens = 0 if use_cache: - if past_key_values is not None and not isinstance(past_key_values, Cache): - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - # non-static cache - if past_key_values is not None: + static_cache = getattr(self.layers[0].self_attn, "past_key_value", None) + if static_cache is not None: + past_seen_tokens = static_cache.get_seq_length() + else: + if not isinstance(past_key_values, Cache): + past_key_values = DynamicCache.from_legacy_cache(past_key_values) past_seen_tokens = past_key_values.get_seq_length() - # static cache - elif past_key_values is None: - static_cache = getattr(self.layers[0].self_attn, "past_key_value", None) - if static_cache is not None: - past_seen_tokens = static_cache.get_seq_length() # `torch.compile`-friendly `torch.arange` from a shape cache_position = torch.ones_like(inputs_embeds[0, :, 0], dtype=torch.int64).cumsum(0) + past_seen_tokens - 1 From 232da2a1e71427dd01b1a66aeeff4e939e7576b5 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Fri, 23 Feb 2024 19:36:27 +0000 Subject: [PATCH 09/18] bc nit --- src/transformers/models/gemma/modeling_gemma.py | 2 +- src/transformers/models/llama/modeling_llama.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index 9bf9dd87c040..155e1b9c18e4 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -1145,7 +1145,7 @@ def 
prepare_inputs_for_generation( if attention_mask is not None and position_ids is None: # create position_ids on the fly for batch generation position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 0) + position_ids.masked_fill_(attention_mask == 0, 1) if past_key_values: position_ids = position_ids[:, -input_ids.shape[1] :] diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index a5479c99ba43..b2f266861ecf 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -1254,7 +1254,7 @@ def prepare_inputs_for_generation( if attention_mask is not None and position_ids is None: # create position_ids on the fly for batch generation position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 0) + position_ids.masked_fill_(attention_mask == 0, 1) if past_key_values: position_ids = position_ids[:, -input_ids.shape[1] :] From 04d53a75a65de6a0c625bd751d063519d09baef9 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Wed, 28 Feb 2024 15:21:29 +0000 Subject: [PATCH 10/18] explicit cache_positions (implicit working when not passed) --- src/transformers/cache_utils.py | 2 ++ src/transformers/generation/utils.py | 7 +++++-- src/transformers/models/gemma/modeling_gemma.py | 17 +++++++++++++++-- src/transformers/models/llama/modeling_llama.py | 17 +++++++++++++++-- 4 files changed, 37 insertions(+), 6 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 250b25d5b010..382fef1085e9 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -400,6 +400,8 @@ def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: """Returns the sequence length of the cached states that were seen by the model. `layer_idx` kept for BC""" # Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's # limit the check to the first batch member and head dimension. + # TODO: This is error prone, a filled cache may be `0.0`. Let's use a stateless integer instead, after + # https://github.com/pytorch/pytorch/issues/120248 is fixed return (self.key_cache[0, 0].any(dim=-1)).sum() def get_max_length(self) -> Optional[int]: diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index d878fc8ebdd4..d337e5593440 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -678,6 +678,8 @@ def _update_model_kwargs_for_generation( dim=-1, ) + model_kwargs["cache_position"] = model_inputs.get("cache_position", None) + return model_kwargs def _reorder_cache(self, past_key_values, beam_idx): @@ -4929,8 +4931,9 @@ def _split_model_inputs( # Here we can have four types of values: tensors, tuples of tensors and booleans, and encoder_outputs which is a # ModelOutput object. 
# bool should not be split but replicated for each split - bool_keys = [k for k in keys if isinstance(model_input[k], bool)] - non_bool_keys = [k for k in keys if not isinstance(model_input[k], bool) and k != "encoder_outputs"] + bool_keys = [k for k in keys if isinstance(model_input[k], bool) or k == "cache_position"] + keys_to_ignore = ["cache_position", "encoder_outputs"] + non_bool_keys = [k for k in keys if not isinstance(model_input[k], bool) and k not in keys_to_ignore] # we split the tensors and tuples of tensors data_split_list = [ diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index 155e1b9c18e4..8a82896f3dcc 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -835,6 +835,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple, BaseModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -867,8 +868,11 @@ def forward( past_key_values = DynamicCache.from_legacy_cache(past_key_values) past_seen_tokens = past_key_values.get_seq_length() - # `torch.compile`-friendly `torch.arange` from a shape - cache_position = torch.ones_like(inputs_embeds[0, :, 0], dtype=torch.int64).cumsum(0) + past_seen_tokens - 1 + if cache_position is None: + # `torch.compile`-friendly `torch.arange` from a shape + cache_position = ( + torch.ones_like(inputs_embeds[0, :, 0], dtype=torch.int64).cumsum(0) + past_seen_tokens - 1 + ) if position_ids is None: position_ids = cache_position.unsqueeze(0) @@ -1025,6 +1029,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple, CausalLMOutputWithPast]: r""" Args: @@ -1068,6 +1073,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) hidden_states = outputs[0] @@ -1149,6 +1155,12 @@ def prepare_inputs_for_generation( if past_key_values: position_ids = position_ids[:, -input_ids.shape[1] :] + # TODO @gante we should only keep a `cache_position` in generate, and do +=1. + # same goes for position ids. Could also help with continued generation. 
+ input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1] + cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device) + position_ids = position_ids.contiguous() if position_ids is not None else None + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and past_key_values is None: model_inputs = {"inputs_embeds": inputs_embeds} @@ -1164,6 +1176,7 @@ def prepare_inputs_for_generation( model_inputs.update( { "position_ids": position_ids.contiguous(), + "cache_position": cache_position, "past_key_values": past_key_values, "use_cache": kwargs.get("use_cache"), "attention_mask": attention_mask, diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index b2f266861ecf..56156f291682 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -943,6 +943,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple, BaseModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -975,8 +976,11 @@ def forward( past_key_values = DynamicCache.from_legacy_cache(past_key_values) past_seen_tokens = past_key_values.get_seq_length() - # `torch.compile`-friendly `torch.arange` from a shape - cache_position = torch.ones_like(inputs_embeds[0, :, 0], dtype=torch.int64).cumsum(0) + past_seen_tokens - 1 + if cache_position is None: + # `torch.compile`-friendly `torch.arange` from a shape + cache_position = ( + torch.ones_like(inputs_embeds[0, :, 0], dtype=torch.int64).cumsum(0) + past_seen_tokens - 1 + ) if position_ids is None: position_ids = cache_position.unsqueeze(0) @@ -1128,6 +1132,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple, CausalLMOutputWithPast]: r""" Args: @@ -1171,6 +1176,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) hidden_states = outputs[0] @@ -1258,6 +1264,12 @@ def prepare_inputs_for_generation( if past_key_values: position_ids = position_ids[:, -input_ids.shape[1] :] + # TODO @gante we should only keep a `cache_position` in generate, and do +=1. + # same goes for position ids. Could also help with continued generation. 
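As in the Gemma version above, the `cache_position` prepared on the following lines simply advances by the number of new tokens at each decoding step. Illustrative values, not taken from the patch:

import torch

# prefill: 6 prompt tokens, nothing cached yet
past_length, input_length = 0, 6
torch.arange(past_length, past_length + input_length)  # tensor([0, 1, 2, 3, 4, 5])

# first decoding step: one new token on top of the 6 cached ones
past_length, input_length = 6, 1
torch.arange(past_length, past_length + input_length)  # tensor([6])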
+ input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1] + cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device) + position_ids = position_ids.contiguous() if position_ids is not None else None + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and past_key_values is None: model_inputs = {"inputs_embeds": inputs_embeds} @@ -1273,6 +1285,7 @@ def prepare_inputs_for_generation( model_inputs.update( { "position_ids": position_ids.contiguous(), + "cache_position": cache_position, "past_key_values": past_key_values, "use_cache": kwargs.get("use_cache"), "attention_mask": attention_mask, From 20baebdfab8d0dd34e1f0d9e385dd9895379e090 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Wed, 28 Feb 2024 16:31:40 +0000 Subject: [PATCH 11/18] add test for implicit cache_position --- .../models/gemma/modeling_gemma.py | 4 + .../models/llama/modeling_llama.py | 4 + tests/test_modeling_common.py | 74 ++++++++++++++++++- 3 files changed, 81 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index 8a82896f3dcc..723b16abb8cd 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -782,6 +782,10 @@ def _reset_cache(self): more detail. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, + this tensor is not affected by padding. It is used to update the cache in the correct position and to infer + the complete sequence length. """ diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 56156f291682..f1d3c4245187 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -892,6 +892,10 @@ def _reset_cache(self): more detail. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, + this tensor is not affected by padding. It is used to update the cache in the correct position and to infer + the complete sequence length. """ diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 32f6abcbe3aa..5b63a0667da8 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -103,7 +103,7 @@ from safetensors.torch import save_file as safe_save_file from torch import nn - from transformers import MODEL_MAPPING, AdaptiveEmbedding + from transformers import MODEL_MAPPING, AdaptiveEmbedding, StaticCache from transformers.modeling_utils import load_state_dict, no_init_weights from transformers.pytorch_utils import id_tensor_storage @@ -3937,6 +3937,78 @@ def test_flash_attn_2_from_config(self): self.assertFalse(fa2_correctly_converted) + @require_torch_gpu + @slow + def test_implicit_cache_position(self): + """ + Tests that passing the correct cache_position yields the same results as passing cache_position=None, i.e. that + inference with implicit cache_position is working. 
+ """ + for model_class in self.all_generative_model_classes: + if not hasattr(model_class, "_setup_cache"): + continue + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config).to(torch_device) + + input_ids = inputs_dict["input_ids"].to(torch_device) + + def run_2_forward_passes_with_cache(model, input_ids, static_cache, compile): + # runs two generate-style forward passes, to ensure cudagraphs need two different values of implicit + # `cache_position` to work correctly + if static_cache: + model._setup_cache( + cache_cls=StaticCache, max_batch_size=input_ids.shape[0], max_cache_len=input_ids.shape[1] + 1 + ) + + if compile: + model = torch.compile(model, fullgraph=True, mode="reduce-overhead") + + # Implicit cache_positions + logits_implicit = [] + outputs = model(input_ids, cache_position=None) + if static_cache: + self.assertTrue(outputs.past_key_values is None) # sanity check -- it is None with static cache + logits_implicit.append(outputs.logits) + new_input_ids = torch.argmax(outputs.logits[:, -1, :], dim=-1).unsqueeze(1) + outputs = model(new_input_ids, cache_position=None, past_key_values=outputs.past_key_values) + logits_implicit.append(outputs.logits) + + if static_cache: + # Restart the cache + model._reset_cache() + model._setup_cache( + cache_cls=StaticCache, max_batch_size=input_ids.shape[0], max_cache_len=input_ids.shape[1] + 1 + ) + + # Explicit cache_positions + logits_explicit = [] + cache_positions = torch.arange(input_ids.shape[1], dtype=torch.long, device=torch_device) + outputs = model(input_ids, cache_position=cache_positions) + if static_cache: + self.assertTrue(outputs.past_key_values is None) # sanity check -- it is None with static cache + logits_explicit.append(outputs.logits) + cache_positions = cache_positions[-1:] + 1 + new_input_ids = torch.argmax(outputs.logits[:, -1, :], dim=-1).unsqueeze(1) + outputs = model(new_input_ids, cache_position=cache_positions, past_key_values=outputs.past_key_values) + logits_explicit.append(outputs.logits) + + if static_cache: + model._reset_cache() + + # Confirm that explicit and implicity cache_positions yield the same results + for idx in range(len(logits_implicit)): + self.assertTrue(torch.allclose(logits_implicit[idx], logits_explicit[idx])) + + # dynamic cache + run_2_forward_passes_with_cache(model, input_ids, static_cache=False, compile=False) + + # eager static cache + run_2_forward_passes_with_cache(model, input_ids, static_cache=True, compile=False) + + # compiled static cache [to confirm that it works with cuda graphs] + run_2_forward_passes_with_cache(model, input_ids, static_cache=True, compile=True) + global_rng = random.Random() From 646f150ac6453eac8202163091685b66e52da55c Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Wed, 28 Feb 2024 19:08:55 +0000 Subject: [PATCH 12/18] deprecate seen_tokens --- src/transformers/cache_utils.py | 23 +++++++++++++++---- .../models/gemma/modeling_gemma.py | 17 +++++++------- .../models/llama/modeling_llama.py | 17 +++++++------- 3 files changed, 37 insertions(+), 20 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 382fef1085e9..13bac74c986c 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -4,6 +4,10 @@ import torch from .configuration_utils import PretrainedConfig +from .utils import logging + + +logger = logging.get_logger(__name__) @dataclass @@ -57,6 +61,17 @@ def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] 
= 0) - return max_length - new_seq_length return previous_seq_length + @property + def seen_tokens(self): + logger.warning_once( + "The `seen_tokens` attribute is deprecated and will be removed in v4.40. Use the `cache_position` " + "variable instead." + ) + if hasattr(self, "_seen_tokens"): + return self._seen_tokens + else: + return None + class DynamicCache(Cache): """ @@ -69,7 +84,7 @@ class DynamicCache(Cache): def __init__(self) -> None: self.key_cache: List[torch.Tensor] = [] self.value_cache: List[torch.Tensor] = [] - self.seen_tokens = 0 # Used in `generate` to keep tally of how many tokens the cache has seen + self._seen_tokens = 0 # Used in `generate` to keep tally of how many tokens the cache has seen def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]: """ @@ -121,7 +136,7 @@ def update( """ # Update the number of seen tokens if layer_idx == 0: - self.seen_tokens += key_states.shape[-2] + self._seen_tokens += key_states.shape[-2] # Update the cache if len(self.key_cache) <= layer_idx: @@ -191,7 +206,7 @@ def __init__(self, window_length: int, num_sink_tokens: int) -> None: self.window_length = window_length self.num_sink_tokens = num_sink_tokens self.cos_sin_cache = {} - self.seen_tokens = 0 # Used in `generate` to keep tally of how many tokens the cache has seen + self._seen_tokens = 0 # Used in `generate` to keep tally of how many tokens the cache has seen @staticmethod def _rotate_half(x): @@ -272,7 +287,7 @@ def update( # Update the number of seen tokens if layer_idx == 0: - self.seen_tokens += key_states.shape[-2] + self._seen_tokens += key_states.shape[-2] # [bsz, num_heads, seq_len, head_dim] if len(self.key_cache) <= layer_idx: diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index 723b16abb8cd..45238b2e4dbc 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -877,6 +877,7 @@ def forward( cache_position = ( torch.ones_like(inputs_embeds[0, :, 0], dtype=torch.int64).cumsum(0) + past_seen_tokens - 1 ) + if position_ids is None: position_ids = cache_position.unsqueeze(0) @@ -1109,24 +1110,24 @@ def forward( ) def prepare_inputs_for_generation( - self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs + self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, cache_position=None, **kwargs ): # With static cache, the `past_key_values` is None + # TODO joao: standardize interface for the different Cache classes and remove of this if has_static_cache = False if past_key_values is None: - has_static_cache = True past_key_values = getattr(self.model.layers[0].self_attn, "past_key_value", None) + has_static_cache = past_key_values is not None past_length = 0 if past_key_values is not None: if isinstance(past_key_values, Cache): - cache_length = past_key_values.get_seq_length() + past_length = ( + cache_position[-1] + 1 if cache_position is not None else past_key_values.get_seq_length() + ) max_cache_length = past_key_values.get_max_length() - # TODO joao: find a better way to track the total number of tokens seen in the static cache - if max_cache_length is not None: - past_length = cache_length - else: - past_length = past_key_values.seen_tokens + cache_length = past_length if max_cache_length is None else min(max_cache_length, int(past_length)) + # TODO joao: remove this `else` after `generate` prioritizes `Cache` objects else: cache_length = past_length = 
past_key_values[0][0].shape[2] max_cache_length = None diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index f1d3c4245187..c4ec236d9938 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -985,6 +985,7 @@ def forward( cache_position = ( torch.ones_like(inputs_embeds[0, :, 0], dtype=torch.int64).cumsum(0) + past_seen_tokens - 1 ) + if position_ids is None: position_ids = cache_position.unsqueeze(0) @@ -1218,24 +1219,24 @@ def forward( ) def prepare_inputs_for_generation( - self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs + self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, cache_position=None, **kwargs ): # With static cache, the `past_key_values` is None + # TODO joao: standardize interface for the different Cache classes and remove of this if has_static_cache = False if past_key_values is None: - has_static_cache = True past_key_values = getattr(self.model.layers[0].self_attn, "past_key_value", None) + has_static_cache = past_key_values is not None past_length = 0 if past_key_values is not None: if isinstance(past_key_values, Cache): - cache_length = past_key_values.get_seq_length() + past_length = ( + cache_position[-1] + 1 if cache_position is not None else past_key_values.get_seq_length() + ) max_cache_length = past_key_values.get_max_length() - # TODO joao: find a better way to track the total number of tokens seen in the static cache - if max_cache_length is not None: - past_length = cache_length - else: - past_length = past_key_values.seen_tokens + cache_length = past_length if max_cache_length is None else min(max_cache_length, int(past_length)) + # TODO joao: remove this `else` after `generate` prioritizes `Cache` objects else: cache_length = past_length = past_key_values[0][0].shape[2] max_cache_length = None From 5f67182099f05fdef0b32f863bc0ef36bb966b34 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Thu, 29 Feb 2024 10:22:28 +0000 Subject: [PATCH 13/18] tmp commit --- src/transformers/generation/utils.py | 76 ++++++++++++++++++---------- src/transformers/utils/__init__.py | 1 + tests/generation/test_utils.py | 44 ++++++++++++++++ 3 files changed, 93 insertions(+), 28 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index d337e5593440..aa0c2edeb380 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -34,7 +34,7 @@ MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING, MODEL_FOR_VISION_2_SEQ_MAPPING, ) -from ..utils import ExplicitEnum, ModelOutput, is_accelerate_available, logging +from ..utils import ExplicitEnum, ModelOutput, is_accelerate_available, is_torchdynamo_compiling, logging from .beam_constraints import DisjunctiveConstraint, PhrasalConstraint from .beam_search import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer from .candidate_generator import ( @@ -1217,6 +1217,50 @@ def _validate_generated_length(self, generation_config, input_ids_length, has_de UserWarning, ) + def _prepare_generation_config(self, generation_config: GenerationConfig, **kwargs: Dict) -> Tuple[GenerationConfig, Dict]: + """ + Prepares the base generation config, then applies any generation configuration options from kwargs. 
+ """ + # priority: `generation_config` argument > `model.generation_config` (the default generation config) + if generation_config is None: + # legacy: users may modify the model configuration to control generation. To trigger this legacy behavior, + # three conditions must be met + # 1) the generation config must have been created from the model config (`_from_model_config` field); + # 2) the generation config must have seen no modification since its creation (the hash is the same); + # 3) the user must have set generation parameters in the model config. + # NOTE: `torch.compile` can't compile `hash`, so this feature is disabled during compilation. + if ( + not is_torchdynamo_compiling() + and self.generation_config._from_model_config + and self.generation_config._original_object_hash == hash(self.generation_config) + and self.config._has_non_default_generation_parameters() + ): + new_generation_config = GenerationConfig.from_model_config(self.config) + if new_generation_config != self.generation_config: + warnings.warn( + "You have modified the pretrained model configuration to control generation. This is a" + " deprecated strategy to control generation and will be removed soon, in a future version." + " Please use and modify the model generation configuration (see" + " https://huggingface.co/docs/transformers/generation_strategies#default-text-generation-configuration )" + ) + self.generation_config = new_generation_config + generation_config = self.generation_config + + # `torch.compile` can't compile `copy.deepcopy`, arguments in `kwargs` that are part of `generation_config` + # will permanently mutate the object with `.update`. As such, passing arguments through `kwargs` is disabled. + if is_torchdynamo_compiling(): + generate_attributes_in_kwargs = [key for key in kwargs.keys() if hasattr(generation_config, key)] + if len(generate_attributes_in_kwargs) > 0: + raise ValueError( + "`torch.compile` exception: all generation configuration attributes must be passed within a " + f"`generation_config` instance passed to `generate` (found: {generate_attributes_in_kwargs})." + ) + else: + generation_config = copy.deepcopy(generation_config) + + model_kwargs = generation_config.update(**kwargs) + return generation_config, model_kwargs + @torch.no_grad() def generate( self, @@ -1322,34 +1366,10 @@ def generate( else: synced_gpus = False - # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call + # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call. All + # unused kwargs must be model kwargs. self._validate_model_class() - - # priority: `generation_config` argument > `model.generation_config` (the default generation config) - if generation_config is None: - # legacy: users may modify the model configuration to control generation. To trigger this legacy behavior, - # three conditions must be met - # 1) the generation config must have been created from the model config (`_from_model_config` field); - # 2) the generation config must have seen no modification since its creation (the hash is the same); - # 3) the user must have set generation parameters in the model config. 
- if ( - self.generation_config._from_model_config - and self.generation_config._original_object_hash == hash(self.generation_config) - and self.config._has_non_default_generation_parameters() - ): - new_generation_config = GenerationConfig.from_model_config(self.config) - if new_generation_config != self.generation_config: - warnings.warn( - "You have modified the pretrained model configuration to control generation. This is a" - " deprecated strategy to control generation and will be removed soon, in a future version." - " Please use and modify the model generation configuration (see" - " https://huggingface.co/docs/transformers/generation_strategies#default-text-generation-configuration )" - ) - self.generation_config = new_generation_config - generation_config = self.generation_config - - generation_config = copy.deepcopy(generation_config) - model_kwargs = generation_config.update(**kwargs) # All unused kwargs must be model kwargs + generation_config, model_kwargs = self._prepare_generation_config(generation_config, **kwargs) self._validate_model_kwargs(model_kwargs.copy()) # 2. Set generation parameters if not already defined diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py index 3a3c65a3b7d6..c4f3e494571d 100644 --- a/src/transformers/utils/__init__.py +++ b/src/transformers/utils/__init__.py @@ -193,6 +193,7 @@ is_torchaudio_available, is_torchdistx_available, is_torchdynamo_available, + is_torchdynamo_compiling, is_torchvision_available, is_training_run_on_sagemaker, is_vision_available, diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index cb224c3c6a9d..9e46c958c9c8 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -27,6 +27,7 @@ is_flaky, require_accelerate, require_torch, + require_torch_gpu, require_torch_multi_accelerator, slow, torch_device, @@ -2135,6 +2136,31 @@ def test_new_cache_format(self, num_beams, do_sample): ) ) + @require_torch_gpu + @slow + def test_generate_compile_fullgraph(self): + """Tests that `.generate` is compatible with torch.compile without graph breaks, keeping the same results""" + for model_class in self.all_generative_model_classes: + if not hasattr(model_class, "_setup_cache"): + continue + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config).to(torch_device) + input_ids = inputs_dict["input_ids"].to(torch_device) + + # dynamic cache + output_dynamic = model.generate(input_ids) + + # eager static cache + model.generation_config.cache_implementation = "static" + output_static = model.generate(input_ids) + self.assertListEqual(output_dynamic.tolist(), output_static.tolist()) + + # compiled static cache + compiled_generate = torch.compile(model.generate, fullgraph=True, mode="reduce-overhead") + output_compiled = compiled_generate(input_ids) + self.assertListEqual(output_dynamic.tolist(), output_compiled.tolist()) + def _check_outputs(self, output, input_ids, config, use_cache=False, num_return_sequences=1): batch_size, seq_length = input_ids.shape num_sequences_in_output = batch_size * num_return_sequences @@ -3638,3 +3664,21 @@ def test_return_unprocessed_logit_scores(self): self.assertTrue(y_prob > 0.001 and n_prob > 0.001) self.assertTrue(y_prob <= 1.0 and n_prob <= 1.0) + + def test_bad_generate_compilation_flags(self): + """ + Tests that certain parameterization options in `generate` properly raise a custom exception (a `ValueError` + defined in `transformers` instead of general 
`torch._dynamo.exc.Unsupported`). + """ + + model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device) + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2") + compiled_generate = torch.compile(model.generate, fullgraph=True, mode="reduce-overhead") + + text = "Hello world" + tokenized_inputs = tokenizer([text], return_tensors="pt") + input_ids = tokenized_inputs.input_ids.to(torch_device) + + # Passing generation_config parameters through kwargs is not supported + with self.assertRaises(ValueError): + compiled_generate(input_ids, max_length=10, do_sample=True, temperature=0.7) From 41a91d93ccbebcab097b1ca1562c57543beb8775 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Thu, 29 Feb 2024 18:02:09 +0000 Subject: [PATCH 14/18] MVP working :D --- src/transformers/generation/utils.py | 198 ++++++++++-------- .../models/llama/modeling_llama.py | 68 +++--- 2 files changed, 149 insertions(+), 117 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index aa0c2edeb380..0aefa9702ee2 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -374,7 +374,7 @@ def prepare_inputs_for_generation(self, *args, **kwargs): def _prepare_model_inputs( self, inputs: Optional[torch.Tensor] = None, - bos_token_id: Optional[int] = None, + bos_token_id: Optional[torch.Tensor] = None, model_kwargs: Optional[Dict[str, torch.Tensor]] = None, ) -> Tuple[torch.Tensor, Optional[str], Dict[str, torch.Tensor]]: """ @@ -438,7 +438,7 @@ def _prepare_model_inputs( def _maybe_initialize_input_ids_for_generation( self, inputs: Optional[torch.Tensor] = None, - bos_token_id: Optional[int] = None, + bos_token_id: Optional[torch.Tensor] = None, model_kwargs: Optional[Dict[str, torch.Tensor]] = None, ) -> torch.LongTensor: """Initializes input ids for generation, if necessary.""" @@ -469,20 +469,29 @@ def _maybe_initialize_input_ids_for_generation( def _prepare_attention_mask_for_generation( self, inputs: torch.Tensor, - pad_token_id: Optional[int], - eos_token_id: Optional[Union[int, List[int]]], + pad_token_id: Optional[Optional[torch.Tensor]], + eos_token_id: Optional[Optional[torch.Tensor]], ) -> torch.LongTensor: + # No information for attention mask inference -> return default attention mask + default_attention_mask = torch.ones(inputs.shape[:2], dtype=torch.long, device=inputs.device) + if pad_token_id is None: + return default_attention_mask + + # Otherwise we have may have information -> try to infer the attention mask is_input_ids = len(inputs.shape) == 2 and inputs.dtype in [torch.int, torch.long] - is_pad_token_in_inputs = (pad_token_id is not None) and (pad_token_id in inputs) - if isinstance(eos_token_id, int): - eos_token_id = [eos_token_id] - is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or (pad_token_id not in eos_token_id) + is_pad_token_in_inputs = (pad_token_id is not None) and ( + torch.isin(elements=inputs, test_elements=pad_token_id).any() + ) + is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or ~( + torch.isin(elements=eos_token_id, test_elements=pad_token_id).any() + ) - # Check if input is input_ids and padded -> only then is attention_mask defined - if is_input_ids and is_pad_token_in_inputs and is_pad_token_not_equal_to_eos_token_id: - return inputs.ne(pad_token_id).long() - else: - return torch.ones(inputs.shape[:2], dtype=torch.long, device=inputs.device) + can_infer_attention_mask = is_input_ids * 
is_pad_token_in_inputs * is_pad_token_not_equal_to_eos_token_id + attention_mask_from_padding = inputs.ne(pad_token_id).long() + attention_mask = ( + attention_mask_from_padding * can_infer_attention_mask + default_attention_mask * ~can_infer_attention_mask + ) + return attention_mask def _prepare_encoder_decoder_kwargs_for_generation( self, inputs_tensor: torch.Tensor, model_kwargs, model_input_name: Optional[str] = None @@ -524,8 +533,7 @@ def _prepare_decoder_input_ids_for_generation( batch_size: int, model_input_name: str, model_kwargs: Dict[str, torch.Tensor], - decoder_start_token_id: Union[int, List[int]] = None, - bos_token_id: int = None, + decoder_start_token_id: Optional[torch.Tensor] = None, device: torch.device = None, ) -> Tuple[torch.LongTensor, Dict[str, torch.Tensor]]: """Prepares `decoder_input_ids` for generation with encoder-decoder models""" @@ -539,20 +547,14 @@ def _prepare_decoder_input_ids_for_generation( decoder_input_ids = None # 2. Encoder-decoder models expect the `decoder_input_ids` to start with a special token. Let's ensure that. - decoder_start_token_id = self._get_decoder_start_token_id(decoder_start_token_id, bos_token_id) if device is None: device = self.device - if isinstance(decoder_start_token_id, list): - if len(decoder_start_token_id) != batch_size: + if decoder_start_token_id.ndim == 1: + if decoder_start_token_id.shape[0] != batch_size: raise ValueError( - f"`decoder_start_token_id` expcted to have length {batch_size} but got {len(decoder_start_token_id)}" + f"`decoder_start_token_id` expcted to have length {batch_size} but got {decoder_start_token_id.shape[0]}" ) - decoder_input_ids_start = torch.tensor(decoder_start_token_id, dtype=torch.long, device=device) - decoder_input_ids_start = decoder_input_ids_start.view(-1, 1) - else: - decoder_input_ids_start = ( - torch.ones((batch_size, 1), dtype=torch.long, device=device) * decoder_start_token_id - ) + decoder_input_ids_start = decoder_input_ids_start.view(-1, 1) # no user input -> use decoder_start_token_id as decoder_input_ids if decoder_input_ids is None: @@ -583,7 +585,7 @@ def _prepare_decoder_input_ids_for_generation( return decoder_input_ids, model_kwargs def _get_decoder_start_token_id( - self, decoder_start_token_id: Union[int, List[int]] = None, bos_token_id: int = None + self, decoder_start_token_id: Optional[torch.tensor] = None, bos_token_id: Optional[torch.tensor] = None ) -> int: decoder_start_token_id = ( decoder_start_token_id @@ -648,7 +650,7 @@ def _update_model_kwargs_for_generation( model_kwargs: Dict[str, Any], is_encoder_decoder: bool = False, standardize_cache_format: bool = False, - model_inputs: Optional[Dict[str, Any]] = None, + model_inputs=None, ) -> Dict[str, Any]: # update past_key_values model_kwargs["past_key_values"] = self._extract_past_from_model_output( @@ -678,7 +680,8 @@ def _update_model_kwargs_for_generation( dim=-1, ) - model_kwargs["cache_position"] = model_inputs.get("cache_position", None) + if "cache_position" in model_kwargs and model_kwargs["cache_position"] is not None: + model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + 1 return model_kwargs @@ -1217,7 +1220,9 @@ def _validate_generated_length(self, generation_config, input_ids_length, has_de UserWarning, ) - def _prepare_generation_config(self, generation_config: GenerationConfig, **kwargs: Dict) -> Tuple[GenerationConfig, Dict]: + def _prepare_generation_config( + self, generation_config: GenerationConfig, **kwargs: Dict + ) -> Tuple[GenerationConfig, Dict]: """ Prepares the 
base generation config, then applies any generation configuration options from kwargs. """ @@ -1249,6 +1254,7 @@ def _prepare_generation_config(self, generation_config: GenerationConfig, **kwar # `torch.compile` can't compile `copy.deepcopy`, arguments in `kwargs` that are part of `generation_config` # will permanently mutate the object with `.update`. As such, passing arguments through `kwargs` is disabled. if is_torchdynamo_compiling(): + model_kwargs = kwargs generate_attributes_in_kwargs = [key for key in kwargs.keys() if hasattr(generation_config, key)] if len(generate_attributes_in_kwargs) > 0: raise ValueError( @@ -1257,10 +1263,42 @@ def _prepare_generation_config(self, generation_config: GenerationConfig, **kwar ) else: generation_config = copy.deepcopy(generation_config) + model_kwargs = generation_config.update(**kwargs) - model_kwargs = generation_config.update(**kwargs) return generation_config, model_kwargs + def _prepare_special_tokens( + self, generation_config: GenerationConfig, kwargs_has_attention_mask: bool + ) -> Tuple[Optional[torch.Tensor]]: + """Prepares the special tokens for generation.""" + + # Convert special tokens to tensors (if they exist) + def _tensor_or_none(token): + return torch.tensor(token, device=self.device, dtype=torch.long) if token is not None else None + + bos_token_id = _tensor_or_none(generation_config.bos_token_id) + eos_token_id = _tensor_or_none(generation_config.eos_token_id) + pad_token_id = _tensor_or_none(generation_config.pad_token_id) + decoder_start_token_id = _tensor_or_none(generation_config.decoder_start_token_id) or bos_token_id + + if self.config.is_encoder_decoder and decoder_start_token_id is None: + raise ValueError( + "`decoder_start_token_id` or `bos_token_id` has to be defined for encoder-decoder generation." + ) + + # Set pad token if unset (and there are conditions to do so) + if pad_token_id is None and eos_token_id is not None: + if not kwargs_has_attention_mask: + logger.warning( + "The attention mask and the pad token id were not set. As a consequence, you may observe " + "unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results." + ) + if eos_token_id.ndim == 1: + pad_token_id = pad_token_id[0] + logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{pad_token_id} for open-end generation.") + + return bos_token_id, eos_token_id, pad_token_id, decoder_start_token_id + @torch.no_grad() def generate( self, @@ -1360,12 +1398,6 @@ def generate( - [`~generation.GenerateBeamEncoderDecoderOutput`] """ - if synced_gpus is None: - if is_deepspeed_zero3_enabled() and dist.get_world_size() > 1: - synced_gpus = True - else: - synced_gpus = False - # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call. All # unused kwargs must be model kwargs. self._validate_model_class() @@ -1373,34 +1405,43 @@ def generate( self._validate_model_kwargs(model_kwargs.copy()) # 2. 
Set generation parameters if not already defined + if synced_gpus is None: + if is_deepspeed_zero3_enabled() and dist.get_world_size() > 1: + synced_gpus = True + else: + synced_gpus = False logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() - if generation_config.pad_token_id is None and generation_config.eos_token_id is not None: - if model_kwargs.get("attention_mask", None) is None: - logger.warning( - "The attention mask and the pad token id were not set. As a consequence, you may observe " - "unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results." - ) - eos_token_id = generation_config.eos_token_id - if isinstance(eos_token_id, list): - eos_token_id = eos_token_id[0] - logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.") - generation_config.pad_token_id = eos_token_id + accepts_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys()) + requires_attention_mask = "encoder_outputs" not in model_kwargs + kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None + bos_token_id, eos_token_id, pad_token_id, decoder_start_token_id = self._prepare_special_tokens( + generation_config, kwargs_has_attention_mask + ) # 3. Define model inputs - # inputs_tensor has to be defined - # model_input_name is defined if model-specific keyword input is passed - # otherwise model_input_name is None - # all model-specific keyword inputs are removed from `model_kwargs` - inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs( - inputs, generation_config.bos_token_id, model_kwargs - ) + inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(inputs, bos_token_id, model_kwargs) batch_size = inputs_tensor.shape[0] + # decoder-only models must use left-padding for generation. + if not self.config.is_encoder_decoder and not is_torchdynamo_compiling(): + # If `input_ids` was given, check if the last id in any sequence is `pad_token_id` + # Note: If using, `inputs_embeds` this check does not work, because we want to be more hands-off. + if ( + pad_token_id is not None + and len(inputs_tensor.shape) == 2 + and torch.sum(inputs_tensor[:, -1] == pad_token_id) > 0 + ): + logger.warning( + "A decoder-only architecture is being used, but right-padding was detected! For correct " + "generation results, please set `padding_side='left'` when initializing the tokenizer." + ) + # 4. 
Define other model kwargs model_kwargs["output_attentions"] = generation_config.output_attentions model_kwargs["output_hidden_states"] = generation_config.output_hidden_states + # decoder-only models with inputs_embeds forwarding must use caching (otherwise we can't detect whether we are # generating the first new token or not, and we only want to use the embeddings for the first new token) if not self.config.is_encoder_decoder and model_input_name == "inputs_embeds": @@ -1408,31 +1449,13 @@ def generate( else: model_kwargs["use_cache"] = generation_config.use_cache - accepts_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys()) - requires_attention_mask = "encoder_outputs" not in model_kwargs - - if model_kwargs.get("attention_mask", None) is None and requires_attention_mask and accepts_attention_mask: + if not kwargs_has_attention_mask and requires_attention_mask and accepts_attention_mask: model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation( - inputs_tensor, generation_config.pad_token_id, generation_config.eos_token_id + inputs_tensor, pad_token_id, eos_token_id ) - # decoder-only models should use left-padding for generation - if not self.config.is_encoder_decoder: - # If `input_ids` was given, check if the last id in any sequence is `pad_token_id` - # Note: If using, `inputs_embeds` this check does not work, because we want to be more hands-off. - if ( - generation_config.pad_token_id is not None - and len(inputs_tensor.shape) == 2 - and torch.sum(inputs_tensor[:, -1] == generation_config.pad_token_id) > 0 - ): - logger.warning( - "A decoder-only architecture is being used, but right-padding was detected! For correct " - "generation results, please set `padding_side='left'` when initializing the tokenizer." - ) - if self.config.is_encoder_decoder and "encoder_outputs" not in model_kwargs: - # if model is encoder decoder encoder_outputs are created - # and added to `model_kwargs` + # if model is encoder decoder encoder_outputs are created and added to `model_kwargs` model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation( inputs_tensor, model_kwargs, model_input_name ) @@ -1443,8 +1466,7 @@ def generate( batch_size=batch_size, model_input_name=model_input_name, model_kwargs=model_kwargs, - decoder_start_token_id=generation_config.decoder_start_token_id, - bos_token_id=generation_config.bos_token_id, + decoder_start_token_id=decoder_start_token_id, device=inputs_tensor.device, ) else: @@ -1488,7 +1510,8 @@ def generate( ) self._setup_cache(cache_cls, max_batch_size=batch_size, max_cache_len=generation_config.max_length) - self._validate_generated_length(generation_config, input_ids_length, has_default_max_length) + if not is_torchdynamo_compiling(): + self._validate_generated_length(generation_config, input_ids_length, has_default_max_length) # 7. determine generation mode generation_mode = self._get_generation_mode(generation_config, assistant_model) @@ -1525,6 +1548,7 @@ def generate( prepared_stopping_criteria = self._get_stopping_criteria( generation_config=generation_config, stopping_criteria=stopping_criteria ) + # 10. 
go into different generation modes if generation_mode == GenerationMode.ASSISTED_GENERATION: if generation_config.num_return_sequences > 1: @@ -2419,10 +2443,13 @@ def greedy_search( ) # keep track of which sequences are already finished - unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device) + batch_size, cur_len = input_ids.shape + unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device) + model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device) this_peer_finished = False # used by synced_gpus only - while True: + # while True: + while not stopping_criteria(input_ids, scores): if synced_gpus: # Under synced_gpus the `forward` call must continue until all gpus complete their sequence. # The following logic allows an early break if all peers finished generating their sequence @@ -2499,15 +2526,14 @@ def greedy_search( ) # stop when each sentence is finished - if unfinished_sequences.max() == 0: - this_peer_finished = True + this_peer_finished = unfinished_sequences.max() == 0 # stop if we exceed the maximum length - if stopping_criteria(input_ids, scores): - this_peer_finished = True + # if stopping_criteria(input_ids, scores): + # this_peer_finished = True - if this_peer_finished and not synced_gpus: - break + # if this_peer_finished and not synced_gpus: + # break if streamer is not None: streamer.end() diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index c4ec236d9938..7649dd56acda 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -1084,11 +1084,9 @@ def _update_causal_mask(self, attention_mask, input_tensor): ) if self.config._attn_implementation == "sdpa": - is_tracing = torch.jit.is_tracing() or isinstance(input_tensor, torch.fx.Proxy) - if not is_tracing and attention_mask is not None and torch.any(attention_mask != 1): - causal_mask = causal_mask.mul(~torch.all(causal_mask == causal_mask.min(), dim=-1)[..., None]).to( - dtype - ) + # is_tracing = torch.jit.is_tracing() or isinstance(input_tensor, torch.fx.Proxy) + # if not is_tracing and attention_mask is not None and torch.any(attention_mask != 1): + causal_mask = causal_mask.mul(~torch.all(causal_mask == causal_mask.min(), dim=-1)[..., None]).to(dtype) return causal_mask @@ -1231,35 +1229,43 @@ def prepare_inputs_for_generation( past_length = 0 if past_key_values is not None: if isinstance(past_key_values, Cache): - past_length = ( - cache_position[-1] + 1 if cache_position is not None else past_key_values.get_seq_length() - ) - max_cache_length = past_key_values.get_max_length() - cache_length = past_length if max_cache_length is None else min(max_cache_length, int(past_length)) + past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length() + max_cache_length = torch.tensor(past_key_values.get_max_length(), device=input_ids.device) + cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length) # TODO joao: remove this `else` after `generate` prioritizes `Cache` objects else: cache_length = past_length = past_key_values[0][0].shape[2] max_cache_length = None - # Keep only the unprocessed tokens: - # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where - # some of the inputs are exclusively passed as part of the cache (e.g. 
when passing input_embeds as - # input) - if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: - input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] - # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard - # input_ids based on the past_length. - elif past_length < input_ids.shape[1]: - input_ids = input_ids[:, past_length:] - # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. - - # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. - if ( - max_cache_length is not None - and attention_mask is not None - and cache_length + input_ids.shape[1] > max_cache_length - ): - attention_mask = attention_mask[:, -max_cache_length:] + attention_based_slicing = attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1] + past_based_slicing = past_length < input_ids.shape[1] + # no_slicing = not attention_based_slicing and not past_based_slicing + input_ids_slice_index = (-(attention_mask.shape[1] - past_length) * attention_based_slicing) + ( + past_length * past_based_slicing + ) + # input_ids_slice_index = (-(attention_mask.shape[1] - past_length) * attention_based_slicing) + (past_length * past_based_slicing) + (0 * no_slicing) + # input_ids = input_ids[:, input_ids_slice_index:] + input_ids = input_ids[:, cache_position] + + # # Keep only the unprocessed tokens: + # # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where + # # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as + # # input) + # if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: + # input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] + # # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard + # # input_ids based on the past_length. + # elif past_length < input_ids.shape[1]: + # input_ids = input_ids[:, past_length:] + # # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. + + # # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. + # if ( + # max_cache_length is not None + # and attention_mask is not None + # and cache_length + input_ids.shape[1] > max_cache_length + # ): + # attention_mask = attention_mask[:, -max_cache_length:] position_ids = kwargs.get("position_ids", None) if attention_mask is not None and position_ids is None: @@ -1271,8 +1277,8 @@ def prepare_inputs_for_generation( # TODO @gante we should only keep a `cache_position` in generate, and do +=1. # same goes for position ids. Could also help with continued generation. 
- input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1] - cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device) + # input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1] + # cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device) position_ids = position_ids.contiguous() if position_ids is not None else None # if `inputs_embeds` are passed, we only want to use them in the 1st generation step From b32fa42719be1a11489f7f6982cd7f0c8523d646 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Thu, 14 Mar 2024 18:49:07 +0000 Subject: [PATCH 15/18] sort a few issues (compile hangs?) --- src/transformers/generation/utils.py | 55 ++++++++++++++++------------ 1 file changed, 31 insertions(+), 24 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 09b2a94fbdb5..f1bfe751f308 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -34,7 +34,7 @@ MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING, MODEL_FOR_VISION_2_SEQ_MAPPING, ) -from ..utils import ExplicitEnum, ModelOutput, is_accelerate_available, is_torchdynamo_compiling, logging +from ..utils import ModelOutput, is_accelerate_available, is_torchdynamo_compiling, logging from .beam_constraints import DisjunctiveConstraint, PhrasalConstraint from .beam_search import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer from .candidate_generator import ( @@ -1187,7 +1187,7 @@ def _prepare_generation_config( # 1) the generation config must have been created from the model config (`_from_model_config` field); # 2) the generation config must have seen no modification since its creation (the hash is the same); # 3) the user must have set generation parameters in the model config. - # NOTE: `torch.compile` can't compile `hash`, so this feature is disabled during compilation. + # NOTE: `torch.compile` can't compile `hash`, this legacy support is disabled with compilation. if ( not is_torchdynamo_compiling() and self.generation_config._from_model_config @@ -1206,7 +1206,7 @@ def _prepare_generation_config( generation_config = self.generation_config # `torch.compile` can't compile `copy.deepcopy`, arguments in `kwargs` that are part of `generation_config` - # will permanently mutate the object with `.update`. As such, passing arguments through `kwargs` is disabled. + # will mutate the object with `.update`. As such, passing these arguments through `kwargs` is disabled. if is_torchdynamo_compiling(): model_kwargs = kwargs generate_attributes_in_kwargs = [key for key in kwargs.keys() if hasattr(generation_config, key)] @@ -2422,19 +2422,33 @@ def _greedy_search( unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device) model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device) - this_peer_finished = False # used by synced_gpus only - # while True: - while not stopping_criteria(input_ids, scores): - if synced_gpus: - # Under synced_gpus the `forward` call must continue until all gpus complete their sequence. - # The following logic allows an early break if all peers finished generating their sequence - this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device) - # send 0.0 if we finished, 1.0 otherwise - dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM) - # did all peers finish? 
the reduced sum will be 0.0 then - if this_peer_finished_flag.item() == 0.0: - break + max_length = None + for criteria in stopping_criteria: + if isinstance(criteria, MaxLengthCriteria): + max_length = criteria.max_length + break + + this_peer_finished = False + def has_unfinished_sequences(this_peer_finished: bool, cur_len: int) -> bool: + if is_torchdynamo_compiling(): + return cur_len < max_length + else: + if synced_gpus: + # Under synced_gpus the `forward` call must continue until all gpus complete their sequence. + # The following logic allows an early break if all peers finished generating their sequence + this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device) + # send 0.0 if we finished, 1.0 otherwise + dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM) + # did all peers finish? the reduced sum will be 0.0 then + if this_peer_finished_flag.item() == 0.0: + return False + else: + if this_peer_finished: + return False + return True + + while has_unfinished_sequences(this_peer_finished, cur_len): # prepare model inputs model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) @@ -2499,15 +2513,8 @@ def _greedy_search( next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0) ) - # stop when each sentence is finished - this_peer_finished = unfinished_sequences.max() == 0 - - # stop if we exceed the maximum length - # if stopping_criteria(input_ids, scores): - # this_peer_finished = True - - # if this_peer_finished and not synced_gpus: - # break + unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores) + this_peer_finished = unfinished_sequences.max() == 0 if streamer is not None: streamer.end() From 8063ac8d678fac9d19a97a36a28734f59113af95 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Thu, 14 Mar 2024 19:19:24 +0000 Subject: [PATCH 16/18] working again :D --- src/transformers/generation/utils.py | 3 +++ .../models/llama/modeling_llama.py | 21 ++++++++++--------- 2 files changed, 14 insertions(+), 10 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index f1bfe751f308..b7e9d3b59d75 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -2431,6 +2431,8 @@ def _greedy_search( this_peer_finished = False def has_unfinished_sequences(this_peer_finished: bool, cur_len: int) -> bool: + # torch.compile does not support data-dependent control flow. 
This is a workaround to allow torch.compile, + # although we lose the ability to stop when all sequences return an EOS token (and other stopping criteria) if is_torchdynamo_compiling(): return cur_len < max_length else: @@ -2515,6 +2517,7 @@ def has_unfinished_sequences(this_peer_finished: bool, cur_len: int) -> bool: unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores) this_peer_finished = unfinished_sequences.max() == 0 + cur_len += 1 if streamer is not None: streamer.end() diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index a8b502bc498c..39e21452fe1b 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -1106,17 +1106,18 @@ def _update_causal_mask(self, attention_mask, input_tensor): and attention_mask is not None and attention_mask.device.type == "cuda" ): + causal_mask = causal_mask.mul(~torch.all(causal_mask == causal_mask.min(), dim=-1)[..., None]).to(dtype) # TODO: For dynamo, rather use a check on fullgraph=True once this is possible (https://github.com/pytorch/pytorch/pull/120400). - is_tracing = ( - torch.jit.is_tracing() - or isinstance(input_tensor, torch.fx.Proxy) - or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling()) - ) - if not is_tracing and torch.any(attention_mask != 1): - # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when - # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. - # Details: https://github.com/pytorch/pytorch/issues/110213 - causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + # is_tracing = ( + # torch.jit.is_tracing() + # or isinstance(input_tensor, torch.fx.Proxy) + # or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling()) + # ) + # if not is_tracing and torch.any(attention_mask != 1): + # # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # # Details: https://github.com/pytorch/pytorch/issues/110213 + # causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) return causal_mask From eef19f1cc3c3b6c1d85e4b6eaaae50c7fb1f6b35 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Thu, 21 Mar 2024 18:31:44 +0000 Subject: [PATCH 17/18] working with torch==2.3.0.dev20240315+cu121 --- src/transformers/generation/utils.py | 33 ++++---- .../models/llama/modeling_llama.py | 79 +++++++++++-------- 2 files changed, 64 insertions(+), 48 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index ed9c670c4a50..e1c606822fa9 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1797,23 +1797,28 @@ def typeerror(): return result - def _has_unfinished_sequences(self, this_peer_finished: bool, synced_gpus: bool, device: torch.device) -> bool: + def _has_unfinished_sequences(self, this_peer_finished: bool, cur_len, max_length, synced_gpus: bool, device: torch.device) -> bool: """ Returns whether there are still unfinished sequences in the device. The existence of unfinished sequences is fed through `this_peer_finished`. ZeRO stage 3-friendly. """ - if synced_gpus: - # Under synced_gpus the `forward` call must continue until all gpus complete their sequence. 
- # The following logic allows an early break if all peers finished generating their sequence - this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(device) - # send 0.0 if we finished, 1.0 otherwise - dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM) - # did all peers finish? the reduced sum will be 0.0 then - if this_peer_finished_flag.item() == 0.0: + # torch.compile does not support data-dependent control flow. This is a workaround to allow torch.compile, + # although we lose the ability to stop when all sequences return an EOS token (and other stopping criteria) + if is_torchdynamo_compiling(): + return cur_len < max_length + else: + if synced_gpus: + # Under synced_gpus the `forward` call must continue until all gpus complete their sequence. + # The following logic allows an early break if all peers finished generating their sequence + this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(device) + # send 0.0 if we finished, 1.0 otherwise + dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM) + # did all peers finish? the reduced sum will be 0.0 then + if this_peer_finished_flag.item() == 0.0: + return False + elif this_peer_finished: return False - elif this_peer_finished: - return False - return True + return True def contrastive_search(self, *args, **kwargs): logger.warning_once( @@ -2427,8 +2432,8 @@ def _greedy_search( if isinstance(criteria, MaxLengthCriteria): max_length = criteria.max_length break - - while self.has_unfinished_sequences(this_peer_finished, cur_len, synced_gpus, device=input_ids.device): + + while self._has_unfinished_sequences(this_peer_finished, cur_len, max_length, synced_gpus, device=input_ids.device): # prepare model inputs model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index dc25d319d513..39e21452fe1b 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -371,7 +371,9 @@ def forward( attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) if attention_mask is not None: # no matter the length, we just slice it - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + causal_mask = attention_mask + if cache_position is not None: + causal_mask = attention_mask[:, :, cache_position, : key_states.shape[-2]] attn_weights = attn_weights + causal_mask # upcast attention to fp32 @@ -656,9 +658,8 @@ def forward( value_states = repeat_kv(value_states, self.num_key_value_groups) causal_mask = attention_mask - # if attention_mask is not None and cache_position is not None: - if attention_mask is not None: - causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] + if attention_mask is not None and cache_position is not None: + causal_mask = causal_mask[:, :, cache_position, : key_states.shape[-2]] # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, # Reference: https://github.com/pytorch/pytorch/issues/112577. 
@@ -791,7 +792,7 @@ class LlamaPreTrainedModel(PreTrainedModel): base_model_prefix = "model" supports_gradient_checkpointing = True _no_split_modules = ["LlamaDecoderLayer"] - _skip_keys_device_placement = ["past_key_values"] + _skip_keys_device_placement = ["past_key_values", "causal_mask"] _supports_flash_attn_2 = True _supports_sdpa = True _supports_cache_class = True @@ -814,6 +815,12 @@ def _setup_cache(self, cache_cls, max_batch_size, max_cache_len: Optional[int] = "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers" ) + if max_cache_len > self.model.causal_mask.shape[-1] or self.device != self.model.causal_mask.device: + causal_mask = torch.full( + (max_cache_len, max_cache_len), fill_value=True, device=self.device, dtype=torch.bool + ) + self.register_buffer("causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False) + for layer in self.model.layers: device = layer.input_layernorm.weight.device if hasattr(self.config, "_pre_quantization_dtype"): @@ -927,6 +934,12 @@ def __init__(self, config: LlamaConfig): self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.gradient_checkpointing = False + # Register a causal mask to separate causal and padding mask creation. Merging happens in the attention class. + # NOTE: This is not friendly with TorchScript, ONNX, ExportedProgram serialization for very large `max_position_embeddings`. + causal_mask = torch.full( + (config.max_position_embeddings, config.max_position_embeddings), fill_value=True, dtype=torch.bool + ) + self.register_buffer("causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False) # Initialize weights and apply final processing self.post_init() @@ -990,7 +1003,7 @@ def forward( if position_ids is None: position_ids = cache_position.unsqueeze(0) - causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position) + causal_mask = self._update_causal_mask(attention_mask, inputs_embeds) # embed positions hidden_states = inputs_embeds @@ -1058,27 +1071,25 @@ def forward( # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes. # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using # `fullgraph=True`. 
See more context in https://github.com/huggingface/transformers/pull/29114 - def _update_causal_mask(self, attention_mask, input_tensor, cache_position): + def _update_causal_mask(self, attention_mask, input_tensor): if self.config._attn_implementation == "flash_attention_2": if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None - dtype, device = input_tensor.dtype, input_tensor.device - min_dtype = torch.finfo(dtype).min - sequence_length = input_tensor.shape[1] - if hasattr(self.layers[0].self_attn, "past_key_value"): # static cache - target_length = self.config.max_position_embeddings - else: # dynamic cache - target_length = ( - attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else cache_position[-1] + 1 - ) + batch_size, seq_length = input_tensor.shape[:2] + dtype = input_tensor.dtype + device = input_tensor.device + + # support going beyond cached `max_position_embedding` + if seq_length > self.causal_mask.shape[-1]: + causal_mask = torch.full((2 * self.causal_mask.shape[-1], 2 * self.causal_mask.shape[-1]), fill_value=1) + self.register_buffer("causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False) - causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device) - if sequence_length != 1: - causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) - causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1) + # We use the current dtype to avoid any overflows + min_dtype = torch.finfo(dtype).min + causal_mask = self.causal_mask[None, None, :, :].to(dtype=dtype, device=device) * min_dtype + causal_mask = causal_mask.expand(batch_size, 1, -1, -1) if attention_mask is not None: causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit if attention_mask.dim() == 2: @@ -1086,27 +1097,27 @@ def _update_causal_mask(self, attention_mask, input_tensor, cache_position): padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0) causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype) elif attention_mask.dim() == 4: - # backwards compatibility: we allow passing a 4D attention mask shorter than the input length with - # cache. In that case, the 4D attention mask attends to the newest tokens only. - if attention_mask.shape[-2] < cache_position[0] + sequence_length: - offset = cache_position[0] - else: - offset = 0 mask_shape = attention_mask.shape mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype - causal_mask[ - : mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3] - ] = mask_slice + causal_mask[: mask_shape[0], : mask_shape[1], : mask_shape[2], : mask_shape[3]] = mask_slice if ( self.config._attn_implementation == "sdpa" and attention_mask is not None and attention_mask.device.type == "cuda" ): - # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when - # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. 
- # Details: https://github.com/pytorch/pytorch/issues/110213 - causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + causal_mask = causal_mask.mul(~torch.all(causal_mask == causal_mask.min(), dim=-1)[..., None]).to(dtype) + # TODO: For dynamo, rather use a check on fullgraph=True once this is possible (https://github.com/pytorch/pytorch/pull/120400). + # is_tracing = ( + # torch.jit.is_tracing() + # or isinstance(input_tensor, torch.fx.Proxy) + # or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling()) + # ) + # if not is_tracing and torch.any(attention_mask != 1): + # # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # # Details: https://github.com/pytorch/pytorch/issues/110213 + # causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) return causal_mask From 3aeb1d43665887384358a32296efab29f4ea622b Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Thu, 21 Mar 2024 18:52:22 +0000 Subject: [PATCH 18/18] smaller diff --- src/transformers/generation/utils.py | 8 +- .../models/llama/modeling_llama.py | 79 ++++++++----------- 2 files changed, 40 insertions(+), 47 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index e1c606822fa9..23abf9d6d89f 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1797,7 +1797,9 @@ def typeerror(): return result - def _has_unfinished_sequences(self, this_peer_finished: bool, cur_len, max_length, synced_gpus: bool, device: torch.device) -> bool: + def _has_unfinished_sequences( + self, this_peer_finished: bool, cur_len, max_length, synced_gpus: bool, device: torch.device + ) -> bool: """ Returns whether there are still unfinished sequences in the device. The existence of unfinished sequences is fed through `this_peer_finished`. ZeRO stage 3-friendly. 
@@ -2433,7 +2435,9 @@ def _greedy_search( max_length = criteria.max_length break - while self._has_unfinished_sequences(this_peer_finished, cur_len, max_length, synced_gpus, device=input_ids.device): + while self._has_unfinished_sequences( + this_peer_finished, cur_len, max_length, synced_gpus, device=input_ids.device + ): # prepare model inputs model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 39e21452fe1b..dc25d319d513 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -371,9 +371,7 @@ def forward( attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) if attention_mask is not None: # no matter the length, we just slice it - causal_mask = attention_mask - if cache_position is not None: - causal_mask = attention_mask[:, :, cache_position, : key_states.shape[-2]] + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] attn_weights = attn_weights + causal_mask # upcast attention to fp32 @@ -658,8 +656,9 @@ def forward( value_states = repeat_kv(value_states, self.num_key_value_groups) causal_mask = attention_mask - if attention_mask is not None and cache_position is not None: - causal_mask = causal_mask[:, :, cache_position, : key_states.shape[-2]] + # if attention_mask is not None and cache_position is not None: + if attention_mask is not None: + causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, # Reference: https://github.com/pytorch/pytorch/issues/112577. @@ -792,7 +791,7 @@ class LlamaPreTrainedModel(PreTrainedModel): base_model_prefix = "model" supports_gradient_checkpointing = True _no_split_modules = ["LlamaDecoderLayer"] - _skip_keys_device_placement = ["past_key_values", "causal_mask"] + _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn_2 = True _supports_sdpa = True _supports_cache_class = True @@ -815,12 +814,6 @@ def _setup_cache(self, cache_cls, max_batch_size, max_cache_len: Optional[int] = "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers" ) - if max_cache_len > self.model.causal_mask.shape[-1] or self.device != self.model.causal_mask.device: - causal_mask = torch.full( - (max_cache_len, max_cache_len), fill_value=True, device=self.device, dtype=torch.bool - ) - self.register_buffer("causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False) - for layer in self.model.layers: device = layer.input_layernorm.weight.device if hasattr(self.config, "_pre_quantization_dtype"): @@ -934,12 +927,6 @@ def __init__(self, config: LlamaConfig): self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.gradient_checkpointing = False - # Register a causal mask to separate causal and padding mask creation. Merging happens in the attention class. - # NOTE: This is not friendly with TorchScript, ONNX, ExportedProgram serialization for very large `max_position_embeddings`. 
- causal_mask = torch.full( - (config.max_position_embeddings, config.max_position_embeddings), fill_value=True, dtype=torch.bool - ) - self.register_buffer("causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False) # Initialize weights and apply final processing self.post_init() @@ -1003,7 +990,7 @@ def forward( if position_ids is None: position_ids = cache_position.unsqueeze(0) - causal_mask = self._update_causal_mask(attention_mask, inputs_embeds) + causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position) # embed positions hidden_states = inputs_embeds @@ -1071,25 +1058,27 @@ def forward( # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes. # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114 - def _update_causal_mask(self, attention_mask, input_tensor): + def _update_causal_mask(self, attention_mask, input_tensor, cache_position): if self.config._attn_implementation == "flash_attention_2": if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None - batch_size, seq_length = input_tensor.shape[:2] - dtype = input_tensor.dtype - device = input_tensor.device - - # support going beyond cached `max_position_embedding` - if seq_length > self.causal_mask.shape[-1]: - causal_mask = torch.full((2 * self.causal_mask.shape[-1], 2 * self.causal_mask.shape[-1]), fill_value=1) - self.register_buffer("causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False) - - # We use the current dtype to avoid any overflows + dtype, device = input_tensor.dtype, input_tensor.device min_dtype = torch.finfo(dtype).min - causal_mask = self.causal_mask[None, None, :, :].to(dtype=dtype, device=device) * min_dtype - causal_mask = causal_mask.expand(batch_size, 1, -1, -1) + sequence_length = input_tensor.shape[1] + if hasattr(self.layers[0].self_attn, "past_key_value"): # static cache + target_length = self.config.max_position_embeddings + else: # dynamic cache + target_length = ( + attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else cache_position[-1] + 1 + ) + + causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1) if attention_mask is not None: causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit if attention_mask.dim() == 2: @@ -1097,27 +1086,27 @@ def _update_causal_mask(self, attention_mask, input_tensor): padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0) causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype) elif attention_mask.dim() == 4: + # backwards compatibility: we allow passing a 4D attention mask shorter than the input length with + # cache. In that case, the 4D attention mask attends to the newest tokens only. 
+ if attention_mask.shape[-2] < cache_position[0] + sequence_length: + offset = cache_position[0] + else: + offset = 0 mask_shape = attention_mask.shape mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype - causal_mask[: mask_shape[0], : mask_shape[1], : mask_shape[2], : mask_shape[3]] = mask_slice + causal_mask[ + : mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3] + ] = mask_slice if ( self.config._attn_implementation == "sdpa" and attention_mask is not None and attention_mask.device.type == "cuda" ): - causal_mask = causal_mask.mul(~torch.all(causal_mask == causal_mask.min(), dim=-1)[..., None]).to(dtype) - # TODO: For dynamo, rather use a check on fullgraph=True once this is possible (https://github.com/pytorch/pytorch/pull/120400). - # is_tracing = ( - # torch.jit.is_tracing() - # or isinstance(input_tensor, torch.fx.Proxy) - # or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling()) - # ) - # if not is_tracing and torch.any(attention_mask != 1): - # # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when - # # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. - # # Details: https://github.com/pytorch/pytorch/issues/110213 - # causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) return causal_mask
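
As a usage sketch of what this series builds toward (mirroring `test_generate_compile_fullgraph` and `test_bad_generate_compilation_flags` above), static-cache generation could be compiled end to end roughly as follows. The checkpoint name is a placeholder for any model that implements `_setup_cache`, and the exact flags may still change while the series is in progress:

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    checkpoint = "meta-llama/Llama-2-7b-hf"  # placeholder: any static-cache-capable model
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype=torch.float16).to("cuda")
    input_ids = tokenizer(["The best color is"], return_tensors="pt").input_ids.to("cuda")

    # Opt in to the static cache so tensor shapes stay fixed across decoding steps.
    model.generation_config.cache_implementation = "static"

    # Under compilation, generation options must live in `generation_config`:
    # passing e.g. `max_length=...` as a kwarg to the compiled call raises a ValueError.
    model.generation_config.max_length = 32

    # Compile `generate` without graph breaks; the compiled greedy loop stops on
    # `max_length` only, since data-dependent (EOS-based) stopping is not traceable.
    compiled_generate = torch.compile(model.generate, fullgraph=True, mode="reduce-overhead")
    output_ids = compiled_generate(input_ids)
    print(tokenizer.batch_decode(output_ids, skip_special_tokens=True))

The first compiled call is expected to be slow (graph capture); whether the reduced-overhead CUDA graphs path pays off depends on sequence length and batch size, which is why the tests above compare compiled and eager static-cache outputs for correctness rather than speed.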